Commit 35e0e078 by Arve Knudsen Committed by GitHub

pkg/util: Check errors (#19832)

* pkg/util: Check errors
* pkg/services: DRY up code
parent 31a346fc
......@@ -61,7 +61,11 @@ func AdminUpdateUserPassword(c *models.ReqContext, form dtos.AdminUpdateUserPass
return
}
passwordHashed := util.EncodePassword(form.Password, userQuery.Result.Salt)
passwordHashed, err := util.EncodePassword(form.Password, userQuery.Result.Salt)
if err != nil {
c.JsonApiErr(500, "Could not encode password", err)
return
}
cmd := models.ChangeUserPasswordCommand{
UserId: userID,
......
package api
import (
"time"
"github.com/grafana/grafana/pkg/api/dtos"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/apikeygen"
"github.com/grafana/grafana/pkg/models"
"time"
)
func GetAPIKeys(c *models.ReqContext) Response {
......@@ -61,7 +62,11 @@ func (hs *HTTPServer) AddAPIKey(c *models.ReqContext, cmd models.AddApiKeyComman
}
cmd.OrgId = c.OrgId
newKeyInfo := apikeygen.New(cmd.OrgId, cmd.Name)
newKeyInfo, err := apikeygen.New(cmd.OrgId, cmd.Name)
if err != nil {
return Error(500, "Generating API key failed", err)
}
cmd.Key = newKeyInfo.HashedKey
if err := bus.Dispatch(&cmd); err != nil {
......
......@@ -100,11 +100,21 @@ func CreateDashboardSnapshot(c *m.ReqContext, cmd m.CreateDashboardSnapshotComma
metrics.MApiDashboardSnapshotExternal.Inc()
} else {
if cmd.Key == "" {
cmd.Key = util.GetRandomString(32)
var err error
cmd.Key, err = util.GetRandomString(32)
if err != nil {
c.JsonApiErr(500, "Could not generate random string", err)
return
}
}
if cmd.DeleteKey == "" {
cmd.DeleteKey = util.GetRandomString(32)
var err error
cmd.DeleteKey, err = util.GetRandomString(32)
if err != nil {
c.JsonApiErr(500, "Could not generate random string", err)
return
}
}
url = setting.ToAbsUrl("dashboard/snapshot/" + cmd.Key)
......
......@@ -51,7 +51,11 @@ func AddOrgInvite(c *m.ReqContext, inviteDto dtos.AddInviteForm) Response {
cmd.Name = inviteDto.Name
cmd.Status = m.TmpUserInvitePending
cmd.InvitedByUserId = c.UserId
cmd.Code = util.GetRandomString(30)
var err error
cmd.Code, err = util.GetRandomString(30)
if err != nil {
return Error(500, "Could not generate random string", err)
}
cmd.Role = inviteDto.Role
cmd.RemoteAddr = c.Req.RemoteAddr
......
......@@ -47,7 +47,11 @@ func ResetPassword(c *m.ReqContext, form dtos.ResetUserPasswordForm) Response {
cmd := m.ChangeUserPasswordCommand{}
cmd.UserId = query.Result.Id
cmd.NewPassword = util.EncodePassword(form.NewPassword, query.Result.Salt)
var err error
cmd.NewPassword, err = util.EncodePassword(form.NewPassword, query.Result.Salt)
if err != nil {
return Error(500, "Failed to encode password", err)
}
if err := bus.Dispatch(&cmd); err != nil {
return Error(500, "Failed to change user password", err)
......
......@@ -34,7 +34,11 @@ func SignUp(c *m.ReqContext, form dtos.SignUpForm) Response {
cmd.Email = form.Email
cmd.Status = m.TmpUserSignUpStarted
cmd.InvitedByUserId = c.UserId
cmd.Code = util.GetRandomString(20)
var err error
cmd.Code, err = util.GetRandomString(20)
if err != nil {
return Error(500, "Failed to generate random string", err)
}
cmd.RemoteAddr = c.Req.RemoteAddr
if err := bus.Dispatch(&cmd); err != nil {
......
......@@ -222,7 +222,10 @@ func ChangeUserPassword(c *m.ReqContext, cmd m.ChangeUserPasswordCommand) Respon
return Error(500, "Could not read user from database", err)
}
passwordHashed := util.EncodePassword(cmd.OldPassword, userQuery.Result.Salt)
passwordHashed, err := util.EncodePassword(cmd.OldPassword, userQuery.Result.Salt)
if err != nil {
return Error(500, "Failed to encode password", err)
}
if passwordHashed != userQuery.Result.Password {
return Error(401, "Invalid old password", nil)
}
......@@ -233,7 +236,10 @@ func ChangeUserPassword(c *m.ReqContext, cmd m.ChangeUserPasswordCommand) Respon
}
cmd.UserId = c.UserId
cmd.NewPassword = util.EncodePassword(cmd.NewPassword, userQuery.Result.Salt)
cmd.NewPassword, err = util.EncodePassword(cmd.NewPassword, userQuery.Result.Salt)
if err != nil {
return Error(500, "Failed to encode password", err)
}
if err := bus.Dispatch(&cmd); err != nil {
return Error(500, "Failed to change user password", err)
......
......@@ -10,6 +10,7 @@ import (
"github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/services/sqlstore"
"github.com/grafana/grafana/pkg/util"
"github.com/grafana/grafana/pkg/util/errutil"
)
const AdminUserId = 1
......@@ -28,7 +29,10 @@ func resetPasswordCommand(c utils.CommandLine, sqlStore *sqlstore.SqlStore) erro
return fmt.Errorf("Could not read user from database. Error: %v", err)
}
passwordHashed := util.EncodePassword(newPassword, userQuery.Result.Salt)
passwordHashed, err := util.EncodePassword(newPassword, userQuery.Result.Salt)
if err != nil {
return err
}
cmd := models.ChangeUserPasswordCommand{
UserId: AdminUserId,
......@@ -36,7 +40,7 @@ func resetPasswordCommand(c utils.CommandLine, sqlStore *sqlstore.SqlStore) erro
}
if err := bus.Dispatch(&cmd); err != nil {
return fmt.Errorf("Failed to update user password")
return errutil.Wrapf(err, "Failed to update user password")
}
logger.Infof("\n")
......
......@@ -21,20 +21,30 @@ type ApiKeyJson struct {
OrgId int64 `json:"id"`
}
func New(orgId int64, name string) KeyGenResult {
jsonKey := ApiKeyJson{}
func New(orgId int64, name string) (KeyGenResult, error) {
result := KeyGenResult{}
jsonKey := ApiKeyJson{}
jsonKey.OrgId = orgId
jsonKey.Name = name
jsonKey.Key = util.GetRandomString(32)
var err error
jsonKey.Key, err = util.GetRandomString(32)
if err != nil {
return result, err
}
result := KeyGenResult{}
result.HashedKey = util.EncodePassword(jsonKey.Key, name)
result.HashedKey, err = util.EncodePassword(jsonKey.Key, name)
if err != nil {
return result, err
}
jsonString, _ := json.Marshal(jsonKey)
jsonString, err := json.Marshal(jsonKey)
if err != nil {
return result, err
}
result.ClientSecret = base64.StdEncoding.EncodeToString(jsonString)
return result
return result, nil
}
func Decode(keyString string) (*ApiKeyJson, error) {
......@@ -52,7 +62,10 @@ func Decode(keyString string) (*ApiKeyJson, error) {
return &keyObj, nil
}
func IsValid(key *ApiKeyJson, hashedKey string) bool {
check := util.EncodePassword(key.Key, key.Name)
return check == hashedKey
func IsValid(key *ApiKeyJson, hashedKey string) (bool, error) {
check, err := util.EncodePassword(key.Key, key.Name)
if err != nil {
return false, err
}
return check == hashedKey, nil
}
......@@ -10,7 +10,8 @@ import (
func TestApiKeyGen(t *testing.T) {
Convey("When generating new api key", t, func() {
result := New(12, "Cool key")
result, err := New(12, "Cool key")
So(err, ShouldBeNil)
So(result.ClientSecret, ShouldNotBeEmpty)
So(result.HashedKey, ShouldNotBeEmpty)
......@@ -19,7 +20,8 @@ func TestApiKeyGen(t *testing.T) {
keyInfo, err := Decode(result.ClientSecret)
So(err, ShouldBeNil)
keyHashed := util.EncodePassword(keyInfo.Key, keyInfo.Name)
keyHashed, err := util.EncodePassword(keyInfo.Key, keyInfo.Name)
So(err, ShouldBeNil)
So(keyHashed, ShouldEqual, result.HashedKey)
})
})
......
......@@ -51,7 +51,12 @@ func (az *AzureBlobUploader) Upload(ctx context.Context, imageDiskPath string) (
}
defer file.Close()
randomFileName := util.GetRandomString(30) + ".png"
randomFileName, err := util.GetRandomString(30)
if err != nil {
return "", err
}
randomFileName += pngExt
// upload image
az.log.Debug("Uploading image to azure_blob", "container_name", az.container_name, "blob_name", randomFileName)
resp, err := blob.FileUpload(az.container_name, randomFileName, file)
......
......@@ -35,7 +35,12 @@ func NewGCSUploader(keyFile, bucket, path string) *GCSUploader {
}
func (u *GCSUploader) Upload(ctx context.Context, imageDiskPath string) (string, error) {
fileName := util.GetRandomString(20) + ".png"
fileName, err := util.GetRandomString(20)
if err != nil {
return "", err
}
fileName += pngExt
key := path.Join(u.path, fileName)
u.log.Debug("Opening key file ", u.keyFile)
......
......@@ -9,6 +9,8 @@ import (
"github.com/grafana/grafana/pkg/setting"
)
const pngExt = ".png"
type ImageUploader interface {
Upload(ctx context.Context, path string) (string, error)
}
......
......@@ -61,7 +61,11 @@ func (u *S3Uploader) Upload(ctx context.Context, imageDiskPath string) (string,
}
s3_endpoint, _ := endpoints.DefaultResolver().EndpointFor("s3", u.region)
key := u.path + util.GetRandomString(20) + ".png"
rand, err := util.GetRandomString(20)
if err != nil {
return "", err
}
key := u.path + rand + pngExt
image_url := s3_endpoint.URL + "/" + u.bucket + "/" + key
log.Debug("Uploading image to s3. url = %s", image_url)
......
......@@ -47,7 +47,12 @@ func (u *WebdavUploader) PublicURL(filename string) string {
func (u *WebdavUploader) Upload(ctx context.Context, pa string) (string, error) {
url, _ := url.Parse(u.url)
filename := util.GetRandomString(20) + ".png"
filename, err := util.GetRandomString(20)
if err != nil {
return "", err
}
filename += pngExt
url.Path = path.Join(url.Path, filename)
imgData, err := ioutil.ReadFile(pa)
......
......@@ -9,7 +9,10 @@ import (
)
var validatePassword = func(providedPassword string, userPassword string, userSalt string) error {
passwordHashed := util.EncodePassword(providedPassword, userSalt)
passwordHashed, err := util.EncodePassword(providedPassword, userSalt)
if err != nil {
return err
}
if subtle.ConstantTimeCompare([]byte(passwordHashed), []byte(userPassword)) != 1 {
return ErrInvalidCredentials
}
......
......@@ -128,7 +128,12 @@ func initContextWithApiKey(ctx *models.ReqContext) bool {
apikey := keyQuery.Result
// validate api key
if !apikeygen.IsValid(decoded, apikey.Key) {
isValid, err := apikeygen.IsValid(decoded, apikey.Key)
if err != nil {
ctx.JsonApiErr(500, "Validating API key failed", err)
return true
}
if !isValid {
ctx.JsonApiErr(401, errStringInvalidAPIKey, err)
return true
}
......
......@@ -27,7 +27,8 @@ func TestMiddlewareBasicAuth(t *testing.T) {
middlewareScenario(t, "Valid API key", func(sc *scenarioContext) {
var orgID int64 = 2
keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
query.Result = &models.ApiKey{OrgId: orgID, Role: models.ROLE_EDITOR, Key: keyhash}
......@@ -54,8 +55,12 @@ func TestMiddlewareBasicAuth(t *testing.T) {
var orgID int64 = 2
bus.AddHandler("grafana-auth", func(query *models.LoginUserQuery) error {
encoded, err := util.EncodePassword(password, salt)
if err != nil {
return err
}
query.User = &models.User{
Password: util.EncodePassword(password, salt),
Password: encoded,
Salt: salt,
}
return nil
......@@ -85,8 +90,12 @@ func TestMiddlewareBasicAuth(t *testing.T) {
authLogin.Init()
bus.AddHandler("user-query", func(query *models.GetUserByLoginQuery) error {
encoded, err := util.EncodePassword(password, salt)
if err != nil {
return err
}
query.Result = &models.User{
Password: util.EncodePassword(password, salt),
Password: encoded,
Id: id,
Salt: salt,
}
......
......@@ -142,7 +142,8 @@ func TestMiddlewareContext(t *testing.T) {
})
middlewareScenario(t, "Valid api key", func(sc *scenarioContext) {
keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash}
......@@ -182,10 +183,10 @@ func TestMiddlewareContext(t *testing.T) {
mockGetTime()
defer resetGetTime()
keyhash := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
keyhash, err := util.EncodePassword("v5nAwpMafFP6znaS4urhdWDLS5511M42", "asd")
So(err, ShouldBeNil)
bus.AddHandler("test", func(query *models.GetApiKeyByNameQuery) error {
// api key expired one second before
expires := getTime().Add(-1 * time.Second).Unix()
query.Result = &models.ApiKey{OrgId: 12, Role: models.ROLE_EDITOR, Key: keyhash,
......
......@@ -33,10 +33,14 @@ func initContextWithRenderAuth(ctx *m.ReqContext) bool {
return true
}
func AddRenderAuthKey(orgId int64, userId int64, orgRole m.RoleType) string {
func AddRenderAuthKey(orgId int64, userId int64, orgRole m.RoleType) (string, error) {
renderKeysLock.Lock()
defer renderKeysLock.Unlock()
key := util.GetRandomString(32)
key, err := util.GetRandomString(32)
if err != nil {
return "", err
}
renderKeys[key] = &m.SignedInUser{
OrgId: orgId,
......@@ -44,13 +48,12 @@ func AddRenderAuthKey(orgId int64, userId int64, orgRole m.RoleType) string {
UserId: userId,
}
renderKeysLock.Unlock()
return key
return key, nil
}
func RemoveRenderAuthKey(key string) {
renderKeysLock.Lock()
defer renderKeysLock.Unlock()
delete(renderKeys, key)
renderKeysLock.Unlock()
}
......@@ -25,16 +25,24 @@ var netClient = &http.Client{
}
func (rs *RenderingService) renderViaHttp(ctx context.Context, opts Opts) (*RenderResult, error) {
filePath := rs.getFilePathForNewImage()
filePath, err := rs.getFilePathForNewImage()
if err != nil {
return nil, err
}
rendererUrl, err := url.Parse(rs.Cfg.RendererUrl)
if err != nil {
return nil, err
}
renderKey, err := rs.getRenderKey(opts.OrgId, opts.UserId, opts.OrgRole)
if err != nil {
return nil, err
}
queryParams := rendererUrl.Query()
queryParams.Add("url", rs.getURL(opts.Path))
queryParams.Add("renderKey", rs.getRenderKey(opts.OrgId, opts.UserId, opts.OrgRole))
queryParams.Add("renderKey", renderKey)
queryParams.Add("width", strconv.Itoa(opts.Width))
queryParams.Add("height", strconv.Itoa(opts.Height))
queryParams.Add("domain", rs.domain)
......
......@@ -30,9 +30,15 @@ func (rs *RenderingService) renderViaPhantomJS(ctx context.Context, opts Opts) (
}
scriptPath, _ := filepath.Abs(filepath.Join(rs.Cfg.PhantomDir, "render.js"))
pngPath := rs.getFilePathForNewImage()
pngPath, err := rs.getFilePathForNewImage()
if err != nil {
return nil, err
}
renderKey := middleware.AddRenderAuthKey(opts.OrgId, opts.UserId, opts.OrgRole)
renderKey, err := middleware.AddRenderAuthKey(opts.OrgId, opts.UserId, opts.OrgRole)
if err != nil {
return nil, err
}
defer middleware.RemoveRenderAuthKey(renderKey)
phantomDebugArg := "--debug=false"
......
......@@ -69,7 +69,15 @@ func (rs *RenderingService) watchAndRestartPlugin(ctx context.Context) error {
}
func (rs *RenderingService) renderViaPlugin(ctx context.Context, opts Opts) (*RenderResult, error) {
pngPath := rs.getFilePathForNewImage()
pngPath, err := rs.getFilePathForNewImage()
if err != nil {
return nil, err
}
renderKey, err := rs.getRenderKey(opts.OrgId, opts.UserId, opts.OrgRole)
if err != nil {
return nil, err
}
rsp, err := rs.grpcPlugin.Render(ctx, &pluginModel.RenderRequest{
Url: rs.getURL(opts.Path),
......@@ -77,16 +85,14 @@ func (rs *RenderingService) renderViaPlugin(ctx context.Context, opts Opts) (*Re
Height: int32(opts.Height),
FilePath: pngPath,
Timeout: int32(opts.Timeout.Seconds()),
RenderKey: rs.getRenderKey(opts.OrgId, opts.UserId, opts.OrgRole),
RenderKey: renderKey,
Encoding: opts.Encoding,
Timezone: isoTimeOffsetToPosixTz(opts.Timezone),
Domain: rs.domain,
})
if err != nil {
return nil, err
}
if rsp.Error != "" {
return nil, fmt.Errorf("Rendering failed: %v", rsp.Error)
}
......
......@@ -112,9 +112,17 @@ func (rs *RenderingService) Render(ctx context.Context, opts Opts) (*RenderResul
return nil, fmt.Errorf("No renderer found")
}
func (rs *RenderingService) getFilePathForNewImage() string {
pngPath, _ := filepath.Abs(filepath.Join(rs.Cfg.ImagesDir, util.GetRandomString(20)))
return pngPath + ".png"
func (rs *RenderingService) getFilePathForNewImage() (string, error) {
rand, err := util.GetRandomString(20)
if err != nil {
return "", err
}
pngPath, err := filepath.Abs(filepath.Join(rs.Cfg.ImagesDir, rand))
if err != nil {
return "", err
}
return pngPath + ".png", nil
}
func (rs *RenderingService) getURL(path string) string {
......@@ -131,6 +139,6 @@ func (rs *RenderingService) getURL(path string) string {
return fmt.Sprintf("%s://%s:%s/%s&render=1", setting.Protocol, rs.domain, setting.HttpPort, path)
}
func (rs *RenderingService) getRenderKey(orgId, userId int64, orgRole models.RoleType) string {
func (rs *RenderingService) getRenderKey(orgId, userId int64, orgRole models.RoleType) (string, error) {
return middleware.AddRenderAuthKey(orgId, userId, orgRole)
}
......@@ -146,10 +146,18 @@ func (m *AddMissingUserSaltAndRandsMigration) Exec(sess *xorm.Session, mg *Migra
}
for _, user := range users {
_, err := sess.Exec("UPDATE "+mg.Dialect.Quote("user")+" SET salt = ?, rands = ? WHERE id = ?", util.GetRandomString(10), util.GetRandomString(10), user.Id)
salt, err := util.GetRandomString(10)
if err != nil {
return err
}
rands, err := util.GetRandomString(10)
if err != nil {
return err
}
if _, err := sess.Exec("UPDATE "+mg.Dialect.Quote("user")+
" SET salt = ?, rands = ? WHERE id = ?", salt, rands, user.Id); err != nil {
return err
}
}
return nil
}
......@@ -114,11 +114,23 @@ func CreateUser(ctx context.Context, cmd *models.CreateUserCommand) error {
LastSeenAt: time.Now().AddDate(-10, 0, 0),
}
user.Salt = util.GetRandomString(10)
user.Rands = util.GetRandomString(10)
salt, err := util.GetRandomString(10)
if err != nil {
return err
}
user.Salt = salt
rands, err := util.GetRandomString(10)
if err != nil {
return err
}
user.Rands = rands
if len(cmd.Password) > 0 {
user.Password = util.EncodePassword(cmd.Password, user.Salt)
encodedPassword, err := util.EncodePassword(cmd.Password, user.Salt)
if err != nil {
return err
}
user.Password = encodedPassword
}
sess.UseBool("is_admin")
......
......@@ -2,7 +2,6 @@ package util
import (
"crypto/hmac"
"crypto/md5"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
......@@ -14,10 +13,13 @@ import (
// GetRandomString generate random string by specify chars.
// source: https://github.com/gogits/gogs/blob/9ee80e3e5426821f03a4e99fad34418f5c736413/modules/base/tool.go#L58
func GetRandomString(n int, alphabets ...byte) string {
func GetRandomString(n int, alphabets ...byte) (string, error) {
const alphanum = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
var bytes = make([]byte, n)
rand.Read(bytes)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
for i, b := range bytes {
if len(alphabets) == 0 {
bytes[i] = alphanum[b%byte(len(alphanum))]
......@@ -25,26 +27,22 @@ func GetRandomString(n int, alphabets ...byte) string {
bytes[i] = alphabets[b%byte(len(alphabets))]
}
}
return string(bytes)
return string(bytes), nil
}
// EncodePassword encodes a password using PBKDF2.
func EncodePassword(password string, salt string) string {
newPasswd := PBKDF2([]byte(password), []byte(salt), 10000, 50, sha256.New)
return hex.EncodeToString(newPasswd)
}
// EncodeMd5 encodes a string to md5 hex value.
func EncodeMd5(str string) string {
m := md5.New()
m.Write([]byte(str))
return hex.EncodeToString(m.Sum(nil))
func EncodePassword(password string, salt string) (string, error) {
newPasswd, err := PBKDF2([]byte(password), []byte(salt), 10000, 50, sha256.New)
if err != nil {
return "", err
}
return hex.EncodeToString(newPasswd), nil
}
// PBKDF2 implements Password-Based Key Derivation Function 2), aimed to reduce
// the vulnerability of encrypted keys to brute force attacks.
// http://code.google.com/p/go/source/browse/pbkdf2/pbkdf2.go?repo=crypto
func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte {
func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) ([]byte, error) {
prf := hmac.New(h, password)
hashLen := prf.Size()
numBlocks := (keyLen + hashLen - 1) / hashLen
......@@ -57,12 +55,17 @@ func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte
// for each block T_i = U_1 ^ U_2 ^ ... ^ U_iter
// U_1 = PRF(password, salt || uint(i))
prf.Reset()
prf.Write(salt)
if _, err := prf.Write(salt); err != nil {
return nil, err
}
buf[0] = byte(block >> 24)
buf[1] = byte(block >> 16)
buf[2] = byte(block >> 8)
buf[3] = byte(block)
prf.Write(buf[:4])
if _, err := prf.Write(buf[:4]); err != nil {
return nil, err
}
dk = prf.Sum(dk)
T := dk[len(dk)-hashLen:]
copy(U, T)
......@@ -70,7 +73,9 @@ func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte
// U_n = PRF(password, U_(n-1))
for n := 2; n <= iter; n++ {
prf.Reset()
prf.Write(U)
if _, err := prf.Write(U); err != nil {
return nil, err
}
U = U[:0]
U = prf.Sum(U)
for x := range U {
......@@ -78,7 +83,7 @@ func PBKDF2(password, salt []byte, iter, keyLen int, h func() hash.Hash) []byte
}
}
}
return dk[:keyLen]
return dk[:keyLen], nil
}
// GetBasicAuthHeader returns a base64 encoded string from user and password.
......
......@@ -23,7 +23,8 @@ func TestEncoding(t *testing.T) {
})
Convey("When encoding password", t, func() {
encodedPassword := EncodePassword("iamgod", "pepper")
encodedPassword, err := EncodePassword("iamgod", "pepper")
So(err, ShouldBeNil)
So(encodedPassword, ShouldEqual, "e59c568621e57756495a468f47c74e07c911b037084dd464bb2ed72410970dc849cabd71b48c394faf08a5405dae53741ce9")
})
}
......@@ -14,7 +14,10 @@ const saltLength = 8
// Decrypt decrypts a payload with a given secret.
func Decrypt(payload []byte, secret string) ([]byte, error) {
salt := payload[:saltLength]
key := encryptionKeyToBytes(secret, string(salt))
key, err := encryptionKeyToBytes(secret, string(salt))
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
......@@ -39,9 +42,15 @@ func Decrypt(payload []byte, secret string) ([]byte, error) {
// Encrypt encrypts a payload with a given secret.
func Encrypt(payload []byte, secret string) ([]byte, error) {
salt := GetRandomString(saltLength)
salt, err := GetRandomString(saltLength)
if err != nil {
return nil, err
}
key := encryptionKeyToBytes(secret, salt)
key, err := encryptionKeyToBytes(secret, salt)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
......@@ -63,6 +72,6 @@ func Encrypt(payload []byte, secret string) ([]byte, error) {
}
// Key needs to be 32bytes
func encryptionKeyToBytes(secret, salt string) []byte {
func encryptionKeyToBytes(secret, salt string) ([]byte, error) {
return PBKDF2([]byte(secret), []byte(salt), 10000, 32, sha256.New)
}
......@@ -10,10 +10,12 @@ func TestEncryption(t *testing.T) {
Convey("When getting encryption key", t, func() {
key := encryptionKeyToBytes("secret", "salt")
key, err := encryptionKeyToBytes("secret", "salt")
So(err, ShouldBeNil)
So(len(key), ShouldEqual, 32)
key = encryptionKeyToBytes("a very long secret key that is larger then 32bytes", "salt")
key, err = encryptionKeyToBytes("a very long secret key that is larger then 32bytes", "salt")
So(err, ShouldBeNil)
So(len(key), ShouldEqual, 32)
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment