Commit 12661e8a by Arve Knudsen Committed by GitHub

Move middleware context handler logic to service (#29605)

* middleware: Move context handler to own service

Signed-off-by: Arve Knudsen <arve.knudsen@gmail.com>

Co-authored-by: Emil Tullsted <sakjur@users.noreply.github.com>
Co-authored-by: Will Browne <wbrowne@users.noreply.github.com>
parent d0f52d53
......@@ -22,7 +22,7 @@ const (
func TestAdminAPIEndpoint(t *testing.T) {
const role = models.ROLE_ADMIN
t.Run("Given a server admin attempts to remove themself as an admin", func(t *testing.T) {
t.Run("Given a server admin attempts to remove themselves as an admin", func(t *testing.T) {
updateCmd := dtos.AdminUpdateUserPermissionsForm{
IsGrafanaAdmin: false,
}
......
......@@ -18,7 +18,7 @@ func (hs *HTTPServer) registerRoutes() {
reqEditorRole := middleware.ReqEditorRole
reqOrgAdmin := middleware.ReqOrgAdmin
reqCanAccessTeams := middleware.AdminOrFeatureEnabled(hs.Cfg.EditorsCanAdmin)
reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn()
reqSnapshotPublicModeOrSignedIn := middleware.SnapshotPublicModeOrSignedIn(hs.Cfg)
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
redirectFromLegacyPanelEditURL := middleware.RedirectFromLegacyPanelEditURL()
......
......@@ -85,7 +85,7 @@ func Success(message string) *NormalResponse {
return JSON(200, resp)
}
// Error create a erroneous response
// Error creates an error response.
func Error(status int, message string, err error) *NormalResponse {
data := make(map[string]interface{})
......
......@@ -8,9 +8,14 @@ import (
"testing"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/middleware"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
"gopkg.in/macaron.v1"
)
......@@ -141,20 +146,68 @@ func (sc *scenarioContext) exec() {
type scenarioFunc func(c *scenarioContext)
type handlerFunc func(c *models.ReqContext) Response
func getContextHandler(t *testing.T) *contexthandler.ContextHandler {
t.Helper()
sqlStore := sqlstore.InitTestDB(t)
remoteCacheSvc := &remotecache.RemoteCache{}
cfg := setting.NewCfg()
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
ctxHdlr := &contexthandler.ContextHandler{}
err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{
{
Name: sqlstore.ServiceName,
Instance: sqlStore,
},
{
Name: remotecache.ServiceName,
Instance: remoteCacheSvc,
},
{
Name: auth.ServiceName,
Instance: userAuthTokenSvc,
},
{
Name: rendering.ServiceName,
Instance: renderSvc,
},
{
Name: contexthandler.ServiceName,
Instance: ctxHdlr,
},
})
require.NoError(t, err)
return ctxHdlr
}
func setupScenarioContext(t *testing.T, url string) *scenarioContext {
sc := &scenarioContext{
url: url,
t: t,
}
viewsPath, _ := filepath.Abs("../../public/views")
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
sc.m = macaron.New()
sc.m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: viewsPath,
Delims: macaron.Delims{Left: "[[", Right: "]]"},
}))
sc.m.Use(middleware.GetContextHandler(nil, nil, nil))
sc.m.Use(getContextHandler(t).Middleware)
return sc
}
type fakeRenderService struct {
rendering.Service
}
func (s *fakeRenderService) Init() error {
return nil
}
......@@ -193,11 +193,11 @@ func (hs *HTTPServer) getFrontendSettingsMap(c *models.ReqContext) (map[string]i
"datasources": dataSources,
"minRefreshInterval": setting.MinRefreshInterval,
"panels": panels,
"appUrl": setting.AppUrl,
"appSubUrl": setting.AppSubUrl,
"appUrl": hs.Cfg.AppURL,
"appSubUrl": hs.Cfg.AppSubURL,
"allowOrgCreate": (setting.AllowUserOrgCreate && c.IsSignedIn) || c.IsGrafanaAdmin,
"authProxyEnabled": setting.AuthProxyEnabled,
"ldapEnabled": setting.LDAPEnabled,
"ldapEnabled": hs.Cfg.LDAPEnabled,
"alertingEnabled": setting.AlertingEnabled,
"alertingErrorOrTimeout": setting.AlertingErrorOrTimeout,
"alertingNoDataOrNullValues": setting.AlertingNoDataOrNullValues,
......
......@@ -18,7 +18,6 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/middleware"
"gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/setting"
......@@ -53,7 +52,7 @@ func setupTestEnvironment(t *testing.T, cfg *setting.Cfg) (*macaron.Macaron, *HT
}
m := macaron.New()
m.Use(middleware.GetContextHandler(nil, nil, nil))
m.Use(getContextHandler(t).Middleware)
m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: filepath.Join(setting.StaticRootPath, "views"),
IndentJSON: true,
......@@ -84,10 +83,12 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
setting.Env = "testing"
tests := []struct {
desc string
hideVersion bool
expected settings
}{
{
desc: "Not hiding version",
hideVersion: false,
expected: settings{
BuildInfo: buildInfo{
......@@ -98,6 +99,7 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
},
},
{
desc: "Hiding version",
hideVersion: true,
expected: settings{
BuildInfo: buildInfo{
......@@ -110,16 +112,18 @@ func TestHTTPServer_GetFrontendSettings_hideVersionAnonyomus(t *testing.T) {
}
for _, test := range tests {
hs.Cfg.AnonymousHideVersion = test.hideVersion
expected := test.expected
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
got := settings{}
err := json.Unmarshal(recorder.Body.Bytes(), &got)
require.NoError(t, err)
require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicates a failure")
assert.EqualValues(t, expected, got)
t.Run(test.desc, func(t *testing.T) {
hs.Cfg.AnonymousHideVersion = test.hideVersion
expected := test.expected
recorder := httptest.NewRecorder()
m.ServeHTTP(recorder, req)
got := settings{}
err := json.Unmarshal(recorder.Body.Bytes(), &got)
require.NoError(t, err)
require.GreaterOrEqual(t, 400, recorder.Code, "status codes higher than 400 indicate a failure")
assert.EqualValues(t, expected, got)
})
}
}
......@@ -29,6 +29,7 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/login"
......@@ -75,6 +76,7 @@ type HTTPServer struct {
SearchService *search.SearchService `inject:""`
ShortURLService *shorturls.ShortURLService `inject:""`
Live *live.GrafanaLive `inject:""`
ContextHandler *contexthandler.ContextHandler `inject:""`
Listener net.Listener
}
......@@ -100,7 +102,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
Addr: net.JoinHostPort(setting.HttpAddr, setting.HttpPort),
Handler: hs.macaron,
}
switch setting.Protocol {
switch hs.Cfg.Protocol {
case setting.HTTP2Scheme:
if err := hs.configureHttp2(); err != nil {
return err
......@@ -118,7 +120,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
}
hs.log.Info("HTTP Server Listen", "address", listener.Addr().String(), "protocol",
setting.Protocol, "subUrl", setting.AppSubUrl, "socket", setting.SocketPath)
hs.Cfg.Protocol, "subUrl", hs.Cfg.AppSubURL, "socket", hs.Cfg.SocketPath)
var wg sync.WaitGroup
wg.Add(1)
......@@ -133,7 +135,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
}
}()
switch setting.Protocol {
switch hs.Cfg.Protocol {
case setting.HTTPScheme, setting.SocketScheme:
if err := hs.httpSrv.Serve(listener); err != nil {
if errors.Is(err, http.ErrServerClosed) {
......@@ -151,7 +153,7 @@ func (hs *HTTPServer) Run(ctx context.Context) error {
return err
}
default:
panic(fmt.Sprintf("Unhandled protocol %q", setting.Protocol))
panic(fmt.Sprintf("Unhandled protocol %q", hs.Cfg.Protocol))
}
wg.Wait()
......@@ -164,7 +166,7 @@ func (hs *HTTPServer) getListener() (net.Listener, error) {
return hs.Listener, nil
}
switch setting.Protocol {
switch hs.Cfg.Protocol {
case setting.HTTPScheme, setting.HTTPSScheme, setting.HTTP2Scheme:
listener, err := net.Listen("tcp", hs.httpSrv.Addr)
if err != nil {
......@@ -172,21 +174,21 @@ func (hs *HTTPServer) getListener() (net.Listener, error) {
}
return listener, nil
case setting.SocketScheme:
listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: setting.SocketPath, Net: "unix"})
listener, err := net.ListenUnix("unix", &net.UnixAddr{Name: hs.Cfg.SocketPath, Net: "unix"})
if err != nil {
return nil, errutil.Wrapf(err, "failed to open listener for socket %s", setting.SocketPath)
return nil, errutil.Wrapf(err, "failed to open listener for socket %s", hs.Cfg.SocketPath)
}
// Make socket writable by group
// nolint:gosec
if err := os.Chmod(setting.SocketPath, 0660); err != nil {
if err := os.Chmod(hs.Cfg.SocketPath, 0660); err != nil {
return nil, errutil.Wrapf(err, "failed to change socket permissions")
}
return listener, nil
default:
hs.log.Error("Invalid protocol", "protocol", setting.Protocol)
return nil, fmt.Errorf("invalid protocol %q", setting.Protocol)
hs.log.Error("Invalid protocol", "protocol", hs.Cfg.Protocol)
return nil, fmt.Errorf("invalid protocol %q", hs.Cfg.Protocol)
}
}
......@@ -271,7 +273,7 @@ func (hs *HTTPServer) configureHttp2() error {
}
func (hs *HTTPServer) newMacaron() *macaron.Macaron {
macaron.Env = setting.Env
macaron.Env = hs.Cfg.Env
m := macaron.New()
// automatically set HEAD for every GET
......@@ -294,13 +296,13 @@ func (hs *HTTPServer) applyRoutes() {
func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
m := hs.macaron
m.Use(middleware.Logger())
m.Use(middleware.Logger(hs.Cfg))
if setting.EnableGzip {
m.Use(middleware.Gziper())
}
m.Use(middleware.Recovery())
m.Use(middleware.Recovery(hs.Cfg))
for _, route := range plugins.StaticRoutes {
pluginRoute := path.Join("/public/plugins/", route.PluginId)
......@@ -316,7 +318,7 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
hs.mapStatic(m, hs.Cfg.ImagesDir, "", "/public/img/attachments")
}
m.Use(middleware.AddDefaultResponseHeaders())
m.Use(middleware.AddDefaultResponseHeaders(hs.Cfg))
if setting.ServeFromSubPath && setting.AppSubUrl != "" {
m.SetURLPrefix(setting.AppSubUrl)
......@@ -334,16 +336,12 @@ func (hs *HTTPServer) addMiddlewaresAndStaticRoutes() {
m.Use(hs.apiHealthHandler)
m.Use(hs.metricsEndpoint)
m.Use(middleware.GetContextHandler(
hs.AuthTokenService,
hs.RemoteCacheService,
hs.RenderService,
))
m.Use(hs.ContextHandler.Middleware)
m.Use(middleware.OrgRedirect())
// needs to be after context handler
if setting.EnforceDomain {
m.Use(middleware.ValidateHostHeader(setting.Domain))
m.Use(middleware.ValidateHostHeader(hs.Cfg.Domain))
}
m.Use(middleware.HandleNoCacheHeader())
......@@ -433,7 +431,7 @@ func (hs *HTTPServer) mapStatic(m *macaron.Macaron, rootDir string, dir string,
}
}
if setting.Env == setting.Dev {
if hs.Cfg.Env == setting.Dev {
headers = func(c *macaron.Context) {
c.Resp.Header().Set("Cache-Control", "max-age=0, must-revalidate, no-cache")
}
......
......@@ -300,7 +300,7 @@ func (hs *HTTPServer) getNavTree(c *models.ReqContext, hasEditPerm bool) ([]*dto
{Text: "Stats", Id: "server-stats", Url: setting.AppSubUrl + "/admin/stats", Icon: "graph-bar"},
}
if setting.LDAPEnabled {
if hs.Cfg.LDAPEnabled {
adminNavLinks = append(adminNavLinks, &dtos.NavLink{
Text: "LDAP", Id: "ldap", Url: setting.AppSubUrl + "/admin/ldap", Icon: "book",
})
......@@ -371,7 +371,7 @@ func (hs *HTTPServer) setIndexViewData(c *models.ReqContext) (*dtos.IndexViewDat
// special case when doing localhost call from image renderer
if c.IsRenderCall && !hs.Cfg.ServeFromSubPath {
appURL = fmt.Sprintf("%s://localhost:%s", setting.Protocol, setting.HttpPort)
appURL = fmt.Sprintf("%s://localhost:%s", hs.Cfg.Protocol, setting.HttpPort)
appSubURL = ""
settings["appSubUrl"] = ""
}
......
......@@ -116,7 +116,7 @@ func (hs *HTTPServer) GetLDAPStatus(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
}
ldapConfig, err := getLDAPConfig()
ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err)
......@@ -158,7 +158,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
}
ldapConfig, err := getLDAPConfig()
ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration. Please verify the configuration and try again", err)
}
......@@ -217,7 +217,7 @@ func (hs *HTTPServer) PostSyncUserWithLDAP(c *models.ReqContext) Response {
upsertCmd := &models.UpsertUserCommand{
ReqContext: c,
ExternalUser: user,
SignupAllowed: setting.LDAPAllowSignup,
SignupAllowed: hs.Cfg.LDAPAllowSignup,
}
err = bus.Dispatch(upsertCmd)
......@@ -235,7 +235,7 @@ func (hs *HTTPServer) GetUserFromLDAP(c *models.ReqContext) Response {
return Error(http.StatusBadRequest, "LDAP is not enabled", nil)
}
ldapConfig, err := getLDAPConfig()
ldapConfig, err := getLDAPConfig(hs.Cfg)
if err != nil {
return Error(http.StatusBadRequest, "Failed to obtain the LDAP configuration", err)
......
......@@ -74,7 +74,7 @@ func getUserFromLDAPContext(t *testing.T, requestURL string) *scenarioContext {
}
func TestGetUserFromLDAPAPIEndpoint_UserNotFound(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -131,7 +131,7 @@ func TestGetUserFromLDAPAPIEndpoint_OrgNotfound(t *testing.T) {
return nil
})
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -193,7 +193,7 @@ func TestGetUserFromLDAPAPIEndpoint(t *testing.T) {
return nil
})
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -273,7 +273,7 @@ func TestGetUserFromLDAPAPIEndpoint_WithTeamHandler(t *testing.T) {
return nil
})
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -349,7 +349,7 @@ func TestGetLDAPStatusAPIEndpoint(t *testing.T) {
{Host: "10.0.0.5", Port: 361, Available: false, Error: errors.New("something is awfully wrong")},
}
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -412,7 +412,7 @@ func postSyncUserWithLDAPContext(t *testing.T, requestURL string, preHook func(t
func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -457,7 +457,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_Success(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -485,7 +485,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotFound(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......@@ -528,7 +528,7 @@ func TestPostSyncUserWithLDAPAPIEndpoint_WhenGrafanaAdmin(t *testing.T) {
func TestPostSyncUserWithLDAPAPIEndpoint_WhenUserNotInLDAP(t *testing.T) {
sc := postSyncUserWithLDAPContext(t, "/api/admin/ldap/sync/34", func(t *testing.T) {
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return &ldap.Config{}, nil
}
......
......@@ -13,7 +13,7 @@ import (
"github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/infra/network"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/middleware"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
......@@ -61,12 +61,12 @@ func (hs *HTTPServer) ValidateRedirectTo(redirectTo string) error {
return nil
}
func (hs *HTTPServer) CookieOptionsFromCfg() middleware.CookieOptions {
func (hs *HTTPServer) CookieOptionsFromCfg() cookies.CookieOptions {
path := "/"
if len(hs.Cfg.AppSubURL) > 0 {
path = hs.Cfg.AppSubURL
}
return middleware.CookieOptions{
return cookies.CookieOptions{
Path: path,
Secure: hs.Cfg.CookieSecure,
SameSiteDisabled: hs.Cfg.CookieSameSiteDisabled,
......@@ -101,7 +101,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
// therefore the loginError should be passed to the view data
// and the view should return immediately before attempting
// to login again via OAuth and enter to a redirect loop
middleware.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(c.Resp, LoginErrorCookieName, hs.CookieOptionsFromCfg)
viewData.Settings["loginError"] = loginError
c.HTML(200, getViewIndex(), viewData)
return
......@@ -113,7 +113,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
if c.IsSignedIn {
// Assign login token to auth proxy users if enable_login_token = true
if setting.AuthProxyEnabled && setting.AuthProxyEnableLoginToken {
if hs.Cfg.AuthProxyEnabled && hs.Cfg.AuthProxyEnableLoginToken {
user := &models.User{Id: c.SignedInUser.UserId, Email: c.SignedInUser.Email, Login: c.SignedInUser.Login}
err := hs.loginUserWithUser(user, c)
if err != nil {
......@@ -129,7 +129,7 @@ func (hs *HTTPServer) LoginView(c *models.ReqContext) {
log.Debugf("Ignored invalid redirect_to cookie value: %v", redirectTo)
redirectTo = hs.Cfg.AppSubURL + "/"
}
middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
c.Redirect(redirectTo)
return
}
......@@ -196,6 +196,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res
Username: cmd.User,
Password: cmd.Password,
IpAddress: c.Req.RemoteAddr,
Cfg: hs.Cfg,
}
err := bus.Dispatch(authQuery)
......@@ -236,7 +237,7 @@ func (hs *HTTPServer) LoginPost(c *models.ReqContext, cmd dtos.LoginCommand) Res
} else {
log.Infof("Ignored invalid redirect_to cookie value: %v", redirectTo)
}
middleware.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
cookies.DeleteCookie(c.Resp, "redirect_to", hs.CookieOptionsFromCfg)
}
metrics.MApiLoginPost.Inc()
......@@ -263,7 +264,7 @@ func (hs *HTTPServer) loginUserWithUser(user *models.User, c *models.ReqContext)
}
hs.log.Info("Successful Login", "User", user.Email)
middleware.WriteSessionCookie(c, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
cookies.WriteSessionCookie(c, hs.Cfg, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
return nil
}
......@@ -278,7 +279,7 @@ func (hs *HTTPServer) Logout(c *models.ReqContext) {
hs.log.Error("failed to revoke auth token", "error", err)
}
middleware.WriteSessionCookie(c, "", -1)
cookies.WriteSessionCookie(c, hs.Cfg, "", -1)
if setting.SignoutRedirectUrl != "" {
c.Redirect(setting.SignoutRedirectUrl)
......@@ -309,7 +310,7 @@ func (hs *HTTPServer) trySetEncryptedCookie(ctx *models.ReqContext, cookieName s
return err
}
middleware.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg)
cookies.WriteCookie(ctx.Resp, cookieName, hex.EncodeToString(encryptedError), 60, hs.CookieOptionsFromCfg)
return nil
}
......
......@@ -18,7 +18,7 @@ import (
"github.com/grafana/grafana/pkg/infra/metrics"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/login/social"
"github.com/grafana/grafana/pkg/middleware"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
......@@ -81,7 +81,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
}
hashedState := hashStatecode(state, setting.OAuthService.OAuthInfos[name].ClientSecret)
middleware.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
cookies.WriteCookie(ctx.Resp, OauthStateCookieName, hashedState, hs.Cfg.OAuthCookieMaxAge, hs.CookieOptionsFromCfg)
if setting.OAuthService.OAuthInfos[name].HostedDomain == "" {
ctx.Redirect(connect.AuthCodeURL(state, oauth2.AccessTypeOnline))
} else {
......@@ -93,7 +93,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
cookieState := ctx.GetCookie(OauthStateCookieName)
// delete cookie
middleware.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
cookies.DeleteCookie(ctx.Resp, OauthStateCookieName, hs.CookieOptionsFromCfg)
if cookieState == "" {
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
......@@ -192,7 +192,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
if redirectTo, err := url.QueryUnescape(ctx.GetCookie("redirect_to")); err == nil && len(redirectTo) > 0 {
if err := hs.ValidateRedirectTo(redirectTo); err == nil {
middleware.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg)
cookies.DeleteCookie(ctx.Resp, "redirect_to", hs.CookieOptionsFromCfg)
ctx.Redirect(redirectTo)
return
}
......
......@@ -592,8 +592,8 @@ func setupAuthProxyLoginTest(t *testing.T, enableLoginToken bool) *scenarioConte
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.AuthProxyEnabled = true
setting.AuthProxyEnableLoginToken = enableLoginToken
hs.Cfg.AuthProxyEnabled = true
hs.Cfg.AuthProxyEnableLoginToken = enableLoginToken
sc.m.Get(sc.url, sc.defaultHandler)
sc.fakeReqNoAssertions("GET", sc.url).exec()
......
......@@ -23,8 +23,17 @@ var (
defaultMaxCacheExpiration = time.Hour * 24
)
const (
ServiceName = "RemoteCache"
)
func init() {
registry.RegisterService(&RemoteCache{})
rc := &RemoteCache{}
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: rc,
InitPriority: registry.Medium,
})
}
// CacheStorage allows the caller to set, get and delete items in the cache.
......
......@@ -25,12 +25,12 @@ var (
var loginLogger = log.New("login")
func Init() {
bus.AddHandler("auth", AuthenticateUser)
bus.AddHandler("auth", authenticateUser)
}
// AuthenticateUser authenticates the user via username & password
func AuthenticateUser(query *models.LoginUserQuery) error {
if err := validateLoginAttempts(query.Username); err != nil {
// authenticateUser authenticates the user via username & password
func authenticateUser(query *models.LoginUserQuery) error {
if err := validateLoginAttempts(query); err != nil {
return err
}
......
......@@ -21,7 +21,7 @@ func TestAuthenticateUser(t *testing.T) {
Username: "user",
Password: "",
}
err := AuthenticateUser(&loginQuery)
err := authenticateUser(&loginQuery)
Convey("login should fail", func() {
So(sc.grafanaLoginWasCalled, ShouldBeFalse)
......@@ -37,7 +37,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts)
......@@ -55,7 +55,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, nil)
......@@ -74,7 +74,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, customErr)
......@@ -92,7 +92,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(false, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, models.ErrUserNotFound)
......@@ -110,7 +110,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, ErrInvalidCredentials)
......@@ -128,7 +128,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, nil, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldBeNil)
......@@ -147,7 +147,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, customErr, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, customErr)
......@@ -165,7 +165,7 @@ func TestAuthenticateUser(t *testing.T) {
mockLoginUsingLDAP(true, ldap.ErrInvalidCredentials, sc)
mockSaveInvalidLoginAttempt(sc)
err := AuthenticateUser(sc.loginUserQuery)
err := authenticateUser(sc.loginUserQuery)
Convey("it should result in", func() {
So(err, ShouldEqual, ErrInvalidCredentials)
......@@ -203,7 +203,7 @@ func mockLoginUsingLDAP(enabled bool, err error, sc *authScenarioContext) {
}
func mockLoginAttemptValidation(err error, sc *authScenarioContext) {
validateLoginAttempts = func(username string) error {
validateLoginAttempts = func(*models.LoginUserQuery) error {
sc.loginAttemptValidationWasCalled = true
return err
}
......
......@@ -5,7 +5,6 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
var (
......@@ -13,13 +12,13 @@ var (
loginAttemptsWindow = time.Minute * 5
)
var validateLoginAttempts = func(username string) error {
if setting.DisableBruteForceLoginProtection {
var validateLoginAttempts = func(query *models.LoginUserQuery) error {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
loginAttemptCountQuery := models.GetUserLoginAttemptCountQuery{
Username: username,
Username: query.Username,
Since: time.Now().Add(-loginAttemptsWindow),
}
......@@ -35,7 +34,7 @@ var validateLoginAttempts = func(username string) error {
}
var saveInvalidLoginAttempt = func(query *models.LoginUserQuery) error {
if setting.DisableBruteForceLoginProtection {
if query.Cfg.DisableBruteForceLoginProtection {
return nil
}
......
......@@ -12,11 +12,16 @@ import (
func TestLoginAttemptsValidation(t *testing.T) {
Convey("Validate login attempts", t, func() {
Convey("Given brute force login protection enabled", func() {
setting.DisableBruteForceLoginProtection = false
cfg := setting.NewCfg()
cfg.DisableBruteForceLoginProtection = false
query := &models.LoginUserQuery{
Username: "user",
Cfg: cfg,
}
Convey("When user login attempt count equals max-1 ", func() {
withLoginAttempts(maxInvalidLoginAttempts - 1)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should not result in error", func() {
So(err, ShouldBeNil)
......@@ -25,7 +30,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count equals max ", func() {
withLoginAttempts(maxInvalidLoginAttempts)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should result in too many login attempts error", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts)
......@@ -34,7 +39,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count is greater than max ", func() {
withLoginAttempts(maxInvalidLoginAttempts + 5)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should result in too many login attempts error", func() {
So(err, ShouldEqual, ErrTooManyLoginAttempts)
......@@ -54,6 +59,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Username: "user",
Password: "pwd",
IpAddress: "192.168.1.1:56433",
Cfg: setting.NewCfg(),
})
So(err, ShouldBeNil)
......@@ -66,11 +72,16 @@ func TestLoginAttemptsValidation(t *testing.T) {
})
Convey("Given brute force login protection disabled", func() {
setting.DisableBruteForceLoginProtection = true
cfg := setting.NewCfg()
cfg.DisableBruteForceLoginProtection = true
query := &models.LoginUserQuery{
Username: "user",
Cfg: cfg,
}
Convey("When user login attempt count equals max-1 ", func() {
withLoginAttempts(maxInvalidLoginAttempts - 1)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should not result in error", func() {
So(err, ShouldBeNil)
......@@ -79,7 +90,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count equals max ", func() {
withLoginAttempts(maxInvalidLoginAttempts)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should not result in error", func() {
So(err, ShouldBeNil)
......@@ -88,7 +99,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When user login attempt count is greater than max ", func() {
withLoginAttempts(maxInvalidLoginAttempts + 5)
err := validateLoginAttempts("user")
err := validateLoginAttempts(query)
Convey("it should not result in error", func() {
So(err, ShouldBeNil)
......@@ -97,7 +108,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Convey("When saving invalid login attempt", func() {
defer bus.ClearBusHandlers()
createLoginAttemptCmd := (*models.CreateLoginAttemptCommand)(nil)
var createLoginAttemptCmd *models.CreateLoginAttemptCommand
bus.AddHandler("test", func(cmd *models.CreateLoginAttemptCommand) error {
createLoginAttemptCmd = cmd
......@@ -108,6 +119,7 @@ func TestLoginAttemptsValidation(t *testing.T) {
Username: "user",
Password: "pwd",
IpAddress: "192.168.1.1:56433",
Cfg: cfg,
})
So(err, ShouldBeNil)
......
......@@ -33,7 +33,7 @@ var loginUsingLDAP = func(query *models.LoginUserQuery) (bool, error) {
return false, nil
}
config, err := getLDAPConfig()
config, err := getLDAPConfig(query.Cfg)
if err != nil {
return true, errutil.Wrap("Failed to get LDAP config", err)
}
......
......@@ -20,7 +20,7 @@ func TestLDAPLogin(t *testing.T) {
LDAPLoginScenario("When login", func(sc *LDAPLoginScenarioContext) {
sc.withLoginResult(false)
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{
Servers: []*ldap.ServerConfig{},
}
......@@ -150,7 +150,14 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) {
LDAPAuthenticatorMock: mock,
}
getLDAPConfig = func() (*ldap.Config, error) {
origNewLDAP := newLDAP
origGetLDAPConfig := getLDAPConfig
defer func() {
newLDAP = origNewLDAP
getLDAPConfig = origGetLDAPConfig
}()
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{
Servers: []*ldap.ServerConfig{
{
......@@ -166,11 +173,6 @@ func LDAPLoginScenario(desc string, fn LDAPLoginScenarioFunc) {
return mock
}
defer func() {
newLDAP = multildap.New
getLDAPConfig = multildap.GetConfig
}()
fn(sc)
})
}
......
......@@ -8,9 +8,9 @@ import (
macaron "gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/middleware/cookies"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
type AuthOptions struct {
......@@ -18,22 +18,6 @@ type AuthOptions struct {
ReqSignedIn bool
}
func getApiKey(c *models.ReqContext) string {
header := c.Req.Header.Get("Authorization")
parts := strings.SplitN(header, " ", 2)
if len(parts) == 2 && parts[0] == "Bearer" {
key := parts[1]
return key
}
username, password, err := util.DecodeBasicAuthHeader(header)
if err == nil && username == "api_key" {
return password
}
return ""
}
func accessForbidden(c *models.ReqContext) {
if c.IsApiRequest() {
c.JsonApiErr(403, "Permission denied", nil)
......@@ -57,7 +41,7 @@ func notAuthorized(c *models.ReqContext) {
// remove any forceLogin=true params
redirectTo = removeForceLoginParams(redirectTo)
WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil)
cookies.WriteCookie(c.Resp, "redirect_to", url.QueryEscape(redirectTo), 0, nil)
c.Redirect(setting.AppSubUrl + "/login")
}
......@@ -135,9 +119,9 @@ func AdminOrFeatureEnabled(enabled bool) macaron.Handler {
}
}
func SnapshotPublicModeOrSignedIn() macaron.Handler {
func SnapshotPublicModeOrSignedIn(cfg *setting.Cfg) macaron.Handler {
return func(c *models.ReqContext) {
if setting.SnapshotPublicMode {
if cfg.SnapshotPublicMode {
return
}
......
package middleware
import (
"errors"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
var header = setting.AuthProxyHeaderName
func logUserIn(auth *authproxy.AuthProxy, username string, logger log.Logger, ignoreCache bool) (int64, error) {
logger.Debug("Trying to log user in", "username", username, "ignoreCache", ignoreCache)
// Try to log in user via various providers
id, err := auth.Login(logger, ignoreCache)
if err != nil {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
logger.Error("Failed to login", "username", username, "message", err.Error(), "error", details,
"ignoreCache", ignoreCache)
return 0, err
}
return id, nil
}
// handleError calls ctx.Handle with the error message and the underlying error.
// If the error is of type authproxy.Error, its DetailsError is unwrapped and passed to ctx.Handle.
// If a callback is provided, it's called with either err.DetailsError, if err is of type
// authproxy.Error, otherwise err itself.
func handleError(ctx *models.ReqContext, err error, statusCode int, cb func(err error)) {
details := err
var e authproxy.Error
if errors.As(err, &e) {
details = e.DetailsError
}
ctx.Handle(statusCode, err.Error(), details)
if cb != nil {
cb(details)
}
}
func initContextWithAuthProxy(store *remotecache.RemoteCache, ctx *models.ReqContext, orgID int64) bool {
username := ctx.Req.Header.Get(header)
auth := authproxy.New(&authproxy.Options{
Store: store,
Ctx: ctx,
OrgID: orgID,
})
logger := log.New("auth.proxy")
// Bail if auth proxy is not enabled
if !auth.IsEnabled() {
return false
}
// If there is no header - we can't move forward
if !auth.HasHeader() {
return false
}
// Check if allowed to continue with this IP
if err := auth.IsAllowedIP(); err != nil {
handleError(ctx, err, 407, func(details error) {
logger.Error("Failed to check whitelisted IP addresses", "message", err.Error(), "error", details)
})
return true
}
id, err := logUserIn(auth, username, logger, false)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
logger.Debug("Got user ID, getting full user info", "userID", id)
user, e := auth.GetSignedUser(id)
if e != nil {
// The reason we couldn't find the user corresponding to the ID might be that the ID was found from a stale
// cache entry. For example, if a user is deleted via the API, corresponding cache entries aren't invalidated
// because cache keys are computed from request header values and not just the user ID. Meaning that
// we can't easily derive cache keys to invalidate when deleting a user. To work around this, we try to
// log the user in again without the cache.
logger.Debug("Failed to get user info given ID, retrying without cache", "userID", id)
if err := auth.RemoveUserFromCache(logger); err != nil {
if !errors.Is(err, remotecache.ErrCacheItemNotFound) {
logger.Error("Got unexpected error when removing user from auth cache", "error", err)
}
}
id, err = logUserIn(auth, username, logger, true)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
user, err = auth.GetSignedUser(id)
if err != nil {
handleError(ctx, err, 407, nil)
return true
}
}
logger.Debug("Successfully got user info", "userID", user.UserId, "username", user.Login)
// Add user info to context
ctx.SignedInUser = user
ctx.IsSignedIn = true
// Remember user data in cache
if err := auth.Remember(id); err != nil {
handleError(ctx, err, 500, func(details error) {
logger.Error(
"Failed to store user in cache",
"username", username,
"message", e.Error(),
"error", details,
)
})
return true
}
return true
}
......@@ -33,16 +33,10 @@ func TestMiddlewareAuth(t *testing.T) {
t.Run("Anonymous auth enabled", func(t *testing.T) {
const orgID int64 = 1
origEnabled := setting.AnonymousEnabled
t.Cleanup(func() {
setting.AnonymousEnabled = origEnabled
})
origName := setting.AnonymousOrgName
t.Cleanup(func() {
setting.AnonymousOrgName = origName
})
setting.AnonymousEnabled = true
setting.AnonymousOrgName = "test"
configure := func(cfg *setting.Cfg) {
cfg.AnonymousEnabled = true
cfg.AnonymousOrgName = "test"
}
middlewareScenario(t, "ReqSignIn true and request with forceLogin in query string", func(
t *testing.T, sc *scenarioContext) {
......@@ -59,7 +53,7 @@ func TestMiddlewareAuth(t *testing.T) {
location, ok := sc.resp.Header()["Location"]
assert.True(t, ok)
assert.Equal(t, "/login", location[0])
})
}, configure)
middlewareScenario(t, "ReqSignIn true and request with same org provided in query string", func(
t *testing.T, sc *scenarioContext) {
......@@ -73,7 +67,7 @@ func TestMiddlewareAuth(t *testing.T) {
sc.fakeReq("GET", fmt.Sprintf("/secure?orgId=%d", orgID)).exec()
assert.Equal(t, 200, sc.resp.Code)
})
}, configure)
middlewareScenario(t, "ReqSignIn true and request with different org provided in query string", func(
t *testing.T, sc *scenarioContext) {
......@@ -90,20 +84,20 @@ func TestMiddlewareAuth(t *testing.T) {
location, ok := sc.resp.Header()["Location"]
assert.True(t, ok)
assert.Equal(t, "/login", location[0])
})
}, configure)
})
middlewareScenario(t, "Snapshot public mode disabled and unauthenticated request should return 401", func(
t *testing.T, sc *scenarioContext) {
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler)
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 401, sc.resp.Code)
})
middlewareScenario(t, "Snapshot public mode enabled and unauthenticated request should return 200", func(
t *testing.T, sc *scenarioContext) {
setting.SnapshotPublicMode = true
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(), sc.defaultHandler)
sc.cfg.SnapshotPublicMode = true
sc.m.Get("/api/snapshot", SnapshotPublicModeOrSignedIn(sc.cfg), sc.defaultHandler)
sc.fakeReq("GET", "/api/snapshot").exec()
assert.Equal(t, 200, sc.resp.Code)
})
......
package middleware
package cookies
import (
"net/http"
......@@ -55,8 +55,8 @@ func WriteCookie(w http.ResponseWriter, name string, value string, maxAge int, g
http.SetCookie(w, &cookie)
}
func WriteSessionCookie(ctx *models.ReqContext, value string, maxLifetime time.Duration) {
if setting.Env == setting.Dev {
func WriteSessionCookie(ctx *models.ReqContext, cfg *setting.Cfg, value string, maxLifetime time.Duration) {
if cfg.Env == setting.Dev {
ctx.Logger.Info("New token", "unhashed token", value)
}
......
......@@ -26,7 +26,7 @@ import (
"gopkg.in/macaron.v1"
)
func Logger() macaron.Handler {
func Logger(cfg *setting.Cfg) macaron.Handler {
return func(res http.ResponseWriter, req *http.Request, c *macaron.Context) {
start := time.Now()
c.Data["perfmon.start"] = start
......@@ -43,7 +43,7 @@ func Logger() macaron.Handler {
status := rw.Status()
if status == 200 || status == 304 {
if !setting.RouterLogging {
if !cfg.RouterLogging {
return
}
}
......
......@@ -7,6 +7,7 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/login"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
......@@ -14,19 +15,13 @@ import (
)
func TestMiddlewareBasicAuth(t *testing.T) {
var origBasicAuthEnabled = setting.BasicAuthEnabled
var origDisableBruteForceLoginProtection = setting.DisableBruteForceLoginProtection
t.Cleanup(func() {
setting.BasicAuthEnabled = origBasicAuthEnabled
setting.DisableBruteForceLoginProtection = origDisableBruteForceLoginProtection
})
setting.BasicAuthEnabled = true
setting.DisableBruteForceLoginProtection = true
bus.ClearBusHandlers()
const id int64 = 12
configure := func(cfg *setting.Cfg) {
cfg.BasicAuthEnabled = true
cfg.DisableBruteForceLoginProtection = true
}
middlewareScenario(t, "Valid API key", func(t *testing.T, sc *scenarioContext) {
const orgID int64 = 2
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
......@@ -44,16 +39,15 @@ func TestMiddlewareBasicAuth(t *testing.T) {
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, models.ROLE_EDITOR, sc.context.OrgRole)
})
}, configure)
middlewareScenario(t, "Handle auth", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
const salt = "Salt"
const orgID int64 = 2
t.Cleanup(bus.ClearBusHandlers)
bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error {
t.Log("Handling LoginUserQuery")
encoded, err := util.EncodePassword(password, salt)
if err != nil {
return err
......@@ -66,6 +60,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
})
bus.AddHandler("get-sign-user", func(query *models.GetSignedInUserQuery) error {
t.Log("Handling GetSignedInUserQuery")
query.Result = &models.SignedInUser{OrgId: orgID, UserId: id}
return nil
})
......@@ -76,7 +71,7 @@ func TestMiddlewareBasicAuth(t *testing.T) {
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, orgID, sc.context.OrgId)
assert.Equal(t, id, sc.context.UserId)
})
}, configure)
middlewareScenario(t, "Auth sequence", func(t *testing.T, sc *scenarioContext) {
const password = "MyPass"
......@@ -104,10 +99,11 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authHeader := util.GetBasicAuthHeader("myUser", password)
sc.fakeReq("GET", "/").withAuthorizationHeader(authHeader).exec()
require.NotNil(t, sc.context)
assert.True(t, sc.context.IsSignedIn)
assert.Equal(t, id, sc.context.UserId)
})
}, configure)
middlewareScenario(t, "Should return error if user is not found", func(t *testing.T, sc *scenarioContext) {
sc.fakeReq("GET", "/")
......@@ -118,8 +114,8 @@ func TestMiddlewareBasicAuth(t *testing.T) {
require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"])
})
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}, configure)
middlewareScenario(t, "Should return error if user & password do not match", func(t *testing.T, sc *scenarioContext) {
bus.AddHandler("user-query", func(loginUserQuery *models.GetUserByLoginQuery) error {
......@@ -134,6 +130,6 @@ func TestMiddlewareBasicAuth(t *testing.T) {
require.Error(t, err)
assert.Equal(t, 401, sc.resp.Code)
assert.Equal(t, errStringInvalidUsernamePassword, sc.respJson["message"])
})
assert.Equal(t, contexthandler.InvalidUsernamePassword, sc.respJson["message"])
}, configure)
}
......@@ -7,6 +7,7 @@ import (
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
......@@ -18,6 +19,8 @@ type advanceTimeFunc func(deltaTime time.Duration)
type rateLimiterScenarioFunc func(c execFunc, t advanceTimeFunc)
func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateLimiterScenarioFunc) {
t.Helper()
t.Run(desc, func(t *testing.T) {
defaultHandler := func(c *models.ReqContext) {
resp := make(map[string]interface{})
......@@ -26,12 +29,14 @@ func rateLimiterScenario(t *testing.T, desc string, rps int, burst int, fn rateL
}
currentTime := time.Now()
cfg := setting.NewCfg()
m := macaron.New()
m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: "",
Delims: macaron.Delims{Left: "[[", Right: "]]"},
}))
m.Use(GetContextHandler(nil, nil, nil))
m.Use(getContextHandler(t, cfg).Middleware)
m.Get("/foo", RateLimit(rps, burst, func() time.Time { return currentTime }), defaultHandler)
fn(func() *httptest.ResponseRecorder {
......
......@@ -103,7 +103,7 @@ func function(pc uintptr) []byte {
// Recovery returns a middleware that recovers from any panics and writes a 500 if there was one.
// While Martini is in development mode, Recovery will also output the panic as HTML.
func Recovery() macaron.Handler {
func Recovery(cfg *setting.Cfg) macaron.Handler {
return func(c *macaron.Context) {
defer func() {
if r := recover(); r != nil {
......@@ -134,7 +134,7 @@ func Recovery() macaron.Handler {
c.Data["Title"] = "Server Error"
c.Data["AppSubUrl"] = setting.AppSubUrl
c.Data["Theme"] = setting.DefaultTheme
c.Data["Theme"] = cfg.DefaultTheme
if setting.Env == setting.Dev {
if err, ok := r.(error); ok {
......@@ -158,7 +158,7 @@ func Recovery() macaron.Handler {
c.JSON(500, resp)
} else {
c.HTML(500, setting.ErrTemplateName)
c.HTML(500, cfg.ErrTemplateName)
}
}
}()
......
......@@ -16,8 +16,6 @@ import (
)
func TestRecoveryMiddleware(t *testing.T) {
setting.ErrTemplateName = "error-template"
t.Run("Given an API route that panics", func(t *testing.T) {
apiURL := "/api/whatever"
recoveryScenario(t, "recovery middleware should return json", apiURL, func(t *testing.T, sc *scenarioContext) {
......@@ -52,18 +50,21 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
t.Run(desc, func(t *testing.T) {
defer bus.ClearBusHandlers()
cfg := setting.NewCfg()
cfg.ErrTemplateName = "error-template"
sc := &scenarioContext{
t: t,
url: url,
cfg: cfg,
}
viewsPath, err := filepath.Abs("../../public/views")
require.NoError(t, err)
sc.m = macaron.New()
sc.m.Use(Recovery())
sc.m.Use(Recovery(cfg))
sc.m.Use(AddDefaultResponseHeaders())
sc.m.Use(AddDefaultResponseHeaders(cfg))
sc.m.Use(macaron.Renderer(macaron.RenderOptions{
Directory: viewsPath,
Delims: macaron.Delims{Left: "[[", Right: "]]"},
......@@ -72,7 +73,8 @@ func recoveryScenario(t *testing.T, desc string, url string, fn scenarioFunc) {
sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.remoteCacheService = remotecache.NewFakeStore(t)
sc.m.Use(GetContextHandler(sc.userAuthTokenService, sc.remoteCacheService, nil))
contextHandler := getContextHandler(t, nil)
sc.m.Use(contextHandler.Middleware)
// mock out gc goroutine
sc.m.Use(OrgRedirect())
......
package middleware
import (
"time"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/rendering"
)
func initContextWithRenderAuth(ctx *models.ReqContext, renderService rendering.Service) bool {
key := ctx.GetCookie("renderKey")
if key == "" {
return false
}
renderUser, exists := renderService.GetRenderUser(key)
if !exists {
ctx.JsonApiErr(401, "Invalid Render Key", nil)
return true
}
ctx.IsSignedIn = true
ctx.SignedInUser = &models.SignedInUser{
OrgId: renderUser.OrgID,
UserId: renderUser.UserID,
OrgRole: models.RoleType(renderUser.OrgRole),
}
ctx.IsRenderCall = true
ctx.LastSeenAt = time.Now()
return true
}
......@@ -11,6 +11,7 @@ import (
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
......@@ -29,6 +30,8 @@ type scenarioContext struct {
url string
userAuthTokenService *auth.FakeUserAuthTokenService
remoteCacheService *remotecache.RemoteCache
cfg *setting.Cfg
contextHandler *contexthandler.ContextHandler
req *http.Request
}
......@@ -94,9 +97,9 @@ func (sc *scenarioContext) exec() {
}
if sc.tokenSessionCookie != "" {
sc.t.Log(`Adding cookie`, "name", setting.LoginCookieName, "value", sc.tokenSessionCookie)
sc.t.Log(`Adding cookie`, "name", sc.cfg.LoginCookieName, "value", sc.tokenSessionCookie)
sc.req.AddCookie(&http.Cookie{
Name: setting.LoginCookieName,
Name: sc.cfg.LoginCookieName,
Value: sc.tokenSessionCookie,
})
}
......
......@@ -3,6 +3,7 @@ package models
import (
"time"
"github.com/grafana/grafana/pkg/setting"
"golang.org/x/oauth2"
)
......@@ -84,6 +85,7 @@ type LoginUserQuery struct {
User *User
IpAddress string
AuthModule string
Cfg *setting.Cfg
}
type GetUserByAuthInfoQuery struct {
......
package registry
import (
"fmt"
"github.com/facebookgo/inject"
)
// BuildServiceGraph builds a graph of services and their dependencies.
// The services are initialized after the graph is built.
func BuildServiceGraph(objs []interface{}, services []*Descriptor) error {
if services == nil {
services = GetServices()
}
for _, service := range services {
objs = append(objs, service.Instance)
}
serviceGraph := inject.Graph{}
// Provide services and their dependencies to the graph.
for _, obj := range objs {
if err := serviceGraph.Provide(&inject.Object{Value: obj}); err != nil {
return fmt.Errorf("failed to provide object to the graph: %w", err)
}
}
// Resolve services and their dependencies.
if err := serviceGraph.Populate(); err != nil {
return fmt.Errorf("failed to populate service dependencies: %w", err)
}
// Initialize services.
for _, service := range services {
if IsDisabled(service.Instance) {
continue
}
if err := service.Instance.Init(); err != nil {
return fmt.Errorf("service init failed: %w", err)
}
}
return nil
}
......@@ -18,8 +18,14 @@ import (
"github.com/grafana/grafana/pkg/util"
)
const ServiceName = "UserAuthTokenService"
func init() {
registry.RegisterService(&UserAuthTokenService{})
registry.Register(&registry.Descriptor{
Name: ServiceName,
Instance: &UserAuthTokenService{},
InitPriority: registry.Medium,
})
}
var getTime = time.Now
......
......@@ -57,8 +57,13 @@ func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
}
}
func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP,
userAgent string) (*models.UserToken, error) {
// Init initializes the service.
// Required for dependency injection.
func (s *FakeUserAuthTokenService) Init() error {
return nil
}
func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) {
return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent)
}
......
package middleware
package contexthandler
import (
"fmt"
......@@ -8,8 +8,12 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/infra/remotecache"
"github.com/grafana/grafana/pkg/middleware/authproxy"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/contexthandler/authproxy"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
macaron "gopkg.in/macaron.v1"
......@@ -41,25 +45,16 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
}
return nil
}
origHeaderName := setting.AuthProxyHeaderName
origEnabled := setting.AuthProxyEnabled
origHeaderProperty := setting.AuthProxyHeaderProperty
bus.AddHandler("", upsertHandler)
bus.AddHandler("", getUserHandler)
t.Cleanup(func() {
setting.AuthProxyHeaderName = origHeaderName
setting.AuthProxyEnabled = origEnabled
setting.AuthProxyHeaderProperty = origHeaderProperty
bus.ClearBusHandlers()
})
setting.AuthProxyHeaderName = "X-Killa"
setting.AuthProxyEnabled = true
setting.AuthProxyHeaderProperty = "username"
svc := getContextHandler(t)
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
store := remotecache.NewFakeStore(t)
ctx := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
......@@ -69,20 +64,72 @@ func TestInitContextWithAuthProxy_CachedInvalidUserID(t *testing.T) {
},
Logger: log.New("Test"),
}
req.Header.Add(setting.AuthProxyHeaderName, name)
req.Header.Set(svc.Cfg.AuthProxyHeaderName, name)
key := fmt.Sprintf(authproxy.CachePrefix, authproxy.HashCacheKey(name))
t.Logf("Injecting stale user ID in cache with key %q", key)
err = store.Set(key, int64(33), 0)
err = svc.RemoteCache.Set(key, int64(33), 0)
require.NoError(t, err)
authEnabled := initContextWithAuthProxy(store, ctx, orgID)
authEnabled := svc.initContextWithAuthProxy(ctx, orgID)
require.True(t, authEnabled)
require.Equal(t, userID, ctx.SignedInUser.UserId)
require.True(t, ctx.IsSignedIn)
i, err := store.Get(key)
i, err := svc.RemoteCache.Get(key)
require.NoError(t, err)
require.Equal(t, userID, i.(int64))
}
type fakeRenderService struct {
rendering.Service
}
func (s *fakeRenderService) Init() error {
return nil
}
func getContextHandler(t *testing.T) *ContextHandler {
t.Helper()
sqlStore := sqlstore.InitTestDB(t)
remoteCacheSvc := &remotecache.RemoteCache{}
cfg := setting.NewCfg()
cfg.RemoteCacheOptions = &setting.RemoteCacheOptions{
Name: "database",
}
cfg.AuthProxyHeaderName = "X-Killa"
cfg.AuthProxyEnabled = true
cfg.AuthProxyHeaderProperty = "username"
userAuthTokenSvc := auth.NewFakeUserAuthTokenService()
renderSvc := &fakeRenderService{}
svc := &ContextHandler{}
err := registry.BuildServiceGraph([]interface{}{cfg}, []*registry.Descriptor{
{
Name: sqlstore.ServiceName,
Instance: sqlStore,
},
{
Name: remotecache.ServiceName,
Instance: remoteCacheSvc,
},
{
Name: auth.ServiceName,
Instance: userAuthTokenSvc,
},
{
Name: rendering.ServiceName,
Instance: renderSvc,
},
{
Name: ServiceName,
Instance: svc,
},
})
require.NoError(t, err)
return svc
}
......@@ -32,7 +32,13 @@ const (
var getLDAPConfig = ldap.GetConfig
// isLDAPEnabled checks if LDAP is enabled
var isLDAPEnabled = ldap.IsEnabled
var isLDAPEnabled = func(cfg *setting.Cfg) bool {
if cfg != nil {
return cfg.LDAPEnabled
}
return setting.LDAPEnabled
}
// newLDAP creates multiple LDAP instance
var newLDAP = multildap.New
......@@ -42,18 +48,11 @@ var supportedHeaderFields = []string{"Name", "Email", "Login", "Groups"}
// AuthProxy struct
type AuthProxy struct {
store *remotecache.RemoteCache
ctx *models.ReqContext
orgID int64
header string
enabled bool
LDAPAllowSignup bool
AuthProxyAutoSignUp bool
whitelistIP string
headerType string
headers map[string]string
cacheTTL int
cfg *setting.Cfg
remoteCache *remotecache.RemoteCache
ctx *models.ReqContext
orgID int64
header string
}
// Error auth proxy specific error
......@@ -77,35 +76,27 @@ func (err Error) Error() string {
// Options for the AuthProxy
type Options struct {
Store *remotecache.RemoteCache
Ctx *models.ReqContext
OrgID int64
RemoteCache *remotecache.RemoteCache
Ctx *models.ReqContext
OrgID int64
}
// New instance of the AuthProxy
func New(options *Options) *AuthProxy {
header := options.Ctx.Req.Header.Get(setting.AuthProxyHeaderName)
func New(cfg *setting.Cfg, options *Options) *AuthProxy {
header := options.Ctx.Req.Header.Get(cfg.AuthProxyHeaderName)
return &AuthProxy{
store: options.Store,
ctx: options.Ctx,
orgID: options.OrgID,
header: header,
enabled: setting.AuthProxyEnabled,
headerType: setting.AuthProxyHeaderProperty,
headers: setting.AuthProxyHeaders,
whitelistIP: setting.AuthProxyWhitelist,
cacheTTL: setting.AuthProxySyncTtl,
LDAPAllowSignup: setting.LDAPAllowSignup,
AuthProxyAutoSignUp: setting.AuthProxyAutoSignUp,
remoteCache: options.RemoteCache,
cfg: cfg,
ctx: options.Ctx,
orgID: options.OrgID,
header: header,
}
}
// IsEnabled checks if the proxy auth is enabled
func (auth *AuthProxy) IsEnabled() bool {
// Bail if the setting is not enabled
return auth.enabled
return auth.cfg.AuthProxyEnabled
}
// HasHeader checks if the we have specified header
......@@ -113,15 +104,15 @@ func (auth *AuthProxy) HasHeader() bool {
return len(auth.header) != 0
}
// IsAllowedIP compares presented IP with the whitelist one
// IsAllowedIP returns whether provided IP is allowed.
func (auth *AuthProxy) IsAllowedIP() error {
ip := auth.ctx.Req.RemoteAddr
if len(strings.TrimSpace(auth.whitelistIP)) == 0 {
if len(strings.TrimSpace(auth.cfg.AuthProxyWhitelist)) == 0 {
return nil
}
proxies := strings.Split(auth.whitelistIP, ",")
proxies := strings.Split(auth.cfg.AuthProxyWhitelist, ",")
var proxyObjs []*net.IPNet
for _, proxy := range proxies {
result, err := coerceProxyAddress(proxy)
......@@ -181,7 +172,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
}
}
if isLDAPEnabled() {
if isLDAPEnabled(auth.cfg) {
id, err := auth.LoginViaLDAP()
if err != nil {
if errors.Is(err, ldap.ErrInvalidCredentials) {
......@@ -205,7 +196,7 @@ func (auth *AuthProxy) Login(logger log.Logger, ignoreCache bool) (int64, error)
func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
cacheKey := auth.getKey()
logger.Debug("Getting user ID via auth cache", "cacheKey", cacheKey)
userID, err := auth.store.Get(cacheKey)
userID, err := auth.remoteCache.Get(cacheKey)
if err != nil {
logger.Debug("Failed getting user ID via auth cache", "error", err)
return 0, err
......@@ -219,7 +210,7 @@ func (auth *AuthProxy) GetUserViaCache(logger log.Logger) (int64, error) {
func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
cacheKey := auth.getKey()
logger.Debug("Removing user from auth cache", "cacheKey", cacheKey)
if err := auth.store.Delete(cacheKey); err != nil {
if err := auth.remoteCache.Delete(cacheKey); err != nil {
return err
}
......@@ -229,12 +220,13 @@ func (auth *AuthProxy) RemoveUserFromCache(logger log.Logger) error {
// LoginViaLDAP logs in user via LDAP request
func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
config, err := getLDAPConfig()
config, err := getLDAPConfig(auth.cfg)
if err != nil {
return 0, newError("failed to get LDAP config", err)
}
extUser, _, err := newLDAP(config.Servers).User(auth.header)
mldap := newLDAP(config.Servers)
extUser, _, err := mldap.User(auth.header)
if err != nil {
return 0, err
}
......@@ -242,7 +234,7 @@ func (auth *AuthProxy) LoginViaLDAP() (int64, error) {
// Have to sync grafana and LDAP user during log in
upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx,
SignupAllowed: auth.LDAPAllowSignup,
SignupAllowed: auth.cfg.LDAPAllowSignup,
ExternalUser: extUser,
}
if err := bus.Dispatch(upsert); err != nil {
......@@ -259,7 +251,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
AuthId: auth.header,
}
switch auth.headerType {
switch auth.cfg.AuthProxyHeaderProperty {
case "username":
extUser.Login = auth.header
......@@ -284,7 +276,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
upsert := &models.UpsertUserCommand{
ReqContext: auth.ctx,
SignupAllowed: setting.AuthProxyAutoSignUp,
SignupAllowed: auth.cfg.AuthProxyAutoSignUp,
ExternalUser: extUser,
}
......@@ -299,8 +291,7 @@ func (auth *AuthProxy) LoginViaHeader() (int64, error) {
// headersIterator iterates over all non-empty supported additional headers
func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
for _, field := range supportedHeaderFields {
h := auth.headers[field]
h := auth.cfg.AuthProxyHeaders[field]
if h == "" {
continue
}
......@@ -311,8 +302,8 @@ func (auth *AuthProxy) headersIterator(fn func(field string, header string)) {
}
}
// GetSignedUser gets full signed user info.
func (auth *AuthProxy) GetSignedUser(userID int64) (*models.SignedInUser, error) {
// GetSignedUser gets full signed in user info.
func (auth *AuthProxy) GetSignedInUser(userID int64) (*models.SignedInUser, error) {
query := &models.GetSignedInUserQuery{
OrgId: auth.orgID,
UserId: userID,
......@@ -330,14 +321,14 @@ func (auth *AuthProxy) Remember(id int64) error {
key := auth.getKey()
// Check if user already in cache
userID, _ := auth.store.Get(key)
userID, _ := auth.remoteCache.Get(key)
if userID != nil {
return nil
}
expiration := time.Duration(auth.cacheTTL) * time.Minute
expiration := time.Duration(auth.cfg.AuthProxySyncTTL) * time.Minute
err := auth.store.Set(key, id, expiration)
err := auth.remoteCache.Set(key, id, expiration)
if err != nil {
return err
}
......@@ -353,5 +344,8 @@ func coerceProxyAddress(proxyAddr string) (*net.IPNet, error) {
}
_, network, err := net.ParseCIDR(proxyAddr)
return network, err
if err != nil {
return nil, fmt.Errorf("could not parse the network: %w", err)
}
return network, nil
}
......@@ -47,9 +47,22 @@ func (m *fakeMultiLDAP) User(login string) (
return result, ldap.ServerConfig{}, nil
}
func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.RemoteCache) *AuthProxy {
const hdrName = "markelog"
func prepareMiddleware(t *testing.T, remoteCache *remotecache.RemoteCache, cb func(*http.Request, *setting.Cfg)) *AuthProxy {
t.Helper()
cfg := setting.NewCfg()
cfg.AuthProxyHeaderName = "X-Killa"
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
req.Header.Set(cfg.AuthProxyHeaderName, hdrName)
if cb != nil {
cb(req, cfg)
}
ctx := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
......@@ -58,10 +71,10 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot
},
}
auth := New(&Options{
Store: store,
Ctx: ctx,
OrgID: 4,
auth := New(cfg, &Options{
RemoteCache: remoteCache,
Ctx: ctx,
OrgID: 4,
})
return auth
......@@ -69,24 +82,17 @@ func prepareMiddleware(t *testing.T, req *http.Request, store *remotecache.Remot
func TestMiddlewareContext(t *testing.T) {
logger := log.New("test")
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
setting.AuthProxyHeaderName = "X-Killa"
store := remotecache.NewFakeStore(t)
name := "markelog"
req.Header.Add(setting.AuthProxyHeaderName, name)
cache := remotecache.NewFakeStore(t)
t.Run("When the cache only contains the main header with a simple cache key", func(t *testing.T) {
const id int64 = 33
// Set cache key
key := fmt.Sprintf(CachePrefix, HashCacheKey(name))
err := store.Set(key, id, 0)
key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName))
err := cache.Set(key, id, 0)
require.NoError(t, err)
// Set up the middleware
auth := prepareMiddleware(t, req, store)
assert.Equal(t, "auth-proxy-sync-ttl:0a7f3374e9659b10980fd66247b0cf2f", auth.getKey())
auth := prepareMiddleware(t, cache, nil)
assert.Equal(t, key, auth.getKey())
gotID, err := auth.Login(logger, false)
require.NoError(t, err)
......@@ -96,15 +102,16 @@ func TestMiddlewareContext(t *testing.T) {
t.Run("When the cache key contains additional headers", func(t *testing.T) {
const id int64 = 33
setting.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
group := "grafana-core-team"
req.Header.Add("X-WEBAUTH-GROUPS", group)
const group = "grafana-core-team"
key := fmt.Sprintf(CachePrefix, HashCacheKey(name+"-"+group))
err := store.Set(key, id, 0)
key := fmt.Sprintf(CachePrefix, HashCacheKey(hdrName+"-"+group))
err := cache.Set(key, id, 0)
require.NoError(t, err)
auth := prepareMiddleware(t, req, store)
auth := prepareMiddleware(t, cache, func(req *http.Request, cfg *setting.Cfg) {
req.Header.Set("X-WEBAUTH-GROUPS", group)
cfg.AuthProxyHeaders = map[string]string{"Groups": "X-WEBAUTH-GROUPS"}
})
assert.Equal(t, "auth-proxy-sync-ttl:14f69b7023baa0ac98c96b31cec07bc0", auth.getKey())
gotID, err := auth.Login(logger, false)
......@@ -115,12 +122,6 @@ func TestMiddlewareContext(t *testing.T) {
func TestMiddlewareContext_ldap(t *testing.T) {
logger := log.New("test")
req, err := http.NewRequest("POST", "http://example.com", nil)
require.NoError(t, err)
setting.AuthProxyHeaderName = "X-Killa"
const headerName = "markelog"
req.Header.Add(setting.AuthProxyHeaderName, headerName)
t.Run("Logs in via LDAP", func(t *testing.T) {
const id int64 = 42
......@@ -133,7 +134,16 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return nil
})
isLDAPEnabled = func() bool {
origIsLDAPEnabled := isLDAPEnabled
origGetLDAPConfig := getLDAPConfig
origNewLDAP := newLDAP
t.Cleanup(func() {
newLDAP = origNewLDAP
isLDAPEnabled = origIsLDAPEnabled
getLDAPConfig = origGetLDAPConfig
})
isLDAPEnabled = func(*setting.Cfg) bool {
return true
}
......@@ -141,7 +151,7 @@ func TestMiddlewareContext_ldap(t *testing.T) {
ID: id,
}
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
config := &ldap.Config{
Servers: []*ldap.ServerConfig{
{
......@@ -156,15 +166,9 @@ func TestMiddlewareContext_ldap(t *testing.T) {
return stub
}
defer func() {
newLDAP = multildap.New
isLDAPEnabled = ldap.IsEnabled
getLDAPConfig = ldap.GetConfig
}()
store := remotecache.NewFakeStore(t)
cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, req, store)
auth := prepareMiddleware(t, cache, nil)
gotID, err := auth.Login(logger, false)
require.NoError(t, err)
......@@ -173,25 +177,28 @@ func TestMiddlewareContext_ldap(t *testing.T) {
assert.True(t, stub.userCalled)
})
t.Run("Gets nice error if ldap is enabled but not configured", func(t *testing.T) {
t.Run("Gets nice error if LDAP is enabled, but not configured", func(t *testing.T) {
const id int64 = 42
isLDAPEnabled = func() bool {
origIsLDAPEnabled := isLDAPEnabled
origNewLDAP := newLDAP
origGetLDAPConfig := getLDAPConfig
t.Cleanup(func() {
isLDAPEnabled = origIsLDAPEnabled
newLDAP = origNewLDAP
getLDAPConfig = origGetLDAPConfig
})
isLDAPEnabled = func(*setting.Cfg) bool {
return true
}
getLDAPConfig = func() (*ldap.Config, error) {
getLDAPConfig = func(*setting.Cfg) (*ldap.Config, error) {
return nil, errors.New("something went wrong")
}
defer func() {
newLDAP = multildap.New
isLDAPEnabled = ldap.IsEnabled
getLDAPConfig = ldap.GetConfig
}()
store := remotecache.NewFakeStore(t)
cache := remotecache.NewFakeStore(t)
auth := prepareMiddleware(t, req, store)
auth := prepareMiddleware(t, cache, nil)
stub := &fakeMultiLDAP{
ID: id,
......
package contexthandler
import (
"context"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/grafana/grafana/pkg/components/gtime"
"github.com/grafana/grafana/pkg/infra/log"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
macaron "gopkg.in/macaron.v1"
)
func TestDontRotateTokensOnCancelledRequests(t *testing.T) {
ctxHdlr := getContextHandler(t)
ctx, cancel := context.WithCancel(context.Background())
reqContext, _, err := initTokenRotationScenario(ctx, t)
require.NoError(t, err)
tryRotateCallCount := 0
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
tryRotateCallCount++
return false, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
fn := ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)
cancel()
fn(reqContext.Resp)
assert.Equal(t, 0, tryRotateCallCount, "Token rotation was attempted")
}
func TestTokenRotationAtEndOfRequest(t *testing.T) {
ctxHdlr := getContextHandler(t)
reqContext, rr, err := initTokenRotationScenario(context.Background(), t)
require.NoError(t, err)
uts := &auth.FakeUserAuthTokenService{
TryRotateTokenProvider: func(ctx context.Context, token *models.UserToken, clientIP net.IP,
userAgent string) (bool, error) {
newToken, err := util.RandomHex(16)
require.NoError(t, err)
token.AuthToken = newToken
return true, nil
},
}
token := &models.UserToken{AuthToken: "oldtoken"}
ctxHdlr.rotateEndOfRequestFunc(reqContext, uts, token)(reqContext.Resp)
foundLoginCookie := false
resp := rr.Result()
defer resp.Body.Close()
for _, c := range resp.Cookies() {
if c.Name == "login_token" {
foundLoginCookie = true
require.NotEqual(t, token.AuthToken, c.Value, "Auth token is still the same")
}
}
assert.True(t, foundLoginCookie, "Could not find cookie")
}
func initTokenRotationScenario(ctx context.Context, t *testing.T) (*models.ReqContext, *httptest.ResponseRecorder, error) {
t.Helper()
origLoginCookieName := setting.LoginCookieName
origLoginMaxLifetime := setting.LoginMaxLifetime
t.Cleanup(func() {
setting.LoginCookieName = origLoginCookieName
setting.LoginMaxLifetime = origLoginMaxLifetime
})
setting.LoginCookieName = "login_token"
var err error
setting.LoginMaxLifetime, err = gtime.ParseDuration("7d")
if err != nil {
return nil, nil, err
}
rr := httptest.NewRecorder()
req, err := http.NewRequestWithContext(ctx, "", "", nil)
if err != nil {
return nil, nil, err
}
reqContext := &models.ReqContext{
Context: &macaron.Context{
Req: macaron.Request{
Request: req,
},
},
Logger: log.New("testlogger"),
}
mw := mockWriter{rr}
reqContext.Resp = mw
return reqContext, rr, nil
}
type mockWriter struct {
*httptest.ResponseRecorder
}
func (mw mockWriter) Flush() {}
func (mw mockWriter) Status() int { return 0 }
func (mw mockWriter) Size() int { return 0 }
func (mw mockWriter) Written() bool { return false }
func (mw mockWriter) Before(macaron.BeforeFunc) {}
func (mw mockWriter) Push(target string, opts *http.PushOptions) error {
return nil
}
......@@ -94,8 +94,12 @@ var config *Config
// GetConfig returns the LDAP config if LDAP is enabled otherwise it returns nil. It returns either cached value of
// the config or it reads it and caches it first.
func GetConfig() (*Config, error) {
if !IsEnabled() {
func GetConfig(cfg *setting.Cfg) (*Config, error) {
if cfg != nil {
if !cfg.LDAPEnabled {
return nil, nil
}
} else if !IsEnabled() {
return nil, nil
}
......
......@@ -25,12 +25,13 @@ import (
func init() {
remotecache.Register(&RenderUser{})
registry.Register(&registry.Descriptor{
Name: "RenderingService",
Name: ServiceName,
Instance: &RenderingService{},
InitPriority: registry.High,
})
}
const ServiceName = "RenderingService"
const renderKeyPrefix = "render-%s"
type RenderUser struct {
......@@ -226,8 +227,8 @@ func (rs *RenderingService) getURL(path string) string {
return fmt.Sprintf("%s%s&render=1", rs.Cfg.RendererCallbackUrl, path)
}
protocol := setting.Protocol
switch setting.Protocol {
protocol := rs.Cfg.Protocol
switch protocol {
case setting.HTTPScheme:
protocol = "http"
case setting.HTTP2Scheme, setting.HTTPSScheme:
......
......@@ -28,7 +28,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTP configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTPScheme
rs.Cfg.Protocol = setting.HTTPScheme
url := rs.getURL(path)
require.Equal(t, "http://localhost:3000/"+path+"&render=1", url)
......@@ -43,7 +43,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTPS configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTPSScheme
rs.Cfg.Protocol = setting.HTTPSScheme
url := rs.getURL(path)
require.Equal(t, "https://localhost:3000/"+path+"&render=1", url)
})
......@@ -51,7 +51,7 @@ func TestGetUrl(t *testing.T) {
t.Run("And protocol HTTP2 configured should return expected path", func(t *testing.T) {
rs.Cfg.ServeFromSubPath = false
rs.Cfg.AppSubURL = ""
setting.Protocol = setting.HTTP2Scheme
rs.Cfg.Protocol = setting.HTTP2Scheme
url := rs.getURL(path)
require.Equal(t, "https://localhost:3000/"+path+"&render=1", url)
})
......
......@@ -6,8 +6,6 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
)
func (ss *SQLStore) addPreferencesQueryAndCommandHandlers() {
......@@ -42,7 +40,7 @@ func (ss *SQLStore) GetPreferencesWithDefaults(query *models.GetPreferencesWithD
}
res := &models.Preferences{
Theme: setting.DefaultTheme,
Theme: ss.Cfg.DefaultTheme,
Timezone: ss.Cfg.DateFormats.DefaultTimezone,
HomeDashboardId: 0,
}
......
......@@ -6,7 +6,6 @@ import (
"testing"
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/stretchr/testify/require"
)
......@@ -14,7 +13,7 @@ func TestPreferencesDataAccess(t *testing.T) {
ss := InitTestDB(t)
t.Run("GetPreferencesWithDefaults with no saved preferences should return defaults", func(t *testing.T) {
setting.DefaultTheme = "light"
ss.Cfg.DefaultTheme = "light"
ss.Cfg.DateFormats.DefaultTimezone = "UTC"
query := &models.GetPreferencesWithDefaultsQuery{User: &models.SignedInUser{}}
......
......@@ -39,16 +39,21 @@ var (
// ContextSessionKey is used as key to save values in `context.Context`
type ContextSessionKey struct{}
const ServiceName = "SqlStore"
const InitPriority = registry.High
func init() {
ss := &SQLStore{}
// This change will make xorm use an empty default schema for postgres and
// by that mimic the functionality of how it was functioning before
// xorm's changes above.
xorm.DefaultPostgresSchema = ""
registry.Register(&registry.Descriptor{
Name: "SQLStore",
Instance: &SQLStore{},
InitPriority: registry.High,
Name: ServiceName,
Instance: ss,
InitPriority: InitPriority,
})
}
......@@ -113,13 +118,20 @@ func (ss *SQLStore) Init() error {
func (ss *SQLStore) ensureMainOrgAndAdminUser() error {
err := ss.InTransaction(context.Background(), func(ctx context.Context) error {
systemUserCountQuery := models.GetSystemUserCountStatsQuery{}
err := bus.DispatchCtx(ctx, &systemUserCountQuery)
var stats models.SystemUserCountStats
err := ss.WithDbSession(ctx, func(sess *DBSession) error {
var rawSql = `SELECT COUNT(id) AS Count FROM ` + dialect.Quote("user")
if _, err := sess.SQL(rawSql).Get(&stats); err != nil {
return fmt.Errorf("could not determine if admin user exists: %w", err)
}
return nil
})
if err != nil {
return fmt.Errorf("could not determine if admin user exists: %w", err)
return err
}
if systemUserCountQuery.Result.Count > 0 {
if stats.Count > 0 {
return nil
}
......@@ -351,7 +363,7 @@ func InitTestDB(t ITestDB) *SQLStore {
testSQLStore = &SQLStore{}
testSQLStore.Bus = bus.New()
testSQLStore.CacheService = localcache.New(5*time.Minute, 10*time.Minute)
testSQLStore.skipEnsureDefaultOrgAndUser = true
testSQLStore.skipEnsureDefaultOrgAndUser = false
dbType := migrator.SQLite
......
......@@ -86,4 +86,6 @@ func (cfg *Cfg) readQuotaSettings() {
ApiKey: quota.Key("global_api_key").MustInt64(-1),
Session: quota.Key("global_session").MustInt64(-1),
}
cfg.Quota = Quota
}
......@@ -133,7 +133,7 @@ func TestLoadingSettings(t *testing.T) {
})
So(err, ShouldBeNil)
So(Domain, ShouldEqual, "test2")
So(cfg.Domain, ShouldEqual, "test2")
})
Convey("Defaults can be overridden in specified config file", func() {
......@@ -239,7 +239,7 @@ func TestLoadingSettings(t *testing.T) {
})
So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 2)
So(cfg.AuthProxySyncTTL, ShouldEqual, 2)
})
Convey("Only ldap_sync_ttl should return the value ldap_sync_ttl", func() {
......@@ -250,7 +250,7 @@ func TestLoadingSettings(t *testing.T) {
})
So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 5)
So(cfg.AuthProxySyncTTL, ShouldEqual, 5)
})
Convey("ldap_sync should override ldap_sync_ttl that is default value", func() {
......@@ -261,7 +261,7 @@ func TestLoadingSettings(t *testing.T) {
})
So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 5)
So(cfg.AuthProxySyncTTL, ShouldEqual, 5)
})
Convey("ldap_sync should not override ldap_sync_ttl that is different from default value", func() {
......@@ -272,7 +272,7 @@ func TestLoadingSettings(t *testing.T) {
})
So(err, ShouldBeNil)
So(AuthProxySyncTtl, ShouldEqual, 12)
So(cfg.AuthProxySyncTTL, ShouldEqual, 12)
})
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment