Commit dcec61e1 by Carl Bergquist Committed by GitHub

Merge pull request #15378 from grafana/auth_token_quotas

use authTokenService for session quotas restrictions
parents 8bd7e5a5 e163aadf
...@@ -16,7 +16,7 @@ func (hs *HTTPServer) registerRoutes() { ...@@ -16,7 +16,7 @@ func (hs *HTTPServer) registerRoutes() {
reqOrgAdmin := middleware.ReqOrgAdmin reqOrgAdmin := middleware.ReqOrgAdmin
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL() redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL() redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
quota := middleware.Quota quota := middleware.Quota(hs.QuotaService)
bind := binding.Bind bind := binding.Bind
r := hs.RouteRegister r := hs.RouteRegister
...@@ -286,7 +286,7 @@ func (hs *HTTPServer) registerRoutes() { ...@@ -286,7 +286,7 @@ func (hs *HTTPServer) registerRoutes() {
dashboardRoute.Post("/calculate-diff", bind(dtos.CalculateDiffOptions{}), Wrap(CalculateDashboardDiff)) dashboardRoute.Post("/calculate-diff", bind(dtos.CalculateDiffOptions{}), Wrap(CalculateDashboardDiff))
dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(PostDashboard)) dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(hs.PostDashboard))
dashboardRoute.Get("/home", Wrap(GetHomeDashboard)) dashboardRoute.Get("/home", Wrap(GetHomeDashboard))
dashboardRoute.Get("/tags", GetDashboardTags) dashboardRoute.Get("/tags", GetDashboardTags)
dashboardRoute.Post("/import", bind(dtos.ImportDashboardCommand{}), Wrap(ImportDashboard)) dashboardRoute.Post("/import", bind(dtos.ImportDashboardCommand{}), Wrap(ImportDashboard))
...@@ -294,7 +294,7 @@ func (hs *HTTPServer) registerRoutes() { ...@@ -294,7 +294,7 @@ func (hs *HTTPServer) registerRoutes() {
dashboardRoute.Group("/id/:dashboardId", func(dashIdRoute routing.RouteRegister) { dashboardRoute.Group("/id/:dashboardId", func(dashIdRoute routing.RouteRegister) {
dashIdRoute.Get("/versions", Wrap(GetDashboardVersions)) dashIdRoute.Get("/versions", Wrap(GetDashboardVersions))
dashIdRoute.Get("/versions/:id", Wrap(GetDashboardVersion)) dashIdRoute.Get("/versions/:id", Wrap(GetDashboardVersion))
dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(RestoreDashboardVersion)) dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(hs.RestoreDashboardVersion))
dashIdRoute.Group("/permissions", func(dashboardPermissionRoute routing.RouteRegister) { dashIdRoute.Group("/permissions", func(dashboardPermissionRoute routing.RouteRegister) {
dashboardPermissionRoute.Get("/", Wrap(GetDashboardPermissionList)) dashboardPermissionRoute.Get("/", Wrap(GetDashboardPermissionList))
......
...@@ -18,7 +18,6 @@ import ( ...@@ -18,7 +18,6 @@ import (
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins" "github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/guardian" "github.com/grafana/grafana/pkg/services/guardian"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
...@@ -208,14 +207,14 @@ func DeleteDashboardByUID(c *m.ReqContext) Response { ...@@ -208,14 +207,14 @@ func DeleteDashboardByUID(c *m.ReqContext) Response {
}) })
} }
func PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response { func (hs *HTTPServer) PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response {
cmd.OrgId = c.OrgId cmd.OrgId = c.OrgId
cmd.UserId = c.UserId cmd.UserId = c.UserId
dash := cmd.GetDashboardModel() dash := cmd.GetDashboardModel()
if dash.Id == 0 && dash.Uid == "" { if dash.Id == 0 && dash.Uid == "" {
limitReached, err := quota.QuotaReached(c, "dashboard") limitReached, err := hs.QuotaService.QuotaReached(c, "dashboard")
if err != nil { if err != nil {
return Error(500, "failed to get quota", err) return Error(500, "failed to get quota", err)
} }
...@@ -463,7 +462,7 @@ func CalculateDashboardDiff(c *m.ReqContext, apiOptions dtos.CalculateDiffOption ...@@ -463,7 +462,7 @@ func CalculateDashboardDiff(c *m.ReqContext, apiOptions dtos.CalculateDiffOption
} }
// RestoreDashboardVersion restores a dashboard to the given version. // RestoreDashboardVersion restores a dashboard to the given version.
func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response { func (hs *HTTPServer) RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response {
dash, rsp := getDashboardHelper(c.OrgId, "", c.ParamsInt64(":dashboardId"), "") dash, rsp := getDashboardHelper(c.OrgId, "", c.ParamsInt64(":dashboardId"), "")
if rsp != nil { if rsp != nil {
return rsp return rsp
...@@ -490,7 +489,7 @@ func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersio ...@@ -490,7 +489,7 @@ func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersio
saveCmd.Dashboard.Set("uid", dash.Uid) saveCmd.Dashboard.Set("uid", dash.Uid)
saveCmd.Message = fmt.Sprintf("Restored from version %d", version.Version) saveCmd.Message = fmt.Sprintf("Restored from version %d", version.Version)
return PostDashboard(c, saveCmd) return hs.PostDashboard(c, saveCmd)
} }
func GetDashboardTags(c *m.ReqContext) { func GetDashboardTags(c *m.ReqContext) {
......
...@@ -881,12 +881,16 @@ func postDashboardScenario(desc string, url string, routePattern string, mock *d ...@@ -881,12 +881,16 @@ func postDashboardScenario(desc string, url string, routePattern string, mock *d
Convey(desc+" "+url, func() { Convey(desc+" "+url, func() {
defer bus.ClearBusHandlers() defer bus.ClearBusHandlers()
hs := HTTPServer{
Bus: bus.GetBus(),
}
sc := setupScenarioContext(url) sc := setupScenarioContext(url)
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response { sc.defaultHandler = Wrap(func(c *m.ReqContext) Response {
sc.context = c sc.context = c
sc.context.SignedInUser = &m.SignedInUser{OrgId: cmd.OrgId, UserId: cmd.UserId} sc.context.SignedInUser = &m.SignedInUser{OrgId: cmd.OrgId, UserId: cmd.UserId}
return PostDashboard(c, cmd) return hs.PostDashboard(c, cmd)
}) })
origNewDashboardService := dashboards.NewService origNewDashboardService := dashboards.NewService
......
...@@ -24,6 +24,7 @@ import ( ...@@ -24,6 +24,7 @@ import (
"github.com/grafana/grafana/pkg/services/cache" "github.com/grafana/grafana/pkg/services/cache"
"github.com/grafana/grafana/pkg/services/datasources" "github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/hooks" "github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/services/rendering" "github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/session" "github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
...@@ -55,6 +56,7 @@ type HTTPServer struct { ...@@ -55,6 +56,7 @@ type HTTPServer struct {
CacheService *cache.CacheService `inject:""` CacheService *cache.CacheService `inject:""`
DatasourceCache datasources.CacheService `inject:""` DatasourceCache datasources.CacheService `inject:""`
AuthTokenService models.UserTokenService `inject:""` AuthTokenService models.UserTokenService `inject:""`
QuotaService *quota.QuotaService `inject:""`
} }
func (hs *HTTPServer) Init() error { func (hs *HTTPServer) Init() error {
......
...@@ -4,18 +4,30 @@ import ( ...@@ -4,18 +4,30 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/log" "github.com/grafana/grafana/pkg/log"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/quota"
) )
func init() { func init() {
bus.AddHandler("auth", UpsertUser) registry.RegisterService(&LoginService{})
} }
var ( var (
logger = log.New("login.ext_user") logger = log.New("login.ext_user")
) )
func UpsertUser(cmd *m.UpsertUserCommand) error { type LoginService struct {
Bus bus.Bus `inject:""`
QuotaService *quota.QuotaService `inject:""`
}
func (ls *LoginService) Init() error {
ls.Bus.AddHandler(ls.UpsertUser)
return nil
}
func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error {
extUser := cmd.ExternalUser extUser := cmd.ExternalUser
userQuery := &m.GetUserByAuthInfoQuery{ userQuery := &m.GetUserByAuthInfoQuery{
...@@ -37,7 +49,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { ...@@ -37,7 +49,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
return ErrInvalidCredentials return ErrInvalidCredentials
} }
limitReached, err := quota.QuotaReached(cmd.ReqContext, "user") limitReached, err := ls.QuotaService.QuotaReached(cmd.ReqContext, "user")
if err != nil { if err != nil {
log.Warn("Error getting user quota. error: %v", err) log.Warn("Error getting user quota. error: %v", err)
return ErrGettingUserQuota return ErrGettingUserQuota
...@@ -57,7 +69,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { ...@@ -57,7 +69,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
AuthModule: extUser.AuthModule, AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId, AuthId: extUser.AuthId,
} }
if err := bus.Dispatch(cmd2); err != nil { if err := ls.Bus.Dispatch(cmd2); err != nil {
return err return err
} }
} }
...@@ -78,12 +90,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error { ...@@ -78,12 +90,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
// Sync isGrafanaAdmin permission // Sync isGrafanaAdmin permission
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin { if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
if err := bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil { if err := ls.Bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil {
return err return err
} }
} }
err = bus.Dispatch(&m.SyncTeamsCommand{ err = ls.Bus.Dispatch(&m.SyncTeamsCommand{
User: cmd.Result, User: cmd.Result,
ExternalUser: extUser, ExternalUser: extUser,
}) })
......
...@@ -395,8 +395,11 @@ func ldapAutherScenario(desc string, fn scenarioFunc) { ...@@ -395,8 +395,11 @@ func ldapAutherScenario(desc string, fn scenarioFunc) {
defer bus.ClearBusHandlers() defer bus.ClearBusHandlers()
sc := &scenarioContext{} sc := &scenarioContext{}
loginService := &LoginService{
Bus: bus.GetBus(),
}
bus.AddHandler("test", UpsertUser) bus.AddHandler("test", loginService.UpsertUser)
bus.AddHandlerCtx("test", func(ctx context.Context, cmd *m.SyncTeamsCommand) error { bus.AddHandlerCtx("test", func(ctx context.Context, cmd *m.SyncTeamsCommand) error {
return nil return nil
......
...@@ -682,6 +682,7 @@ type fakeUserAuthTokenService struct { ...@@ -682,6 +682,7 @@ type fakeUserAuthTokenService struct {
tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error) tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
lookupTokenProvider func(unhashedToken string) (*m.UserToken, error) lookupTokenProvider func(unhashedToken string) (*m.UserToken, error)
revokeTokenProvider func(token *m.UserToken) error revokeTokenProvider func(token *m.UserToken) error
activeAuthTokenCount func() (int64, error)
} }
func newFakeUserAuthTokenService() *fakeUserAuthTokenService { func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
...@@ -704,6 +705,9 @@ func newFakeUserAuthTokenService() *fakeUserAuthTokenService { ...@@ -704,6 +705,9 @@ func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
revokeTokenProvider: func(token *m.UserToken) error { revokeTokenProvider: func(token *m.UserToken) error {
return nil return nil
}, },
activeAuthTokenCount: func() (int64, error) {
return 10, nil
},
} }
} }
...@@ -722,3 +726,7 @@ func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP, ...@@ -722,3 +726,7 @@ func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP,
func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error { func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
return s.revokeTokenProvider(token) return s.revokeTokenProvider(token)
} }
func (s *fakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
return s.activeAuthTokenCount()
}
...@@ -9,9 +9,12 @@ import ( ...@@ -9,9 +9,12 @@ import (
"github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/quota"
) )
func Quota(target string) macaron.Handler { // Quota returns a function that returns a function used to call quotaservice based on target name
func Quota(quotaService *quota.QuotaService) func(target string) macaron.Handler {
//https://open.spotify.com/track/7bZSoBEAEEUsGEuLOf94Jm?si=T1Tdju5qRSmmR0zph_6RBw fuuuuunky
return func(target string) macaron.Handler {
return func(c *m.ReqContext) { return func(c *m.ReqContext) {
limitReached, err := quota.QuotaReached(c, target) limitReached, err := quotaService.QuotaReached(c, target)
if err != nil { if err != nil {
c.JsonApiErr(500, "failed to get quota", err) c.JsonApiErr(500, "failed to get quota", err)
return return
...@@ -21,4 +24,5 @@ func Quota(target string) macaron.Handler { ...@@ -21,4 +24,5 @@ func Quota(target string) macaron.Handler {
return return
} }
} }
}
} }
...@@ -3,9 +3,10 @@ package middleware ...@@ -3,9 +3,10 @@ package middleware
import ( import (
"testing" "testing"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
) )
...@@ -13,10 +14,6 @@ import ( ...@@ -13,10 +14,6 @@ import (
func TestMiddlewareQuota(t *testing.T) { func TestMiddlewareQuota(t *testing.T) {
Convey("Given the grafana quota middleware", t, func() { Convey("Given the grafana quota middleware", t, func() {
session.GetSessionCount = func() int {
return 4
}
setting.AnonymousEnabled = false setting.AnonymousEnabled = false
setting.Quota = setting.QuotaSettings{ setting.Quota = setting.QuotaSettings{
Enabled: true, Enabled: true,
...@@ -39,6 +36,12 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -39,6 +36,12 @@ func TestMiddlewareQuota(t *testing.T) {
}, },
} }
fakeAuthTokenService := newFakeUserAuthTokenService()
qs := &quota.QuotaService{
AuthTokenService: fakeAuthTokenService,
}
QuotaFn := Quota(qs)
middlewareScenario("with user not logged in", func(sc *scenarioContext) { middlewareScenario("with user not logged in", func(sc *scenarioContext) {
bus.AddHandler("globalQuota", func(query *m.GetGlobalQuotaByTargetQuery) error { bus.AddHandler("globalQuota", func(query *m.GetGlobalQuotaByTargetQuery) error {
query.Result = &m.GlobalQuotaDTO{ query.Result = &m.GlobalQuotaDTO{
...@@ -48,26 +51,30 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -48,26 +51,30 @@ func TestMiddlewareQuota(t *testing.T) {
} }
return nil return nil
}) })
Convey("global quota not reached", func() { Convey("global quota not reached", func() {
sc.m.Get("/user", Quota("user"), sc.defaultHandler) sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec() sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200) So(sc.resp.Code, ShouldEqual, 200)
}) })
Convey("global quota reached", func() { Convey("global quota reached", func() {
setting.Quota.Global.User = 4 setting.Quota.Global.User = 4
sc.m.Get("/user", Quota("user"), sc.defaultHandler) sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec() sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403) So(sc.resp.Code, ShouldEqual, 403)
}) })
Convey("global session quota not reached", func() { Convey("global session quota not reached", func() {
setting.Quota.Global.Session = 10 setting.Quota.Global.Session = 10
sc.m.Get("/user", Quota("session"), sc.defaultHandler) sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec() sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200) So(sc.resp.Code, ShouldEqual, 200)
}) })
Convey("global session quota reached", func() { Convey("global session quota reached", func() {
setting.Quota.Global.Session = 1 setting.Quota.Global.Session = 1
sc.m.Get("/user", Quota("session"), sc.defaultHandler) sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec() sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403) So(sc.resp.Code, ShouldEqual, 403)
}) })
...@@ -95,6 +102,7 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -95,6 +102,7 @@ func TestMiddlewareQuota(t *testing.T) {
} }
return nil return nil
}) })
bus.AddHandler("userQuota", func(query *m.GetUserQuotaByTargetQuery) error { bus.AddHandler("userQuota", func(query *m.GetUserQuotaByTargetQuery) error {
query.Result = &m.UserQuotaDTO{ query.Result = &m.UserQuotaDTO{
Target: query.Target, Target: query.Target,
...@@ -103,6 +111,7 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -103,6 +111,7 @@ func TestMiddlewareQuota(t *testing.T) {
} }
return nil return nil
}) })
bus.AddHandler("orgQuota", func(query *m.GetOrgQuotaByTargetQuery) error { bus.AddHandler("orgQuota", func(query *m.GetOrgQuotaByTargetQuery) error {
query.Result = &m.OrgQuotaDTO{ query.Result = &m.OrgQuotaDTO{
Target: query.Target, Target: query.Target,
...@@ -111,45 +120,49 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -111,45 +120,49 @@ func TestMiddlewareQuota(t *testing.T) {
} }
return nil return nil
}) })
Convey("global datasource quota reached", func() { Convey("global datasource quota reached", func() {
setting.Quota.Global.DataSource = 4 setting.Quota.Global.DataSource = 4
sc.m.Get("/ds", Quota("data_source"), sc.defaultHandler) sc.m.Get("/ds", QuotaFn("data_source"), sc.defaultHandler)
sc.fakeReq("GET", "/ds").exec() sc.fakeReq("GET", "/ds").exec()
So(sc.resp.Code, ShouldEqual, 403) So(sc.resp.Code, ShouldEqual, 403)
}) })
Convey("user Org quota not reached", func() { Convey("user Org quota not reached", func() {
setting.Quota.User.Org = 5 setting.Quota.User.Org = 5
sc.m.Get("/org", Quota("org"), sc.defaultHandler) sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec() sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 200) So(sc.resp.Code, ShouldEqual, 200)
}) })
Convey("user Org quota reached", func() { Convey("user Org quota reached", func() {
setting.Quota.User.Org = 4 setting.Quota.User.Org = 4
sc.m.Get("/org", Quota("org"), sc.defaultHandler) sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec() sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 403) So(sc.resp.Code, ShouldEqual, 403)
}) })
Convey("org dashboard quota not reached", func() { Convey("org dashboard quota not reached", func() {
setting.Quota.Org.Dashboard = 10 setting.Quota.Org.Dashboard = 10
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler) sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec() sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200) So(sc.resp.Code, ShouldEqual, 200)
}) })
Convey("org dashboard quota reached", func() { Convey("org dashboard quota reached", func() {
setting.Quota.Org.Dashboard = 4 setting.Quota.Org.Dashboard = 4
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler) sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec() sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 403) So(sc.resp.Code, ShouldEqual, 403)
}) })
Convey("org dashboard quota reached but quotas disabled", func() { Convey("org dashboard quota reached but quotas disabled", func() {
setting.Quota.Org.Dashboard = 4 setting.Quota.Org.Dashboard = 4
setting.Quota.Enabled = false setting.Quota.Enabled = false
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler) sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec() sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200) So(sc.resp.Code, ShouldEqual, 200)
}) })
}) })
}) })
} }
...@@ -29,4 +29,5 @@ type UserTokenService interface { ...@@ -29,4 +29,5 @@ type UserTokenService interface {
LookupToken(unhashedToken string) (*UserToken, error) LookupToken(unhashedToken string) (*UserToken, error)
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error) TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
RevokeToken(token *UserToken) error RevokeToken(token *UserToken) error
ActiveTokenCount() (int64, error)
} }
...@@ -35,6 +35,13 @@ func (s *UserAuthTokenService) Init() error { ...@@ -35,6 +35,13 @@ func (s *UserAuthTokenService) Init() error {
return nil return nil
} }
func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) {
var model userAuthToken
count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model)
return count, err
}
func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) { func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
clientIP = util.ParseIPAddress(clientIP) clientIP = util.ParseIPAddress(clientIP)
token, err := util.RandomHex(16) token, err := util.RandomHex(16)
...@@ -79,13 +86,8 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo ...@@ -79,13 +86,8 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken) s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
} }
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
createdAfter := getTime().Add(-tokenMaxLifetime).Unix()
rotatedAfter := getTime().Add(-tokenMaxInactiveLifetime).Unix()
var model userAuthToken var model userAuthToken
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, createdAfter, rotatedAfter).Get(&model) exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -219,6 +221,16 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error { ...@@ -219,6 +221,16 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
return nil return nil
} }
func (s *UserAuthTokenService) createdAfterParam() int64 {
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
return getTime().Add(-tokenMaxLifetime).Unix()
}
func (s *UserAuthTokenService) rotatedAfterParam() int64 {
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
return getTime().Add(-tokenMaxInactiveLifetime).Unix()
}
func hashToken(token string) string { func hashToken(token string) string {
hashBytes := sha256.Sum256([]byte(token + setting.SecretKey)) hashBytes := sha256.Sum256([]byte(token + setting.SecretKey))
return hex.EncodeToString(hashBytes[:]) return hex.EncodeToString(hashBytes[:])
......
...@@ -31,6 +31,12 @@ func TestUserAuthToken(t *testing.T) { ...@@ -31,6 +31,12 @@ func TestUserAuthToken(t *testing.T) {
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
So(userToken.AuthTokenSeen, ShouldBeFalse) So(userToken.AuthTokenSeen, ShouldBeFalse)
Convey("Can count active tokens", func() {
count, err := userAuthTokenService.ActiveTokenCount()
So(err, ShouldBeNil)
So(count, ShouldEqual, 1)
})
Convey("When lookup unhashed token should return user auth token", func() { Convey("When lookup unhashed token should return user auth token", func() {
userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
So(err, ShouldBeNil) So(err, ShouldBeNil)
...@@ -114,6 +120,12 @@ func TestUserAuthToken(t *testing.T) { ...@@ -114,6 +120,12 @@ func TestUserAuthToken(t *testing.T) {
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken) notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
So(notGood, ShouldBeNil) So(notGood, ShouldBeNil)
Convey("should not find active token when expired", func() {
count, err := userAuthTokenService.ActiveTokenCount()
So(err, ShouldBeNil)
So(count, ShouldEqual, 0)
})
}) })
Convey("when rotated_at is 5 days ago and created_at is 29 days and 23:59:59 ago should not find token", func() { Convey("when rotated_at is 5 days ago and created_at is 29 days and 23:59:59 ago should not find token", func() {
......
...@@ -3,11 +3,23 @@ package quota ...@@ -3,11 +3,23 @@ package quota
import ( import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/session" "github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
) )
func QuotaReached(c *m.ReqContext, target string) (bool, error) { func init() {
registry.RegisterService(&QuotaService{})
}
type QuotaService struct {
AuthTokenService m.UserTokenService `inject:""`
}
func (qs *QuotaService) Init() error {
return nil
}
func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, error) {
if !setting.Quota.Enabled { if !setting.Quota.Enabled {
return false, nil return false, nil
} }
...@@ -30,7 +42,12 @@ func QuotaReached(c *m.ReqContext, target string) (bool, error) { ...@@ -30,7 +42,12 @@ func QuotaReached(c *m.ReqContext, target string) (bool, error) {
return true, nil return true, nil
} }
if target == "session" { if target == "session" {
usedSessions := session.GetSessionCount()
usedSessions, err := qs.AuthTokenService.ActiveTokenCount()
if err != nil {
return false, err
}
if int64(usedSessions) > scope.DefaultLimit { if int64(usedSessions) > scope.DefaultLimit {
c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit) c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
return true, nil return true, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment