Commit 8029e485 by Marcus Efraimsson

support get user tokens/revoke all user tokens in UserTokenService

parent 878b41d6
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
msession "github.com/go-macaron/session" msession "github.com/go-macaron/session"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/session" "github.com/grafana/grafana/pkg/services/session"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
...@@ -155,7 +156,7 @@ func TestMiddlewareContext(t *testing.T) { ...@@ -155,7 +156,7 @@ func TestMiddlewareContext(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: unhashedToken, UnhashedToken: unhashedToken,
...@@ -184,14 +185,14 @@ func TestMiddlewareContext(t *testing.T) { ...@@ -184,14 +185,14 @@ func TestMiddlewareContext(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",
}, nil }, nil
} }
sc.userAuthTokenService.tryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) { sc.userAuthTokenService.TryRotateTokenProvider = func(userToken *m.UserToken, clientIP, userAgent string) (bool, error) {
userToken.UnhashedToken = "rotated" userToken.UnhashedToken = "rotated"
return true, nil return true, nil
} }
...@@ -226,7 +227,7 @@ func TestMiddlewareContext(t *testing.T) { ...@@ -226,7 +227,7 @@ func TestMiddlewareContext(t *testing.T) {
middlewareScenario("Invalid/expired auth token in cookie", func(sc *scenarioContext) { middlewareScenario("Invalid/expired auth token in cookie", func(sc *scenarioContext) {
sc.withTokenSessionCookie("token") sc.withTokenSessionCookie("token")
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return nil, m.ErrUserTokenNotFound return nil, m.ErrUserTokenNotFound
} }
...@@ -562,7 +563,7 @@ func middlewareScenario(desc string, fn scenarioFunc) { ...@@ -562,7 +563,7 @@ func middlewareScenario(desc string, fn scenarioFunc) {
})) }))
session.Init(&msession.Options{}, 0) session.Init(&msession.Options{}, 0)
sc.userAuthTokenService = newFakeUserAuthTokenService() sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.m.Use(GetContextHandler(sc.userAuthTokenService)) sc.m.Use(GetContextHandler(sc.userAuthTokenService))
// mock out gc goroutine // mock out gc goroutine
session.StartSessionGC = func() {} session.StartSessionGC = func() {}
...@@ -595,7 +596,7 @@ type scenarioContext struct { ...@@ -595,7 +596,7 @@ type scenarioContext struct {
handlerFunc handlerFunc handlerFunc handlerFunc
defaultHandler macaron.Handler defaultHandler macaron.Handler
url string url string
userAuthTokenService *fakeUserAuthTokenService userAuthTokenService *auth.FakeUserAuthTokenService
req *http.Request req *http.Request
} }
...@@ -676,57 +677,3 @@ func (sc *scenarioContext) exec() { ...@@ -676,57 +677,3 @@ func (sc *scenarioContext) exec() {
type scenarioFunc func(c *scenarioContext) type scenarioFunc func(c *scenarioContext)
type handlerFunc func(c *m.ReqContext) type handlerFunc func(c *m.ReqContext)
type fakeUserAuthTokenService struct {
createTokenProvider func(userId int64, clientIP, userAgent string) (*m.UserToken, error)
tryRotateTokenProvider func(token *m.UserToken, clientIP, userAgent string) (bool, error)
lookupTokenProvider func(unhashedToken string) (*m.UserToken, error)
revokeTokenProvider func(token *m.UserToken) error
activeAuthTokenCount func() (int64, error)
}
func newFakeUserAuthTokenService() *fakeUserAuthTokenService {
return &fakeUserAuthTokenService{
createTokenProvider: func(userId int64, clientIP, userAgent string) (*m.UserToken, error) {
return &m.UserToken{
UserId: 0,
UnhashedToken: "",
}, nil
},
tryRotateTokenProvider: func(token *m.UserToken, clientIP, userAgent string) (bool, error) {
return false, nil
},
lookupTokenProvider: func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{
UserId: 0,
UnhashedToken: "",
}, nil
},
revokeTokenProvider: func(token *m.UserToken) error {
return nil
},
activeAuthTokenCount: func() (int64, error) {
return 10, nil
},
}
}
func (s *fakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*m.UserToken, error) {
return s.createTokenProvider(userId, clientIP, userAgent)
}
func (s *fakeUserAuthTokenService) LookupToken(unhashedToken string) (*m.UserToken, error) {
return s.lookupTokenProvider(unhashedToken)
}
func (s *fakeUserAuthTokenService) TryRotateToken(token *m.UserToken, clientIP, userAgent string) (bool, error) {
return s.tryRotateTokenProvider(token, clientIP, userAgent)
}
func (s *fakeUserAuthTokenService) RevokeToken(token *m.UserToken) error {
return s.revokeTokenProvider(token)
}
func (s *fakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
return s.activeAuthTokenCount()
}
...@@ -24,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { ...@@ -24,7 +24,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 0, UserId: 0,
UnhashedToken: "", UnhashedToken: "",
...@@ -50,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) { ...@@ -50,7 +50,7 @@ func TestOrgRedirectMiddleware(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",
......
...@@ -3,6 +3,7 @@ package middleware ...@@ -3,6 +3,7 @@ package middleware
import ( import (
"testing" "testing"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/services/quota" "github.com/grafana/grafana/pkg/services/quota"
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
...@@ -36,7 +37,7 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -36,7 +37,7 @@ func TestMiddlewareQuota(t *testing.T) {
}, },
} }
fakeAuthTokenService := newFakeUserAuthTokenService() fakeAuthTokenService := auth.NewFakeUserAuthTokenService()
qs := &quota.QuotaService{ qs := &quota.QuotaService{
AuthTokenService: fakeAuthTokenService, AuthTokenService: fakeAuthTokenService,
} }
...@@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) { ...@@ -87,7 +88,7 @@ func TestMiddlewareQuota(t *testing.T) {
return nil return nil
}) })
sc.userAuthTokenService.lookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) { sc.userAuthTokenService.LookupTokenProvider = func(unhashedToken string) (*m.UserToken, error) {
return &m.UserToken{ return &m.UserToken{
UserId: 12, UserId: 12,
UnhashedToken: "", UnhashedToken: "",
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/grafana/grafana/pkg/bus" "github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models" m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/auth"
"github.com/grafana/grafana/pkg/setting" "github.com/grafana/grafana/pkg/setting"
. "github.com/smartystreets/goconvey/convey" . "github.com/smartystreets/goconvey/convey"
macaron "gopkg.in/macaron.v1" macaron "gopkg.in/macaron.v1"
...@@ -62,7 +63,7 @@ func recoveryScenario(desc string, url string, fn scenarioFunc) { ...@@ -62,7 +63,7 @@ func recoveryScenario(desc string, url string, fn scenarioFunc) {
Delims: macaron.Delims{Left: "[[", Right: "]]"}, Delims: macaron.Delims{Left: "[[", Right: "]]"},
})) }))
sc.userAuthTokenService = newFakeUserAuthTokenService() sc.userAuthTokenService = auth.NewFakeUserAuthTokenService()
sc.m.Use(GetContextHandler(sc.userAuthTokenService)) sc.m.Use(GetContextHandler(sc.userAuthTokenService))
// mock out gc goroutine // mock out gc goroutine
sc.m.Use(OrgRedirect()) sc.m.Use(OrgRedirect())
......
...@@ -29,5 +29,8 @@ type UserTokenService interface { ...@@ -29,5 +29,8 @@ type UserTokenService interface {
LookupToken(unhashedToken string) (*UserToken, error) LookupToken(unhashedToken string) (*UserToken, error)
TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error) TryRotateToken(token *UserToken, clientIP, userAgent string) (bool, error)
RevokeToken(token *UserToken) error RevokeToken(token *UserToken) error
RevokeAllUserTokens(userId int64) error
ActiveTokenCount() (int64, error) ActiveTokenCount() (int64, error)
GetUserToken(userId, userTokenId int64) (*UserToken, error)
GetUserTokens(userId int64) ([]*UserToken, error)
} }
...@@ -221,6 +221,57 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error { ...@@ -221,6 +221,57 @@ func (s *UserAuthTokenService) RevokeToken(token *models.UserToken) error {
return nil return nil
} }
func (s *UserAuthTokenService) RevokeAllUserTokens(userId int64) error {
sql := `DELETE from user_auth_token WHERE user_id = ?`
res, err := s.SQLStore.NewSession().Exec(sql, userId)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
s.log.Debug("all user tokens for user revoked", "userId", userId, "count", affected)
return nil
}
func (s *UserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
var token userAuthToken
exists, err := s.SQLStore.NewSession().Where("id = ? AND user_id = ?", userTokenId, userId).Get(&token)
if err != nil {
return nil, err
}
if !exists {
return nil, models.ErrUserTokenNotFound
}
var result models.UserToken
token.toUserToken(&result)
return &result, nil
}
func (s *UserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
var tokens []*userAuthToken
err := s.SQLStore.NewSession().Where("user_id = ? AND created_at > ? AND rotated_at > ?", userId, s.createdAfterParam(), s.rotatedAfterParam()).Find(&tokens)
if err != nil {
return nil, err
}
result := []*models.UserToken{}
for _, token := range tokens {
var userToken models.UserToken
token.toUserToken(&userToken)
result = append(result, &userToken)
}
return result, nil
}
func (s *UserAuthTokenService) createdAfterParam() int64 { func (s *UserAuthTokenService) createdAfterParam() int64 {
tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour tokenMaxLifetime := time.Duration(s.Cfg.LoginMaxLifetimeDays) * 24 * time.Hour
return getTime().Add(-tokenMaxLifetime).Unix() return getTime().Add(-tokenMaxLifetime).Unix()
......
...@@ -75,6 +75,47 @@ func TestUserAuthToken(t *testing.T) { ...@@ -75,6 +75,47 @@ func TestUserAuthToken(t *testing.T) {
err = userAuthTokenService.RevokeToken(userToken) err = userAuthTokenService.RevokeToken(userToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldEqual, models.ErrUserTokenNotFound)
}) })
Convey("When creating an additional token", func() {
userToken2, err := userAuthTokenService.CreateToken(userID, "192.168.10.11:1234", "some user agent")
So(err, ShouldBeNil)
So(userToken2, ShouldNotBeNil)
Convey("Can get first user token", func() {
token, err := userAuthTokenService.GetUserToken(userID, userToken.Id)
So(err, ShouldBeNil)
So(token, ShouldNotBeNil)
So(token.Id, ShouldEqual, userToken.Id)
})
Convey("Can get second user token", func() {
token, err := userAuthTokenService.GetUserToken(userID, userToken2.Id)
So(err, ShouldBeNil)
So(token, ShouldNotBeNil)
So(token.Id, ShouldEqual, userToken2.Id)
})
Convey("Can get user tokens", func() {
tokens, err := userAuthTokenService.GetUserTokens(userID)
So(err, ShouldBeNil)
So(tokens, ShouldHaveLength, 2)
So(tokens[0].Id, ShouldEqual, userToken.Id)
So(tokens[1].Id, ShouldEqual, userToken2.Id)
})
Convey("Can revoke all user tokens", func() {
err := userAuthTokenService.RevokeAllUserTokens(userID)
So(err, ShouldBeNil)
model, err := ctx.getAuthTokenByID(userToken.Id)
So(err, ShouldBeNil)
So(model, ShouldBeNil)
model2, err := ctx.getAuthTokenByID(userToken2.Id)
So(err, ShouldBeNil)
So(model2, ShouldBeNil)
})
})
}) })
Convey("expires correctly", func() { Convey("expires correctly", func() {
......
package auth
import "github.com/grafana/grafana/pkg/models"
type FakeUserAuthTokenService struct {
CreateTokenProvider func(userId int64, clientIP, userAgent string) (*models.UserToken, error)
TryRotateTokenProvider func(token *models.UserToken, clientIP, userAgent string) (bool, error)
LookupTokenProvider func(unhashedToken string) (*models.UserToken, error)
RevokeTokenProvider func(token *models.UserToken) error
RevokeAllUserTokensProvider func(userId int64) error
ActiveAuthTokenCount func() (int64, error)
GetUserTokenProvider func(userId, userTokenId int64) (*models.UserToken, error)
GetUserTokensProvider func(userId int64) ([]*models.UserToken, error)
}
func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
return &FakeUserAuthTokenService{
CreateTokenProvider: func(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
return &models.UserToken{
UserId: 0,
UnhashedToken: "",
}, nil
},
TryRotateTokenProvider: func(token *models.UserToken, clientIP, userAgent string) (bool, error) {
return false, nil
},
LookupTokenProvider: func(unhashedToken string) (*models.UserToken, error) {
return &models.UserToken{
UserId: 0,
UnhashedToken: "",
}, nil
},
RevokeTokenProvider: func(token *models.UserToken) error {
return nil
},
RevokeAllUserTokensProvider: func(userId int64) error {
return nil
},
ActiveAuthTokenCount: func() (int64, error) {
return 10, nil
},
GetUserTokenProvider: func(userId, userTokenId int64) (*models.UserToken, error) {
return nil, nil
},
GetUserTokensProvider: func(userId int64) ([]*models.UserToken, error) {
return nil, nil
},
}
}
func (s *FakeUserAuthTokenService) CreateToken(userId int64, clientIP, userAgent string) (*models.UserToken, error) {
return s.CreateTokenProvider(userId, clientIP, userAgent)
}
func (s *FakeUserAuthTokenService) LookupToken(unhashedToken string) (*models.UserToken, error) {
return s.LookupTokenProvider(unhashedToken)
}
func (s *FakeUserAuthTokenService) TryRotateToken(token *models.UserToken, clientIP, userAgent string) (bool, error) {
return s.TryRotateTokenProvider(token, clientIP, userAgent)
}
func (s *FakeUserAuthTokenService) RevokeToken(token *models.UserToken) error {
return s.RevokeTokenProvider(token)
}
func (s *FakeUserAuthTokenService) RevokeAllUserTokens(userId int64) error {
return s.RevokeAllUserTokensProvider(userId)
}
func (s *FakeUserAuthTokenService) ActiveTokenCount() (int64, error) {
return s.ActiveAuthTokenCount()
}
func (s *FakeUserAuthTokenService) GetUserToken(userId, userTokenId int64) (*models.UserToken, error) {
return s.GetUserTokenProvider(userId, userTokenId)
}
func (s *FakeUserAuthTokenService) GetUserTokens(userId int64) ([]*models.UserToken, error) {
return s.GetUserTokensProvider(userId)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment