Commit dcec61e1 by Carl Bergquist Committed by GitHub

Merge pull request #15378 from grafana/auth_token_quotas

use authTokenService for session quotas restrictions
parents 8bd7e5a5 e163aadf
......@@ -16,7 +16,7 @@ func (hs *HTTPServer) registerRoutes() {
reqOrgAdmin := middleware.ReqOrgAdmin
redirectFromLegacyDashboardURL := middleware.RedirectFromLegacyDashboardURL()
redirectFromLegacyDashboardSoloURL := middleware.RedirectFromLegacyDashboardSoloURL()
quota := middleware.Quota
quota := middleware.Quota(hs.QuotaService)
bind := binding.Bind
r := hs.RouteRegister
......@@ -286,7 +286,7 @@ func (hs *HTTPServer) registerRoutes() {
dashboardRoute.Post("/calculate-diff", bind(dtos.CalculateDiffOptions{}), Wrap(CalculateDashboardDiff))
dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(PostDashboard))
dashboardRoute.Post("/db", bind(m.SaveDashboardCommand{}), Wrap(hs.PostDashboard))
dashboardRoute.Get("/home", Wrap(GetHomeDashboard))
dashboardRoute.Get("/tags", GetDashboardTags)
dashboardRoute.Post("/import", bind(dtos.ImportDashboardCommand{}), Wrap(ImportDashboard))
......@@ -294,7 +294,7 @@ func (hs *HTTPServer) registerRoutes() {
dashboardRoute.Group("/id/:dashboardId", func(dashIdRoute routing.RouteRegister) {
dashIdRoute.Get("/versions", Wrap(GetDashboardVersions))
dashIdRoute.Get("/versions/:id", Wrap(GetDashboardVersion))
dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(RestoreDashboardVersion))
dashIdRoute.Post("/restore", bind(dtos.RestoreDashboardVersionCommand{}), Wrap(hs.RestoreDashboardVersion))
dashIdRoute.Group("/permissions", func(dashboardPermissionRoute routing.RouteRegister) {
dashboardPermissionRoute.Get("/", Wrap(GetDashboardPermissionList))
......
......@@ -18,7 +18,6 @@ import (
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/services/guardian"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
......@@ -208,14 +207,14 @@ func DeleteDashboardByUID(c *m.ReqContext) Response {
})
}
func PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response {
func (hs *HTTPServer) PostDashboard(c *m.ReqContext, cmd m.SaveDashboardCommand) Response {
cmd.OrgId = c.OrgId
cmd.UserId = c.UserId
dash := cmd.GetDashboardModel()
if dash.Id == 0 && dash.Uid == "" {
limitReached, err := quota.QuotaReached(c, "dashboard")
limitReached, err := hs.QuotaService.QuotaReached(c, "dashboard")
if err != nil {
return Error(500, "failed to get quota", err)
}
......@@ -463,7 +462,7 @@ func CalculateDashboardDiff(c *m.ReqContext, apiOptions dtos.CalculateDiffOption
}
// RestoreDashboardVersion restores a dashboard to the given version.
func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response {
func (hs *HTTPServer) RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersionCommand) Response {
dash, rsp := getDashboardHelper(c.OrgId, "", c.ParamsInt64(":dashboardId"), "")
if rsp != nil {
return rsp
......@@ -490,7 +489,7 @@ func RestoreDashboardVersion(c *m.ReqContext, apiCmd dtos.RestoreDashboardVersio
saveCmd.Dashboard.Set("uid", dash.Uid)
saveCmd.Message = fmt.Sprintf("Restored from version %d", version.Version)
return PostDashboard(c, saveCmd)
return hs.PostDashboard(c, saveCmd)
}
func GetDashboardTags(c *m.ReqContext) {
......
......@@ -881,12 +881,16 @@ func postDashboardScenario(desc string, url string, routePattern string, mock *d
Convey(desc+" "+url, func() {
defer bus.ClearBusHandlers()
hs := HTTPServer{
Bus: bus.GetBus(),
}
sc := setupScenarioContext(url)
sc.defaultHandler = Wrap(func(c *m.ReqContext) Response {
sc.context = c
sc.context.SignedInUser = &m.SignedInUser{OrgId: cmd.OrgId, UserId: cmd.UserId}
return PostDashboard(c, cmd)
return hs.PostDashboard(c, cmd)
})
origNewDashboardService := dashboards.NewService
......
......@@ -24,6 +24,7 @@ import (
"github.com/grafana/grafana/pkg/services/cache"
"github.com/grafana/grafana/pkg/services/datasources"
"github.com/grafana/grafana/pkg/services/hooks"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/services/rendering"
"github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/setting"
......@@ -55,6 +56,7 @@ type HTTPServer struct {
CacheService *cache.CacheService `inject:""`
DatasourceCache datasources.CacheService `inject:""`
AuthTokenService models.UserTokenService `inject:""`
QuotaService *quota.QuotaService `inject:""`
}
func (hs *HTTPServer) Init() error {
......
......@@ -4,18 +4,30 @@ import (
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/log"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/services/quota"
)
func init() {
bus.AddHandler("auth", UpsertUser)
registry.RegisterService(&LoginService{})
}
var (
logger = log.New("login.ext_user")
)
func UpsertUser(cmd *m.UpsertUserCommand) error {
type LoginService struct {
Bus bus.Bus `inject:""`
QuotaService *quota.QuotaService `inject:""`
}
func (ls *LoginService) Init() error {
ls.Bus.AddHandler(ls.UpsertUser)
return nil
}
func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error {
extUser := cmd.ExternalUser
userQuery := &m.GetUserByAuthInfoQuery{
......@@ -37,7 +49,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
return ErrInvalidCredentials
}
limitReached, err := quota.QuotaReached(cmd.ReqContext, "user")
limitReached, err := ls.QuotaService.QuotaReached(cmd.ReqContext, "user")
if err != nil {
log.Warn("Error getting user quota. error: %v", err)
return ErrGettingUserQuota
......@@ -57,7 +69,7 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId,
}
if err := bus.Dispatch(cmd2); err != nil {
if err := ls.Bus.Dispatch(cmd2); err != nil {
return err
}
}
......@@ -78,12 +90,12 @@ func UpsertUser(cmd *m.UpsertUserCommand) error {
// Sync isGrafanaAdmin permission
if extUser.IsGrafanaAdmin != nil && *extUser.IsGrafanaAdmin != cmd.Result.IsAdmin {
if err := bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil {
if err := ls.Bus.Dispatch(&m.UpdateUserPermissionsCommand{UserId: cmd.Result.Id, IsGrafanaAdmin: *extUser.IsGrafanaAdmin}); err != nil {
return err
}
}
err = bus.Dispatch(&m.SyncTeamsCommand{
err = ls.Bus.Dispatch(&m.SyncTeamsCommand{
User: cmd.Result,
ExternalUser: extUser,
})
......
......@@ -395,8 +395,11 @@ func ldapAutherScenario(desc string, fn scenarioFunc) {
defer bus.ClearBusHandlers()
sc := &scenarioContext{}
loginService := &LoginService{
Bus: bus.GetBus(),
}
bus.AddHandler("test", UpsertUser)
bus.AddHandler("test", loginService.UpsertUser)
bus.AddHandlerCtx("test", func(ctx context.Context, cmd *m.SyncTeamsCommand) error {
return nil
......
......@@ -682,6 +682,7 @@ type fakeUserAuthTokenService struct {
tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
lookupTokenProvider func(unhashedToken string) (*m.UserToken, error)
revokeTokenProvider func(token *m.UserToken) error
activeAuthTokenCount func() (int64, error)
}
func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
......@@ -704,6 +705,9 @@ func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
revokeTokenProvider: func(token *m.UserToken) error {
return nil
},
activeAuthTokenCount: func() (int64, error) {
return 10, nil
},
}
}
......@@ -722,3 +726,7 @@ func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP,
func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
return s.revokeTokenProvider(token)
}
func (s *fakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
return s.activeAuthTokenCount()
}
......@@ -9,16 +9,20 @@ import (
"github.com/grafana/grafana/pkg/services/quota"
)
func Quota(target string) macaron.Handler {
return func(c *m.ReqContext) {
limitReached, err := quota.QuotaReached(c, target)
if err != nil {
c.JsonApiErr(500, "failed to get quota", err)
return
}
if limitReached {
c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
return
// Quota returns a function that returns a function used to call quotaservice based on target name
func Quota(quotaService *quota.QuotaService) func(target string) macaron.Handler {
//https://open.spotify.com/track/7bZSoBEAEEUsGEuLOf94Jm?si=T1Tdju5qRSmmR0zph_6RBw fuuuuunky
return func(target string) macaron.Handler {
return func(c *m.ReqContext) {
limitReached, err := quotaService.QuotaReached(c, target)
if err != nil {
c.JsonApiErr(500, "failed to get quota", err)
return
}
if limitReached {
c.JsonApiErr(403, fmt.Sprintf("%s Quota reached", target), nil)
return
}
}
}
}
......@@ -3,9 +3,10 @@ package middleware
import (
"testing"
"github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey"
)
......@@ -13,10 +14,6 @@ import (
func TestMiddlewareQuota(t *testing.T) {
Convey("Given the grafana quota middleware", t, func() {
session.GetSessionCount = func() int {
return 4
}
setting.AnonymousEnabled = false
setting.Quota = setting.QuotaSettings{
Enabled: true,
......@@ -39,6 +36,12 @@ func TestMiddlewareQuota(t *testing.T) {
},
}
fakeAuthTokenService := newFakeUserAuthTokenService()
qs := &quota.QuotaService{
AuthTokenService: fakeAuthTokenService,
}
QuotaFn := Quota(qs)
middlewareScenario("with user not logged in", func(sc *scenarioContext) {
bus.AddHandler("globalQuota", func(query *m.GetGlobalQuotaByTargetQuery) error {
query.Result = &m.GlobalQuotaDTO{
......@@ -48,26 +51,30 @@ func TestMiddlewareQuota(t *testing.T) {
}
return nil
})
Convey("global quota not reached", func() {
sc.m.Get("/user", Quota("user"), sc.defaultHandler)
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
Convey("global quota reached", func() {
setting.Quota.Global.User = 4
sc.m.Get("/user", Quota("user"), sc.defaultHandler)
sc.m.Get("/user", QuotaFn("user"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403)
})
Convey("global session quota not reached", func() {
setting.Quota.Global.Session = 10
sc.m.Get("/user", Quota("session"), sc.defaultHandler)
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
Convey("global session quota reached", func() {
setting.Quota.Global.Session = 1
sc.m.Get("/user", Quota("session"), sc.defaultHandler)
sc.m.Get("/user", QuotaFn("session"), sc.defaultHandler)
sc.fakeReq("GET", "/user").exec()
So(sc.resp.Code, ShouldEqual, 403)
})
......@@ -95,6 +102,7 @@ func TestMiddlewareQuota(t *testing.T) {
}
return nil
})
bus.AddHandler("userQuota", func(query *m.GetUserQuotaByTargetQuery) error {
query.Result = &m.UserQuotaDTO{
Target: query.Target,
......@@ -103,6 +111,7 @@ func TestMiddlewareQuota(t *testing.T) {
}
return nil
})
bus.AddHandler("orgQuota", func(query *m.GetOrgQuotaByTargetQuery) error {
query.Result = &m.OrgQuotaDTO{
Target: query.Target,
......@@ -111,45 +120,49 @@ func TestMiddlewareQuota(t *testing.T) {
}
return nil
})
Convey("global datasource quota reached", func() {
setting.Quota.Global.DataSource = 4
sc.m.Get("/ds", Quota("data_source"), sc.defaultHandler)
sc.m.Get("/ds", QuotaFn("data_source"), sc.defaultHandler)
sc.fakeReq("GET", "/ds").exec()
So(sc.resp.Code, ShouldEqual, 403)
})
Convey("user Org quota not reached", func() {
setting.Quota.User.Org = 5
sc.m.Get("/org", Quota("org"), sc.defaultHandler)
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
Convey("user Org quota reached", func() {
setting.Quota.User.Org = 4
sc.m.Get("/org", Quota("org"), sc.defaultHandler)
sc.m.Get("/org", QuotaFn("org"), sc.defaultHandler)
sc.fakeReq("GET", "/org").exec()
So(sc.resp.Code, ShouldEqual, 403)
})
Convey("org dashboard quota not reached", func() {
setting.Quota.Org.Dashboard = 10
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
Convey("org dashboard quota reached", func() {
setting.Quota.Org.Dashboard = 4
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 403)
})
Convey("org dashboard quota reached but quotas disabled", func() {
setting.Quota.Org.Dashboard = 4
setting.Quota.Enabled = false
sc.m.Get("/dashboard", Quota("dashboard"), sc.defaultHandler)
sc.m.Get("/dashboard", QuotaFn("dashboard"), sc.defaultHandler)
sc.fakeReq("GET", "/dashboard").exec()
So(sc.resp.Code, ShouldEqual, 200)
})
})
})
}
......@@ -29,4 +29,5 @@ type UserTokenService interface {
LookupToken(unhashedToken string) (*UserToken, error)
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
RevokeToken(token *UserToken) error
ActiveTokenCount() (int64, error)
}
......@@ -35,6 +35,13 @@ func (s *UserAuthTokenService) Init() error {
return nil
}
func (s *UserAuthTokenService) ActiveTokenCount() (int64, error) {
var model userAuthToken
count, err := s.SQLStore.NewSession().Where(`created_at > ? AND rotated_at > ?`, s.createdAfterParam(), s.rotatedAfterParam()).Count(&model)
return count, err
}
func (s *UserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
clientIP = util.ParseIPAddress(clientIP)
token, err := util.RandomHex(16)
......@@ -79,13 +86,8 @@ func (s *UserAuthTokenService) LookupToken(unhashedToken string) (*models.UserTo
s.log.Debug("looking up token", "unhashed", unhashedToken, "hashed", hashedToken)
}
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
createdAfter := getTime().Add(-tokenMaxLifetime).Unix()
rotatedAfter := getTime().Add(-tokenMaxInactiveLifetime).Unix()
var model userAuthToken
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, createdAfter, rotatedAfter).Get(&model)
exists, err := s.SQLStore.NewSession().Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", hashedToken, hashedToken, s.createdAfterParam(), s.rotatedAfterParam()).Get(&model)
if err != nil {
return nil, err
}
......@@ -219,6 +221,16 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
return nil
}
func (s *UserAuthTokenService) createdAfterParam() int64 {
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
return getTime().Add(-tokenMaxLifetime).Unix()
}
func (s *UserAuthTokenService) rotatedAfterParam() int64 {
tokenMaxInactiveLifetime := time.Duration(s.Cfg.LoginMaxInactiveLifetimeDays) * 24 * time.Hour
return getTime().Add(-tokenMaxInactiveLifetime).Unix()
}
func hashToken(token string) string {
hashBytes := sha256.Sum256([]byte(token + setting.SecretKey))
return hex.EncodeToString(hashBytes[:])
......
......@@ -31,6 +31,12 @@ func TestUserAuthToken(t *testing.T) {
So(userToken, ShouldNotBeNil)
So(userToken.AuthTokenSeen, ShouldBeFalse)
Convey("Can count active tokens", func() {
count, err := userAuthTokenService.ActiveTokenCount()
So(err, ShouldBeNil)
So(count, ShouldEqual, 1)
})
Convey("When lookup unhashed token should return user auth token", func() {
userToken, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
So(err, ShouldBeNil)
......@@ -114,6 +120,12 @@ func TestUserAuthToken(t *testing.T) {
notGood, err := userAuthTokenService.LookupToken(userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound)
So(notGood, ShouldBeNil)
Convey("should not find active token when expired", func() {
count, err := userAuthTokenService.ActiveTokenCount()
So(err, ShouldBeNil)
So(count, ShouldEqual, 0)
})
})
Convey("when rotated_at is 5 days ago and created_at is 29 days and 23:59:59 ago should not find token", func() {
......
......@@ -3,11 +3,23 @@ package quota
import (
"github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/registry"
"github.com/grafana/grafana/pkg/setting"
)
func QuotaReached(c *m.ReqContext, target string) (bool, error) {
func init() {
registry.RegisterService(&QuotaService{})
}
type QuotaService struct {
AuthTokenService m.UserTokenService `inject:""`
}
func (qs *QuotaService) Init() error {
return nil
}
func (qs *QuotaService) QuotaReached(c *m.ReqContext, target string) (bool, error) {
if !setting.Quota.Enabled {
return false, nil
}
......@@ -30,7 +42,12 @@ func QuotaReached(c *m.ReqContext, target string) (bool, error) {
return true, nil
}
if target == "session" {
usedSessions := session.GetSessionCount()
usedSessions, err := qs.AuthTokenService.ActiveTokenCount()
if err != nil {
return false, err
}
if int64(usedSessions) > scope.DefaultLimit {
c.Logger.Debug("Sessions limit reached", "active", usedSessions, "limit", scope.DefaultLimit)
return true, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment