Commit 7cbf3d8d by Sean Johnson Committed by Marcus Efraimsson

OAuth: Fix role mapping from id token (#20300)

As part of the new improvements to JMESPath in #17149 a 
commit ensuring Role and Email data could be queried from 
id_token was lost.
This refactors and fixes extracting from both token and user 
info api consistently where before only either token or 
either user info api was used for extracting data/attributes.

Fixes #20243

Co-authored-by: Timo Wendt <timo@tjwendt.de>
Co-authored-by: twendt <timo@tjwendt.de>
Co-authored-by: henninge <henning@eggers.name>
Co-Authored-by: Henning Eggers <henning.eggers@inovex.de>
Co-authored-by: Marcus Efraimsson <marcus.efraimsson@gmail.com>
parent fef31acb
......@@ -9,6 +9,8 @@ import (
"net/mail"
"regexp"
"github.com/grafana/grafana/pkg/util/errutil"
"github.com/grafana/grafana/pkg/models"
"github.com/jmespath/go-jmespath"
"golang.org/x/oauth2"
......@@ -43,8 +45,8 @@ func (s *SocialGenericOAuth) IsTeamMember(client *http.Client) bool {
return true
}
teamMemberships, err := s.FetchTeamMemberships(client)
if err != nil {
teamMemberships, ok := s.FetchTeamMemberships(client)
if !ok {
return false
}
......@@ -64,8 +66,8 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
return true
}
organizations, err := s.FetchOrganizations(client)
if err != nil {
organizations, ok := s.FetchOrganizations(client)
if !ok {
return false
}
......@@ -80,128 +82,6 @@ func (s *SocialGenericOAuth) IsOrganizationMember(client *http.Client) bool {
return false
}
// searchJSONForAttr searches the provided JSON response for the given attribute
// using the configured attribute path associated with the generic OAuth
// provider.
// Returns an empty string if an attribute is not found.
func (s *SocialGenericOAuth) searchJSONForAttr(attributePath string, data []byte) string {
if attributePath == "" {
s.log.Error("No attribute path specified")
return ""
}
if len(data) == 0 {
s.log.Error("Empty user info JSON response provided")
return ""
}
var buf interface{}
if err := json.Unmarshal(data, &buf); err != nil {
s.log.Error("Failed to unmarshal user info JSON response", "err", err.Error())
return ""
}
val, err := jmespath.Search(attributePath, buf)
if err != nil {
s.log.Error("Failed to search user info JSON response with provided path", "attributePath", attributePath, "err", err.Error())
return ""
}
strVal, ok := val.(string)
if ok {
return strVal
}
s.log.Error("Attribute not found when searching JSON with provided path", "attributePath", attributePath)
return ""
}
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
IsPrimary bool `json:"is_primary"`
Verified bool `json:"verified"`
IsConfirmed bool `json:"is_confirmed"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
var data struct {
Values []Record `json:"values"`
}
err = json.Unmarshal(response.Body, &data)
if err != nil {
return "", fmt.Errorf("Error getting email address: %s", err)
}
records = data.Values
}
var email = ""
for _, record := range records {
if record.Primary || record.IsPrimary {
email = record.Email
break
}
}
return email, nil
}
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, error) {
type Record struct {
Id int `json:"id"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
return nil, fmt.Errorf("Error getting team memberships: %s", err)
}
var ids = make([]int, len(records))
for i, record := range records {
ids[i] = record.Id
}
return ids, nil
}
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, error) {
type Record struct {
Login string `json:"login"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil {
return nil, fmt.Errorf("Error getting organizations: %s", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
return nil, fmt.Errorf("Error getting organizations: %s", err)
}
var logins = make([]string, len(records))
for i, record := range records {
logins[i] = record.Login
}
return logins, nil
}
type UserInfoJson struct {
Name string `json:"name"`
DisplayName string `json:"display_name"`
......@@ -210,44 +90,38 @@ type UserInfoJson struct {
Email string `json:"email"`
Upn string `json:"upn"`
Attributes map[string][]string `json:"attributes"`
rawJSON []byte
}
func (info *UserInfoJson) String() string {
return fmt.Sprintf(
"Name: %s, Displayname: %s, Login: %s, Username: %s, Email: %s, Upn: %s, Attributes: %v",
info.Name, info.DisplayName, info.Login, info.Username, info.Email, info.Upn, info.Attributes)
}
func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
var data UserInfoJson
var rawUserInfoResponse HttpGetResponse
var err error
if !s.extractToken(&data, token) {
rawUserInfoResponse, err = HttpGet(client, s.apiUrl)
if err != nil {
return nil, fmt.Errorf("Error getting user info: %s", err)
}
userInfo := &BasicUserInfo{}
err = json.Unmarshal(rawUserInfoResponse.Body, &data)
if err != nil {
return nil, fmt.Errorf("Error decoding user info JSON: %s", err)
}
if s.extractToken(&data, token) {
s.fillUserInfo(userInfo, &data)
}
name := s.extractName(&data)
if s.extractAPI(&data, client) {
s.fillUserInfo(userInfo, &data)
}
email := s.extractEmail(&data, rawUserInfoResponse.Body)
if email == "" {
email, err = s.FetchPrivateEmail(client)
if userInfo.Email == "" {
userInfo.Email, err = s.FetchPrivateEmail(client)
if err != nil {
return nil, err
}
}
role := s.extractRole(&data, rawUserInfoResponse.Body)
login := s.extractLogin(&data, email)
userInfo := &BasicUserInfo{
Name: name,
Login: login,
Email: email,
Role: role,
if userInfo.Login == "" {
userInfo.Login = userInfo.Email
}
if !s.IsTeamMember(client) {
......@@ -258,10 +132,28 @@ func (s *SocialGenericOAuth) UserInfo(client *http.Client, token *oauth2.Token)
return nil, errors.New("User not a member of one of the required organizations")
}
s.log.Debug("User info result", "result", userInfo)
return userInfo, nil
}
func (s *SocialGenericOAuth) fillUserInfo(userInfo *BasicUserInfo, data *UserInfoJson) {
if userInfo.Email == "" {
userInfo.Email = s.extractEmail(data)
}
if userInfo.Role == "" {
userInfo.Role = s.extractRole(data)
}
if userInfo.Name == "" {
userInfo.Name = s.extractName(data)
}
if userInfo.Login == "" {
userInfo.Login = s.extractLogin(data)
}
}
func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Token) bool {
var err error
idToken := token.Extra("id_token")
if idToken == nil {
s.log.Debug("No id_token found", "token", token)
......@@ -275,34 +167,49 @@ func (s *SocialGenericOAuth) extractToken(data *UserInfoJson, token *oauth2.Toke
return false
}
payload, err := base64.RawURLEncoding.DecodeString(matched[2])
data.rawJSON, err = base64.RawURLEncoding.DecodeString(matched[2])
if err != nil {
s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "err", err)
s.log.Error("Error base64 decoding id_token", "raw_payload", matched[2], "error", err)
return false
}
err = json.Unmarshal(payload, data)
err = json.Unmarshal(data.rawJSON, data)
if err != nil {
s.log.Error("Error decoding id_token JSON", "payload", string(payload), "err", err)
s.log.Error("Error decoding id_token JSON", "raw_json", string(data.rawJSON), "error", err)
data.rawJSON = []byte{}
return false
}
if email := s.extractEmail(data, payload); email == "" {
s.log.Debug("No email found in id_token", "json", string(payload), "data", data)
s.log.Debug("Received id_token", "raw_json", string(data.rawJSON), "data", data)
return true
}
func (s *SocialGenericOAuth) extractAPI(data *UserInfoJson, client *http.Client) bool {
rawUserInfoResponse, err := HttpGet(client, s.apiUrl)
if err != nil {
s.log.Debug("Error getting user info response", "url", s.apiUrl, "error", err)
return false
}
data.rawJSON = rawUserInfoResponse.Body
s.log.Debug("Received id_token", "json", string(payload), "data", data)
err = json.Unmarshal(data.rawJSON, data)
if err != nil {
s.log.Error("Error decoding user info response", "raw_json", data.rawJSON, "error", err)
data.rawJSON = []byte{}
return false
}
s.log.Debug("Received user info response", "raw_json", string(data.rawJSON), "data", data)
return true
}
func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byte) string {
func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson) string {
if data.Email != "" {
return data.Email
}
if s.emailAttributePath != "" {
email := s.searchJSONForAttr(s.emailAttributePath, userInfoResp)
email := s.searchJSONForAttr(s.emailAttributePath, data.rawJSON)
if email != "" {
return email
}
......@@ -318,15 +225,15 @@ func (s *SocialGenericOAuth) extractEmail(data *UserInfoJson, userInfoResp []byt
if emailErr == nil {
return emailAddr.Address
}
s.log.Debug("Failed to parse e-mail address", "err", emailErr.Error())
s.log.Debug("Failed to parse e-mail address", "error", emailErr.Error())
}
return ""
}
func (s *SocialGenericOAuth) extractRole(data *UserInfoJson, userInfoResp []byte) string {
func (s *SocialGenericOAuth) extractRole(data *UserInfoJson) string {
if s.roleAttributePath != "" {
role := s.searchJSONForAttr(s.roleAttributePath, userInfoResp)
role := s.searchJSONForAttr(s.roleAttributePath, data.rawJSON)
if role != "" {
return role
}
......@@ -334,7 +241,7 @@ func (s *SocialGenericOAuth) extractRole(data *UserInfoJson, userInfoResp []byte
return ""
}
func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) string {
func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson) string {
if data.Login != "" {
return data.Login
}
......@@ -343,7 +250,7 @@ func (s *SocialGenericOAuth) extractLogin(data *UserInfoJson, email string) stri
return data.Username
}
return email
return ""
}
func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
......@@ -357,3 +264,139 @@ func (s *SocialGenericOAuth) extractName(data *UserInfoJson) string {
return ""
}
// searchJSONForAttr searches the provided JSON response for the given attribute
// using the configured attribute path associated with the generic OAuth
// provider.
// Returns an empty string if an attribute is not found.
func (s *SocialGenericOAuth) searchJSONForAttr(attributePath string, data []byte) string {
if attributePath == "" {
s.log.Error("No attribute path specified")
return ""
}
if len(data) == 0 {
s.log.Error("Empty user info JSON response provided")
return ""
}
var buf interface{}
if err := json.Unmarshal(data, &buf); err != nil {
s.log.Error("Failed to unmarshal user info JSON response", "err", err.Error())
return ""
}
val, err := jmespath.Search(attributePath, buf)
if err != nil {
s.log.Error("Failed to search user info JSON response with provided path", "attributePath", attributePath, "err", err.Error())
return ""
}
strVal, ok := val.(string)
if ok {
return strVal
}
s.log.Error("Attribute not found when searching JSON with provided path", "attributePath", attributePath)
return ""
}
func (s *SocialGenericOAuth) FetchPrivateEmail(client *http.Client) (string, error) {
type Record struct {
Email string `json:"email"`
Primary bool `json:"primary"`
IsPrimary bool `json:"is_primary"`
Verified bool `json:"verified"`
IsConfirmed bool `json:"is_confirmed"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/emails"))
if err != nil {
s.log.Error("Error getting email address", "url", s.apiUrl+"/emails", "error", err)
return "", errutil.Wrap("Error getting email address", err)
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
var data struct {
Values []Record `json:"values"`
}
err = json.Unmarshal(response.Body, &data)
if err != nil {
s.log.Error("Error decoding email addresses response", "raw_json", string(response.Body), "error", err)
return "", errutil.Wrap("Erro decoding email addresses response", err)
}
records = data.Values
}
s.log.Debug("Received email addresses", "emails", records)
var email = ""
for _, record := range records {
if record.Primary || record.IsPrimary {
email = record.Email
break
}
}
s.log.Debug("Using email address", "email", email)
return email, nil
}
func (s *SocialGenericOAuth) FetchTeamMemberships(client *http.Client) ([]int, bool) {
type Record struct {
Id int `json:"id"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/teams"))
if err != nil {
s.log.Error("Error getting team memberships", "url", s.apiUrl+"/teams", "error", err)
return nil, false
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
s.log.Error("Error decoding team memberships response", "raw_json", string(response.Body), "error", err)
return nil, false
}
var ids = make([]int, len(records))
for i, record := range records {
ids[i] = record.Id
}
s.log.Debug("Received team memberships", "ids", ids)
return ids, true
}
func (s *SocialGenericOAuth) FetchOrganizations(client *http.Client) ([]string, bool) {
type Record struct {
Login string `json:"login"`
}
response, err := HttpGet(client, fmt.Sprintf(s.apiUrl+"/orgs"))
if err != nil {
s.log.Error("Error getting organizations", "url", s.apiUrl+"/orgs", "error", err)
return nil, false
}
var records []Record
err = json.Unmarshal(response.Body, &records)
if err != nil {
s.log.Error("Error decoding organization response", "response", string(response.Body), "error", err)
return nil, false
}
var logins = make([]string, len(records))
for i, record := range records {
logins[i] = record.Login
}
s.log.Debug("Received organizations", "logins", logins)
return logins, true
}
package social
import (
"github.com/grafana/grafana/pkg/infra/log"
. "github.com/smartystreets/goconvey/convey"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"time"
"github.com/stretchr/testify/require"
"testing"
"github.com/grafana/grafana/pkg/infra/log"
"golang.org/x/oauth2"
)
func TestSearchJSONForEmail(t *testing.T) {
Convey("Given a generic OAuth provider", t, func() {
t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{
log: log.New("generic_oauth_test"),
......@@ -77,16 +86,16 @@ func TestSearchJSONForEmail(t *testing.T) {
for _, test := range tests {
provider.emailAttributePath = test.EmailAttributePath
Convey(test.Name, func() {
t.Run(test.Name, func(t *testing.T) {
actualResult := provider.searchJSONForAttr(test.EmailAttributePath, test.UserInfoJSONResponse)
So(actualResult, ShouldEqual, test.ExpectedResult)
require.Equal(t, test.ExpectedResult, actualResult)
})
}
})
}
func TestSearchJSONForRole(t *testing.T) {
Convey("Given a generic OAuth provider", t, func() {
t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{
log: log.New("generic_oauth_test"),
......@@ -131,9 +140,173 @@ func TestSearchJSONForRole(t *testing.T) {
for _, test := range tests {
provider.roleAttributePath = test.RoleAttributePath
Convey(test.Name, func() {
t.Run(test.Name, func(t *testing.T) {
actualResult := provider.searchJSONForAttr(test.RoleAttributePath, test.UserInfoJSONResponse)
So(actualResult, ShouldEqual, test.ExpectedResult)
require.Equal(t, test.ExpectedResult, actualResult)
})
}
})
}
func TestUserInfoSearchesForEmailAndRole(t *testing.T) {
t.Run("Given a generic OAuth provider", func(t *testing.T) {
provider := SocialGenericOAuth{
SocialBase: &SocialBase{
log: log.New("generic_oauth_test"),
},
emailAttributePath: "email",
}
tests := []struct {
Name string
APIURLReponse interface{}
OAuth2Extra interface{}
RoleAttributePath string
ExpectedEmail string
ExpectedRole string
}{
{
Name: "Given a valid id_token, a valid role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token, no role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
RoleAttributePath: "",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given a valid id_token, an invalid role path, no api response, use id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
RoleAttributePath: "invalid_path",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a valid role path, a valid api response, use api response",
APIURLReponse: map[string]interface{}{
"role": "Admin",
"email": "john.doe@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given no id_token, no role path, a valid api response, use api response",
APIURLReponse: map[string]interface{}{
"email": "john.doe@example.com",
},
RoleAttributePath: "",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a role path, a valid api response without a role, use api response",
APIURLReponse: map[string]interface{}{
"email": "john.doe@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given no id_token, a valid role path, no api response, no data",
RoleAttributePath: "role",
ExpectedEmail: "",
ExpectedRole: "",
},
{
Name: "Given a valid id_token, a valid role path, a valid api response, prefer id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
"email": "from_response@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token, an invalid role path, a valid api response, prefer id_token",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin", "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4iLCJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.9PtHcCaXxZa2HDlASyKIaFGfOKlw2ILQo32xlvhvhRg",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
"email": "from_response@example.com",
},
RoleAttributePath: "invalid_path",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "",
},
{
Name: "Given a valid id_token with no email, a valid role path, a valid api response with no role, merge",
OAuth2Extra: map[string]interface{}{
// { "role": "Admin" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiQWRtaW4ifQ.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
APIURLReponse: map[string]interface{}{
"email": "from_response@example.com",
},
RoleAttributePath: "role",
ExpectedEmail: "from_response@example.com",
ExpectedRole: "Admin",
},
{
Name: "Given a valid id_token with no role, a valid role path, a valid api response with no email, merge",
OAuth2Extra: map[string]interface{}{
// { "email": "john.doe@example.com" }
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJlbWFpbCI6ImpvaG4uZG9lQGV4YW1wbGUuY29tIn0.k5GwPcZvGe2BE_jgwN0ntz0nz4KlYhEd0hRRLApkTJ4",
},
APIURLReponse: map[string]interface{}{
"role": "FromResponse",
},
RoleAttributePath: "role",
ExpectedEmail: "john.doe@example.com",
ExpectedRole: "FromResponse",
},
}
for _, test := range tests {
provider.roleAttributePath = test.RoleAttributePath
t.Run(test.Name, func(t *testing.T) {
response, _ := json.Marshal(test.APIURLReponse)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, string(response))
}))
provider.apiUrl = ts.URL
staticToken := oauth2.Token{
AccessToken: "",
TokenType: "",
RefreshToken: "",
Expiry: time.Now(),
}
token := staticToken.WithExtra(test.OAuth2Extra)
actualResult, _ := provider.UserInfo(ts.Client(), token)
require.Equal(t, test.ExpectedEmail, actualResult.Email)
require.Equal(t, test.ExpectedEmail, actualResult.Login)
require.Equal(t, test.ExpectedRole, actualResult.Role)
})
}
})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment