Commit 60ddad8f by Alexander Zobnin Committed by GitHub

Batch disable users (#17254)

* batch disable users

* batch revoke users tokens

* split batch disable user and revoke token

* fix tests for batch disable users

* Chore: add BatchDisableUsers() to the bus
parent 1497f3d7
...@@ -94,6 +94,11 @@ type DisableUserCommand struct { ...@@ -94,6 +94,11 @@ type DisableUserCommand struct {
IsDisabled bool IsDisabled bool
} }
type BatchDisableUsersCommand struct {
UserIds []int64
IsDisabled bool
}
type DeleteUserCommand struct { type DeleteUserCommand struct {
UserId int64 UserId int64
} }
......
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"strings"
"time" "time"
"github.com/grafana/grafana/pkg/infra/serverlock" "github.com/grafana/grafana/pkg/infra/serverlock"
...@@ -305,6 +306,36 @@ func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId i ...@@ -305,6 +306,36 @@ func (s *UserAuthTokenService) RevokeAllUserTokens(ctx context.Context, userId i
}) })
} }
func (s *UserAuthTokenService) BatchRevokeAllUserTokens(ctx context.Context, userIds []int64) error {
return s.SQLStore.WithTransactionalDbSession(ctx, func(dbSession *sqlstore.DBSession) error {
if len(userIds) == 0 {
return nil
}
user_id_params := strings.Repeat(",?", len(userIds)-1)
sql := "DELETE from user_auth_token WHERE user_id IN (?" + user_id_params + ")"
params := []interface{}{sql}
for _, v := range userIds {
params = append(params, v)
}
res, err := dbSession.Exec(params...)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
s.log.Debug("all user tokens for given users revoked", "usersCount", len(userIds), "count", affected)
return err
})
}
func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) { func (s *UserAuthTokenService) GetUserToken(ctx context.Context, userId, userTokenId int64) (*models.UserToken, error) {
var result models.UserToken var result models.UserToken
......
...@@ -117,6 +117,26 @@ func TestUserAuthToken(t *testing.T) { ...@@ -117,6 +117,26 @@ func TestUserAuthToken(t *testing.T) {
So(model2, ShouldBeNil) So(model2, ShouldBeNil)
}) })
}) })
Convey("When revoking users tokens in a batch", func() {
Convey("Can revoke all users tokens", func() {
userIds := []int64{}
for i := 0; i < 3; i++ {
userId := userID + int64(i+1)
userIds = append(userIds, userId)
userAuthTokenService.CreateToken(context.Background(), userId, "192.168.10.11:1234", "some user agent")
}
err := userAuthTokenService.BatchRevokeAllUserTokens(context.Background(), userIds)
So(err, ShouldBeNil)
for _, v := range userIds {
tokens, err := userAuthTokenService.GetUserTokens(context.Background(), v)
So(err, ShouldBeNil)
So(len(tokens), ShouldEqual, 0)
}
})
})
}) })
Convey("expires correctly", func() { Convey("expires correctly", func() {
......
...@@ -28,6 +28,7 @@ func (ss *SqlStore) addUserQueryAndCommandHandlers() { ...@@ -28,6 +28,7 @@ func (ss *SqlStore) addUserQueryAndCommandHandlers() {
bus.AddHandler("sql", SearchUsers) bus.AddHandler("sql", SearchUsers)
bus.AddHandler("sql", GetUserOrgList) bus.AddHandler("sql", GetUserOrgList)
bus.AddHandler("sql", DisableUser) bus.AddHandler("sql", DisableUser)
bus.AddHandler("sql", BatchDisableUsers)
bus.AddHandler("sql", DeleteUser) bus.AddHandler("sql", DeleteUser)
bus.AddHandler("sql", UpdateUserPermissions) bus.AddHandler("sql", UpdateUserPermissions)
bus.AddHandler("sql", SetUserHelpFlag) bus.AddHandler("sql", SetUserHelpFlag)
...@@ -487,6 +488,31 @@ func DisableUser(cmd *m.DisableUserCommand) error { ...@@ -487,6 +488,31 @@ func DisableUser(cmd *m.DisableUserCommand) error {
return err return err
} }
func BatchDisableUsers(cmd *m.BatchDisableUsersCommand) error {
return inTransaction(func(sess *DBSession) error {
userIds := cmd.UserIds
if len(userIds) == 0 {
return nil
}
user_id_params := strings.Repeat(",?", len(userIds)-1)
disableSQL := "UPDATE " + dialect.Quote("user") + " SET is_disabled=? WHERE Id IN (?" + user_id_params + ")"
disableParams := []interface{}{disableSQL, cmd.IsDisabled}
for _, v := range userIds {
disableParams = append(disableParams, v)
}
_, err := sess.Exec(disableParams...)
if err != nil {
return err
}
return nil
})
}
func DeleteUser(cmd *m.DeleteUserCommand) error { func DeleteUser(cmd *m.DeleteUserCommand) error {
return inTransaction(func(sess *DBSession) error { return inTransaction(func(sess *DBSession) error {
return deleteUserInTransaction(sess, cmd) return deleteUserInTransaction(sess, cmd)
......
...@@ -175,6 +175,40 @@ func TestUserDataAccess(t *testing.T) { ...@@ -175,6 +175,40 @@ func TestUserDataAccess(t *testing.T) {
So(found, ShouldBeTrue) So(found, ShouldBeTrue)
}) })
}) })
Convey("When batch disabling users", func() {
userIdsToDisable := []int64{}
for i := 0; i < 3; i++ {
userIdsToDisable = append(userIdsToDisable, users[i].Id)
}
disableCmd := m.BatchDisableUsersCommand{UserIds: userIdsToDisable, IsDisabled: true}
err = BatchDisableUsers(&disableCmd)
So(err, ShouldBeNil)
Convey("Should disable all provided users", func() {
query := m.SearchUsersQuery{}
err = SearchUsers(&query)
So(query.Result.TotalCount, ShouldEqual, 5)
for _, user := range query.Result.Users {
shouldBeDisabled := false
// Check if user id is in the userIdsToDisable list
for _, disabledUserId := range userIdsToDisable {
if user.Id == disabledUserId {
So(user.IsDisabled, ShouldBeTrue)
shouldBeDisabled = true
}
}
// Otherwise user shouldn't be disabled
if !shouldBeDisabled {
So(user.IsDisabled, ShouldBeFalse)
}
}
})
})
}) })
Convey("Given one grafana admin user", func() { Convey("Given one grafana admin user", func() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment