Commit 7cbf3d8d by Sean Johnson Committed by Marcus Efraimsson

OAuth: Fix role mapping from id token (#20300)

As part of the new improvements to JMESPath in #17149 a 
commit ensuring Role and Email data could be queried from 
id_token was lost.
This refactors and fixes extracting from both token and user 
info api consistently where before only either token or 
either user info api was used for extracting data/attributes.

Fixes #20243

Co-authored-by: Timo Wendt <timo@tjwendt.de>
Co-authored-by: twendt <timo@tjwendt.de>
Co-authored-by: henninge <henning@eggers.name>
Co-Authored-by: Henning Eggers <henning.eggers@inovex.de>
Co-authored-by: Marcus Efraimsson <marcus.efraimsson@gmail.com>
parent fef31acb
...@@ -9,6 +9,8 @@ import ( ...@@ -9,6 +9,8 @@ import (
"net/mail" "net/mail"
"regexp" "regexp"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/grafana/grafana/pkg/models" "github.com/grafana/grafana/pkg/models"
"github.com/jmespath/go-jmespath" "github.com/jmespath/go-jmespath"
"golang.org/x/oauth2" "golang.org/x/oauth2"
...@@ -43,8 +45,8 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool { ...@@ -43,8 +45,8 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
return true return true
} }
teamMemberships, err := s.FetchTeamMemberships(client) teamMemberships, ok := s.FetchTeamMemberships(client)
if err != nil { if !ok {
return false return false
} }
...@@ -64,8 +66,8 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool { ...@@ -64,8 +66,8 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
return true return true
} }
organizations, err := s.FetchOrganizations(client) organizations, ok := s.FetchOrganizations(client)
if err != nil { if !ok {
return false return false
} }
...@@ -80,128 +82,6 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool { ...@@ -80,128 +82,6 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
return false return false
} }
// searchJSONForAttr searches the provided JSON response for the given attribute
// using the configured attribute path associated with the generic OAuth
// provider.
// Returns an empty string if an attribute is not found.
func (s *SocialGenericOAuth) searchJSONForAttr(attributePath string, data []byte) string {
if attributePath == "" {
s.log.Error("No attribute path specified")
return ""
}
if len(data) == 0 {
s.log.Error("Empty user info JSON response provided")
return ""
}
var buf interface{}
if err := json.Unmarshal(data, &buf); err != nil {
s.log.Error("Failed to unmarshal user info JSON response", "err", err.Error())
return ""
}
val, err := jmespath.Search(attributePath, buf)
if err != nil {
s.log.Error("Failed to search user info JSON response with provided path", "attributePath", attributePath, "err", err.Error())
return ""
}
strVal, ok := val.(string)
if ok {
return strVal
}
s.log.Error("Attribute not found when searching JSON with provided path", "attributePath", attributePath)
return ""
}
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
IsPrimary bool `json:"is_primary"`
Verified bool `json:"verified"`
IsConfirmed bool `json:"is_confirmed"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
var data struct {
Values []Record `json:"values"`
}
err = json.Unmarshal(response.Body, &data)
if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err)
}
records = data.Values
}
var email = ""
for _, record := range records {
if record.Primary || record.IsPrimary {
email = record.Email
break
}
}
return email, nil
}
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error) {
type Record struct {
Id int `json:"id"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err)
}
var ids = make([]int, len(records))
for i, record := range records {
ids[i] = record.Id
}
return ids, nil
}
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, error) {
type Record struct {
Login string `json:"login"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil {
return nil, fmt.Errorf("Error getting organizations: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
return nil, fmt.Errorf("Error getting organizations: %s", err)
}
var logins = make([]string, len(records))
for i, record := range records {
logins[i] = record.Login
}
return logins, nil
}
type UserInfoJson struct { type UserInfoJson struct {
Name string `json:"name"` Name string `json:"name"`
DisplayName string `json:"display_name"` DisplayName string `json:"display_name"`
...@@ -210,44 +90,38 @@ type UserInfoJson struct { ...@@ -210,44 +90,38 @@ type UserInfoJson struct {
Email string `json:"email"` Email string `json:"email"`
Upn string `json:"upn"` Upn string `json:"upn"`
Attributes map[string][]string `json:"attributes"` Attributes map[string][]string `json:"attributes"`
rawJSON []byte
}
func (info *UserInfoJson) String() string {
return fmt.Sprintf(
"Name: %s, Displayname: %s, Login: %s, Username: %s, Email: %s, Upn: %s, Attributes: %v",
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
} }
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) { func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data UserInfoJson var data UserInfoJson
var rawUserInfoResponse HttpGetResponse
var err error var err error
if !s.extractToken(&data, token) { userInfo := &BasicUserInfo{}
rawUserInfoResponse, err = HttpGet(client, s.apiUrl)
if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err)
}
err = json.Unmarshal(rawUserInfoResponse.Body, &data) if s.extractToken(&data, token) {
if err != nil { s.fillUserInfo(userInfo, &data)
return nil, fmt.Errorf("Error decoding user info JSON: %s", err)
}
} }
name := s.extractName(&data) if s.extractAPI(&data, client) {
s.fillUserInfo(userInfo, &data)
}
email := s.extractEmail(&data, rawUserInfoResponse.Body) if userInfo.Email == "" {
if email == "" { userInfo.Email, err = s.FetchPrivateEmail(client)
email, err = s.FetchPrivateEmail(client)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
role := s.extractRole(&data, rawUserInfoResponse.Body) if userInfo.Login == "" {
userInfo.Login = userInfo.Email
login := s.extractLogin(&data, email)
userInfo := &BasicUserInfo{
Name: name,
Login: login,
Email: email,
Role: role,
} }
if !s.IsTeamMember(client) { if !s.IsTeamMember(client) {
...@@ -258,10 +132,28 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) ...@@ -258,10 +132,28 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
return nil, errors.New("User not a member of one of the required organizations") return nil, errors.New("User not a member of one of the required organizations")
} }
s.log.Debug("User info result", "result", userInfo)
return userInfo, nil return userInfo, nil
} }
func (s *SocialGenericOAuth) fillUserInfo(userInfo *BasicUserInfo, data *UserInfoJson) {
if userInfo.Email == "" {
userInfo.Email = s.extractEmail(data)
}
if userInfo.Role == "" {
userInfo.Role = s.extractRole(data)
}
if userInfo.Name == "" {
userInfo.Name = s.extractName(data)
}
if userInfo.Login == "" {
userInfo.Login = s.extractLogin(data)
}
}
func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool { func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool {
var err error
idToken := token.Extra("id_token") idToken := token.Extra("id_token")
if idToken == nil { if idToken == nil {
s.log.Debug("No id_token found", "token", token) s.log.Debug("No id_token found", "token", token)
...@@ -275,34 +167,49 @@ func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Toke ...@@ -275,34 +167,49 @@ func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Toke
return false return false
} }
payload, err := base64.RawURLEncoding.DecodeString(matched[2]) data.rawJSON, err = base64.RawURLEncoding.DecodeString(matched[2])
if err != nil { if err != nil {
s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "err", err) s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "error", err)
return false return false
} }
err = json.Unmarshal(payload, data) err = json.Unmarshal(data.rawJSON, data)
if err != nil { if err != nil {
s.log.Error("Error decoding id_token JSON", "payload", string(payload), "err", err) s.log.Error("Error decoding id_token JSON", "raw_json", string(data.rawJSON), "error", err)
data.rawJSON = []byte{}
return false return false
} }
if email := s.extractEmail(data, payload); email == "" { s.log.Debug("Received id_token", "raw_json", string(data.rawJSON), "data", data)
s.log.Debug("No email found in id_token", "json", string(payload), "data", data) return true
}
func (s *SocialGenericOAuth) extractAPI(data *UserInfoJson, client *http.Client) bool {
rawUserInfoResponse, err := HttpGet(client, s.apiUrl)
if err != nil {
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
return false return false
} }
data.rawJSON = rawUserInfoResponse.Body
s.log.Debug("Received id_token", "json", string(payload), "data", data) err = json.Unmarshal(data.rawJSON, data)
if err != nil {
s.log.Error("Error decoding user info response", "raw_json", data.rawJSON, "error", err)
data.rawJSON = []byte{}
return false
}
s.log.Debug("Received user info response", "raw_json", string(data.rawJSON), "data", data)
return true return true
} }
func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byte) string { func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson) string {
if data.Email != "" { if data.Email != "" {
return data.Email return data.Email
} }
if s.emailAttributePath != "" { if s.emailAttributePath != "" {
email := s.searchJSONForAttr(s.emailAttributePath, userInfoResp) email := s.searchJSONForAttr(s.emailAttributePath, data.rawJSON)
if email != "" { if email != "" {
return email return email
} }
...@@ -318,15 +225,15 @@ func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byt ...@@ -318,15 +225,15 @@ func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byt
if emailErr == nil { if emailErr == nil {
return emailAddr.Address return emailAddr.Address
} }
s.log.Debug("Failed to parse e-mail address", "err", emailErr.Error()) s.log.Debug("Failed to parse e-mail address", "error", emailErr.Error())
} }
return "" return ""
} }
func (s *SocialGenericOAuth) extractRole(data *UserInfoJson, userInfoResp []byte) string { func (s *SocialGenericOAuth) extractRole(data *UserInfoJson) string {
if s.roleAttributePath != "" { if s.roleAttributePath != "" {
role := s.searchJSONForAttr(s.roleAttributePath, userInfoResp) role := s.searchJSONForAttr(s.roleAttributePath, data.rawJSON)
if role != "" { if role != "" {
return role return role
} }
...@@ -334,7 +241,7 @@ func (s *SocialGenericOAuth) extractRole(data *UserInfoJson, userInfoResp []byte ...@@ -334,7 +241,7 @@ func (s *SocialGenericOAuth) extractRole(data *UserInfoJson, userInfoResp []byte
return "" return ""
} }
func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) string { func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson) string {
if data.Login != "" { if data.Login != "" {
return data.Login return data.Login
} }
...@@ -343,7 +250,7 @@ func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) stri ...@@ -343,7 +250,7 @@ func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) stri
return data.Username return data.Username
} }
return email return ""
} }
func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string { func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
...@@ -357,3 +264,139 @@ func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string { ...@@ -357,3 +264,139 @@ func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
return "" return ""
} }
// searchJSONForAttr searches the provided JSON response for the given attribute
// using the configured attribute path associated with the generic OAuth
// provider.
// Returns an empty string if an attribute is not found.
func (s *SocialGenericOAuth) searchJSONForAttr(attributePath string, data []byte) string {
if attributePath == "" {
s.log.Error("No attribute path specified")
return ""
}
if len(data) == 0 {
s.log.Error("Empty user info JSON response provided")
return ""
}
var buf interface{}
if err := json.Unmarshal(data, &buf); err != nil {
s.log.Error("Failed to unmarshal user info JSON response", "err", err.Error())
return ""
}
val, err := jmespath.Search(attributePath, buf)
if err != nil {
s.log.Error("Failed to search user info JSON response with provided path", "attributePath", attributePath, "err", err.Error())
return ""
}
strVal, ok := val.(string)
if ok {
return strVal
}
s.log.Error("Attribute not found when searching JSON with provided path", "attributePath", attributePath)
return ""
}
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
IsPrimary bool `json:"is_primary"`
Verified bool `json:"verified"`
IsConfirmed bool `json:"is_confirmed"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
return "", errutil.Wrap("Error getting email address", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
var data struct {
Values []Record `json:"values"`
}
err = json.Unmarshal(response.Body, &data)
if err != nil {
s.log.Error("Error decoding email addresses response", "raw_json", string(response.Body), "error", err)
return "", errutil.Wrap("Erro decoding email addresses response", err)
}
records = data.Values
}
s.log.Debug("Received email addresses", "emails", records)
var email = ""
for _, record := range records {
if record.Primary || record.IsPrimary {
email = record.Email
break
}
}
s.log.Debug("Using email address", "email", email)
return email, nil
}
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, bool) {
type Record struct {
Id int `json:"id"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil {
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
return nil, false
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
s.log.Error("Error decoding team memberships response", "raw_json", string(response.Body), "error", err)
return nil, false
}
var ids = make([]int, len(records))
for i, record := range records {
ids[i] = record.Id
}
s.log.Debug("Received team memberships", "ids", ids)
return ids, true
}
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) {
type Record struct {
Login string `json:"login"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil {
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
return nil, false
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
s.log.Error("Error decoding organization response", "response", string(response.Body), "error", err)
return nil, false
}
var logins = make([]string, len(records))
for i, record := range records {
logins[i] = record.Login
}
s.log.Debug("Received organizations", "logins", logins)
return logins, true
}
package social package social
import ( import (
"github.com/grafana/grafana/pkg/infra/log" "encoding/json"
. "github.com/smartystreets/goconvey/convey" "io"
"net/http"
"net/http/httptest"
"time"
"github.com/stretchr/testify/require"
"testing" "testing"
"github.com/grafana/grafana/pkg/infra/log"
"golang.org/x/oauth2"
) )
func TestSearchJSONForEmail(t *testing.T) { func TestSearchJSONForEmail(t *testing.T) {
Convey("Given a generic OAuth provider", t, func() { t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{ provider := SocialGenericOAuth{
SocialBase: &SocialBase{ SocialBase: &SocialBase{
log: log.New("generic_oauth_test"), log: log.New("generic_oauth_test"),
...@@ -77,16 +86,16 @@ func TestSearchJSONForEmail(t *testing.T) { ...@@ -77,16 +86,16 @@ func TestSearchJSONForEmail(t *testing.T) {
for _, test := range tests { for _, test := range tests {
provider.emailAttributePath = test.EmailAttributePath provider.emailAttributePath = test.EmailAttributePath
Convey(test.Name, func() { t.Run(test.Name, func(t *testing.T) {
actualResult := provider.searchJSONForAttr(test.EmailAttributePath, test.UserInfoJSONResponse) actualResult := provider.searchJSONForAttr(test.EmailAttributePath, test.UserInfoJSONResponse)
So(actualResult, ShouldEqual, test.ExpectedResult) require.Equal(t, test.ExpectedResult, actualResult)
}) })
} }
}) })
} }
func TestSearchJSONForRole(t *testing.T) { func TestSearchJSONForRole(t *testing.T) {
Convey("Given a generic OAuth provider", t, func() { t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{ provider := SocialGenericOAuth{
SocialBase: &SocialBase{ SocialBase: &SocialBase{
log: log.New("generic_oauth_test"), log: log.New("generic_oauth_test"),
...@@ -131,9 +140,173 @@ func TestSearchJSONForRole(t *testing.T) { ...@@ -131,9 +140,173 @@ func TestSearchJSONForRole(t *testing.T) {
for _, test := range tests { for _, test := range tests {
provider.roleAttributePath = test.RoleAttributePath provider.roleAttributePath = test.RoleAttributePath
Convey(test.Name, func() { t.Run(test.Name, func(t *testing.T) {
actualResult := provider.searchJSONForAttr(test.RoleAttributePath, test.UserInfoJSONResponse) actualResult := provider.searchJSONForAttr(test.RoleAttributePath, test.UserInfoJSONResponse)
So(actualResult, ShouldEqual, test.ExpectedResult) require.Equal(t, test.ExpectedResult, actualResult)
})
}
})
}
func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{
log: log.New("generic_oauth_test"),
},
emailAttributePath: "email",
}
tests := []struct {
Name string
APIURLReponse interface{}
OAuth2Extra interface{}
RoleAttributePath string
ExpectedEmail string
ExpectedRole string
}{
{
Name: "Given a valid id_token, a valid role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token, no role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
RoleAttributePath: "",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given a valid id_token, an invalid role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
RoleAttributePath: "invalid_path",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a valid role path, a valid api response, use api response",
APIURLReponse: map[string]interface{}{
"role": "Admin",
"email": "john.doe@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given no id_token, no role path, a valid api response, use api response",
APIURLReponse: map[string]interface{}{
"email": "john.doe@example.com",
},
RoleAttributePath: "",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a role path, a valid api response without a role, use api response",
APIURLReponse: map[string]interface{}{
"email": "john.doe@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a valid role path, no api response, no data",
RoleAttributePath: "role",
ExpectedEmail: "",
ExpectedRole: "",
},
{
Name: "Given a valid id_token, a valid role path, a valid api response, prefer id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
"email": "from_response@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token, an invalid role path, a valid api response, prefer id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
"email": "from_response@example.com",
},
RoleAttributePath: "invalid_path",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given a valid id_token with no email, a valid role path, a valid api response with no role, merge",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4ifQ.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
APIURLReponse: map[string]interface{}{
"email": "from_response@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "from_response@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token with no role, a valid role path, a valid api response with no email, merge",
OAuth2Extra: map[string]interface{}{
// { "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "FromResponse",
},
}
for _, test := range tests {
provider.roleAttributePath = test.RoleAttributePath
t.Run(test.Name, func(t *testing.T) {
response, _ := json.Marshal(test.APIURLReponse)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, string(response))
}))
provider.apiUrl = ts.URL
staticToken := oauth2.Token{
AccessToken: "",
TokenType: "",
RefreshToken: "",
Expiry: time.Now(),
}
token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, _ := provider.UserInfo(ts.Client(), token)
require.Equal(t, test.ExpectedEmail, actualResult.Email)
require.Equal(t, test.ExpectedEmail, actualResult.Login)
require.Equal(t, test.ExpectedRole, actualResult.Role)
}) })
} }
}) })
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment