Commit fda7e686 by Daniel Lee Committed by GitHub

Merge pull request #15205 from seanlaff/12556-oauth-pass-thru

Add oauth pass-thru option for datasources
parents 516b7ce8 50c58544
......@@ -165,6 +165,7 @@ func (hs *HTTPServer) OAuthLogin(ctx *m.ReqContext) {
extUser := &m.ExternalUserInfo{
AuthModule: "oauth_" + name,
OAuthToken: token,
AuthId: userInfo.Id,
Name: userInfo.Name,
Login: userInfo.Login,
......
......@@ -14,8 +14,11 @@ import (
"time"
"github.com/opentracing/opentracing-go"
"golang.org/x/oauth2"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/log"
"github.com/grafana/grafana/pkg/login/social"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
......@@ -221,6 +224,10 @@ func (proxy *DataSourceProxy) getDirector() func(req *http.Request) {
if proxy.route != nil {
ApplyRoute(proxy.ctx.Req.Context(), req, proxy.proxyPath, proxy.route, proxy.ds)
}
if proxy.ds.JsonData != nil && proxy.ds.JsonData.Get("oauthPassThru").MustBool() {
addOAuthPassThruAuth(proxy.ctx, req)
}
}
}
......@@ -311,3 +318,46 @@ func checkWhiteList(c *m.ReqContext, host string) bool {
return true
}
func addOAuthPassThruAuth(c *m.ReqContext, req *http.Request) {
authInfoQuery := &m.GetAuthInfoQuery{UserId: c.UserId}
if err := bus.Dispatch(authInfoQuery); err != nil {
logger.Error("Error feching oauth information for user", "error", err)
return
}
provider := authInfoQuery.Result.AuthModule
connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
if !ok {
logger.Error("Failed to find oauth provider with given name", "provider", provider)
return
}
// TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(c.Req.Context(), &oauth2.Token{
AccessToken: authInfoQuery.Result.OAuthAccessToken,
Expiry: authInfoQuery.Result.OAuthExpiry,
RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
TokenType: authInfoQuery.Result.OAuthTokenType,
}).Token()
if err != nil {
logger.Error("Failed to retrieve access token from oauth provider", "provider", authInfoQuery.Result.AuthModule)
return
}
// If the tokens are not the same, update the entry in the DB
if token.AccessToken != authInfoQuery.Result.OAuthAccessToken {
updateAuthCommand := &m.UpdateAuthInfoCommand{
UserId: authInfoQuery.Result.Id,
AuthModule: authInfoQuery.Result.AuthModule,
AuthId: authInfoQuery.Result.AuthId,
OAuthToken: token,
}
if err := bus.Dispatch(updateAuthCommand); err != nil {
logger.Error("Failed to update access token during token refresh", "error", err)
return
}
}
req.Header.Del("Authorization")
req.Header.Add("Authorization", fmt.Sprintf("%s %s", token.Type(), token.AccessToken))
}
......@@ -9,10 +9,13 @@ import (
"testing"
"time"
"golang.org/x/oauth2"
macaron "gopkg.in/macaron.v1"
"github.com/grafana/grafana/pkg/bus"
"github.com/grafana/grafana/pkg/components/simplejson"
"github.com/grafana/grafana/pkg/log"
"github.com/grafana/grafana/pkg/login/social"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/plugins"
"github.com/grafana/grafana/pkg/setting"
......@@ -389,6 +392,54 @@ func TestDSRouteRule(t *testing.T) {
})
})
Convey("When proxying a datasource that has oauth token pass-thru enabled", func() {
social.SocialMap["generic_oauth"] = &social.SocialGenericOAuth{
SocialBase: &social.SocialBase{
Config: &oauth2.Config{},
},
}
bus.AddHandler("test", func(query *m.GetAuthInfoQuery) error {
query.Result = &m.UserAuth{
Id: 1,
UserId: 1,
AuthModule: "generic_oauth",
OAuthAccessToken: "testtoken",
OAuthRefreshToken: "testrefreshtoken",
OAuthTokenType: "Bearer",
OAuthExpiry: time.Now().AddDate(0, 0, 1),
}
return nil
})
plugin := &plugins.DataSourcePlugin{}
ds := &m.DataSource{
Type: "custom-datasource",
Url: "http://host/root/",
JsonData: simplejson.NewFromAny(map[string]interface{}{
"oauthPassThru": true,
}),
}
req, _ := http.NewRequest("GET", "http://localhost/asd", nil)
ctx := &m.ReqContext{
SignedInUser: &m.SignedInUser{UserId: 1},
Context: &macaron.Context{
Req: macaron.Request{Request: req},
},
}
proxy := NewDataSourceProxy(ds, plugin, ctx, "/path/to/folder/", &setting.Cfg{})
req, err := http.NewRequest(http.MethodGet, "http://grafana.com/sub", nil)
So(err, ShouldBeNil)
proxy.getDirector()(req)
Convey("Should have access token in header", func() {
So(req.Header.Get("Authorization"), ShouldEqual, fmt.Sprintf("%s %s", "Bearer", "testtoken"))
})
})
Convey("When SendUserHeader config is enabled", func() {
req := getDatasourceProxiedRequest(
&m.ReqContext{
......
......@@ -63,11 +63,12 @@ func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error {
return err
}
if extUser.AuthModule != "" && extUser.AuthId != "" {
if extUser.AuthModule != "" {
cmd2 := &m.SetAuthInfoCommand{
UserId: cmd.Result.Id,
AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId,
OAuthToken: extUser.OAuthToken,
}
if err := ls.Bus.Dispatch(cmd2); err != nil {
return err
......@@ -81,6 +82,14 @@ func (ls *LoginService) UpsertUser(cmd *m.UpsertUserCommand) error {
if err != nil {
return err
}
// Always persist the latest token at log-in
if extUser.AuthModule != "" && extUser.OAuthToken != nil {
err = updateUserAuth(cmd.Result, extUser)
if err != nil {
return err
}
}
}
err = syncOrgRoles(cmd.Result, extUser)
......@@ -155,6 +164,18 @@ func updateUser(user *m.User, extUser *m.ExternalUserInfo) error {
return bus.Dispatch(updateCmd)
}
func updateUserAuth(user *m.User, extUser *m.ExternalUserInfo) error {
updateCmd := &m.UpdateAuthInfoCommand{
AuthModule: extUser.AuthModule,
AuthId: extUser.AuthId,
UserId: user.Id,
OAuthToken: extUser.OAuthToken,
}
logger.Debug("Updating user_auth info", "user_id", user.Id)
return bus.Dispatch(updateCmd)
}
func syncOrgRoles(user *m.User, extUser *m.ExternalUserInfo) error {
// don't sync org roles if none are specified
if len(extUser.OrgRoles) == 0 {
......
......@@ -31,6 +31,7 @@ type SocialConnector interface {
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx context.Context, code string) (*oauth2.Token, error)
Client(ctx context.Context, t *oauth2.Token) *http.Client
TokenSource(ctx context.Context, t *oauth2.Token) oauth2.TokenSource
}
type SocialBase struct {
......
......@@ -2,17 +2,24 @@ package models
import (
"time"
"golang.org/x/oauth2"
)
type UserAuth struct {
Id int64
UserId int64
AuthModule string
AuthId string
Created time.Time
Id int64
UserId int64
AuthModule string
AuthId string
Created time.Time
OAuthAccessToken string
OAuthRefreshToken string
OAuthTokenType string
OAuthExpiry time.Time
}
type ExternalUserInfo struct {
OAuthToken *oauth2.Token
AuthModule string
AuthId string
UserId int64
......@@ -39,6 +46,14 @@ type SetAuthInfoCommand struct {
AuthModule string
AuthId string
UserId int64
OAuthToken *oauth2.Token
}
type UpdateAuthInfoCommand struct {
AuthModule string
AuthId string
UserId int64
OAuthToken *oauth2.Token
}
type DeleteAuthInfoCommand struct {
......@@ -67,6 +82,7 @@ type GetUserByAuthInfoQuery struct {
}
type GetAuthInfoQuery struct {
UserId int64
AuthModule string
AuthId string
......
......@@ -25,4 +25,21 @@ func addUserAuthMigrations(mg *Migrator) {
mg.AddMigration("alter user_auth.auth_id to length 190", NewRawSqlMigration("").
Postgres("ALTER TABLE user_auth ALTER COLUMN auth_id TYPE VARCHAR(190);").
Mysql("ALTER TABLE user_auth MODIFY auth_id VARCHAR(190);"))
mg.AddMigration("Add OAuth access token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_access_token", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth refresh token to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_refresh_token", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth token type to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_token_type", Type: DB_Text, Nullable: true,
}))
mg.AddMigration("Add OAuth expiry to user_auth", NewAddColumnMigration(userAuthV1, &Column{
Name: "o_auth_expiry", Type: DB_DateTime, Nullable: true,
}))
mg.AddMigration("Add index to user_id column in user_auth", NewAddIndexMigration(userAuthV1, &Index{
Cols: []string{"user_id"},
}))
}
package sqlstore
import (
"encoding/base64"
"time"
"github.com/grafana/grafana/pkg/bus"
m "github.com/grafana/grafana/pkg/models"
"github.com/grafana/grafana/pkg/setting"
"github.com/grafana/grafana/pkg/util"
)
var getTime = time.Now
func init() {
bus.AddHandler("sql", GetUserByAuthInfo)
bus.AddHandler("sql", GetAuthInfo)
bus.AddHandler("sql", SetAuthInfo)
bus.AddHandler("sql", UpdateAuthInfo)
bus.AddHandler("sql", DeleteAuthInfo)
}
......@@ -94,7 +100,7 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
}
// create authInfo record to link accounts
if authQuery.Result == nil && query.AuthModule != "" && query.AuthId != "" {
if authQuery.Result == nil && query.AuthModule != "" {
cmd2 := &m.SetAuthInfoCommand{
UserId: user.Id,
AuthModule: query.AuthModule,
......@@ -111,10 +117,11 @@ func GetUserByAuthInfo(query *m.GetUserByAuthInfoQuery) error {
func GetAuthInfo(query *m.GetAuthInfoQuery) error {
userAuth := &m.UserAuth{
UserId: query.UserId,
AuthModule: query.AuthModule,
AuthId: query.AuthId,
}
has, err := x.Get(userAuth)
has, err := x.Desc("created").Get(userAuth)
if err != nil {
return err
}
......@@ -122,6 +129,22 @@ func GetAuthInfo(query *m.GetAuthInfoQuery) error {
return m.ErrUserNotFound
}
secretAccessToken, err := decodeAndDecrypt(userAuth.OAuthAccessToken)
if err != nil {
return err
}
secretRefreshToken, err := decodeAndDecrypt(userAuth.OAuthRefreshToken)
if err != nil {
return err
}
secretTokenType, err := decodeAndDecrypt(userAuth.OAuthTokenType)
if err != nil {
return err
}
userAuth.OAuthAccessToken = secretAccessToken
userAuth.OAuthRefreshToken = secretRefreshToken
userAuth.OAuthTokenType = secretTokenType
query.Result = userAuth
return nil
}
......@@ -132,7 +155,27 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
AuthId: cmd.AuthId,
Created: time.Now(),
Created: getTime(),
}
if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil {
return err
}
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil {
return err
}
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil {
return err
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
_, err := sess.Insert(authUser)
......@@ -140,9 +183,76 @@ func SetAuthInfo(cmd *m.SetAuthInfoCommand) error {
})
}
func UpdateAuthInfo(cmd *m.UpdateAuthInfoCommand) error {
return inTransaction(func(sess *DBSession) error {
authUser := &m.UserAuth{
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
AuthId: cmd.AuthId,
Created: getTime(),
}
if cmd.OAuthToken != nil {
secretAccessToken, err := encryptAndEncode(cmd.OAuthToken.AccessToken)
if err != nil {
return err
}
secretRefreshToken, err := encryptAndEncode(cmd.OAuthToken.RefreshToken)
if err != nil {
return err
}
secretTokenType, err := encryptAndEncode(cmd.OAuthToken.TokenType)
if err != nil {
return err
}
authUser.OAuthAccessToken = secretAccessToken
authUser.OAuthRefreshToken = secretRefreshToken
authUser.OAuthTokenType = secretTokenType
authUser.OAuthExpiry = cmd.OAuthToken.Expiry
}
cond := &m.UserAuth{
UserId: cmd.UserId,
AuthModule: cmd.AuthModule,
}
_, err := sess.Update(authUser, cond)
return err
})
}
func DeleteAuthInfo(cmd *m.DeleteAuthInfoCommand) error {
return inTransaction(func(sess *DBSession) error {
_, err := sess.Delete(cmd.UserAuth)
return err
})
}
// decodeAndDecrypt will decode the string with the standard bas64 decoder
// and then decrypt it with grafana's secretKey
func decodeAndDecrypt(s string) (string, error) {
// Bail out if empty string since it'll cause a segfault in util.Decrypt
if s == "" {
return "", nil
}
decoded, err := base64.StdEncoding.DecodeString(s)
if err != nil {
return "", err
}
decrypted, err := util.Decrypt(decoded, setting.SecretKey)
if err != nil {
return "", err
}
return string(decrypted), nil
}
// encryptAndEncode will encrypt a string with grafana's secretKey, and
// then encode it with the standard bas64 encoder
func encryptAndEncode(s string) (string, error) {
encrypted, err := util.Encrypt([]byte(s), setting.SecretKey)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
......@@ -4,8 +4,10 @@ import (
"context"
"fmt"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
"golang.org/x/oauth2"
m "github.com/grafana/grafana/pkg/models"
)
......@@ -126,5 +128,97 @@ func TestUserAuth(t *testing.T) {
So(err, ShouldEqual, m.ErrUserNotFound)
So(query.Result, ShouldBeNil)
})
Convey("Can set & retrieve oauth token information", func() {
token := &oauth2.Token{
AccessToken: "testaccess",
RefreshToken: "testrefresh",
Expiry: time.Now(),
TokenType: "Bearer",
}
// Find a user to set tokens on
login := "loginuser0"
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test", AuthId: "test"}
err = GetUserByAuthInfo(query)
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
cmd := &m.UpdateAuthInfoCommand{
UserId: query.Result.Id,
AuthId: query.AuthId,
AuthModule: query.AuthModule,
OAuthToken: token,
}
err = UpdateAuthInfo(cmd)
So(err, ShouldBeNil)
getAuthQuery := &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.OAuthAccessToken, ShouldEqual, token.AccessToken)
So(getAuthQuery.Result.OAuthRefreshToken, ShouldEqual, token.RefreshToken)
So(getAuthQuery.Result.OAuthTokenType, ShouldEqual, token.TokenType)
})
Convey("Always return the most recently used auth_module", func() {
// Find a user to set tokens on
login := "loginuser0"
// Calling GetUserByAuthInfoQuery on an existing user will populate an entry in the user_auth table
// Make the first log-in during the past
getTime = func() time.Time { return time.Now().AddDate(0, 0, -2) }
query := &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test1", AuthId: "test1"}
err = GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
// Add a second auth module for this user
// Have this module's last log-in be more recent
getTime = func() time.Time { return time.Now().AddDate(0, 0, -1) }
query = &m.GetUserByAuthInfoQuery{Login: login, AuthModule: "test2", AuthId: "test2"}
err = GetUserByAuthInfo(query)
getTime = time.Now
So(err, ShouldBeNil)
So(query.Result.Login, ShouldEqual, login)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery := &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test2")
// "log in" again with the first auth module
updateAuthCmd := &m.UpdateAuthInfoCommand{UserId: query.Result.Id, AuthModule: "test1", AuthId: "test1"}
err = UpdateAuthInfo(updateAuthCmd)
So(err, ShouldBeNil)
// Get the latest entry by not supply an authmodule or authid
getAuthQuery = &m.GetAuthInfoQuery{
UserId: query.Result.Id,
}
err = GetAuthInfo(getAuthQuery)
So(err, ShouldBeNil)
So(getAuthQuery.Result.AuthModule, ShouldEqual, "test1")
})
})
}
......@@ -71,22 +71,25 @@
<h3 class="page-heading">Auth</h3>
<div class="gf-form-group">
<div class="gf-form-inline">
<gf-form-checkbox class="gf-form" label="Basic Auth" checked="current.basicAuth" label-class="width-10" switch-class="max-width-6"></gf-form-checkbox>
<gf-form-checkbox class="gf-form" label="Basic Auth" checked="current.basicAuth" label-class="width-13" switch-class="max-width-6"></gf-form-checkbox>
<gf-form-checkbox class="gf-form" label="With Credentials" tooltip="Whether credentials such as cookies or auth
headers should be sent with cross-site requests." checked="current.withCredentials" label-class="width-11"
headers should be sent with cross-site requests." checked="current.withCredentials" label-class="width-13"
switch-class="max-width-6"></gf-form-checkbox>
</div>
<div class="gf-form-inline">
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="TLS Client Auth" label-class="width-10"
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="TLS Client Auth" label-class="width-13"
checked="current.jsonData.tlsAuth" switch-class="max-width-6"></gf-form-checkbox>
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="With CA Cert" tooltip="Needed for
verifing self-signed TLS Certs" checked="current.jsonData.tlsAuthWithCACert" label-class="width-11"
verifing self-signed TLS Certs" checked="current.jsonData.tlsAuthWithCACert" label-class="width-13"
switch-class="max-width-6"></gf-form-checkbox>
</div>
<div class="gf-form-inline">
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="Skip TLS Verify" label-class="width-10"
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="Skip TLS Verify" label-class="width-13"
checked="current.jsonData.tlsSkipVerify" switch-class="max-width-6"></gf-form-checkbox>
</div>
<div class="gf-form-inline">
<gf-form-checkbox class="gf-form" ng-if="current.access=='proxy'" label="Forward OAuth Identity" label-class="width-13" tooltip="Forward the user's upstream OAuth identity to the datasource (Their access token gets passed along)." checked="current.jsonData.oauthPassThru" switch-class="max-width-6"></gf-form-checkbox>
</div>
</div>
<div class="gf-form-group" ng-if="current.basicAuth">
......@@ -102,4 +105,4 @@
</div>
<datasource-tls-auth-settings current="current" ng-if="(current.jsonData.tlsAuth || current.jsonData.tlsAuthWithCACert) && current.access=='proxy'">
</datasource-tls-auth-settings>
\ No newline at end of file
</datasource-tls-auth-settings>
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment