Commit 19caa100 by Bill Oley Committed by GitHub

OAuth: Fix token refresh failure when custom SSL settings are configured for…

OAuth: Fix token refresh failure when custom SSL settings are configured for OAuth provider (#27523)

OAuth token refresh fails when custom SSL settings are configured for 
oauth provider. These changes makes sure that custom SSL settings 
are applied for HTTP client before refreshing token.

Fixes #27514
parent 5d11d8fa
...@@ -4,13 +4,10 @@ import ( ...@@ -4,13 +4,10 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"io/ioutil"
"net/http" "net/http"
"net/url" "net/url"
...@@ -116,46 +113,14 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) { ...@@ -116,46 +113,14 @@ func (hs *HTTPServer) OAuthLogin(ctx *models.ReqContext) {
return return
} }
// handle callback oauthClient, err := social.GetOAuthHttpClient(name)
tr := &http.Transport{ if err != nil {
Proxy: http.ProxyFromEnvironment, ctx.Logger.Error("Failed to create OAuth http client", "error", err)
TLSClientConfig: &tls.Config{ hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
InsecureSkipVerify: setting.OAuthService.OAuthInfos[name].TlsSkipVerify, HttpStatus: http.StatusInternalServerError,
}, PublicMessage: "login.OAuthLogin(" + err.Error() + ")",
} })
oauthClient := &http.Client{ return
Transport: tr,
}
if setting.OAuthService.OAuthInfos[name].TlsClientCert != "" || setting.OAuthService.OAuthInfos[name].TlsClientKey != "" {
cert, err := tls.LoadX509KeyPair(setting.OAuthService.OAuthInfos[name].TlsClientCert, setting.OAuthService.OAuthInfos[name].TlsClientKey)
if err != nil {
ctx.Logger.Error("Failed to setup TlsClientCert", "oauth", name, "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(Failed to setup TlsClientCert)",
})
return
}
tr.TLSClientConfig.Certificates = append(tr.TLSClientConfig.Certificates, cert)
}
if setting.OAuthService.OAuthInfos[name].TlsClientCa != "" {
caCert, err := ioutil.ReadFile(setting.OAuthService.OAuthInfos[name].TlsClientCa)
if err != nil {
ctx.Logger.Error("Failed to setup TlsClientCa", "oauth", name, "error", err)
hs.handleOAuthLoginError(ctx, loginInfo, LoginError{
HttpStatus: http.StatusInternalServerError,
PublicMessage: "login.OAuthLogin(Failed to setup TlsClientCa)",
})
return
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tr.TLSClientConfig.RootCAs = caCertPool
} }
oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient) oauthCtx := context.WithValue(context.Background(), oauth2.HTTPClient, oauthClient)
......
...@@ -2,6 +2,7 @@ package pluginproxy ...@@ -2,6 +2,7 @@ package pluginproxy
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
...@@ -311,10 +312,10 @@ func addOAuthPassThruAuth(c *models.ReqContext, req *http.Request) { ...@@ -311,10 +312,10 @@ func addOAuthPassThruAuth(c *models.ReqContext, req *http.Request) {
return return
} }
provider := authInfoQuery.Result.AuthModule authProvider := authInfoQuery.Result.AuthModule
connect, ok := social.SocialMap[strings.TrimPrefix(provider, "oauth_")] // The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does connect, err := social.GetConnector(authProvider)
if !ok { if err != nil {
logger.Error("Failed to find oauth provider with given name", "provider", provider) logger.Error("Failed to get OAuth connector", "error", err)
return return
} }
...@@ -324,8 +325,16 @@ func addOAuthPassThruAuth(c *models.ReqContext, req *http.Request) { ...@@ -324,8 +325,16 @@ func addOAuthPassThruAuth(c *models.ReqContext, req *http.Request) {
RefreshToken: authInfoQuery.Result.OAuthRefreshToken, RefreshToken: authInfoQuery.Result.OAuthRefreshToken,
TokenType: authInfoQuery.Result.OAuthTokenType, TokenType: authInfoQuery.Result.OAuthTokenType,
} }
client, err := social.GetOAuthHttpClient(authProvider)
if err != nil {
logger.Error("Failed to create OAuth http client", "error", err)
return
}
oauthctx := context.WithValue(c.Req.Context(), oauth2.HTTPClient, client)
// TokenSource handles refreshing the token if it has expired // TokenSource handles refreshing the token if it has expired
token, err := connect.TokenSource(c.Req.Context(), persistedToken).Token() token, err := connect.TokenSource(oauthctx, persistedToken).Token()
if err != nil { if err != nil {
logger.Error("Failed to retrieve access token from OAuth provider", "provider", authInfoQuery.Result.AuthModule, "userid", c.UserId, "username", c.Login, "error", err) logger.Error("Failed to retrieve access token from OAuth provider", "provider", authInfoQuery.Result.AuthModule, "userid", c.UserId, "username", c.Login, "error", err)
return return
......
...@@ -383,6 +383,9 @@ func TestDSRouteRule(t *testing.T) { ...@@ -383,6 +383,9 @@ func TestDSRouteRule(t *testing.T) {
Config: &oauth2.Config{}, Config: &oauth2.Config{},
}, },
} }
setting.OAuthService = &setting.OAuther{}
setting.OAuthService.OAuthInfos = make(map[string]*setting.OAuthInfo)
setting.OAuthService.OAuthInfos["generic_oauth"] = &setting.OAuthInfo{}
bus.AddHandler("test", func(query *models.GetAuthInfoQuery) error { bus.AddHandler("test", func(query *models.GetAuthInfoQuery) error {
query.Result = &models.UserAuth{ query.Result = &models.UserAuth{
......
package social package social
import ( import (
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"net/http" "net/http"
"strings" "strings"
...@@ -13,6 +17,10 @@ import ( ...@@ -13,6 +17,10 @@ import (
"github.com/grafana/grafana/pkg/util" "github.com/grafana/grafana/pkg/util"
) )
var (
logger = log.New("social")
)
type BasicUserInfo struct { type BasicUserInfo struct {
Id string Id string
Name string Name string
...@@ -225,3 +233,58 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool { ...@@ -225,3 +233,58 @@ var GetOAuthProviders = func(cfg *setting.Cfg) map[string]bool {
return result return result
} }
func GetOAuthHttpClient(name string) (*http.Client, error) {
if setting.OAuthService == nil {
return nil, fmt.Errorf("OAuth not enabled")
}
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
name = strings.TrimPrefix(name, "oauth_")
info, ok := setting.OAuthService.OAuthInfos[name]
if !ok {
return nil, fmt.Errorf("Could not find %s in OAuth Settings", name)
}
// handle call back
tr := &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: info.TlsSkipVerify,
},
}
oauthClient := &http.Client{
Transport: tr,
}
if info.TlsClientCert != "" || info.TlsClientKey != "" {
cert, err := tls.LoadX509KeyPair(info.TlsClientCert, info.TlsClientKey)
logger.Error("Failed to setup TlsClientCert", "oauth", name, "error", err)
if err != nil {
return nil, fmt.Errorf("Failed to setup TlsClientCert")
}
tr.TLSClientConfig.Certificates = append(tr.TLSClientConfig.Certificates, cert)
}
if info.TlsClientCa != "" {
caCert, err := ioutil.ReadFile(info.TlsClientCa)
if err != nil {
logger.Error("Failed to setup TlsClientCa", "oauth", name, "error", err)
return nil, fmt.Errorf("Failed to setup TlsClientCa")
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
tr.TLSClientConfig.RootCAs = caCertPool
}
return oauthClient, nil
}
func GetConnector(name string) (SocialConnector, error) {
// The socialMap keys don't have "oauth_" prefix, but everywhere else in the system does
provider := strings.TrimPrefix(name, "oauth_")
connector, ok := SocialMap[provider]
if !ok {
return nil, fmt.Errorf("Failed to find oauth provider for %s", name)
}
return connector, nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment