Commit 982e095f by Daniel Lee

dsproxy: add mutex protection to the token caches

parent b5800ffe
...@@ -8,6 +8,7 @@ import ( ...@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"strconv" "strconv"
"sync"
"time" "time"
"golang.org/x/oauth2" "golang.org/x/oauth2"
...@@ -17,10 +18,24 @@ import ( ...@@ -17,10 +18,24 @@ import (
) )
var ( var (
tokenCache = map[string]*jwtToken{} tokenCache = tokenCacheType{
oauthJwtTokenCache = map[string]*oauth2.Token{} cache: map[string]*jwtToken{},
}
oauthJwtTokenCache = oauthJwtTokenCacheType{
cache: map[string]*oauth2.Token{},
}
) )
type tokenCacheType struct {
cache map[string]*jwtToken
sync.Mutex
}
type oauthJwtTokenCacheType struct {
cache map[string]*oauth2.Token
sync.Mutex
}
type accessTokenProvider struct { type accessTokenProvider struct {
route *plugins.AppPluginRoute route *plugins.AppPluginRoute
datasourceID int64 datasourceID int64
...@@ -40,7 +55,9 @@ func newAccessTokenProvider(dsID int64, pluginRoute *plugins.AppPluginRoute) *ac ...@@ -40,7 +55,9 @@ func newAccessTokenProvider(dsID int64, pluginRoute *plugins.AppPluginRoute) *ac
} }
func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) { func (provider *accessTokenProvider) getAccessToken(data templateData) (string, error) {
if cachedToken, found := tokenCache[provider.getAccessTokenCacheKey()]; found { tokenCache.Lock()
defer tokenCache.Unlock()
if cachedToken, found := tokenCache.cache[provider.getAccessTokenCacheKey()]; found {
if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) { if cachedToken.ExpiresOn.After(time.Now().Add(time.Second * 10)) {
logger.Info("Using token from cache") logger.Info("Using token from cache")
return cachedToken.AccessToken, nil return cachedToken.AccessToken, nil
...@@ -79,7 +96,7 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, ...@@ -79,7 +96,7 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64) expiresOnEpoch, _ := strconv.ParseInt(token.ExpiresOnString, 10, 64)
token.ExpiresOn = time.Unix(expiresOnEpoch, 0) token.ExpiresOn = time.Unix(expiresOnEpoch, 0)
tokenCache[provider.getAccessTokenCacheKey()] = &token tokenCache.cache[provider.getAccessTokenCacheKey()] = &token
logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn) logger.Info("Got new access token", "ExpiresOn", token.ExpiresOn)
...@@ -87,7 +104,9 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string, ...@@ -87,7 +104,9 @@ func (provider *accessTokenProvider) getAccessToken(data templateData) (string,
} }
func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) { func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data templateData) (string, error) {
if cachedToken, found := oauthJwtTokenCache[provider.getAccessTokenCacheKey()]; found { oauthJwtTokenCache.Lock()
defer oauthJwtTokenCache.Unlock()
if cachedToken, found := oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()]; found {
if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) { if cachedToken.Expiry.After(time.Now().Add(time.Second * 10)) {
logger.Info("Using token from cache") logger.Info("Using token from cache")
return cachedToken.AccessToken, nil return cachedToken.AccessToken, nil
...@@ -127,7 +146,9 @@ func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data ...@@ -127,7 +146,9 @@ func (provider *accessTokenProvider) getJwtAccessToken(ctx context.Context, data
return "", err return "", err
} }
oauthJwtTokenCache[provider.getAccessTokenCacheKey()] = token oauthJwtTokenCache.cache[provider.getAccessTokenCacheKey()] = token
logger.Info("Got new access token", "ExpiresOn", token.Expiry)
return token.AccessToken, nil return token.AccessToken, nil
} }
...@@ -139,21 +160,9 @@ var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token, ...@@ -139,21 +160,9 @@ var getTokenSource = func(conf *jwt.Config, ctx context.Context) (*oauth2.Token,
return nil, err return nil, err
} }
// logger.Info("interpolatedVal", "token.AccessToken", token.AccessToken)
return token, nil return token, nil
} }
func (provider *accessTokenProvider) getAccessTokenCacheKey() string { func (provider *accessTokenProvider) getAccessTokenCacheKey() string {
return fmt.Sprintf("%v_%v_%v", provider.datasourceID, provider.route.Path, provider.route.Method) return fmt.Sprintf("%v_%v_%v", provider.datasourceID, provider.route.Path, provider.route.Method)
} }
//Export access token lookup
func GetAccessTokenFromCache(datasourceID int64, path string, method string) (string, error) {
key := fmt.Sprintf("%v_%v_%v", datasourceID, path, method)
if cachedToken, found := oauthJwtTokenCache[key]; found {
return cachedToken.AccessToken, nil
} else {
return "", fmt.Errorf("Key doesnt exist")
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment