Commit 2a70c730 by Agnès Toulet Committed by GitHub

Auth: add expired token error and update CreateToken function (#30203)

* Auth: add error for expired token

* Auth: save token error into context data

* Auth: send full user and req context to CreateToken

* Auth: add token ID in context

* add TokenExpiredError struct

* update auth tests

* remove most of the changes to CreateToken func

* clean up

* Login: add requestURI in CreateToken ctx

* update RequestURIKey comment
parent 218a8de2
package api package api
import ( import (
"context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"net/http" "net/http"
...@@ -259,10 +260,12 @@ func (hs *HTTPServer) loginUserWithUser(user *models.User, c *models.ReqContext) ...@@ -259,10 +260,12 @@ func (hs *HTTPServer) loginUserWithUser(user *models.User, c *models.ReqContext)
} }
hs.log.Debug("Got IP address from client address", "addr", addr, "ip", ip) hs.log.Debug("Got IP address from client address", "addr", addr, "ip", ip)
userToken, err := hs.AuthTokenService.CreateToken(c.Req.Context(), user.Id, ip, c.Req.UserAgent()) ctx := context.WithValue(c.Req.Context(), models.RequestURIKey{}, c.Req.RequestURI)
userToken, err := hs.AuthTokenService.CreateToken(ctx, user, ip, c.Req.UserAgent())
if err != nil { if err != nil {
return errutil.Wrap("failed to create auth token", err) return errutil.Wrap("failed to create auth token", err)
} }
c.UserToken = userToken
hs.log.Info("Successful Login", "User", user.Email) hs.log.Info("Successful Login", "User", user.Email)
cookies.WriteSessionCookie(c, hs.Cfg, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime) cookies.WriteSessionCookie(c, hs.Cfg, userToken.UnhashedToken, hs.Cfg.LoginMaxLifetime)
......
...@@ -46,6 +46,10 @@ type LoginInfo struct { ...@@ -46,6 +46,10 @@ type LoginInfo struct {
Error error Error error
} }
// RequestURIKey is used as key to save request URI in contexts
// (used for the Enterprise auditing feature)
type RequestURIKey struct{}
// --------------------- // ---------------------
// COMMANDS // COMMANDS
......
...@@ -11,6 +11,13 @@ var ( ...@@ -11,6 +11,13 @@ var (
ErrUserTokenNotFound = errors.New("user token not found") ErrUserTokenNotFound = errors.New("user token not found")
) )
type TokenExpiredError struct {
UserID int64
TokenID int64
}
func (e *TokenExpiredError) Error() string { return "user token expired" }
// UserToken represents a user token // UserToken represents a user token
type UserToken struct { type UserToken struct {
Id int64 Id int64
...@@ -33,7 +40,7 @@ type RevokeAuthTokenCmd struct { ...@@ -33,7 +40,7 @@ type RevokeAuthTokenCmd struct {
// UserTokenService are used for generating and validating user tokens // UserTokenService are used for generating and validating user tokens
type UserTokenService interface { type UserTokenService interface {
CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*UserToken, error) CreateToken(ctx context.Context, user *User, clientIP net.IP, userAgent string) (*UserToken, error)
LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error) LookupToken(ctx context.Context, unhashedToken string) (*UserToken, error)
TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, error) TryRotateToken(ctx context.Context, token *UserToken, clientIP net.IP, userAgent string) (bool, error)
RevokeToken(ctx context.Context, token *UserToken) error RevokeToken(ctx context.Context, token *UserToken) error
......
...@@ -60,7 +60,7 @@ func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, err ...@@ -60,7 +60,7 @@ func (s *UserAuthTokenService) ActiveTokenCount(ctx context.Context) (int64, err
return count, err return count, err
} }
func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) { func (s *UserAuthTokenService) CreateToken(ctx context.Context, user *models.User, clientIP net.IP, userAgent string) (*models.UserToken, error) {
token, err := util.RandomHex(16) token, err := util.RandomHex(16)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -75,7 +75,7 @@ func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, cl ...@@ -75,7 +75,7 @@ func (s *UserAuthTokenService) CreateToken(ctx context.Context, userId int64, cl
} }
userAuthToken := userAuthToken{ userAuthToken := userAuthToken{
UserId: userId, UserId: user.Id,
AuthToken: hashedToken, AuthToken: hashedToken,
PrevAuthToken: hashedToken, PrevAuthToken: hashedToken,
ClientIp: clientIPStr, ClientIp: clientIPStr,
...@@ -116,11 +116,9 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st ...@@ -116,11 +116,9 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st
var exists bool var exists bool
var err error var err error
err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error { err = s.SQLStore.WithDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?) AND created_at > ? AND rotated_at > ?", exists, err = dbSession.Where("(auth_token = ? OR prev_auth_token = ?)",
hashedToken, hashedToken,
hashedToken, hashedToken).
s.createdAfterParam(),
s.rotatedAfterParam()).
Get(&model) Get(&model)
return err return err
...@@ -133,6 +131,13 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st ...@@ -133,6 +131,13 @@ func (s *UserAuthTokenService) LookupToken(ctx context.Context, unhashedToken st
return nil, models.ErrUserTokenNotFound return nil, models.ErrUserTokenNotFound
} }
if model.CreatedAt <= s.createdAfterParam() || model.RotatedAt <= s.rotatedAfterParam() {
return nil, &models.TokenExpiredError{
UserID: model.UserId,
TokenID: model.Id,
}
}
if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen { if model.AuthToken != hashedToken && model.PrevAuthToken == hashedToken && model.AuthTokenSeen {
modelCopy := model modelCopy := model
modelCopy.AuthTokenSeen = false modelCopy.AuthTokenSeen = false
......
...@@ -20,7 +20,8 @@ func TestUserAuthToken(t *testing.T) { ...@@ -20,7 +20,8 @@ func TestUserAuthToken(t *testing.T) {
Convey("Test user auth token", t, func() { Convey("Test user auth token", t, func() {
ctx := createTestContext(t) ctx := createTestContext(t)
userAuthTokenService := ctx.tokenService userAuthTokenService := ctx.tokenService
userID := int64(10) user := &models.User{Id: int64(10)}
userID := user.Id
t := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC) t := time.Date(2018, 12, 13, 13, 45, 0, 0, time.UTC)
getTime = func() time.Time { getTime = func() time.Time {
...@@ -28,7 +29,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -28,7 +29,7 @@ func TestUserAuthToken(t *testing.T) {
} }
Convey("When creating token", func() { Convey("When creating token", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
...@@ -80,7 +81,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -80,7 +81,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("When creating an additional token", func() { Convey("When creating an additional token", func() {
userToken2, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken2, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken2, ShouldNotBeNil) So(userToken2, ShouldNotBeNil)
...@@ -127,7 +128,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -127,7 +128,7 @@ func TestUserAuthToken(t *testing.T) {
for i := 0; i < 3; i++ { for i := 0; i < 3; i++ {
userId := userID + int64(i+1) userId := userID + int64(i+1)
userIds = append(userIds, userId) userIds = append(userIds, userId)
_, err := userAuthTokenService.CreateToken(context.Background(), userId, _, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
} }
...@@ -145,7 +146,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -145,7 +146,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("expires correctly", func() { Convey("expires correctly", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
...@@ -181,13 +182,13 @@ func TestUserAuthToken(t *testing.T) { ...@@ -181,13 +182,13 @@ func TestUserAuthToken(t *testing.T) {
So(stillGood, ShouldNotBeNil) So(stillGood, ShouldNotBeNil)
}) })
Convey("when rotated_at is 7:00:00 ago should not find token", func() { Convey("when rotated_at is 7:00:00 ago should return token expired error", func() {
getTime = func() time.Time { getTime = func() time.Time {
return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour) return time.Unix(model.RotatedAt, 0).Add(24 * 7 * time.Hour)
} }
notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldHaveSameTypeAs, &models.TokenExpiredError{})
So(notGood, ShouldBeNil) So(notGood, ShouldBeNil)
Convey("should not find active token when expired", func() { Convey("should not find active token when expired", func() {
...@@ -211,7 +212,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -211,7 +212,7 @@ func TestUserAuthToken(t *testing.T) {
So(stillGood, ShouldNotBeNil) So(stillGood, ShouldNotBeNil)
}) })
Convey("when rotated_at is 5 days ago and created_at is 30 days ago should not find token", func() { Convey("when rotated_at is 5 days ago and created_at is 30 days ago should return token expired error", func() {
updated, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix()) updated, err := ctx.updateRotatedAt(model.Id, time.Unix(model.CreatedAt, 0).Add(24*25*time.Hour).Unix())
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(updated, ShouldBeTrue) So(updated, ShouldBeTrue)
...@@ -221,13 +222,13 @@ func TestUserAuthToken(t *testing.T) { ...@@ -221,13 +222,13 @@ func TestUserAuthToken(t *testing.T) {
} }
notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken) notGood, err := userAuthTokenService.LookupToken(context.Background(), userToken.UnhashedToken)
So(err, ShouldEqual, models.ErrUserTokenNotFound) So(err, ShouldHaveSameTypeAs, &models.TokenExpiredError{})
So(notGood, ShouldBeNil) So(notGood, ShouldBeNil)
}) })
}) })
Convey("can properly rotate tokens", func() { Convey("can properly rotate tokens", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
...@@ -312,7 +313,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -312,7 +313,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("keeps prev token valid for 1 minute after it is confirmed", func() { Convey("keeps prev token valid for 1 minute after it is confirmed", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
...@@ -345,7 +346,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -345,7 +346,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("will not mark token unseen when prev and current are the same", func() { Convey("will not mark token unseen when prev and current are the same", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
...@@ -365,7 +366,7 @@ func TestUserAuthToken(t *testing.T) { ...@@ -365,7 +366,7 @@ func TestUserAuthToken(t *testing.T) {
}) })
Convey("Rotate token", func() { Convey("Rotate token", func() {
userToken, err := userAuthTokenService.CreateToken(context.Background(), userID, userToken, err := userAuthTokenService.CreateToken(context.Background(), user,
net.ParseIP("192.168.10.11"), "some user agent") net.ParseIP("192.168.10.11"), "some user agent")
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(userToken, ShouldNotBeNil) So(userToken, ShouldNotBeNil)
......
...@@ -8,7 +8,7 @@ import ( ...@@ -8,7 +8,7 @@ import (
) )
type FakeUserAuthTokenService struct { type FakeUserAuthTokenService struct {
CreateTokenProvider func(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) CreateTokenProvider func(ctx context.Context, user *models.User, clientIP net.IP, userAgent string) (*models.UserToken, error)
TryRotateTokenProvider func(ctx context.Context, token *models.UserToken, clientIP net.IP, userAgent string) (bool, error) TryRotateTokenProvider func(ctx context.Context, token *models.UserToken, clientIP net.IP, userAgent string) (bool, error)
LookupTokenProvider func(ctx context.Context, unhashedToken string) (*models.UserToken, error) LookupTokenProvider func(ctx context.Context, unhashedToken string) (*models.UserToken, error)
RevokeTokenProvider func(ctx context.Context, token *models.UserToken) error RevokeTokenProvider func(ctx context.Context, token *models.UserToken) error
...@@ -21,7 +21,7 @@ type FakeUserAuthTokenService struct { ...@@ -21,7 +21,7 @@ type FakeUserAuthTokenService struct {
func NewFakeUserAuthTokenService() *FakeUserAuthTokenService { func NewFakeUserAuthTokenService() *FakeUserAuthTokenService {
return &FakeUserAuthTokenService{ return &FakeUserAuthTokenService{
CreateTokenProvider: func(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) { CreateTokenProvider: func(ctx context.Context, user *models.User, clientIP net.IP, userAgent string) (*models.UserToken, error) {
return &models.UserToken{ return &models.UserToken{
UserId: 0, UserId: 0,
UnhashedToken: "", UnhashedToken: "",
...@@ -63,8 +63,8 @@ func (s *FakeUserAuthTokenService) Init() error { ...@@ -63,8 +63,8 @@ func (s *FakeUserAuthTokenService) Init() error {
return nil return nil
} }
func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, userId int64, clientIP net.IP, userAgent string) (*models.UserToken, error) { func (s *FakeUserAuthTokenService) CreateToken(ctx context.Context, user *models.User, clientIP net.IP, userAgent string) (*models.UserToken, error) {
return s.CreateTokenProvider(context.Background(), userId, clientIP, userAgent) return s.CreateTokenProvider(context.Background(), user, clientIP, userAgent)
} }
func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) { func (s *FakeUserAuthTokenService) LookupToken(ctx context.Context, unhashedToken string) (*models.UserToken, error) {
......
...@@ -258,6 +258,8 @@ func (h *ContextHandler) initContextWithToken(ctx *models.ReqContext, orgID int6 ...@@ -258,6 +258,8 @@ func (h *ContextHandler) initContextWithToken(ctx *models.ReqContext, orgID int6
if err != nil { if err != nil {
ctx.Logger.Error("Failed to look up user based on cookie", "error", err) ctx.Logger.Error("Failed to look up user based on cookie", "error", err)
cookies.WriteSessionCookie(ctx, h.Cfg, "", -1) cookies.WriteSessionCookie(ctx, h.Cfg, "", -1)
ctx.Data["lookupTokenErr"] = err
return false return false
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment