Commit ff6a082e by twendt Committed by GitHub

Auth: Azure AD OAuth (#20030)

* Implement Azure AD oauth

* Use go-jose and cleanup

* Update go-jose in go.mod

* cleanup

* Add unit tests

* Fix scopes

* Add documentation page

* Improve documentation

* Convert extract_role into function.

* Do not use upn and replace unique_name with preferred_username

* Configure login button

* Use official microsoft icon and color from branding guideline.

* Add Azure AD config section in sample.ini.
parent ceca067d
......@@ -366,6 +366,17 @@ client_secret = some_secret
scopes = user:email
allowed_organizations =
#################################### Azure AD OAuth #######################
[auth.azuread]
name = Azure AD
enabled = false
allow_sign_up = true
client_id = some_client_id
client_secret = some_client_secret
scopes = openid email profile
auth_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/authorize
token_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/token
#################################### Generic OAuth #######################
[auth.generic_oauth]
name = OAuth
......
......@@ -356,6 +356,17 @@
;scopes = user:email
;allowed_organizations =
#################################### Azure AD OAuth #######################
[auth.azuread]
;name = Azure AD
;enabled = false
;allow_sign_up = true
;client_id = some_client_id
;client_secret = some_client_secret
;scopes = openid email profile
;auth_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/authorize
;token_url = https://login.microsoftonline.com/<tenant-id>/oauth2/v2.0/token
#################################### Generic OAuth ##########################
[auth.generic_oauth]
;enabled = false
......
+++
title = "Azure AD OAuth2 Authentication"
description = "Grafana OAuthentication Guide "
keywords = ["grafana", "configuration", "documentation", "oauth"]
type = "docs"
[menu.docs]
name = "Azure AD"
identifier = "azuread_oauth2"
parent = "authentication"
weight = 3
+++
# Azure AD OAuth2 Authentication
The Azure AD authentication provides the possibility to use an Azure Active Directory tenant as an identity provider for Grafana.
By using Azure AD Application Roles it is also possible to assign Users and Groups to Grafana roles from the Azure Portal.
To enable the Azure AD OAuth2 you must register your application with Azure AD.
# Create Azure AD application
1. Log in to [Azure Portal](https://portal.azure.com) and click **Azure Active Directory** in the side menu.
1. Click **App Registrations** and add a new application registration:
- Name: Grafana
- Application type: Web app / API
- Sign-on URL: `https://<grafana domain>/login/azuread`
1. Click the name of the new application to open the application details page.
1. Click **Endpoints**.
- Note down the **OAuth 2.0 authorization endpoint (v2)**, this will be the auth url.
- Note down the **OAuth 2.0 token endpoint (v2)**, this will be the token url.
1. Close the Endpoints page to come back to the application details page.
1. Note down the "Application ID", this will be the OAuth client id.
1. Click **Certificates & secrets** and add a new entry under Client secrets.
- Description: Grafana OAuth
- Expires: Never
1. Click **Add** then copy the key value, this will be the OAuth client secret.
1. Click **Manifest**.
- Add definitions for the required Application Roles for Grafana (Viewer, Editor, Admin). Without this configuration all users will be assigned to the Viewer role.
- Every role has to have a unique id. On Linux this can be created with `uuidgen` for instance.
```json
"appRoles": [
{
"allowedMemberTypes": [
"User"
],
"description": "Grafana admin Users",
"displayName": "Grafana Admin",
"id": "SOME_UNIQUE_ID",
"isEnabled": true,
"lang": null,
"origin": "Application",
"value": "Admin"
},
{
"allowedMemberTypes": [
"User"
],
"description": "Grafana read only Users",
"displayName": "Grafana Viewer",
"id": "SOME_UNIQUE_ID",
"isEnabled": true,
"lang": null,
"origin": "Application",
"value": "Viewer"
},
{
"allowedMemberTypes": [
"User"
],
"description": "Grafana Editor Users",
"displayName": "Grafana Editor",
"id": "SOME_UNIQUE_ID",
"isEnabled": true,
"lang": null,
"origin": "Application",
"value": "Editor"
}
],
```
1. Click Overview and then on **Managed application in local directory** to show the Enterprise Application details.
1. Click on **Users and groups** and add Users/Groups to the Grafana roles by using **Add User**.
1. Add the following to the [Grafana configuration file]({{< relref "../installation/configuration.md#config-file-locations" >}}):
```ini
[auth.azuread]
name = Azure AD
enabled = true
allow_sign_up = true
client_id = APPLICATION_ID
client_secret = CLIENT_SECRET
scopes = openid email profile
auth_url = https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/authorize
token_url = https://login.microsoftonline.com/TENANT_ID/oauth2/v2.0/token
```
> Note: Ensure that the [root_url]({{< relref "../installation/configuration/#root-url" >}}) in Grafana is set in your Azure Application Reply URLs (App -> Settings -> Reply URLs)
......@@ -39,6 +39,8 @@
name: Generic OAuth
- link: /auth/google/
name: Google
- link: /auth/azuread/
name: Azure AD
- link: /auth/github/
name: GitHub
- link: /auth/gitlab/
......
......@@ -83,6 +83,6 @@ require (
gopkg.in/macaron.v1 v1.3.4
gopkg.in/mail.v2 v2.3.1
gopkg.in/redis.v5 v5.2.9
gopkg.in/square/go-jose.v2 v2.3.0
gopkg.in/square/go-jose.v2 v2.4.1
gopkg.in/yaml.v2 v2.2.5
)
......@@ -419,6 +419,8 @@ gopkg.in/redis.v5 v5.2.9 h1:MNZYOLPomQzZMfpN3ZtD1uyJ2IDonTTlxYiV/pEApiw=
gopkg.in/redis.v5 v5.2.9/go.mod h1:6gtv0/+A4iM08kdRfocWYB3bLX2tebpNtfKlFT6H4mY=
gopkg.in/square/go-jose.v2 v2.3.0 h1:nLzhkFyl5bkblqYBoiWJUt5JkWOzmiaBtCxdJAqJd3U=
gopkg.in/square/go-jose.v2 v2.3.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/square/go-jose.v2 v2.4.1 h1:H0TmLt7/KmzlrDOpa1F+zr0Tk90PbJYBfsVUmRLrf9Y=
gopkg.in/square/go-jose.v2 v2.4.1/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI=
gopkg.in/stretchr/testify.v1 v1.2.2/go.mod h1:QI5V/q6UbPmuhtm10CaFZxED9NreB8PnFYN9JcR6TxU=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
......
......@@ -219,6 +219,11 @@ $external-services: (
borderColor: #b83e31,
icon: '',
),
azuread: (
bgColor: #2f2f2f,
borderColor: #2f2f2f,
icon: '',
),
grafanacom: (
bgColor: #262628,
borderColor: #393939,
......
......@@ -341,6 +341,8 @@ func GetAuthProviderLabel(authModule string) string {
return "GitHub"
case "oauth_google":
return "Google"
case "oauth_azuread":
return "AzureAD"
case "oauth_gitlab":
return "GitLab"
case "oauth_grafana_com", "oauth_grafananet":
......
......@@ -160,6 +160,7 @@ func TestMetrics(t *testing.T) {
oauthProviders := map[string]bool{
"github": true,
"gitlab": true,
"azuread": true,
"google": true,
"generic_oauth": true,
"grafana_com": true,
......@@ -254,6 +255,7 @@ func TestMetrics(t *testing.T) {
So(metrics.Get("stats.auth_enabled.oauth_github.count").MustInt(), ShouldEqual, 1)
So(metrics.Get("stats.auth_enabled.oauth_gitlab.count").MustInt(), ShouldEqual, 1)
So(metrics.Get("stats.auth_enabled.oauth_google.count").MustInt(), ShouldEqual, 1)
So(metrics.Get("stats.auth_enabled.oauth_azuread.count").MustInt(), ShouldEqual, 1)
So(metrics.Get("stats.auth_enabled.oauth_generic_oauth.count").MustInt(), ShouldEqual, 1)
So(metrics.Get("stats.auth_enabled.oauth_grafana_com.count").MustInt(), ShouldEqual, 1)
......
package social
import (
"errors"
"fmt"
"net/http"
"strings"
"github.com/grafana/grafana/pkg/models"
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2/jwt"
)
type SocialAzureAD struct {
*SocialBase
allowedDomains []string
allowSignup bool
}
type azureClaims struct {
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Roles []string `json:"roles"`
Name string `json:"name"`
ID string `json:"oid"`
}
func (s *SocialAzureAD) Type() int {
return int(models.AZUREAD)
}
func (s *SocialAzureAD) IsEmailAllowed(email string) bool {
return isEmailAllowed(email, s.allowedDomains)
}
func (s *SocialAzureAD) IsSignupAllowed() bool {
return s.allowSignup
}
func (s *SocialAzureAD) UserInfo(_ *http.Client, token *oauth2.Token) (*BasicUserInfo, error) {
idToken := token.Extra("id_token")
if idToken == nil {
return nil, fmt.Errorf("No id_token found")
}
parsedToken, err := jwt.ParseSigned(idToken.(string))
if err != nil {
return nil, fmt.Errorf("Error parsing id token")
}
var claims azureClaims
if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil {
return nil, fmt.Errorf("Error getting claims from id token")
}
email := extractEmail(claims)
if email == "" {
return nil, errors.New("Error getting user info: No email found in access token")
}
role := extractRole(claims)
return &BasicUserInfo{
Id: claims.ID,
Name: claims.Name,
Email: email,
Login: email,
Role: string(role),
}, nil
}
func extractEmail(claims azureClaims) string {
if claims.Email == "" {
if claims.PreferredUsername != "" {
return claims.PreferredUsername
}
}
return claims.Email
}
func extractRole(claims azureClaims) models.RoleType {
if len(claims.Roles) == 0 {
return models.ROLE_VIEWER
}
roleOrder := []models.RoleType{
models.ROLE_ADMIN,
models.ROLE_EDITOR,
models.ROLE_VIEWER,
}
for _, role := range roleOrder {
if found := hasRole(claims.Roles, role); found {
return role
}
}
return models.ROLE_VIEWER
}
func hasRole(roles []string, role models.RoleType) bool {
for _, item := range roles {
if strings.EqualFold(item, string(role)) {
return true
}
}
return false
}
package social
import (
"golang.org/x/oauth2"
"gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/jwt"
"net/http"
"reflect"
"testing"
"time"
)
func TestSocialAzureAD_UserInfo(t *testing.T) {
type fields struct {
SocialBase *SocialBase
allowedDomains []string
allowSignup bool
}
type args struct {
client *http.Client
}
tests := []struct {
name string
fields fields
claims *azureClaims
args args
want *BasicUserInfo
wantErr bool
}{
{
name: "Email in email claim",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Viewer",
Groups: nil,
},
},
{
name: "No email",
claims: &azureClaims{
Email: "",
PreferredUsername: "",
Roles: []string{},
Name: "My Name",
ID: "1234",
},
want: nil,
wantErr: true,
},
{
name: "No id token",
claims: nil,
want: nil,
wantErr: true,
},
{
name: "Email in preferred_username claim",
claims: &azureClaims{
Email: "",
PreferredUsername: "me@example.com",
Roles: []string{},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Viewer",
Groups: nil,
},
},
{
name: "Admin role",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{"Admin"},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Admin",
Groups: nil,
},
},
{
name: "Lowercase Admin role",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{"admin"},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Admin",
Groups: nil,
},
},
{
name: "Only other roles",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{"AppAdmin"},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Viewer",
Groups: nil,
},
},
{
name: "Editor role",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{"Editor"},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Editor",
Groups: nil,
},
},
{
name: "Admin and Editor roles in claim",
claims: &azureClaims{
Email: "me@example.com",
PreferredUsername: "",
Roles: []string{"Admin", "Editor"},
Name: "My Name",
ID: "1234",
},
want: &BasicUserInfo{
Id: "1234",
Name: "My Name",
Email: "me@example.com",
Login: "me@example.com",
Company: "",
Role: "Admin",
Groups: nil,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &SocialAzureAD{
SocialBase: tt.fields.SocialBase,
allowedDomains: tt.fields.allowedDomains,
allowSignup: tt.fields.allowSignup,
}
key := []byte("secret")
sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.HS256, Key: key}, (&jose.SignerOptions{}).WithType("JWT"))
if err != nil {
panic(err)
}
cl := jwt.Claims{
Subject: "subject",
Issuer: "issuer",
NotBefore: jwt.NewNumericDate(time.Date(2016, 1, 1, 0, 0, 0, 0, time.UTC)),
Audience: jwt.Audience{"leela", "fry"},
}
var raw string
if tt.claims != nil {
raw, err = jwt.Signed(sig).Claims(cl).Claims(tt.claims).CompactSerialize()
if err != nil {
t.Error(err)
}
} else {
raw, err = jwt.Signed(sig).Claims(cl).CompactSerialize()
if err != nil {
t.Error(err)
}
}
token := &oauth2.Token{}
if tt.claims != nil {
token = token.WithExtra(map[string]interface{}{"id_token": raw})
}
got, err := s.UserInfo(tt.args.client, token)
if (err != nil) != tt.wantErr {
t.Errorf("UserInfo() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("UserInfo() got = %v, want %v", got, tt.want)
}
})
}
}
......@@ -55,7 +55,7 @@ const (
var (
SocialBaseUrl = "/login/"
SocialMap = make(map[string]SocialConnector)
allOauthes = []string{"github", "gitlab", "google", "generic_oauth", "grafananet", grafanaCom}
allOauthes = []string{"github", "gitlab", "google", "generic_oauth", "grafananet", grafanaCom, "azuread"}
)
func NewOAuthService() {
......@@ -152,6 +152,18 @@ func NewOAuthService() {
}
}
// AzureAD.
if name == "azuread" {
SocialMap["azuread"] = &SocialAzureAD{
SocialBase: &SocialBase{
Config: &config,
log: logger,
},
allowedDomains: info.AllowedDomains,
allowSignup: info.AllowSignup,
}
}
// Generic - Uses the same scheme as Github.
if name == "generic_oauth" {
SocialMap["generic_oauth"] = &SocialGenericOAuth{
......
......@@ -9,4 +9,5 @@ const (
GENERIC
GRAFANA_COM
GITLAB
AZUREAD
)
......@@ -15,6 +15,10 @@ const loginServices: () => LoginServices = () => {
enabled: oauthEnabled && config.oauth.google,
name: 'Google',
},
azuread: {
enabled: config.oauth.azuread,
name: 'Microsoft',
},
github: {
enabled: oauthEnabled && config.oauth.github,
name: 'GitHub',
......
<svg xmlns="http://www.w3.org/2000/svg" width="21" height="21" viewBox="0 0 21 21"><title>MS-SymbolLockup</title><rect x="1" y="1" width="9" height="9" fill="#f25022"/><rect x="1" y="11" width="9" height="9" fill="#00a4ef"/><rect x="11" y="1" width="9" height="9" fill="#7fba00"/><rect x="11" y="11" width="9" height="9" fill="#ffb900"/></svg>
\ No newline at end of file
......@@ -222,6 +222,11 @@ $external-services: (
borderColor: #b83e31,
icon: '',
),
azuread: (
bgColor: #2f2f2f,
borderColor: #2f2f2f,
icon: '',
),
grafanacom: (
bgColor: #262628,
borderColor: #393939,
......
......@@ -216,6 +216,15 @@ $btn-service-icon-width: 35px;
}
}
.btn-service--azuread {
.btn-service-icon {
background-image: url(/public/img/microsoft_auth_icon.svg);
background-repeat: no-repeat;
background-position: 50%;
background-size: 60%;
}
}
//Toggle button
.toggle-btn {
......
......@@ -29,7 +29,7 @@ import (
"math/big"
"golang.org/x/crypto/ed25519"
"gopkg.in/square/go-jose.v2/cipher"
josecipher "gopkg.in/square/go-jose.v2/cipher"
"gopkg.in/square/go-jose.v2/json"
)
......@@ -288,7 +288,7 @@ func (ctx rsaDecrypterSigner) signPayload(payload []byte, alg SignatureAlgorithm
out, err = rsa.SignPKCS1v15(RandReader, ctx.privateKey, hash, hashed)
case PS256, PS384, PS512:
out, err = rsa.SignPSS(RandReader, ctx.privateKey, hash, hashed, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto,
SaltLength: rsa.PSSSaltLengthEqualsHash,
})
}
......
......@@ -17,8 +17,10 @@
package josecipher
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/binary"
)
......@@ -44,16 +46,38 @@ func DeriveECDHES(alg string, apuData, apvData []byte, priv *ecdsa.PrivateKey, p
panic("public key not on same curve as private key")
}
z, _ := priv.PublicKey.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
reader := NewConcatKDF(crypto.SHA256, z.Bytes(), algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
z, _ := priv.Curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
zBytes := z.Bytes()
// Note that calling z.Bytes() on a big.Int may strip leading zero bytes from
// the returned byte array. This can lead to a problem where zBytes will be
// shorter than expected which breaks the key derivation. Therefore we must pad
// to the full length of the expected coordinate here before calling the KDF.
octSize := dSize(priv.Curve)
if len(zBytes) != octSize {
zBytes = append(bytes.Repeat([]byte{0}, octSize-len(zBytes)), zBytes...)
}
reader := NewConcatKDF(crypto.SHA256, zBytes, algID, ptyUInfo, ptyVInfo, supPubInfo, []byte{})
key := make([]byte, size)
// Read on the KDF will never fail
_, _ = reader.Read(key)
return key
}
// dSize returns the size in octets for a coordinate on a elliptic curve.
func dSize(curve elliptic.Curve) int {
order := curve.Params().P
bitLen := order.BitLen()
size := bitLen / 8
if bitLen%8 != 0 {
size++
}
return size
}
func lengthPrefixed(data []byte) []byte {
out := make([]byte, len(data)+4)
binary.BigEndian.PutUint32(out, uint32(len(data)))
......
......@@ -141,6 +141,8 @@ func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions)
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case *JSONWebKey:
keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
case OpaqueKeyEncrypter:
keyID, rawKey = encryptionKey.KeyID(), encryptionKey
default:
rawKey = encryptionKey
}
......@@ -267,9 +269,11 @@ func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKey
recipient, err := makeJWERecipient(alg, encryptionKey.Key)
recipient.keyID = encryptionKey.KeyID
return recipient, err
default:
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
return newOpaqueKeyEncrypter(alg, encrypter)
}
return recipientKeyInfo{}, ErrUnsupportedKeyType
}
// newDecrypter creates an appropriate decrypter based on the key type
......@@ -295,9 +299,11 @@ func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
return newDecrypter(decryptionKey.Key)
case *JSONWebKey:
return newDecrypter(decryptionKey.Key)
default:
return nil, ErrUnsupportedKeyType
}
if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
return &opaqueKeyDecrypter{decrypter: okd}, nil
}
return nil, ErrUnsupportedKeyType
}
// Implementation of encrypt method producing a JWE object.
......
......@@ -23,13 +23,12 @@ import (
"encoding/binary"
"io"
"math/big"
"regexp"
"strings"
"unicode"
"gopkg.in/square/go-jose.v2/json"
)
var stripWhitespaceRegex = regexp.MustCompile("\\s")
// Helper function to serialize known-good objects.
// Precondition: value is not a nil pointer.
func mustSerializeJSON(value interface{}) []byte {
......@@ -56,7 +55,14 @@ func mustSerializeJSON(value interface{}) []byte {
// Strip all newlines and whitespace
func stripWhitespace(data string) string {
return stripWhitespaceRegex.ReplaceAllString(data, "")
buf := strings.Builder{}
buf.Grow(len(data))
for _, r := range data {
if !unicode.IsSpace(r) {
buf.WriteRune(r)
}
}
return buf.String()
}
// Perform compression based on algorithm
......
......@@ -230,7 +230,7 @@ func (k *JSONWebKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
case *rsa.PrivateKey:
input, err = rsaThumbprintInput(key.N, key.E)
case ed25519.PrivateKey:
input, err = edThumbprintInput(ed25519.PublicKey(key[0:32]))
input, err = edThumbprintInput(ed25519.PublicKey(key[32:]))
default:
return nil, fmt.Errorf("square/go-jose: unknown key type '%s'", reflect.TypeOf(key))
}
......@@ -357,11 +357,11 @@ func (key rawJSONWebKey) ecPublicKey() (*ecdsa.PublicKey, error) {
// the curve specified in the "crv" parameter.
// https://tools.ietf.org/html/rfc7518#section-6.2.1.2
if curveSize(curve) != len(key.X.data) {
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for x")
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for x")
}
if curveSize(curve) != len(key.Y.data) {
return nil, fmt.Errorf("square/go-jose: invalid EC private key, wrong length for y")
return nil, fmt.Errorf("square/go-jose: invalid EC public key, wrong length for y")
}
x := key.X.bigInt()
......@@ -421,8 +421,8 @@ func (key rawJSONWebKey) edPrivateKey() (ed25519.PrivateKey, error) {
}
privateKey := make([]byte, ed25519.PrivateKeySize)
copy(privateKey[0:32], key.X.bytes())
copy(privateKey[32:], key.D.bytes())
copy(privateKey[0:32], key.D.bytes())
copy(privateKey[32:], key.X.bytes())
rv := ed25519.PrivateKey(privateKey)
return rv, nil
}
......@@ -483,9 +483,9 @@ func (key rawJSONWebKey) rsaPrivateKey() (*rsa.PrivateKey, error) {
}
func fromEdPrivateKey(ed ed25519.PrivateKey) (*rawJSONWebKey, error) {
raw := fromEdPublicKey(ed25519.PublicKey(ed[0:32]))
raw := fromEdPublicKey(ed25519.PublicKey(ed[32:]))
raw.D = newBuffer(ed[32:])
raw.D = newBuffer(ed[0:32])
return raw, nil
}
......
......@@ -17,6 +17,7 @@
package jose
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
......@@ -75,13 +76,21 @@ type Signature struct {
}
// ParseSigned parses a signed message in compact or full serialization format.
func ParseSigned(input string) (*JSONWebSignature, error) {
input = stripWhitespace(input)
if strings.HasPrefix(input, "{") {
return parseSignedFull(input)
func ParseSigned(signature string) (*JSONWebSignature, error) {
signature = stripWhitespace(signature)
if strings.HasPrefix(signature, "{") {
return parseSignedFull(signature)
}
return parseSignedCompact(input)
return parseSignedCompact(signature, nil)
}
// ParseDetached parses a signed message in compact serialization format with detached payload.
func ParseDetached(signature string, payload []byte) (*JSONWebSignature, error) {
if payload == nil {
return nil, errors.New("square/go-jose: nil payload")
}
return parseSignedCompact(stripWhitespace(signature), payload)
}
// Get a header value
......@@ -93,20 +102,39 @@ func (sig Signature) mergedHeaders() rawHeader {
}
// Compute data to be signed
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) []byte {
var serializedProtected string
func (obj JSONWebSignature) computeAuthData(payload []byte, signature *Signature) ([]byte, error) {
var authData bytes.Buffer
protectedHeader := new(rawHeader)
if signature.original != nil && signature.original.Protected != nil {
serializedProtected = signature.original.Protected.base64()
if err := json.Unmarshal(signature.original.Protected.bytes(), protectedHeader); err != nil {
return nil, err
}
authData.WriteString(signature.original.Protected.base64())
} else if signature.protected != nil {
serializedProtected = base64.RawURLEncoding.EncodeToString(mustSerializeJSON(signature.protected))
protectedHeader = signature.protected
authData.WriteString(base64.RawURLEncoding.EncodeToString(mustSerializeJSON(protectedHeader)))
}
needsBase64 := true
if protectedHeader != nil {
var err error
if needsBase64, err = protectedHeader.getB64(); err != nil {
needsBase64 = true
}
}
authData.WriteByte('.')
if needsBase64 {
authData.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
serializedProtected = ""
authData.Write(payload)
}
return []byte(fmt.Sprintf("%s.%s",
serializedProtected,
base64.RawURLEncoding.EncodeToString(payload)))
return authData.Bytes(), nil
}
// parseSignedFull parses a message in full format.
......@@ -246,20 +274,26 @@ func (parsed *rawJSONWebSignature) sanitized() (*JSONWebSignature, error) {
}
// parseSignedCompact parses a message in compact format.
func parseSignedCompact(input string) (*JSONWebSignature, error) {
func parseSignedCompact(input string, payload []byte) (*JSONWebSignature, error) {
parts := strings.Split(input, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("square/go-jose: compact JWS format must have three parts")
}
if parts[1] != "" && payload != nil {
return nil, fmt.Errorf("square/go-jose: payload is not detached")
}
rawProtected, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil {
return nil, err
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
if payload == nil {
payload, err = base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, err
}
}
signature, err := base64.RawURLEncoding.DecodeString(parts[2])
......@@ -275,19 +309,30 @@ func parseSignedCompact(input string) (*JSONWebSignature, error) {
return raw.sanitized()
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebSignature) CompactSerialize() (string, error) {
func (obj JSONWebSignature) compactSerialize(detached bool) (string, error) {
if len(obj.Signatures) != 1 || obj.Signatures[0].header != nil || obj.Signatures[0].protected == nil {
return "", ErrNotSupported
}
serializedProtected := mustSerializeJSON(obj.Signatures[0].protected)
serializedProtected := base64.RawURLEncoding.EncodeToString(mustSerializeJSON(obj.Signatures[0].protected))
payload := ""
signature := base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)
if !detached {
payload = base64.RawURLEncoding.EncodeToString(obj.payload)
}
return fmt.Sprintf("%s.%s.%s", serializedProtected, payload, signature), nil
}
// CompactSerialize serializes an object using the compact serialization format.
func (obj JSONWebSignature) CompactSerialize() (string, error) {
return obj.compactSerialize(false)
}
return fmt.Sprintf(
"%s.%s.%s",
base64.RawURLEncoding.EncodeToString(serializedProtected),
base64.RawURLEncoding.EncodeToString(obj.payload),
base64.RawURLEncoding.EncodeToString(obj.Signatures[0].Signature)), nil
// DetachedCompactSerialize serializes an object using the compact serialization format with detached payload.
func (obj JSONWebSignature) DetachedCompactSerialize() (string, error) {
return obj.compactSerialize(true)
}
// FullSerialize serializes an object using the full JSON serialization format.
......
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"bytes"
"reflect"
"gopkg.in/square/go-jose.v2/json"
"gopkg.in/square/go-jose.v2"
)
// Builder is a utility for making JSON Web Tokens. Calls can be chained, and
// errors are accumulated until the final call to CompactSerialize/FullSerialize.
type Builder interface {
// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
// into single JSON object. If you are passing private claims, make sure to set
// struct field tags to specify the name for the JSON key to be used when
// serializing.
Claims(i interface{}) Builder
// Token builds a JSONWebToken from provided data.
Token() (*JSONWebToken, error)
// FullSerialize serializes a token using the full serialization format.
FullSerialize() (string, error)
// CompactSerialize serializes a token using the compact serialization format.
CompactSerialize() (string, error)
}
// NestedBuilder is a utility for making Signed-Then-Encrypted JSON Web Tokens.
// Calls can be chained, and errors are accumulated until final call to
// CompactSerialize/FullSerialize.
type NestedBuilder interface {
// Claims encodes claims into JWE/JWS form. Multiple calls will merge claims
// into single JSON object. If you are passing private claims, make sure to set
// struct field tags to specify the name for the JSON key to be used when
// serializing.
Claims(i interface{}) NestedBuilder
// Token builds a NestedJSONWebToken from provided data.
Token() (*NestedJSONWebToken, error)
// FullSerialize serializes a token using the full serialization format.
FullSerialize() (string, error)
// CompactSerialize serializes a token using the compact serialization format.
CompactSerialize() (string, error)
}
type builder struct {
payload map[string]interface{}
err error
}
type signedBuilder struct {
builder
sig jose.Signer
}
type encryptedBuilder struct {
builder
enc jose.Encrypter
}
type nestedBuilder struct {
builder
sig jose.Signer
enc jose.Encrypter
}
// Signed creates builder for signed tokens.
func Signed(sig jose.Signer) Builder {
return &signedBuilder{
sig: sig,
}
}
// Encrypted creates builder for encrypted tokens.
func Encrypted(enc jose.Encrypter) Builder {
return &encryptedBuilder{
enc: enc,
}
}
// SignedAndEncrypted creates builder for signed-then-encrypted tokens.
// ErrInvalidContentType will be returned if encrypter doesn't have JWT content type.
func SignedAndEncrypted(sig jose.Signer, enc jose.Encrypter) NestedBuilder {
if contentType, _ := enc.Options().ExtraHeaders[jose.HeaderContentType].(jose.ContentType); contentType != "JWT" {
return &nestedBuilder{
builder: builder{
err: ErrInvalidContentType,
},
}
}
return &nestedBuilder{
sig: sig,
enc: enc,
}
}
func (b builder) claims(i interface{}) builder {
if b.err != nil {
return b
}
m, ok := i.(map[string]interface{})
switch {
case ok:
return b.merge(m)
case reflect.Indirect(reflect.ValueOf(i)).Kind() == reflect.Struct:
m, err := normalize(i)
if err != nil {
return builder{
err: err,
}
}
return b.merge(m)
default:
return builder{
err: ErrInvalidClaims,
}
}
}
func normalize(i interface{}) (map[string]interface{}, error) {
m := make(map[string]interface{})
raw, err := json.Marshal(i)
if err != nil {
return nil, err
}
d := json.NewDecoder(bytes.NewReader(raw))
d.UseNumber()
if err := d.Decode(&m); err != nil {
return nil, err
}
return m, nil
}
func (b *builder) merge(m map[string]interface{}) builder {
p := make(map[string]interface{})
for k, v := range b.payload {
p[k] = v
}
for k, v := range m {
p[k] = v
}
return builder{
payload: p,
}
}
func (b *builder) token(p func(interface{}) ([]byte, error), h []jose.Header) (*JSONWebToken, error) {
return &JSONWebToken{
payload: p,
Headers: h,
}, nil
}
func (b *signedBuilder) Claims(i interface{}) Builder {
return &signedBuilder{
builder: b.builder.claims(i),
sig: b.sig,
}
}
func (b *signedBuilder) Token() (*JSONWebToken, error) {
sig, err := b.sign()
if err != nil {
return nil, err
}
h := make([]jose.Header, len(sig.Signatures))
for i, v := range sig.Signatures {
h[i] = v.Header
}
return b.builder.token(sig.Verify, h)
}
func (b *signedBuilder) CompactSerialize() (string, error) {
sig, err := b.sign()
if err != nil {
return "", err
}
return sig.CompactSerialize()
}
func (b *signedBuilder) FullSerialize() (string, error) {
sig, err := b.sign()
if err != nil {
return "", err
}
return sig.FullSerialize(), nil
}
func (b *signedBuilder) sign() (*jose.JSONWebSignature, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
return b.sig.Sign(p)
}
func (b *encryptedBuilder) Claims(i interface{}) Builder {
return &encryptedBuilder{
builder: b.builder.claims(i),
enc: b.enc,
}
}
func (b *encryptedBuilder) CompactSerialize() (string, error) {
enc, err := b.encrypt()
if err != nil {
return "", err
}
return enc.CompactSerialize()
}
func (b *encryptedBuilder) FullSerialize() (string, error) {
enc, err := b.encrypt()
if err != nil {
return "", err
}
return enc.FullSerialize(), nil
}
func (b *encryptedBuilder) Token() (*JSONWebToken, error) {
enc, err := b.encrypt()
if err != nil {
return nil, err
}
return b.builder.token(enc.Decrypt, []jose.Header{enc.Header})
}
func (b *encryptedBuilder) encrypt() (*jose.JSONWebEncryption, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
return b.enc.Encrypt(p)
}
func (b *nestedBuilder) Claims(i interface{}) NestedBuilder {
return &nestedBuilder{
builder: b.builder.claims(i),
sig: b.sig,
enc: b.enc,
}
}
func (b *nestedBuilder) Token() (*NestedJSONWebToken, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return nil, err
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}
func (b *nestedBuilder) CompactSerialize() (string, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return "", err
}
return enc.CompactSerialize()
}
func (b *nestedBuilder) FullSerialize() (string, error) {
enc, err := b.signAndEncrypt()
if err != nil {
return "", err
}
return enc.FullSerialize(), nil
}
func (b *nestedBuilder) signAndEncrypt() (*jose.JSONWebEncryption, error) {
if b.err != nil {
return nil, b.err
}
p, err := json.Marshal(b.payload)
if err != nil {
return nil, err
}
sig, err := b.sig.Sign(p)
if err != nil {
return nil, err
}
p2, err := sig.CompactSerialize()
if err != nil {
return nil, err
}
return b.enc.Encrypt([]byte(p2))
}
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"strconv"
"time"
"gopkg.in/square/go-jose.v2/json"
)
// Claims represents public claim values (as specified in RFC 7519).
type Claims struct {
Issuer string `json:"iss,omitempty"`
Subject string `json:"sub,omitempty"`
Audience Audience `json:"aud,omitempty"`
Expiry *NumericDate `json:"exp,omitempty"`
NotBefore *NumericDate `json:"nbf,omitempty"`
IssuedAt *NumericDate `json:"iat,omitempty"`
ID string `json:"jti,omitempty"`
}
// NumericDate represents date and time as the number of seconds since the
// epoch, including leap seconds. Non-integer values can be represented
// in the serialized format, but we round to the nearest second.
type NumericDate int64
// NewNumericDate constructs NumericDate from time.Time value.
func NewNumericDate(t time.Time) *NumericDate {
if t.IsZero() {
return nil
}
// While RFC 7519 technically states that NumericDate values may be
// non-integer values, we don't bother serializing timestamps in
// claims with sub-second accurancy and just round to the nearest
// second instead. Not convined sub-second accuracy is useful here.
out := NumericDate(t.Unix())
return &out
}
// MarshalJSON serializes the given NumericDate into its JSON representation.
func (n NumericDate) MarshalJSON() ([]byte, error) {
return []byte(strconv.FormatInt(int64(n), 10)), nil
}
// UnmarshalJSON reads a date from its JSON representation.
func (n *NumericDate) UnmarshalJSON(b []byte) error {
s := string(b)
f, err := strconv.ParseFloat(s, 64)
if err != nil {
return ErrUnmarshalNumericDate
}
*n = NumericDate(f)
return nil
}
// Time returns time.Time representation of NumericDate.
func (n *NumericDate) Time() time.Time {
if n == nil {
return time.Time{}
}
return time.Unix(int64(*n), 0)
}
// Audience represents the recipents that the token is intended for.
type Audience []string
// UnmarshalJSON reads an audience from its JSON representation.
func (s *Audience) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch v := v.(type) {
case string:
*s = []string{v}
case []interface{}:
a := make([]string, len(v))
for i, e := range v {
s, ok := e.(string)
if !ok {
return ErrUnmarshalAudience
}
a[i] = s
}
*s = a
default:
return ErrUnmarshalAudience
}
return nil
}
func (s Audience) Contains(v string) bool {
for _, a := range s {
if a == v {
return true
}
}
return false
}
/*-
* Copyright 2017 Square Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
Package jwt provides an implementation of the JSON Web Token standard.
*/
package jwt
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import "errors"
// ErrUnmarshalAudience indicates that aud claim could not be unmarshalled.
var ErrUnmarshalAudience = errors.New("square/go-jose/jwt: expected string or array value to unmarshal to Audience")
// ErrUnmarshalNumericDate indicates that JWT NumericDate could not be unmarshalled.
var ErrUnmarshalNumericDate = errors.New("square/go-jose/jwt: expected number value to unmarshal NumericDate")
// ErrInvalidClaims indicates that given claims have invalid type.
var ErrInvalidClaims = errors.New("square/go-jose/jwt: expected claims to be value convertible into JSON object")
// ErrInvalidIssuer indicates invalid iss claim.
var ErrInvalidIssuer = errors.New("square/go-jose/jwt: validation failed, invalid issuer claim (iss)")
// ErrInvalidSubject indicates invalid sub claim.
var ErrInvalidSubject = errors.New("square/go-jose/jwt: validation failed, invalid subject claim (sub)")
// ErrInvalidAudience indicated invalid aud claim.
var ErrInvalidAudience = errors.New("square/go-jose/jwt: validation failed, invalid audience claim (aud)")
// ErrInvalidID indicates invalid jti claim.
var ErrInvalidID = errors.New("square/go-jose/jwt: validation failed, invalid ID claim (jti)")
// ErrNotValidYet indicates that token is used before time indicated in nbf claim.
var ErrNotValidYet = errors.New("square/go-jose/jwt: validation failed, token not valid yet (nbf)")
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
var ErrExpired = errors.New("square/go-jose/jwt: validation failed, token is expired (exp)")
// ErrIssuedInTheFuture indicates that the iat field is in the future.
var ErrIssuedInTheFuture = errors.New("square/go-jose/jwt: validation field, token issued in the future (iat)")
// ErrInvalidContentType indicates that token requires JWT cty header.
var ErrInvalidContentType = errors.New("square/go-jose/jwt: expected content type to be JWT (cty header)")
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import (
"fmt"
"strings"
jose "gopkg.in/square/go-jose.v2"
"gopkg.in/square/go-jose.v2/json"
)
// JSONWebToken represents a JSON Web Token (as specified in RFC7519).
type JSONWebToken struct {
payload func(k interface{}) ([]byte, error)
unverifiedPayload func() []byte
Headers []jose.Header
}
type NestedJSONWebToken struct {
enc *jose.JSONWebEncryption
Headers []jose.Header
}
// Claims deserializes a JSONWebToken into dest using the provided key.
func (t *JSONWebToken) Claims(key interface{}, dest ...interface{}) error {
payloadKey := tryJWKS(t.Headers, key)
b, err := t.payload(payloadKey)
if err != nil {
return err
}
for _, d := range dest {
if err := json.Unmarshal(b, d); err != nil {
return err
}
}
return nil
}
// UnsafeClaimsWithoutVerification deserializes the claims of a
// JSONWebToken into the dests. For signed JWTs, the claims are not
// verified. This function won't work for encrypted JWTs.
func (t *JSONWebToken) UnsafeClaimsWithoutVerification(dest ...interface{}) error {
if t.unverifiedPayload == nil {
return fmt.Errorf("square/go-jose: Cannot get unverified claims")
}
claims := t.unverifiedPayload()
for _, d := range dest {
if err := json.Unmarshal(claims, d); err != nil {
return err
}
}
return nil
}
func (t *NestedJSONWebToken) Decrypt(decryptionKey interface{}) (*JSONWebToken, error) {
key := tryJWKS(t.Headers, decryptionKey)
b, err := t.enc.Decrypt(key)
if err != nil {
return nil, err
}
sig, err := ParseSigned(string(b))
if err != nil {
return nil, err
}
return sig, nil
}
// ParseSigned parses token from JWS form.
func ParseSigned(s string) (*JSONWebToken, error) {
sig, err := jose.ParseSigned(s)
if err != nil {
return nil, err
}
headers := make([]jose.Header, len(sig.Signatures))
for i, signature := range sig.Signatures {
headers[i] = signature.Header
}
return &JSONWebToken{
payload: sig.Verify,
unverifiedPayload: sig.UnsafePayloadWithoutVerification,
Headers: headers,
}, nil
}
// ParseEncrypted parses token from JWE form.
func ParseEncrypted(s string) (*JSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
return &JSONWebToken{
payload: enc.Decrypt,
Headers: []jose.Header{enc.Header},
}, nil
}
// ParseSignedAndEncrypted parses signed-then-encrypted token from JWE form.
func ParseSignedAndEncrypted(s string) (*NestedJSONWebToken, error) {
enc, err := jose.ParseEncrypted(s)
if err != nil {
return nil, err
}
contentType, _ := enc.Header.ExtraHeaders[jose.HeaderContentType].(string)
if strings.ToUpper(contentType) != "JWT" {
return nil, ErrInvalidContentType
}
return &NestedJSONWebToken{
enc: enc,
Headers: []jose.Header{enc.Header},
}, nil
}
func tryJWKS(headers []jose.Header, key interface{}) interface{} {
jwks, ok := key.(*jose.JSONWebKeySet)
if !ok {
return key
}
var kid string
for _, header := range headers {
if header.KeyID != "" {
kid = header.KeyID
break
}
}
if kid == "" {
return key
}
keys := jwks.Key(kid)
if len(keys) == 0 {
return key
}
return keys[0].Key
}
/*-
* Copyright 2016 Zbigniew Mandziejewicz
* Copyright 2016 Square, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package jwt
import "time"
const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 1.0 * time.Minute
)
// Expected defines values used for protected claims validation.
// If field has zero value then validation is skipped.
type Expected struct {
// Issuer matches the "iss" claim exactly.
Issuer string
// Subject matches the "sub" claim exactly.
Subject string
// Audience matches the values in "aud" claim, regardless of their order.
Audience Audience
// ID matches the "jti" claim exactly.
ID string
// Time matches the "exp" and "nbf" claims with leeway.
Time time.Time
}
// WithTime copies expectations with new time.
func (e Expected) WithTime(t time.Time) Expected {
e.Time = t
return e
}
// Validate checks claims in a token against expected values.
// A default leeway value of one minute is used to compare time values.
//
// The default leeway will cause the token to be deemed valid until one
// minute after the expiration time. If you're a server application that
// wants to give an extra minute to client tokens, use this
// function. If you're a client application wondering if the server
// will accept your token, use ValidateWithLeeway with a leeway <=0,
// otherwise this function might make you think a token is valid when
// it is not.
func (c Claims) Validate(e Expected) error {
return c.ValidateWithLeeway(e, DefaultLeeway)
}
// ValidateWithLeeway checks claims in a token against expected values. A
// custom leeway may be specified for comparing time values. You may pass a
// zero value to check time values with no leeway, but you should not that
// numeric date values are rounded to the nearest second and sub-second
// precision is not supported.
//
// The leeway gives some extra time to the token from the server's
// point of view. That is, if the token is expired, ValidateWithLeeway
// will still accept the token for 'leeway' amount of time. This fails
// if you're using this function to check if a server will accept your
// token, because it will think the token is valid even after it
// expires. So if you're a client validating if the token is valid to
// be submitted to a server, use leeway <=0, if you're a server
// validation a token, use leeway >=0.
func (c Claims) ValidateWithLeeway(e Expected, leeway time.Duration) error {
if e.Issuer != "" && e.Issuer != c.Issuer {
return ErrInvalidIssuer
}
if e.Subject != "" && e.Subject != c.Subject {
return ErrInvalidSubject
}
if e.ID != "" && e.ID != c.ID {
return ErrInvalidID
}
if len(e.Audience) != 0 {
for _, v := range e.Audience {
if !c.Audience.Contains(v) {
return ErrInvalidAudience
}
}
}
if !e.Time.IsZero() {
if c.NotBefore != nil && e.Time.Add(leeway).Before(c.NotBefore.Time()) {
return ErrNotValidYet
}
if c.Expiry != nil && e.Time.Add(-leeway).After(c.Expiry.Time()) {
return ErrExpired
}
// IssuedAt is optional but cannot be in the future. This is not required by the RFC, but
// something is misconfigured if this happens and we should not trust it.
if c.IssuedAt != nil && e.Time.Add(leeway).Before(c.IssuedAt.Time()) {
return ErrIssuedInTheFuture
}
}
return nil
}
......@@ -81,3 +81,64 @@ type opaqueVerifier struct {
func (o *opaqueVerifier) verifyPayload(payload []byte, signature []byte, alg SignatureAlgorithm) error {
return o.verifier.VerifyPayload(payload, signature, alg)
}
// OpaqueKeyEncrypter is an interface that supports encrypting keys with an opaque key.
type OpaqueKeyEncrypter interface {
// KeyID returns the kid
KeyID() string
// Algs returns a list of supported key encryption algorithms.
Algs() []KeyAlgorithm
// encryptKey encrypts the CEK using the given algorithm.
encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error)
}
type opaqueKeyEncrypter struct {
encrypter OpaqueKeyEncrypter
}
func newOpaqueKeyEncrypter(alg KeyAlgorithm, encrypter OpaqueKeyEncrypter) (recipientKeyInfo, error) {
var algSupported bool
for _, salg := range encrypter.Algs() {
if alg == salg {
algSupported = true
break
}
}
if !algSupported {
return recipientKeyInfo{}, ErrUnsupportedAlgorithm
}
return recipientKeyInfo{
keyID: encrypter.KeyID(),
keyAlg: alg,
keyEncrypter: &opaqueKeyEncrypter{
encrypter: encrypter,
},
}, nil
}
func (oke *opaqueKeyEncrypter) encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error) {
return oke.encrypter.encryptKey(cek, alg)
}
//OpaqueKeyDecrypter is an interface that supports decrypting keys with an opaque key.
type OpaqueKeyDecrypter interface {
DecryptKey(encryptedKey []byte, header Header) ([]byte, error)
}
type opaqueKeyDecrypter struct {
decrypter OpaqueKeyDecrypter
}
func (okd *opaqueKeyDecrypter) decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error) {
mergedHeaders := rawHeader{}
mergedHeaders.merge(&headers)
mergedHeaders.merge(recipient.header)
header, err := mergedHeaders.sanitized()
if err != nil {
return nil, err
}
return okd.decrypter.DecryptKey(recipient.encryptedKey, header)
}
......@@ -153,12 +153,18 @@ const (
headerJWK = "jwk" // *JSONWebKey
headerKeyID = "kid" // string
headerNonce = "nonce" // string
headerB64 = "b64" // bool
headerP2C = "p2c" // *byteBuffer (int)
headerP2S = "p2s" // *byteBuffer ([]byte)
)
// supportedCritical is the set of supported extensions that are understood and processed.
var supportedCritical = map[string]bool{
headerB64: true,
}
// rawHeader represents the JOSE header for JWE/JWS objects (used for parsing).
//
// The decoding of the constituent items is deferred because we want to marshal
......@@ -349,6 +355,21 @@ func (parsed rawHeader) getP2S() (*byteBuffer, error) {
return parsed.getByteBuffer(headerP2S)
}
// getB64 extracts parsed "b64" from the raw JSON, defaulting to true.
func (parsed rawHeader) getB64() (bool, error) {
v := parsed[headerB64]
if v == nil {
return true, nil
}
var b64 bool
err := json.Unmarshal(*v, &b64)
if err != nil {
return true, err
}
return b64, nil
}
// sanitized produces a cleaned-up header object from the raw JSON.
func (parsed rawHeader) sanitized() (h Header, err error) {
for k, v := range parsed {
......
......@@ -17,6 +17,7 @@
package jose
import (
"bytes"
"crypto/ecdsa"
"crypto/rsa"
"encoding/base64"
......@@ -77,6 +78,27 @@ func (so *SignerOptions) WithType(typ ContentType) *SignerOptions {
return so.WithHeader(HeaderType, typ)
}
// WithCritical adds the given names to the critical ("crit") header and returns
// the updated SignerOptions.
func (so *SignerOptions) WithCritical(names ...string) *SignerOptions {
if so.ExtraHeaders[headerCritical] == nil {
so.WithHeader(headerCritical, make([]string, 0, len(names)))
}
crit := so.ExtraHeaders[headerCritical].([]string)
so.ExtraHeaders[headerCritical] = append(crit, names...)
return so
}
// WithBase64 adds a base64url-encode payload ("b64") header and returns the updated
// SignerOptions. When the "b64" value is "false", the payload is not base64 encoded.
func (so *SignerOptions) WithBase64(b64 bool) *SignerOptions {
if !b64 {
so.WithHeader(headerB64, b64)
so.WithCritical(headerB64)
}
return so
}
type payloadSigner interface {
signPayload(payload []byte, alg SignatureAlgorithm) (Signature, error)
}
......@@ -233,7 +255,10 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
if ctx.embedJWK {
protected[headerJWK] = recipient.publicKey()
} else {
protected[headerKeyID] = recipient.publicKey().KeyID
keyID := recipient.publicKey().KeyID
if keyID != "" {
protected[headerKeyID] = keyID
}
}
}
......@@ -250,12 +275,26 @@ func (ctx *genericSigner) Sign(payload []byte) (*JSONWebSignature, error) {
}
serializedProtected := mustSerializeJSON(protected)
needsBase64 := true
if b64, ok := protected[headerB64]; ok {
if needsBase64, ok = b64.(bool); !ok {
return nil, errors.New("square/go-jose: Invalid b64 header parameter")
}
}
var input bytes.Buffer
input := []byte(fmt.Sprintf("%s.%s",
base64.RawURLEncoding.EncodeToString(serializedProtected),
base64.RawURLEncoding.EncodeToString(payload)))
input.WriteString(base64.RawURLEncoding.EncodeToString(serializedProtected))
input.WriteByte('.')
signatureInfo, err := recipient.signer.signPayload(input, recipient.sigAlg)
if needsBase64 {
input.WriteString(base64.RawURLEncoding.EncodeToString(payload))
} else {
input.Write(payload)
}
signatureInfo, err := recipient.signer.signPayload(input.Bytes(), recipient.sigAlg)
if err != nil {
return nil, err
}
......@@ -324,12 +363,18 @@ func (obj JSONWebSignature) DetachedVerify(payload []byte, verificationKey inter
if err != nil {
return err
}
if len(critical) > 0 {
// Unsupported crit header
for _, name := range critical {
if !supportedCritical[name] {
return ErrCryptoFailure
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
return ErrCryptoFailure
}
input := obj.computeAuthData(payload, &signature)
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
......@@ -366,18 +411,25 @@ func (obj JSONWebSignature) DetachedVerifyMulti(payload []byte, verificationKey
return -1, Signature{}, err
}
outer:
for i, signature := range obj.Signatures {
headers := signature.mergedHeaders()
critical, err := headers.getCritical()
if err != nil {
continue
}
if len(critical) > 0 {
// Unsupported crit header
for _, name := range critical {
if !supportedCritical[name] {
continue outer
}
}
input, err := obj.computeAuthData(payload, &signature)
if err != nil {
continue
}
input := obj.computeAuthData(payload, &signature)
alg := headers.getSignatureAlgorithm()
err = verifier.verifyPayload(input, signature.Signature, alg)
if err == nil {
......
......@@ -396,9 +396,10 @@ gopkg.in/redis.v5/internal/consistenthash
gopkg.in/redis.v5/internal/hashtag
gopkg.in/redis.v5/internal/pool
gopkg.in/redis.v5/internal/proto
# gopkg.in/square/go-jose.v2 v2.3.0
# gopkg.in/square/go-jose.v2 v2.4.1
gopkg.in/square/go-jose.v2
gopkg.in/square/go-jose.v2/cipher
gopkg.in/square/go-jose.v2/json
gopkg.in/square/go-jose.v2/jwt
# gopkg.in/yaml.v2 v2.2.5
gopkg.in/yaml.v2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment