Skip to content

Instantly share code, notes, and snippets.

@abursavich
Last active August 29, 2015 14:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save abursavich/ec5d1cf2be49ec9fec38 to your computer and use it in GitHub Desktop.
Save abursavich/ec5d1cf2be49ec9fec38 to your computer and use it in GitHub Desktop.
package oauth2
import (
"errors"
"io"
"net/http"
"sync"
"time"
"github.com/golang/groupcache/singleflight"
)
// START IGNORE -- This has no bearing on the cancellation stuff and is held over
// from something else I was fiddling with a while back.
const (
	// defaultTokenType is the Authorization scheme used when a Token does not
	// carry a TokenType of its own.
	defaultTokenType = "Bearer"
)
// Token is an OAuth2 credential together with the metadata needed to decide
// whether it can still be used.
type Token struct {
	AccessToken  string            `json:"access_token"`
	TokenType    string            `json:"token_type,omitempty"`
	RefreshToken string            `json:"refresh_token,omitempty"`
	Expiry       time.Time         `json:"expiry,omitempty"`
	Extra        map[string]string `json:"extra,omitempty"`
	Subject      string            `json:"subject,omitempty"`
}

// Expired reports whether the token should not be used as-is.
// A token without an access token is always expired. A token with a zero
// Expiry never expires. Otherwise the token is expired once Expiry lies in
// the past.
func (t *Token) Expired() bool {
	switch {
	case t.AccessToken == "":
		return true
	case t.Expiry.IsZero():
		return false
	default:
		return t.Expiry.Before(time.Now())
	}
}
type (
	// TokenFetcher obtains a fresh token. tokenStore passes its currently held
	// token as existing, which is nil if the store was created without one.
	TokenFetcher interface {
		FetchToken(existing *Token) (*Token, error)
	}

	// TokenStore holds the token used to authorize requests.
	// Token returns the held token, SetToken replaces it, and RefreshToken
	// obtains a new token and returns it.
	TokenStore interface {
		Token() *Token
		SetToken(t *Token) error
		RefreshToken() (*Token, error)
	}
)
// NewTokenStore returns a TokenStore that starts out holding t and uses tf to
// fetch replacement tokens when RefreshToken is called.
func NewTokenStore(t *Token, tf TokenFetcher) TokenStore {
	store := new(tokenStore)
	store.t, store.tf = t, tf
	return store
}
type tokenStore struct {
mu sync.RWMutex // protects t
t *Token
sf singleflight.Group // collapses concurrent fetches
tf TokenFetcher
}
// Token returns the currently held token (possibly nil) under the read lock.
func (ts *tokenStore) Token() *Token {
	ts.mu.RLock()
	defer ts.mu.RUnlock()
	return ts.t
}
// SetToken replaces the held token under the write lock. It never fails.
func (ts *tokenStore) SetToken(t *Token) error {
	ts.mu.Lock()
	defer ts.mu.Unlock()
	ts.t = t
	return nil
}
// RefreshToken fetches a new token via the TokenFetcher, stores it, and
// returns it. Concurrent callers share a single in-flight fetch (all calls use
// the same singleflight key) and all receive its result.
//
// On failure it returns (nil, err). Previously a fetch error caused a panic:
// singleflight yields a nil interface value, and asserting it to *Token panics.
func (ts *tokenStore) RefreshToken() (*Token, error) {
	v, err := ts.sf.Do("", func() (interface{}, error) {
		tok, err := ts.tf.FetchToken(ts.Token())
		if err != nil {
			return nil, err
		}
		if err := ts.SetToken(tok); err != nil {
			return nil, err
		}
		return tok, nil
	})
	if err != nil {
		return nil, err
	}
	// v holds a *Token (possibly a typed nil if the fetcher returned one).
	return v.(*Token), nil
}
// END IGNORE
type (
	// Transport is an http.RoundTripper that authorizes requests using the
	// token from its embedded TokenStore.
	Transport interface {
		http.RoundTripper
		TokenStore // IGNORE
	}

	// cancelRoundTripper is a RoundTripper that can also abort an in-flight
	// request; NewTransport checks for it to decide whether to return a
	// cancelable Transport.
	cancelRoundTripper interface {
		http.RoundTripper
		CancelRequest(*http.Request)
	}
)
// NewTransport returns a Transport that sets an Authorization header, built
// from the token held by ts, on a copy of every request before passing it to
// rt. A nil rt means http.DefaultTransport. If rt has a CancelRequest method,
// the returned Transport has one as well; otherwise it does not.
func NewTransport(rt http.RoundTripper, ts TokenStore) Transport {
	if rt == nil {
		rt = http.DefaultTransport
	}
	crt, cancelable := rt.(cancelRoundTripper)
	if !cancelable {
		return &transport{rt: rt, TokenStore: ts}
	}
	return &cancelerTransport{
		crt:         crt,
		TokenStore:  ts,
		reqCanceler: make(map[*http.Request]func()),
	}
}
// transport is the Transport used when the underlying RoundTripper cannot
// cancel requests.
type transport struct {
	rt         http.RoundTripper
	TokenStore // IGNORE
}

// RoundTrip sends an authorized copy of r through the underlying RoundTripper.
//
// If no usable token can be obtained, r.Body is closed before returning the
// error. The http.RoundTripper contract requires the body to be closed even on
// failure, and the underlying transport never received it in that case.
func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
	req, err := makeRequest(t, r)
	if err != nil {
		if r.Body != nil {
			r.Body.Close()
		}
		return nil, err
	}
	return t.rt.RoundTrip(req)
}
type cancelerTransport struct {
crt cancelRoundTripper
TokenStore // IGNORE
reqMu sync.Mutex
reqCanceler map[*http.Request]func()
}
// RoundTrip authorizes r and sends it through the cancelable RoundTripper.
//
// Calling CancelRequest(r) while RoundTrip is running aborts it. Before the
// token is attached, it abandons waiting for the token. Afterwards, it forwards
// the cancellation to the wrapped transport for the rewritten request. Once
// headers arrive, the response body is wrapped so that the registration for r
// is dropped when the body reaches EOF or is closed.
//
// If the request fails before reaching the wrapped transport (token error or
// cancellation), r.Body is closed, as the http.RoundTripper contract requires.
func (t *cancelerTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	// cancel is closed through CancelRequest(r). CancelRequest deletes the map
	// entry before invoking it, so the channel is closed at most once.
	cancel := make(chan struct{})
	// NOTE(review): the result is ignored; if r is already registered (the
	// same *http.Request in flight twice), this registration is skipped.
	t.setReqCanceler(r, false, func() {
		close(cancel)
	})
	req, err := cancelMakeRequest(cancel, t, r)
	if err != nil {
		t.setReqCanceler(r, true, nil)
		// The wrapped transport never saw the body; close it ourselves.
		if r.Body != nil {
			r.Body.Close()
		}
		return nil, err
	}
	// While crt.RoundTrip is blocked, translate a close of cancel into a
	// cancellation of the rewritten request. done stops the watcher on return.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-cancel:
			t.crt.CancelRequest(req)
		case <-done:
		}
	}()
	resp, err := t.crt.RoundTrip(req)
	if err != nil {
		t.setReqCanceler(r, true, nil)
		return resp, err
	}
	// Headers arrived: from now on CancelRequest(r) cancels req directly.
	if !t.setReqCanceler(r, true, func() { t.crt.CancelRequest(req) }) {
		// Lost a race and there was already an attempt to cancel the request.
		t.crt.CancelRequest(req)
		return resp, nil
	}
	// Drop the registration once the caller is finished with the body.
	resp.Body = &cancelBody{resp.Body, func() {
		t.setReqCanceler(r, true, nil)
	}}
	return resp, nil
}
// CancelRequest aborts the in-flight request r, if one is registered. The
// registration is removed under the lock, and its cancel function runs after
// the lock is released. The cancel function itself calls setReqCanceler
// indirectly, so it must not run while reqMu is held.
func (t *cancelerTransport) CancelRequest(r *http.Request) {
	t.reqMu.Lock()
	fn, found := t.reqCanceler[r]
	if found {
		delete(t.reqCanceler, r)
	}
	t.reqMu.Unlock()

	if found {
		fn()
	}
}
// setReqCanceler updates the cancel function registered for r, but only if
// r's current registration state matches exists. A nil fn removes the entry;
// otherwise fn is stored. It reports whether the update was applied.
func (t *cancelerTransport) setReqCanceler(r *http.Request, exists bool, fn func()) bool {
	t.reqMu.Lock()
	defer t.reqMu.Unlock()

	_, present := t.reqCanceler[r]
	if present != exists {
		return false
	}
	switch {
	case fn == nil:
		delete(t.reqCanceler, r)
	default:
		t.reqCanceler[r] = fn
	}
	return true
}
// cancelMakeRequest runs makeRequest in the background and waits for either
// its result or a close of cancel, whichever happens first. On cancellation
// the background call keeps running to completion, but its result is
// discarded.
func cancelMakeRequest(cancel chan struct{}, t Transport, r *http.Request) (*http.Request, error) {
	var (
		out    *http.Request
		outErr error
	)
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		out, outErr = makeRequest(t, r)
	}()

	select {
	case <-finished:
		return out, outErr
	case <-cancel:
		return nil, errors.New("oauth2: request canceled")
	}
}
// makeRequest returns a shallow copy of r carrying an Authorization header
// built from t's token. If the token is missing or expired, it is refreshed
// first, and a refresh error is returned as-is. The copy has its own header
// map, so r's header map is not modified; the value slices are shared, and
// only the Authorization entry is replaced. An empty TokenType falls back to
// defaultTokenType.
//
// If the refresh reports success but yields no token, an error is returned
// rather than panicking on a nil dereference.
func makeRequest(t Transport, r *http.Request) (*http.Request, error) {
	tok := t.Token()
	if tok == nil || tok.Expired() {
		var err error
		tok, err = t.RefreshToken()
		if err != nil {
			return nil, err
		}
		if tok == nil {
			return nil, errors.New("oauth2: token refresh returned no token")
		}
	}
	req := *r
	req.Header = make(http.Header, len(r.Header)+1)
	for k, v := range r.Header {
		req.Header[k] = v
	}
	typ := tok.TokenType
	if typ == "" {
		typ = defaultTokenType
	}
	req.Header.Set("Authorization", typ+" "+tok.AccessToken)
	return &req, nil
}
// cancelBody wraps a response body and invokes done when the body is
// exhausted (Read returns io.EOF) and again when it is closed, so done may
// run more than once.
type cancelBody struct {
	rc   io.ReadCloser // wrapped response body
	done func()        // completion callback
}

// Read delegates to the wrapped body and, on io.EOF, runs done before
// returning the (unchanged) results.
func (b *cancelBody) Read(p []byte) (int, error) {
	n, err := b.rc.Read(p)
	if err != io.EOF {
		return n, err
	}
	// TODO: It might be overkill to do this here.
	// Perhaps only on Close is sufficient.
	b.done()
	return n, err
}

// Close closes the wrapped body, then runs done, and returns the close error.
func (b *cancelBody) Close() error {
	closeErr := b.rc.Close()
	b.done()
	return closeErr
}
package oauth2
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"
)
// testTokenStore is a TokenStore for tests. RefreshToken optionally sleeps
// for d (to simulate a slow refresh), then replaces the held token with a
// copy that expires ten seconds from now.
type testTokenStore struct {
	mu  sync.Mutex
	tok *Token
	d   time.Duration
}

// Token returns the held token under the lock.
func (t *testTokenStore) Token() *Token {
	t.mu.Lock()
	tok := t.tok
	t.mu.Unlock()
	return tok
}

// SetToken replaces the held token under the lock. It never fails.
func (t *testTokenStore) SetToken(tok *Token) error {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.tok = tok
	return nil
}

// RefreshToken sleeps for d (if positive) outside the lock, then stores and
// returns a copy of the held token with a fresh expiry.
func (t *testTokenStore) RefreshToken() (*Token, error) {
	if t.d > 0 {
		time.Sleep(t.d)
	}
	t.mu.Lock()
	defer t.mu.Unlock()
	fresh := *t.tok
	fresh.Expiry = time.Now().Add(10 * time.Second)
	t.tok = &fresh
	return &fresh, nil
}
// requestCanceler is the method set the tests probe for to decide whether a
// Transport can cancel in-flight requests.
type requestCanceler interface {
	CancelRequest(req *http.Request)
}
// TestHeaderTimeout checks that an http.Client timeout aborts a request whose
// server never sends response headers. If the request unexpectedly succeeds,
// the response body is now closed instead of leaked.
func TestHeaderTimeout(t *testing.T) {
	t.Parallel()
	ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}}
	transport := NewTransport(nil, ts)
	if _, ok := transport.(requestCanceler); !ok {
		t.Fatal("transport does not implement CancelRequest(*http.Request)")
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   10 * time.Millisecond,
	}
	done := make(chan struct{})
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Stall before writing headers until the test ends, or until a fallback
		// well past the client timeout.
		select {
		case <-done:
		case <-time.After(client.Timeout * 10):
		}
	}))
	// Deferred calls run in reverse: done is closed first, releasing the
	// handler so srv.Close does not wait on it.
	defer srv.Close()
	defer close(done)
	resp, err := client.Get(srv.URL)
	if err == nil {
		resp.Body.Close()
		t.Error("no error getting headers")
	}
}
// TestBodyTimeout checks that an http.Client timeout that fires after headers
// have arrived makes reading the response body fail. The response body is now
// closed when the test finishes.
func TestBodyTimeout(t *testing.T) {
	t.Parallel()
	ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}}
	transport := NewTransport(nil, ts)
	if _, ok := transport.(requestCanceler); !ok {
		t.Fatal("transport does not implement CancelRequest(*http.Request)")
	}
	client := &http.Client{
		Transport: transport,
		Timeout:   100 * time.Millisecond,
	}
	done := make(chan struct{})
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(make([]byte, 1<<14)) // large enough to flush the buffer
		// Keep the body open until the test ends or a fallback elapses.
		select {
		case <-done:
		case <-time.After(client.Timeout * 10):
		}
	}))
	defer srv.Close()
	defer close(done)
	resp, err := client.Get(srv.URL)
	if err != nil {
		// TODO: This might be a little flaky since it relies on timing
		// but a better solution isn't jumping out at me at the moment.
		t.Errorf("unexpected error: %q", err.Error())
		return
	}
	defer resp.Body.Close()
	_, err = ioutil.ReadAll(resp.Body)
	if err == nil {
		t.Error("no error reading body")
	}
}
// TestRefreshTimeout checks that an http.Client timeout aborts a request that
// is still waiting for a slow token refresh. The expired token forces a
// refresh, which sleeps far longer than the client timeout. An unexpected
// success now fails the test cleanly instead of panicking on a nil error.
func TestRefreshTimeout(t *testing.T) {
	t.Parallel()
	ts := &testTokenStore{
		tok: &Token{
			AccessToken: "FOOBARTOKEN",
			Expiry:      time.Now().Add(-time.Hour),
		},
		d: 5 * time.Second,
	}
	transport := NewTransport(nil, ts)
	client := &http.Client{
		Transport: transport,
		Timeout:   10 * time.Millisecond,
	}
	const exp = "oauth2: request canceled"
	resp, err := client.Get("http://localhost")
	if err == nil {
		resp.Body.Close()
		t.Fatalf("error expected: %q; got: <nil>", exp)
	}
	if !strings.HasSuffix(err.Error(), exp) {
		t.Errorf("error expected: %q; got: %q", exp, err.Error())
	}
}
// TestCancel checks that CancelRequest on the transport aborts a request that
// the server has received but not yet answered. The request-construction
// error is now checked, and an unexpected response body is closed.
func TestCancel(t *testing.T) {
	t.Parallel()
	ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}}
	transport := NewTransport(nil, ts)
	canceler, ok := transport.(requestCanceler)
	if !ok {
		t.Fatal("transport does not implement CancelRequest(*http.Request)")
	}
	client := &http.Client{Transport: transport}
	sig := make(chan bool)
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		<-sig // signals the test that the request reached the handler
		<-sig // holds the handler open until close(sig) at test end
	}))
	defer srv.Close()
	defer close(sig)
	req, err := http.NewRequest("GET", srv.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	errCh := make(chan error, 1)
	go func() {
		resp, err := client.Do(req)
		if err == nil {
			resp.Body.Close()
		}
		errCh <- err
	}()
	sig <- true
	canceler.CancelRequest(req)
	select {
	case err := <-errCh:
		if err == nil {
			t.Error("no error")
		}
	case <-time.After(time.Second):
		t.Error("timeout: no error")
	}
}
// TestUnsupportedTimeout checks that wrapping a RoundTripper without
// CancelRequest yields a Transport without CancelRequest, and that a client
// timeout is then rejected.
// NOTE(review): the expected "Timeout not supported" error comes from older
// net/http Client behavior; confirm against the Go version in use.
func TestUnsupportedTimeout(t *testing.T) {
	t.Parallel()
	// Embedding the interface, not *http.Transport, hides CancelRequest.
	type basicRoundTripper struct {
		http.RoundTripper
	}
	store := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}}
	tr := NewTransport(basicRoundTripper{http.DefaultTransport}, store)
	if _, canCancel := tr.(requestCanceler); canCancel {
		t.Fatal("transport should not implement CancelRequest(*http.Request)")
	}
	client := &http.Client{Transport: tr, Timeout: time.Millisecond}
	_, err := client.Get("http://localhost")
	const exp = "Timeout not supported"
	switch {
	case err == nil:
		t.Errorf("error expected: %q; got: <nil>", exp)
	case !strings.HasSuffix(err.Error(), exp):
		t.Errorf("error expected: %q; got: %q", exp, err.Error())
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment