Commit e9fcca16 by Torkel Ödegaard

updated to new golang/x/oauth2

parent c04a2aba
......@@ -29,10 +29,6 @@
"Rev": "5c23849a66f4593e68909bb6c1fa30651b5b0541"
},
{
"ImportPath": "github.com/golang/oauth2",
"Rev": "5fd31d511c212ab476371f61b4fa28e9f168a8f0"
},
{
"ImportPath": "github.com/macaron-contrib/session",
"Rev": "f00d48fd4f85088603c1493b0a99fdfe95d0658c"
},
......
language: go
go: 1.3
install:
- go get -v -tags='appengine appenginevm' ./...
script:
- go test -v -tags='appengine appenginevm' ./...
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.
Copyright (c) 2009 The oauth2 Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# OAuth2 for Go
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2)
oauth2 package contains a client implementation for OAuth 2.0 spec.
## Installation
~~~~
go get github.com/golang/oauth2
~~~~
See godoc for further documentation and examples.
* [godoc.org/github.com/golang/oauth2](http://godoc.org/github.com/golang/oauth2)
* [godoc.org/github.com/golang/oauth2/google](http://godoc.org/github.com/golang/oauth2/google)
## Contributing
Fork the repo, make changes, run the tests and open a pull request.
Before we can accept any pull requests
we have to jump through a couple of legal hurdles,
primarily a Contributor License Agreement (CLA):
- **If you are an individual writing original source code**
and you're sure you own the intellectual property,
then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html).
- **If you work for a company that wants to allow you to contribute your work**,
then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html).
You can sign these electronically (just scroll to the bottom).
After that, we'll be able to accept your pull requests.
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2_test
import (
"fmt"
"log"
"net/http"
"testing"
"github.com/golang/oauth2"
)
// TODO(jbd): Remove after Go 1.4.
// Related to https://codereview.appspot.com/107320046
func TestA(t *testing.T) {}
func Example_regular() {
opts, err := oauth2.New(
oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
oauth2.RedirectURL("YOUR_REDIRECT_URL"),
oauth2.Scope("SCOPE1", "SCOPE2"),
oauth2.Endpoint(
"https://provider.com/o/oauth2/auth",
"https://provider.com/o/oauth2/token",
),
)
if err != nil {
log.Fatal(err)
}
// Redirect user to consent page to ask for permission
// for the scopes specified above.
url := opts.AuthCodeURL("state", "online", "auto")
fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Use the authorization code that is pushed to the redirect URL.
// NewTransportWithCode will do the handshake to retrieve
// an access token and initiate a Transport that is
// authorized and authenticated by the retrieved token.
var code string
if _, err = fmt.Scan(&code); err != nil {
log.Fatal(err)
}
t, err := opts.NewTransportFromCode(code)
if err != nil {
log.Fatal(err)
}
// You can use t to initiate a new http.Client and
// start making authenticated requests.
client := http.Client{Transport: t}
client.Get("...")
}
func Example_jWT() {
opts, err := oauth2.New(
// The contents of your RSA private key or your PEM file
// that contains a private key.
// If you have a p12 file instead, you
// can use `openssl` to export the private key into a pem file.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
// It only supports PEM containers with no passphrase.
oauth2.JWTClient(
"xxx@developer.gserviceaccount.com",
[]byte("-----BEGIN RSA PRIVATE KEY-----...")),
oauth2.Scope("SCOPE1", "SCOPE2"),
oauth2.JWTEndpoint("https://provider.com/o/oauth2/token"),
// If you would like to impersonate a user, you can
// create a transport with a subject. The following GET
// request will be made on the behalf of user@example.com.
// Subject is optional.
oauth2.Subject("user@example.com"),
)
if err != nil {
log.Fatal(err)
}
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine,!appenginevm
package google
import (
"net/http"
"strings"
"sync"
"time"
"github.com/golang/oauth2"
"appengine"
"appengine/memcache"
"appengine/urlfetch"
)
var (
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
memcacheGob memcacher = &aeMemcache{}
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
accessTokenFunc = appengine.AccessToken
// mu protects multiple threads from attempting to fetch a token at the same time.
mu sync.Mutex
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
tokens map[string]*oauth2.Token
)
// safetyMargin is used to avoid clock-skew problems.
// 5 minutes is conservative because tokens are valid for 60 minutes.
const safetyMargin = 5 * time.Minute
func init() {
tokens = make(map[string]*oauth2.Token)
}
// AppEngineContext requires an App Engine request context.
func AppEngineContext(ctx appengine.Context) oauth2.Option {
return func(opts *oauth2.Options) error {
opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
opts.Client = &http.Client{
Transport: &urlfetch.Transport{Context: ctx},
}
return nil
}
}
// FetchToken fetches a new access token for the provided scopes.
// Tokens are cached locally and also with Memcache so that the app can scale
// without hitting quota limits by calling appengine.AccessToken too frequently.
func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
return func(existing *oauth2.Token) (*oauth2.Token, error) {
mu.Lock()
defer mu.Unlock()
key := ":" + strings.Join(opts.Scopes, "_")
now := time.Now().Add(safetyMargin)
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
return t, nil
}
delete(tokens, key)
// Attempt to get token from Memcache
tok := new(oauth2.Token)
_, err := memcacheGob.Get(ctx, key, tok)
if err == nil && !tok.Expiry.Before(now) {
tokens[key] = tok // Save token locally
return tok, nil
}
token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
if err != nil {
return nil, err
}
t := &oauth2.Token{
AccessToken: token,
Expiry: expiry,
}
tokens[key] = t
// Also back up token in Memcache
if err = memcacheGob.Set(ctx, &memcache.Item{
Key: key,
Value: []byte{},
Object: *t,
Expiration: expiry.Sub(now),
}); err != nil {
ctx.Errorf("unexpected memcache.Set error: %v", err)
}
return t, nil
}
}
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
type aeMemcache struct{}
func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
return memcache.Gob.Get(c, key, tok)
}
func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error {
return memcache.Gob.Set(c, item)
}
type memcacher interface {
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
Set(c appengine.Context, item *memcache.Item) error
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine,!appenginevm
package google
import (
"fmt"
"log"
"net/http"
"sync"
"testing"
"time"
"github.com/golang/oauth2"
"appengine"
"appengine/memcache"
)
type tokMap map[string]*oauth2.Token
type mockMemcache struct {
mu sync.RWMutex
vals tokMap
getCount, setCount int
}
func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getCount++
v, ok := m.vals[key]
if !ok {
return nil, fmt.Errorf("unexpected test error: key %q not found", key)
}
*tok = *v
return nil, nil // memcache.Item is ignored anyway - return nil
}
func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error {
m.mu.Lock()
defer m.mu.Unlock()
m.setCount++
tok, ok := item.Object.(oauth2.Token)
if !ok {
log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item)
}
m.vals[item.Key] = &tok
return nil
}
var accessTokenCount = 0
func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) {
accessTokenCount++
return "mytoken", time.Now(), nil
}
const (
testScope = "myscope"
testScopeKey = ":" + testScope
)
func init() {
accessTokenFunc = mockAccessToken
}
func TestFetchTokenLocalCacheMiss(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 1; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 1; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenLocalCacheHit(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
// Pre-populate the local cache
tokens[testScopeKey] = &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get("server")
if err != nil {
t.Errorf("unable to FetchToken: %v", err)
}
if w := 0; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache remains populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenMemcacheHit(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: 1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenLocalCacheExpired(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
// Pre-populate the local cache
tokens[testScopeKey] = &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(-1 * time.Hour),
}
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: 1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache remains populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenMemcacheExpired(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(-1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: -1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 1; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 1; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm !appengine
package google
import (
"strings"
"sync"
"time"
"github.com/golang/oauth2"
"google.golang.org/appengine"
"google.golang.org/appengine/memcache"
)
var (
// memcacheGob enables mocking of the memcache.Gob calls for unit testing.
memcacheGob memcacher = &aeMemcache{}
// accessTokenFunc enables mocking of the appengine.AccessToken call for unit testing.
accessTokenFunc = appengine.AccessToken
// mu protects multiple threads from attempting to fetch a token at the same time.
mu sync.Mutex
// tokens implements a local cache of tokens to prevent hitting quota limits for appengine.AccessToken calls.
tokens map[string]*oauth2.Token
)
// safetyMargin is used to avoid clock-skew problems.
// 5 minutes is conservative because tokens are valid for 60 minutes.
const safetyMargin = 5 * time.Minute
func init() {
tokens = make(map[string]*oauth2.Token)
}
// AppEngineContext requires an App Engine request context.
func AppEngineContext(ctx appengine.Context) oauth2.Option {
return func(opts *oauth2.Options) error {
opts.TokenFetcherFunc = makeAppEngineTokenFetcher(ctx, opts)
return nil
}
}
// FetchToken fetches a new access token for the provided scopes.
// Tokens are cached locally and also with Memcache so that the app can scale
// without hitting quota limits by calling appengine.AccessToken too frequently.
func makeAppEngineTokenFetcher(ctx appengine.Context, opts *oauth2.Options) func(*oauth2.Token) (*oauth2.Token, error) {
return func(existing *oauth2.Token) (*oauth2.Token, error) {
mu.Lock()
defer mu.Unlock()
key := ":" + strings.Join(opts.Scopes, "_")
now := time.Now().Add(safetyMargin)
if t, ok := tokens[key]; ok && !t.Expiry.Before(now) {
return t, nil
}
delete(tokens, key)
// Attempt to get token from Memcache
tok := new(oauth2.Token)
_, err := memcacheGob.Get(ctx, key, tok)
if err == nil && !tok.Expiry.Before(now) {
tokens[key] = tok // Save token locally
return tok, nil
}
token, expiry, err := accessTokenFunc(ctx, opts.Scopes...)
if err != nil {
return nil, err
}
t := &oauth2.Token{
AccessToken: token,
Expiry: expiry,
}
tokens[key] = t
// Also back up token in Memcache
if err = memcacheGob.Set(ctx, &memcache.Item{
Key: key,
Value: []byte{},
Object: *t,
Expiration: expiry.Sub(now),
}); err != nil {
ctx.Errorf("unexpected memcache.Set error: %v", err)
}
return t, nil
}
}
// aeMemcache wraps the needed Memcache functionality to make it easy to mock
type aeMemcache struct{}
func (m *aeMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
return memcache.Gob.Get(c, key, tok)
}
func (m *aeMemcache) Set(c appengine.Context, item *memcache.Item) error {
return memcache.Gob.Set(c, item)
}
type memcacher interface {
Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error)
Set(c appengine.Context, item *memcache.Item) error
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm !appengine
package google
import (
"fmt"
"log"
"net/http"
"sync"
"testing"
"time"
"github.com/golang/oauth2"
"google.golang.org/appengine"
"google.golang.org/appengine/memcache"
)
type tokMap map[string]*oauth2.Token
type mockMemcache struct {
mu sync.RWMutex
vals tokMap
getCount, setCount int
}
func (m *mockMemcache) Get(c appengine.Context, key string, tok *oauth2.Token) (*memcache.Item, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.getCount++
v, ok := m.vals[key]
if !ok {
return nil, fmt.Errorf("unexpected test error: key %q not found", key)
}
*tok = *v
return nil, nil // memcache.Item is ignored anyway - return nil
}
func (m *mockMemcache) Set(c appengine.Context, item *memcache.Item) error {
m.mu.Lock()
defer m.mu.Unlock()
m.setCount++
tok, ok := item.Object.(oauth2.Token)
if !ok {
log.Fatalf("unexpected test error: item.Object is not an oauth2.Token: %#v", item)
}
m.vals[item.Key] = &tok
return nil
}
var accessTokenCount = 0
func mockAccessToken(c appengine.Context, scopes ...string) (token string, expiry time.Time, err error) {
accessTokenCount++
return "mytoken", time.Now(), nil
}
const (
testScope = "myscope"
testScopeKey = ":" + testScope
)
func init() {
accessTokenFunc = mockAccessToken
}
func TestFetchTokenLocalCacheMiss(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 1; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 1; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenLocalCacheHit(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
// Pre-populate the local cache
tokens[testScopeKey] = &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get("server")
if err != nil {
t.Errorf("unable to FetchToken: %v", err)
}
if w := 0; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache remains populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenMemcacheHit(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: 1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenLocalCacheExpired(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
// Pre-populate the local cache
tokens[testScopeKey] = &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(-1 * time.Hour),
}
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: 1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 0; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 0; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache remains populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
func TestFetchTokenMemcacheExpired(t *testing.T) {
m := &mockMemcache{vals: make(tokMap)}
memcacheGob = m
accessTokenCount = 0
delete(tokens, testScopeKey) // clear local cache
// Pre-populate the memcache
tok := &oauth2.Token{
AccessToken: "mytoken",
Expiry: time.Now().Add(-1 * time.Hour),
}
m.Set(nil, &memcache.Item{
Key: testScopeKey,
Object: *tok,
Expiration: -1 * time.Hour,
})
m.setCount = 0
f, err := oauth2.New(
AppEngineContext(nil),
oauth2.Scope(testScope),
)
if err != nil {
t.Error(err)
}
c := http.Client{Transport: f.NewTransport()}
c.Get("server")
if w := 1; m.getCount != w {
t.Errorf("bad memcache.Get count: got %v, want %v", m.getCount, w)
}
if w := 1; accessTokenCount != w {
t.Errorf("bad AccessToken count: got %v, want %v", accessTokenCount, w)
}
if w := 1; m.setCount != w {
t.Errorf("bad memcache.Set count: got %v, want %v", m.setCount, w)
}
// Make sure local cache has been populated
_, ok := tokens[testScopeKey]
if !ok {
t.Errorf("local cache not populated!")
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appenginevm !appengine
package google_test
import (
"fmt"
"log"
"net/http"
"testing"
"github.com/golang/oauth2"
"github.com/golang/oauth2/google"
"google.golang.org/appengine"
)
// Remove after Go 1.4.
// Related to https://codereview.appspot.com/107320046
func TestA(t *testing.T) {}
func Example_webServer() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
opts, err := oauth2.New(
oauth2.Client("YOUR_CLIENT_ID", "YOUR_CLIENT_SECRET"),
oauth2.RedirectURL("YOUR_REDIRECT_URL"),
oauth2.Scope(
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
),
google.Endpoint(),
)
if err != nil {
log.Fatal(err)
}
// Redirect user to Google's consent page to ask for permission
// for the scopes specified above.
url := opts.AuthCodeURL("state", "online", "auto")
fmt.Printf("Visit the URL for the auth dialog: %v", url)
// Handle the exchange code to initiate a transport
t, err := opts.NewTransportFromCode("exchange-code")
if err != nil {
log.Fatal(err)
}
client := http.Client{Transport: t}
client.Get("...")
}
func Example_serviceAccountsJSON() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
// Navigate to your project, then see the "Credentials" page
// under "APIs & Auth".
// To create a service account client, click "Create new Client ID",
// select "Service Account", and click "Create Client ID". A JSON
// key file will then be downloaded to your computer.
opts, err := oauth2.New(
google.ServiceAccountJSONKey("/path/to/your-project-key.json"),
oauth2.Scope(
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
),
)
if err != nil {
log.Fatal(err)
}
// Initiate an http.Client. The following GET request will be
// authorized and authenticated on the behalf of
// your service account.
client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
func Example_serviceAccounts() {
// Your credentials should be obtained from the Google
// Developer Console (https://console.developers.google.com).
opts, err := oauth2.New(
// The contents of your RSA private key or your PEM file
// that contains a private key.
// If you have a p12 file instead, you
// can use `openssl` to export the private key into a pem file.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
// It only supports PEM containers with no passphrase.
oauth2.JWTClient(
"xxx@developer.gserviceaccount.com",
[]byte("-----BEGIN RSA PRIVATE KEY-----...")),
oauth2.Scope(
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
),
google.JWTEndpoint(),
// If you would like to impersonate a user, you can
// create a transport with a subject. The following GET
// request will be made on the behalf of user@example.com.
// Subject is optional.
oauth2.Subject("user@example.com"),
)
if err != nil {
log.Fatal(err)
}
// Initiate an http.Client, the following GET request will be
// authorized and authenticated on the behalf of user@example.com.
client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
func Example_appEngine() {
ctx := appengine.NewContext(nil)
opts, err := oauth2.New(
google.AppEngineContext(ctx),
oauth2.Scope(
"https://www.googleapis.com/auth/bigquery",
"https://www.googleapis.com/auth/blogger",
),
)
if err != nil {
log.Fatal(err)
}
// The following client will be authorized by the App Engine
// app's service account for the provided scopes.
client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
func Example_computeEngine() {
opts, err := oauth2.New(
// Query Google Compute Engine's metadata server to retrieve
// an access token for the provided account.
// If no account is specified, "default" is used.
google.ComputeEngineAccount(""),
)
if err != nil {
log.Fatal(err)
}
client := http.Client{Transport: opts.NewTransport()}
client.Get("...")
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package google provides support for making
// OAuth2 authorized and authenticated HTTP requests
// to Google APIs. It supports Web server, client-side,
// service accounts, Google Compute Engine service accounts,
// and Google App Engine service accounts authorization
// and authentications flows:
//
// For more information, please read
// https://developers.google.com/accounts/docs/OAuth2.
package google
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"time"
"github.com/golang/oauth2"
"github.com/golang/oauth2/internal"
)
var (
uriGoogleAuth, _ = url.Parse("https://accounts.google.com/o/oauth2/auth")
uriGoogleToken, _ = url.Parse("https://accounts.google.com/o/oauth2/token")
)
type metaTokenRespBody struct {
AccessToken string `json:"access_token"`
ExpiresIn time.Duration `json:"expires_in"`
TokenType string `json:"token_type"`
}
// JWTEndpoint adds the endpoints required to complete the 2-legged service account flow.
func JWTEndpoint() oauth2.Option {
return func(opts *oauth2.Options) error {
opts.AUD = uriGoogleToken
return nil
}
}
// Endpoint adds the endpoints required to do the 3-legged Web server flow.
func Endpoint() oauth2.Option {
return func(opts *oauth2.Options) error {
opts.AuthURL = uriGoogleAuth
opts.TokenURL = uriGoogleToken
return nil
}
}
// ComputeEngineAccount uses the specified account to retrieve an access
// token from the Google Compute Engine's metadata server. If no user is
// provided, "default" is being used.
func ComputeEngineAccount(account string) oauth2.Option {
return func(opts *oauth2.Options) error {
if account == "" {
account = "default"
}
opts.TokenFetcherFunc = makeComputeFetcher(opts, account)
return nil
}
}
// ServiceAccountJSONKey uses the provided Google Developers
// JSON key file to authorize the user. See the "Credentials" page under
// "APIs & Auth" for your project at https://console.developers.google.com
// to download a JSON key file.
func ServiceAccountJSONKey(filename string) oauth2.Option {
return func(opts *oauth2.Options) error {
b, err := ioutil.ReadFile(filename)
if err != nil {
return err
}
var key struct {
Email string `json:"client_email"`
PrivateKey string `json:"private_key"`
}
if err := json.Unmarshal(b, &key); err != nil {
return err
}
pk, err := internal.ParseKey([]byte(key.PrivateKey))
if err != nil {
return err
}
opts.Email = key.Email
opts.PrivateKey = pk
opts.AUD = uriGoogleToken
return nil
}
}
func makeComputeFetcher(opts *oauth2.Options, account string) func(*oauth2.Token) (*oauth2.Token, error) {
return func(t *oauth2.Token) (*oauth2.Token, error) {
u := "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/" + account + "/token"
req, err := http.NewRequest("GET", u, nil)
if err != nil {
return nil, err
}
req.Header.Add("X-Google-Metadata-Request", "True")
c := &http.Client{}
if opts.Client != nil {
c = opts.Client
}
resp, err := c.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
return nil, fmt.Errorf("oauth2: can't retrieve a token from metadata server, status code: %d", resp.StatusCode)
}
var tokenResp metaTokenRespBody
err = json.NewDecoder(resp.Body).Decode(&tokenResp)
if err != nil {
return nil, err
}
return &oauth2.Token{
AccessToken: tokenResp.AccessToken,
TokenType: tokenResp.TokenType,
Expiry: time.Now().Add(tokenResp.ExpiresIn * time.Second),
}, nil
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package internal
import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
)
// ParseKey converts the binary contents of a private key file
// to an *rsa.PrivateKey. It detects whether the private key is in a
// PEM container or not. If so, it extracts the the private key
// from PEM container before conversion. It only supports PEM
// containers with no passphrase.
func ParseKey(key []byte) (*rsa.PrivateKey, error) {
block, _ := pem.Decode(key)
if block != nil {
key = block.Bytes
}
parsedKey, err := x509.ParsePKCS8PrivateKey(key)
if err != nil {
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
if err != nil {
return nil, err
}
}
parsed, ok := parsedKey.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("oauth2: private key is invalid")
}
return parsed, nil
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package jws provides encoding and decoding utilities for
// signed JWS messages.
package jws
import (
"bytes"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
)
// The JWT claim set contains information about the JWT including the
// permissions being requested (scopes), the target of the token, the issuer,
// the time the token was issued, and the lifetime of the token.
type ClaimSet struct {
Iss string `json:"iss"` // email address of the client_id of the application making the access token request
Scope string `json:"scope,omitempty"` // space-delimited list of the permissions the application requests
Aud string `json:"aud"` // descriptor of the intended target of the assertion (Optional).
Exp int64 `json:"exp"` // the expiration time of the assertion
Iat int64 `json:"iat"` // the time the assertion was issued.
Typ string `json:"typ,omitempty"` // token type (Optional).
// Email for which the application is requesting delegated access (Optional).
Sub string `json:"sub,omitempty"`
// The old name of Sub. Client keeps setting Prn to be
// complaint with legacy OAuth 2.0 providers. (Optional)
Prn string `json:"prn,omitempty"`
// See http://tools.ietf.org/html/draft-jones-json-web-token-10#section-4.3
// This array is marshalled using custom code (see (c *ClaimSet) encode()).
PrivateClaims map[string]interface{} `json:"-"`
exp time.Time
iat time.Time
}
func (c *ClaimSet) encode() (string, error) {
if c.exp.IsZero() || c.iat.IsZero() {
// Reverting time back for machines whose time is not perfectly in sync.
// If client machine's time is in the future according
// to Google servers, an access token will not be issued.
now := time.Now().Add(-10 * time.Second)
c.iat = now
c.exp = now.Add(time.Hour)
}
c.Exp = c.exp.Unix()
c.Iat = c.iat.Unix()
b, err := json.Marshal(c)
if err != nil {
return "", err
}
if len(c.PrivateClaims) == 0 {
return base64Encode(b), nil
}
// Marshal private claim set and then append it to b.
prv, err := json.Marshal(c.PrivateClaims)
if err != nil {
return "", fmt.Errorf("jws: invalid map of private claims %v", c.PrivateClaims)
}
// Concatenate public and private claim JSON objects.
if !bytes.HasSuffix(b, []byte{'}'}) {
return "", fmt.Errorf("jws: invalid JSON %s", b)
}
if !bytes.HasPrefix(prv, []byte{'{'}) {
return "", fmt.Errorf("jws: invalid JSON %s", prv)
}
b[len(b)-1] = ',' // Replace closing curly brace with a comma.
b = append(b, prv[1:]...) // Append private claims.
return base64Encode(b), nil
}
// Header represents the header for the signed JWS payloads.
type Header struct {
// The algorithm used for signature.
Algorithm string `json:"alg"`
// Represents the token type.
Typ string `json:"typ"`
}
func (h *Header) encode() (string, error) {
b, err := json.Marshal(h)
if err != nil {
return "", err
}
return base64Encode(b), nil
}
// Decode decodes a claim set from a JWS payload.
func Decode(payload string) (c *ClaimSet, err error) {
// decode returned id token to get expiry
s := strings.Split(payload, ".")
if len(s) < 2 {
// TODO(jbd): Provide more context about the error.
return nil, errors.New("jws: invalid token received")
}
decoded, err := base64Decode(s[1])
if err != nil {
return nil, err
}
c = &ClaimSet{}
err = json.NewDecoder(bytes.NewBuffer(decoded)).Decode(c)
return c, err
}
// Encode encodes a signed JWS with provided header and claim set.
func Encode(header *Header, c *ClaimSet, signature *rsa.PrivateKey) (payload string, err error) {
var encodedHeader, encodedClaimSet string
encodedHeader, err = header.encode()
if err != nil {
return
}
encodedClaimSet, err = c.encode()
if err != nil {
return
}
ss := fmt.Sprintf("%s.%s", encodedHeader, encodedClaimSet)
h := sha256.New()
h.Write([]byte(ss))
b, err := rsa.SignPKCS1v15(rand.Reader, signature, crypto.SHA256, h.Sum(nil))
if err != nil {
return
}
sig := base64Encode(b)
return fmt.Sprintf("%s.%s", ss, sig), nil
}
// base64Encode returns and Base64url encoded version of the input string with any
// trailing "=" stripped.
func base64Encode(b []byte) string {
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
}
// base64Decode decodes the Base64url encoded string
func base64Decode(s string) ([]byte, error) {
// add back missing padding
switch len(s) % 4 {
case 2:
s += "=="
case 3:
s += "="
}
return base64.URLEncoding.DecodeString(s)
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
"time"
"github.com/golang/oauth2/internal"
"github.com/golang/oauth2/jws"
)
var (
defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
defaultHeader = &jws.Header{Algorithm: "RS256", Typ: "JWT"}
)
// JWTClient requires OAuth 2.0 JWT credentials.
// Required for the 2-legged JWT flow.
func JWTClient(email string, key []byte) Option {
return func(o *Options) error {
pk, err := internal.ParseKey(key)
if err != nil {
return err
}
o.Email = email
o.PrivateKey = pk
return nil
}
}
// JWTEndpoint requires the JWT token endpoint of the OAuth 2.0 provider.
func JWTEndpoint(aud string) Option {
return func(o *Options) error {
au, err := url.Parse(aud)
if err != nil {
return err
}
o.AUD = au
return nil
}
}
// Subject requires a user to impersonate.
// Optional.
func Subject(user string) Option {
return func(o *Options) error {
o.Subject = user
return nil
}
}
func makeTwoLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil {
t = &Token{}
}
claimSet := &jws.ClaimSet{
Iss: o.Email,
Scope: strings.Join(o.Scopes, " "),
Aud: o.AUD.String(),
}
if o.Subject != "" {
claimSet.Sub = o.Subject
// prn is the old name of sub. Keep setting it
// to be compatible with legacy OAuth 2.0 providers.
claimSet.Prn = o.Subject
}
payload, err := jws.Encode(defaultHeader, claimSet, o.PrivateKey)
if err != nil {
return nil, err
}
v := url.Values{}
v.Set("grant_type", defaultGrantType)
v.Set("assertion", payload)
c := o.Client
if c == nil {
c = &http.Client{}
}
resp, err := c.PostForm(o.AUD.String(), v)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if c := resp.StatusCode; c < 200 || c > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", resp.Status, body)
}
b := make(map[string]interface{})
if err := json.Unmarshal(body, &b); err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
token := &Token{}
token.AccessToken, _ = b["access_token"].(string)
token.TokenType, _ = b["token_type"].(string)
token.raw = b
if e, ok := b["expires_in"].(int); ok {
token.Expiry = time.Now().Add(time.Duration(e) * time.Second)
}
if idtoken, ok := b["id_token"].(string); ok {
// decode returned id token to get expiry
claimSet, err := jws.Decode(idtoken)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
token.Expiry = time.Unix(claimSet.Exp, 0)
return token, nil
}
return token, nil
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"net/http"
"net/http/httptest"
"testing"
)
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)
func TestJWTFetch_JSONResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
"scope": "user",
"token_type": "bearer"
}`))
}))
defer ts.Close()
f, err := New(
JWTClient("aaa@xxx.com", dummyPrivateKey),
JWTEndpoint(ts.URL),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get(ts.URL)
tok := tr.Token()
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
}
}
func TestJWTFetch_BadResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
f, err := New(
JWTClient("aaa@xxx.com", dummyPrivateKey),
JWTEndpoint(ts.URL),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get(ts.URL)
tok := tr.Token()
if err != nil {
t.Errorf("Failed retrieving token: %s.", err)
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
}
}
func TestJWTFetch_BadResponseType(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
f, err := New(
JWTClient("aaa@xxx.com", dummyPrivateKey),
JWTEndpoint(ts.URL),
)
if err != nil {
t.Error(err)
}
tr := f.NewTransport()
c := http.Client{Transport: tr}
c.Get(ts.URL)
tok := tr.Token()
if err != nil {
t.Errorf("Failed retrieving token: %s.", err)
}
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package oauth2 provides support for making
// OAuth2 authorized and authenticated HTTP requests.
// It can additionally grant authorization with Bearer JWT.
package oauth2
import (
"crypto/rsa"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"mime"
"time"
"net/http"
"net/url"
"strconv"
"strings"
)
// TokenStore implementations read and write OAuth 2.0 tokens from a persistence layer.
type TokenStore interface {
// ReadToken reads the token from the store.
// If the read is successful, it should return the token and a nil error.
// The returned tokens may be expired tokens.
// If there is no token in the store, it should return a nil token and a nil error.
// It should return a non-nil error when an unrecoverable failure occurs.
ReadToken() (*Token, error)
// WriteToken writes the token to the cache.
WriteToken(*Token)
}
// Option represents a function that applies some state to
// an Options object.
type Option func(*Options) error
// Client requires the OAuth 2.0 client credentials. You need to provide
// the client identifier and optionally the client secret that are
// assigned to your application by the OAuth 2.0 provider.
func Client(id, secret string) Option {
return func(opts *Options) error {
opts.ClientID = id
opts.ClientSecret = secret
return nil
}
}
// RedirectURL requires the URL to which the user will be returned after
// granting (or denying) access.
func RedirectURL(url string) Option {
return func(opts *Options) error {
opts.RedirectURL = url
return nil
}
}
// Scope requires a list of requested permission scopes.
// It is optinal to specify scopes.
func Scope(scopes ...string) Option {
return func(o *Options) error {
o.Scopes = scopes
return nil
}
}
// Endpoint requires OAuth 2.0 provider's authorization and token endpoints.
func Endpoint(authURL, tokenURL string) Option {
return func(o *Options) error {
au, err := url.Parse(authURL)
if err != nil {
return err
}
tu, err := url.Parse(tokenURL)
if err != nil {
return err
}
o.AuthURL = au
o.TokenURL = tu
return nil
}
}
// HTTPClient allows you to provide a custom http.Client to be
// used to retrieve tokens from the OAuth 2.0 provider.
func HTTPClient(c *http.Client) Option {
return func(o *Options) error {
o.Client = c
return nil
}
}
// New builds a new options object and determines the type of the OAuth 2.0
// (2-legged, 3-legged or custom) by looking at the provided options.
// If the flow type cannot determined automatically, an error is returned.
func New(option ...Option) (*Options, error) {
opts := &Options{}
for _, fn := range option {
if err := fn(opts); err != nil {
return nil, err
}
}
switch {
case opts.TokenFetcherFunc != nil:
return opts, nil
case opts.AUD != nil:
// TODO(jbd): Assert the required JWT params.
opts.TokenFetcherFunc = makeTwoLeggedFetcher(opts)
return opts, nil
case opts.AuthURL != nil && opts.TokenURL != nil:
// TODO(jbd): Assert the required OAuth2 params.
opts.TokenFetcherFunc = makeThreeLeggedFetcher(opts)
return opts, nil
default:
return nil, errors.New("oauth2: missing endpoints, can't determine how to fetch tokens")
}
}
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-zero string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
//
// Access type is an OAuth extension that gets sent as the
// "access_type" field in the URL from AuthCodeURL.
// It may be "online" (default) or "offline".
// If your application needs to refresh access tokens when the
// user is not present at the browser, then use offline. This
// will result in your application obtaining a refresh token
// the first time your application exchanges an authorization
// code for a user.
//
// Approval prompt indicates whether the user should be
// re-prompted for consent. If set to "auto" (default) the
// user will be prompted only if they haven't previously
// granted consent and the code can only be exchanged for an
// access token. If set to "force" the user will always be prompted,
// and the code can be exchanged for a refresh token.
func (o *Options) AuthCodeURL(state, accessType, prompt string) string {
u := *o.AuthURL
v := url.Values{
"response_type": {"code"},
"client_id": {o.ClientID},
"redirect_uri": condVal(o.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")),
"state": condVal(state),
"access_type": condVal(accessType),
"approval_prompt": condVal(prompt),
}
q := v.Encode()
if u.RawQuery == "" {
u.RawQuery = q
} else {
u.RawQuery += "&" + q
}
return u.String()
}
// exchange exchanges the authorization code with the OAuth 2.0 provider
// to retrieve a new access token.
func (o *Options) exchange(code string) (*Token, error) {
return retrieveToken(o, url.Values{
"grant_type": {"authorization_code"},
"code": {code},
"redirect_uri": condVal(o.RedirectURL),
"scope": condVal(strings.Join(o.Scopes, " ")),
})
}
// NewTransportFromTokenStore reads the token from the store and returns
// a Transport that is authorized and the authenticated
// by the returned token.
func (o *Options) NewTransportFromTokenStore(store TokenStore) (*Transport, error) {
tok, err := store.ReadToken()
if err != nil {
return nil, err
}
o.TokenStore = store
if tok == nil {
return nil, nil
}
return o.newTransportFromToken(tok), nil
}
// NewTransportFromCode exchanges the code to retrieve a new access token
// and returns an authorized and authenticated Transport.
func (o *Options) NewTransportFromCode(code string) (*Transport, error) {
token, err := o.exchange(code)
if err != nil {
return nil, err
}
return o.newTransportFromToken(token), nil
}
// NewTransport returns a Transport.
func (o *Options) NewTransport() *Transport {
return o.newTransportFromToken(nil)
}
// newTransportFromToken returns a new Transport that is authorized
// and authenticated with the provided token.
func (o *Options) newTransportFromToken(t *Token) *Transport {
// TODO(jbd): App Engine options initiate an http.Client that
// depends on the urlfetcher, but it breaks the promise we made
// that the options object should be working finely with nil-values
// for the http.Client.
tr := http.DefaultTransport
if o.Client != nil && o.Client.Transport != nil {
tr = o.Client.Transport
}
return newTransport(tr, o, t)
}
func makeThreeLeggedFetcher(o *Options) func(t *Token) (*Token, error) {
return func(t *Token) (*Token, error) {
if t == nil || t.RefreshToken == "" {
return nil, errors.New("oauth2: cannot fetch access token without refresh token")
}
return retrieveToken(o, url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {t.RefreshToken},
})
}
}
// Options represents an object to keep the state of the OAuth 2.0 flow.
type Options struct {
// ClientID is the OAuth client identifier used when communicating with
// the configured OAuth provider.
ClientID string
// ClientSecret is the OAuth client secret used when communicating with
// the configured OAuth provider.
ClientSecret string
// RedirectURL is the URL to which the user will be returned after
// granting (or denying) access.
RedirectURL string
// Email is the OAuth client identifier used when communicating with
// the configured OAuth provider.
Email string
// PrivateKey contains the contents of an RSA private key or the
// contents of a PEM file that contains a private key. The provided
// private key is used to sign JWT payloads.
// PEM containers with a passphrase are not supported.
// Use the following command to convert a PKCS 12 file into a PEM.
//
// $ openssl pkcs12 -in key.p12 -out key.pem -nodes
//
PrivateKey *rsa.PrivateKey
// Scopes identify the level of access being requested.
Subject string
// Scopes optionally specifies a list of requested permission scopes.
Scopes []string
// AuthURL represents the authorization endpoint of the OAuth 2.0 provider.
AuthURL *url.URL
// TokenURL represents the token endpoint of the OAuth 2.0 provider.
TokenURL *url.URL
// AUD represents the token endpoint required to complete the 2-legged JWT flow.
AUD *url.URL
// TokenStore reads a token from the store and writes it back to the store
// if a token refresh occurs.
// Optional.
TokenStore TokenStore
TokenFetcherFunc func(t *Token) (*Token, error)
Client *http.Client
}
func retrieveToken(o *Options, v url.Values) (*Token, error) {
v.Set("client_id", o.ClientID)
bustedAuth := !providerAuthHeaderWorks(o.TokenURL.String())
if bustedAuth && o.ClientSecret != "" {
v.Set("client_secret", o.ClientSecret)
}
req, err := http.NewRequest("POST", o.TokenURL.String(), strings.NewReader(v.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
if !bustedAuth && o.ClientSecret != "" {
req.SetBasicAuth(o.ClientID, o.ClientSecret)
}
c := o.Client
if c == nil {
c = &http.Client{}
}
r, err := c.Do(req)
if err != nil {
return nil, err
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
}
if code := r.StatusCode; code < 200 || code > 299 {
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
}
token := &Token{}
expires := int(0)
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
switch content {
case "application/x-www-form-urlencoded", "text/plain":
vals, err := url.ParseQuery(string(body))
if err != nil {
return nil, err
}
token.AccessToken = vals.Get("access_token")
token.TokenType = vals.Get("token_type")
token.RefreshToken = vals.Get("refresh_token")
token.raw = vals
e := vals.Get("expires_in")
if e == "" {
// TODO(jbd): Facebook's OAuth2 implementation is broken and
// returns expires_in field in expires. Remove the fallback to expires,
// when Facebook fixes their implementation.
e = vals.Get("expires")
}
expires, _ = strconv.Atoi(e)
default:
b := make(map[string]interface{})
if err = json.Unmarshal(body, &b); err != nil {
return nil, err
}
token.AccessToken, _ = b["access_token"].(string)
token.TokenType, _ = b["token_type"].(string)
token.RefreshToken, _ = b["refresh_token"].(string)
token.raw = b
e, ok := b["expires_in"].(float64)
if !ok {
// TODO(jbd): Facebook's OAuth2 implementation is broken and
// returns expires_in field in expires. Remove the fallback to expires,
// when Facebook fixes their implementation.
e, _ = b["expires"].(float64)
}
expires = int(e)
}
// Don't overwrite `RefreshToken` with an empty value
// if this was a token refreshing request.
if token.RefreshToken == "" {
token.RefreshToken = v.Get("refresh_token")
}
if expires == 0 {
token.Expiry = time.Time{}
} else {
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
}
return token, nil
}
func condVal(v string) []string {
if v == "" {
return nil
}
return []string{v}
}
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
// implements the OAuth2 spec correctly
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
// In summary:
// - Reddit only accepts client secret in the Authorization header
// - Dropbox accepts either it in URL param or Auth header, but not both.
// - Google only accepts URL param (not spec compliant?), not Auth header
func providerAuthHeaderWorks(tokenURL string) bool {
if strings.HasPrefix(tokenURL, "https://accounts.google.com/") ||
strings.HasPrefix(tokenURL, "https://github.com/") ||
strings.HasPrefix(tokenURL, "https://api.instagram.com/") ||
strings.HasPrefix(tokenURL, "https://www.douban.com/") ||
strings.HasPrefix(tokenURL, "https://api.dropbox.com/") ||
strings.HasPrefix(tokenURL, "https://api.soundcloud.com/") ||
strings.HasPrefix(tokenURL, "https://www.linkedin.com/") {
// Some sites fail to implement the OAuth2 spec fully.
return false
}
// Assume the provider implements the spec properly
// otherwise. We can add more exceptions as they're
// discovered. We will _not_ be adding configurable hooks
// to this package to let users select server bugs.
return true
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)
type mockTransport struct {
rt func(req *http.Request) (resp *http.Response, err error)
}
func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
return t.rt(req)
}
type mockCache struct {
token *Token
readErr error
}
func (c *mockCache) ReadToken() (*Token, error) {
return c.token, c.readErr
}
func (c *mockCache) WriteToken(*Token) {
// do nothing
}
func newOpts(url string) *Options {
opts, _ := New(
Client("CLIENT_ID", "CLIENT_SECRET"),
RedirectURL("REDIRECT_URL"),
Scope("scope1", "scope2"),
Endpoint(url+"/auth", url+"/token"),
)
return opts
}
func TestAuthCodeURL(t *testing.T) {
opts := newOpts("server")
url := opts.AuthCodeURL("foo", "offline", "force")
if url != "server/auth?access_type=offline&approval_prompt=force&client_id=CLIENT_ID&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo" {
t.Errorf("Auth code URL doesn't match the expected, found: %v", url)
}
}
func TestAuthCodeURL_Optional(t *testing.T) {
opts, _ := New(
Client("CLIENT_ID", ""),
Endpoint("auth-url", "token-token"),
)
url := opts.AuthCodeURL("", "", "")
if url != "auth-url?client_id=CLIENT_ID&response_type=code" {
t.Fatalf("Auth code URL doesn't match the expected, found: %v", url)
}
}
func TestExchangeRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
}
headerAuth := r.Header.Get("Authorization")
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
}
headerContentType := r.Header.Get("Content-Type")
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" {
t.Errorf("Unexpected exchange payload, %v is found.", string(body))
}
w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr, err := opts.NewTransportFromCode("exchange-code")
if err != nil {
t.Error(err)
}
tok := tr.Token()
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
}
}
func TestExchangeRequest_JSONResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() != "/token" {
t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
}
headerAuth := r.Header.Get("Authorization")
if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
}
headerContentType := r.Header.Get("Content-Type")
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Errorf("Failed reading request body: %s.", err)
}
if string(body) != "client_id=CLIENT_ID&code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL&scope=scope1+scope2" {
t.Errorf("Unexpected exchange payload, %v is found.", string(body))
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`))
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr, err := opts.NewTransportFromCode("exchange-code")
if err != nil {
t.Error(err)
}
tok := tr.Token()
if tok.Expiry.IsZero() {
t.Errorf("Token expiry should not be zero.")
}
if tok.Expired() {
t.Errorf("Token shouldn't be expired.")
}
if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
if tok.TokenType != "bearer" {
t.Errorf("Unexpected token type, %#v.", tok.TokenType)
}
scope := tok.Extra("scope")
if scope != "user" {
t.Errorf("Unexpected value for scope: %v", scope)
}
}
func TestExchangeRequest_BadResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr, err := opts.NewTransportFromCode("exchange-code")
if err != nil {
t.Error(err)
}
tok := tr.Token()
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
}
func TestExchangeRequest_BadResponseType(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr, err := opts.NewTransportFromCode("exchange-code")
if err != nil {
t.Error(err)
}
tok := tr.Token()
if tok.AccessToken != "" {
t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
}
}
func TestExchangeRequest_NonBasicAuth(t *testing.T) {
tr := &mockTransport{
rt: func(r *http.Request) (w *http.Response, err error) {
headerAuth := r.Header.Get("Authorization")
if headerAuth != "" {
t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
}
return nil, errors.New("no response")
},
}
c := &http.Client{Transport: tr}
opts, err := New(
Client("CLIENT_ID", ""),
Endpoint("https://accounts.google.com/auth", "https://accounts.google.com/token"),
HTTPClient(c),
)
if err != nil {
t.Error(err)
}
opts.NewTransportFromCode("code")
}
func TestTokenRefreshRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return
}
if r.URL.String() != "/token" {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
}
headerContentType := r.Header.Get("Content-Type")
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, _ := ioutil.ReadAll(r.Body)
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
}
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr := opts.NewTransport()
tr.token = &Token{RefreshToken: "REFRESH_TOKEN"}
c := http.Client{Transport: tr}
c.Get(ts.URL + "/somethingelse")
}
func TestFetchWithNoRefreshToken(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.String() == "/somethingelse" {
return
}
if r.URL.String() != "/token" {
t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
}
headerContentType := r.Header.Get("Content-Type")
if headerContentType != "application/x-www-form-urlencoded" {
t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
}
body, _ := ioutil.ReadAll(r.Body)
if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
}
}))
defer ts.Close()
opts := newOpts(ts.URL)
tr := opts.NewTransport()
c := http.Client{Transport: tr}
_, err := c.Get(ts.URL + "/somethingelse")
if err == nil {
t.Errorf("Fetch should return an error if no refresh token is set")
}
}
func TestCacheNoToken(t *testing.T) {
opts, err := New(
Client("CLIENT_ID", "CLIENT_SECRET"),
Endpoint("/auth", "/token"),
)
if err != nil {
t.Error(err)
}
tr, err := opts.NewTransportFromTokenStore(&mockCache{token: nil, readErr: nil})
if err != nil {
t.Errorf("No error expected, %v is found", err)
}
if tr != nil {
t.Errorf("No transport should have been initiated, tr is found to be %v", tr)
}
}
// Copyright 2014 The oauth2 Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package oauth2
import (
"net/http"
"net/url"
"sync"
"time"
)
const (
defaultTokenType = "Bearer"
)
// Token represents the crendentials used to authorize
// the requests to access protected resources on the OAuth 2.0
// provider's backend.
type Token struct {
// A token that authorizes and authenticates the requests.
AccessToken string `json:"access_token"`
// Identifies the type of token returned.
TokenType string `json:"token_type,omitempty"`
// A token that may be used to obtain a new access token.
RefreshToken string `json:"refresh_token,omitempty"`
// The remaining lifetime of the access token.
Expiry time.Time `json:"expiry,omitempty"`
// raw optionally contains extra metadata from the server
// when updating a token.
raw interface{}
}
// Extra returns an extra field returned from the server during token retrieval.
// E.g.
// idToken := token.Extra("id_token")
//
func (t *Token) Extra(key string) string {
if vals, ok := t.raw.(url.Values); ok {
return vals.Get(key)
}
if raw, ok := t.raw.(map[string]interface{}); ok {
if val, ok := raw[key].(string); ok {
return val
}
}
return ""
}
// Expired returns true if there is no access token or the
// access token is expired.
func (t *Token) Expired() bool {
if t.AccessToken == "" {
return true
}
if t.Expiry.IsZero() {
return false
}
return t.Expiry.Before(time.Now())
}
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests.
type Transport struct {
opts *Options
base http.RoundTripper
mu sync.RWMutex
token *Token
}
// NewTransport creates a new Transport that uses the provided
// token fetcher as token retrieving strategy. It authenticates
// the requests and delegates origTransport to make the actual requests.
func newTransport(base http.RoundTripper, opts *Options, token *Token) *Transport {
return &Transport{
base: base,
opts: opts,
token: token,
}
}
// RoundTrip authorizes and authenticates the request with an
// access token. If no token exists or token is expired,
// tries to refresh/fetch a new token.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
token := t.token
if token == nil || token.Expired() {
// Check if the token is refreshable.
// If token is refreshable, don't return an error,
// rather refresh.
if err := t.refreshToken(); err != nil {
return nil, err
}
token = t.token
if t.opts.TokenStore != nil {
t.opts.TokenStore.WriteToken(token)
}
}
// To set the Authorization header, we must make a copy of the Request
// so that we don't modify the Request we were given.
// This is required by the specification of http.RoundTripper.
req = cloneRequest(req)
typ := token.TokenType
if typ == "" {
typ = defaultTokenType
}
req.Header.Set("Authorization", typ+" "+token.AccessToken)
return t.base.RoundTrip(req)
}
// Token returns the token that authorizes and
// authenticates the transport.
func (t *Transport) Token() *Token {
t.mu.RLock()
defer t.mu.RUnlock()
return t.token
}
// refreshToken retrieves a new token, if a refreshing/fetching
// method is known and required credentials are presented
// (such as a refresh token).
func (t *Transport) refreshToken() error {
t.mu.Lock()
defer t.mu.Unlock()
token, err := t.opts.TokenFetcherFunc(t.token)
if err != nil {
return err
}
t.token = token
return nil
}
// cloneRequest returns a clone of the provided *http.Request.
// The clone is a shallow copy of the struct and its Header map.
func cloneRequest(r *http.Request) *http.Request {
// shallow copy of the struct
r2 := new(http.Request)
*r2 = *r
// deep copy of the Header
r2.Header = make(http.Header)
for k, s := range r.Header {
r2.Header[k] = s
}
return r2
}
package oauth2
import (
"net/http"
"net/http/httptest"
"testing"
"time"
)
type mockTokenFetcher struct{ token *Token }
func (f *mockTokenFetcher) Fn() func(*Token) (*Token, error) {
return func(*Token) (*Token, error) {
return f.token, nil
}
}
func TestInitialTokenRead(t *testing.T) {
tr := newTransport(http.DefaultTransport, nil, &Token{AccessToken: "abc"})
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer abc" {
t.Errorf("Transport doesn't set the Authorization header from the initial token")
}
})
defer server.Close()
client := http.Client{Transport: tr}
client.Get(server.URL)
}
func TestTokenFetch(t *testing.T) {
fetcher := &mockTokenFetcher{
token: &Token{
AccessToken: "abc",
},
}
tr := newTransport(http.DefaultTransport, &Options{TokenFetcherFunc: fetcher.Fn()}, nil)
server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
if r.Header.Get("Authorization") != "Bearer abc" {
t.Errorf("Transport doesn't set the Authorization header from the fetched token")
}
})
defer server.Close()
client := http.Client{Transport: tr}
client.Get(server.URL)
if tr.Token().AccessToken != "abc" {
t.Errorf("New token is not set, found %v", tr.Token())
}
}
func TestExpiredWithNoAccessToken(t *testing.T) {
token := &Token{}
if !token.Expired() {
t.Errorf("Token should be expired if no access token is provided")
}
}
func TestExpiredWithExpiry(t *testing.T) {
token := &Token{
Expiry: time.Now().Add(-5 * time.Hour),
}
if !token.Expired() {
t.Errorf("Token should be expired if no access token is provided")
}
}
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(handler))
}
File added
......@@ -4,6 +4,8 @@ import (
"errors"
"fmt"
"golang.org/x/oauth2"
"github.com/torkelo/grafana-pro/pkg/bus"
"github.com/torkelo/grafana-pro/pkg/log"
"github.com/torkelo/grafana-pro/pkg/middleware"
......@@ -27,13 +29,13 @@ func OAuthLogin(ctx *middleware.Context) {
code := ctx.Query("code")
if code == "" {
ctx.Redirect(connect.AuthCodeURL("", "online", "auto"))
ctx.Redirect(connect.AuthCodeURL("", oauth2.AccessTypeOnline))
return
}
log.Info("code: %v", code)
// handle call back
transport, err := connect.NewTransportFromCode(code)
token, err := connect.Exchange(oauth2.NoContext, code)
if err != nil {
ctx.Handle(500, "login.OAuthLogin(NewTransportWithCode)", err)
return
......@@ -41,7 +43,7 @@ func OAuthLogin(ctx *middleware.Context) {
log.Trace("login.OAuthLogin(Got token)")
userInfo, err := connect.UserInfo(transport)
userInfo, err := connect.UserInfo(token)
if err != nil {
ctx.Handle(500, fmt.Sprintf("login.OAuthLogin(get info from %s)", name), err)
return
......
......@@ -2,15 +2,13 @@ package social
import (
"encoding/json"
"net/http"
"strconv"
"strings"
"github.com/torkelo/grafana-pro/pkg/log"
"github.com/torkelo/grafana-pro/pkg/models"
"github.com/torkelo/grafana-pro/pkg/setting"
"github.com/golang/oauth2"
"golang.org/x/oauth2"
)
type BasicUserInfo struct {
......@@ -23,10 +21,10 @@ type BasicUserInfo struct {
type SocialConnector interface {
Type() int
UserInfo(transport *oauth2.Transport) (*BasicUserInfo, error)
UserInfo(token *oauth2.Token) (*BasicUserInfo, error)
AuthCodeURL(state, accessType, prompt string) string
NewTransportFromCode(code string) (*oauth2.Transport, error)
AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string
Exchange(ctx oauth2.Context, code string) (*oauth2.Token, error)
}
var (
......@@ -60,41 +58,40 @@ func NewOAuthService() {
}
setting.OAuthService.OAuthInfos[name] = info
options, err := oauth2.New(
oauth2.Client(info.ClientId, info.ClientSecret),
oauth2.Scope(info.Scopes...),
oauth2.Endpoint(info.AuthUrl, info.TokenUrl),
oauth2.RedirectURL(strings.TrimSuffix(setting.AppUrl, "/")+SocialBaseUrl+name),
)
if err != nil {
log.Error(3, "Failed to init oauth service", err)
continue
config := oauth2.Config{
ClientID: info.ClientId,
ClientSecret: info.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: info.AuthUrl,
TokenURL: info.TokenUrl,
},
RedirectURL: strings.TrimSuffix(setting.AppUrl, "/") + SocialBaseUrl + name,
Scopes: info.Scopes,
}
// GitHub.
if name == "github" {
setting.OAuthService.GitHub = true
SocialMap["github"] = &SocialGithub{Options: options}
SocialMap["github"] = &SocialGithub{Config: &config}
}
// Google.
if name == "google" {
setting.OAuthService.Google = true
SocialMap["google"] = &SocialGoogle{Options: options}
SocialMap["google"] = &SocialGoogle{Config: &config}
}
}
}
type SocialGithub struct {
*oauth2.Options
*oauth2.Config
}
func (s *SocialGithub) Type() int {
return int(models.GITHUB)
}
func (s *SocialGithub) UserInfo(transport *oauth2.Transport) (*BasicUserInfo, error) {
func (s *SocialGithub) UserInfo(token *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id int `json:"id"`
Name string `json:"login"`
......@@ -102,7 +99,7 @@ func (s *SocialGithub) UserInfo(transport *oauth2.Transport) (*BasicUserInfo, er
}
var err error
client := http.Client{Transport: transport}
client := s.Client(oauth2.NoContext, token)
r, err := client.Get("https://api.github.com/user")
if err != nil {
return nil, err
......@@ -129,14 +126,14 @@ func (s *SocialGithub) UserInfo(transport *oauth2.Transport) (*BasicUserInfo, er
// \/ /_____/ \/
type SocialGoogle struct {
*oauth2.Options
*oauth2.Config
}
func (s *SocialGoogle) Type() int {
return int(models.GOOGLE)
}
func (s *SocialGoogle) UserInfo(transport *oauth2.Transport) (*BasicUserInfo, error) {
func (s *SocialGoogle) UserInfo(token *oauth2.Token) (*BasicUserInfo, error) {
var data struct {
Id string `json:"id"`
Name string `json:"name"`
......@@ -145,7 +142,7 @@ func (s *SocialGoogle) UserInfo(transport *oauth2.Transport) (*BasicUserInfo, er
var err error
reqUrl := "https://www.googleapis.com/oauth2/v1/userinfo"
client := http.Client{Transport: transport}
client := s.Client(oauth2.NoContext, token)
r, err := client.Get(reqUrl)
if err != nil {
return nil, err
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment