Last active
August 29, 2015 14:09
-
-
Save abursavich/ec5d1cf2be49ec9fec38 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package oauth2 | |
import ( | |
"errors" | |
"io" | |
"net/http" | |
"sync" | |
"time" | |
"github.com/golang/groupcache/singleflight" | |
) | |
// START IGNORE -- This has no bearing on the cancellation stuff and is held over
// from something else I was fiddling with a while back.

// defaultTokenType is the Authorization scheme used for a Token whose
// TokenType is empty.
const (
	defaultTokenType = "Bearer"
)
// Token is an OAuth2 access token together with the metadata needed to use
// and renew it. The struct tags define its JSON encoding.
type Token struct {
	// AccessToken is the credential sent in the Authorization header. An
	// empty AccessToken makes Expired report true.
	AccessToken string `json:"access_token"`
	// TokenType is the Authorization scheme; makeRequest falls back to
	// defaultTokenType when it is empty.
	TokenType string `json:"token_type,omitempty"`
	// RefreshToken is not read in this file. NOTE(review): presumably it is
	// consumed by a TokenFetcher to obtain a replacement token — confirm.
	RefreshToken string `json:"refresh_token,omitempty"`
	// Expiry is when the token stops being valid; the zero value means it
	// never expires. Note that encoding/json ignores omitempty on struct
	// types, so a zero Expiry is still encoded.
	Expiry time.Time `json:"expiry,omitempty"`
	// Extra and Subject carry additional data that this file does not
	// interpret.
	Extra   map[string]string `json:"extra,omitempty"`
	Subject string            `json:"subject,omitempty"`
}

// Expired reports whether t can no longer be used to authorize requests.
// A nil token and a token without an AccessToken are always expired. A token
// with a zero Expiry never expires. Otherwise the token is expired once its
// Expiry lies in the past.
func (t *Token) Expired() bool {
	switch {
	case t == nil || t.AccessToken == "":
		return true
	case t.Expiry.IsZero():
		return false
	default:
		return t.Expiry.Before(time.Now())
	}
}
// TokenFetcher obtains a replacement token. tokenStore.RefreshToken passes the
// currently stored token (which may be nil) as existing.
// NOTE(review): presumably implementations use existing (e.g. its
// RefreshToken field) to obtain the new token — confirm with implementers.
type TokenFetcher interface {
	FetchToken(existing *Token) (*Token, error)
}
// TokenStore holds the token a Transport uses to authorize requests.
type TokenStore interface {
	// Token returns the current token. makeRequest treats a nil or expired
	// token as a signal to call RefreshToken.
	Token() *Token
	// SetToken replaces the current token.
	SetToken(t *Token) error
	// RefreshToken returns a new token to use in place of the current one.
	RefreshToken() (*Token, error)
}
func NewTokenStore(t *Token, tf TokenFetcher) TokenStore { | |
return &tokenStore{t: t, tf: tf} | |
} | |
// tokenStore is the TokenStore implementation returned by NewTokenStore.
type tokenStore struct {
	mu sync.RWMutex // protects t
	t  *Token
	sf singleflight.Group // collapses concurrent fetches
	tf TokenFetcher       // source of replacement tokens for RefreshToken
}
func (ts *tokenStore) Token() *Token { | |
ts.mu.RLock() | |
t := ts.t | |
ts.mu.RUnlock() | |
return t | |
} | |
func (ts *tokenStore) SetToken(t *Token) error { | |
ts.mu.Lock() | |
ts.t = t | |
ts.mu.Unlock() | |
return nil | |
} | |
func (ts *tokenStore) RefreshToken() (*Token, error) { | |
t, err := ts.sf.Do("", func() (interface{}, error) { | |
var err error | |
t := ts.Token() | |
t, err = ts.tf.FetchToken(t) | |
if err != nil { | |
return nil, err | |
} | |
ts.SetToken(t) | |
return t, nil | |
}) | |
return t.(*Token), err | |
} | |
// END IGNORE | |
// Transport is an http.RoundTripper that authorizes each outgoing request
// with the token held by its embedded TokenStore (see makeRequest).
type Transport interface {
	http.RoundTripper
	TokenStore // IGNORE
}
// cancelRoundTripper is a RoundTripper that can also abort an in-flight
// request via CancelRequest. NewTransport checks for it to decide whether the
// Transport it returns supports cancellation.
type cancelRoundTripper interface {
	http.RoundTripper
	CancelRequest(*http.Request)
}
func NewTransport(rt http.RoundTripper, ts TokenStore) Transport { | |
if rt == nil { | |
rt = http.DefaultTransport | |
} | |
if crt, ok := rt.(cancelRoundTripper); ok { | |
return &cancelerTransport{ | |
crt: crt, | |
TokenStore: ts, | |
reqCanceler: make(map[*http.Request]func()), | |
} | |
} | |
return &transport{rt: rt, TokenStore: ts} | |
} | |
// transport is the Transport returned by NewTransport when the wrapped
// RoundTripper cannot cancel requests.
type transport struct {
	rt         http.RoundTripper // sends the authorized requests
	TokenStore                   // IGNORE
}
func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) { | |
req, err := makeRequest(t, r) | |
if err != nil { | |
return nil, err | |
} | |
return t.rt.RoundTrip(req) | |
} | |
// cancelerTransport is the Transport returned by NewTransport when the
// wrapped RoundTripper supports CancelRequest. It tracks in-flight requests so
// that its own CancelRequest can abort them.
type cancelerTransport struct {
	crt        cancelRoundTripper
	TokenStore // IGNORE
	reqMu      sync.Mutex
	// reqCanceler maps each in-flight request (as passed by the caller) to
	// the func that cancels it. All accesses hold reqMu.
	reqCanceler map[*http.Request]func()
}
func (t *cancelerTransport) RoundTrip(r *http.Request) (*http.Response, error) { | |
cancel := make(chan struct{}) | |
t.setReqCanceler(r, false, func() { | |
close(cancel) | |
}) | |
req, err := cancelMakeRequest(cancel, t, r) | |
if err != nil { | |
t.setReqCanceler(r, true, nil) | |
return nil, err | |
} | |
done := make(chan struct{}) | |
defer close(done) | |
go func() { | |
select { | |
case <-cancel: | |
t.crt.CancelRequest(req) | |
case <-done: | |
} | |
}() | |
resp, err := t.crt.RoundTrip(req) | |
if err != nil { | |
t.setReqCanceler(r, true, nil) | |
return resp, err | |
} | |
if !t.setReqCanceler(r, true, func() { t.crt.CancelRequest(req) }) { | |
// Lost a race and there was already an attempt to cancel the request. | |
t.crt.CancelRequest(req) | |
return resp, nil | |
} | |
resp.Body = &cancelBody{resp.Body, func() { | |
t.setReqCanceler(r, true, nil) | |
}} | |
return resp, nil | |
} | |
func (t *cancelerTransport) CancelRequest(r *http.Request) { | |
t.reqMu.Lock() | |
if cancel, ok := t.reqCanceler[r]; ok { | |
delete(t.reqCanceler, r) | |
defer cancel() | |
} | |
t.reqMu.Unlock() | |
} | |
func (t *cancelerTransport) setReqCanceler(r *http.Request, exists bool, fn func()) bool { | |
t.reqMu.Lock() | |
defer t.reqMu.Unlock() | |
if _, ok := t.reqCanceler[r]; ok != exists { | |
return false | |
} | |
if fn != nil { | |
t.reqCanceler[r] = fn | |
} else { | |
delete(t.reqCanceler, r) | |
} | |
return true | |
} | |
func cancelMakeRequest(cancel chan struct{}, t Transport, r *http.Request) (*http.Request, error) { | |
type result struct { | |
req *http.Request | |
err error | |
} | |
resCh := make(chan result, 1) | |
go func() { | |
req, err := makeRequest(t, r) | |
resCh <- result{req, err} | |
}() | |
select { | |
case res := <-resCh: | |
return res.req, res.err | |
case <-cancel: | |
return nil, errors.New("oauth2: request canceled") | |
} | |
} | |
func makeRequest(t Transport, r *http.Request) (*http.Request, error) { | |
tok := t.Token() | |
if tok == nil || tok.Expired() { | |
var err error | |
tok, err = t.RefreshToken() | |
if err != nil { | |
return nil, err | |
} | |
} | |
req := *r | |
req.Header = make(http.Header, len(r.Header)+1) | |
for k, v := range r.Header { | |
req.Header[k] = v | |
} | |
typ := tok.TokenType | |
if typ == "" { | |
typ = defaultTokenType | |
} | |
req.Header.Set("Authorization", typ+" "+tok.AccessToken) | |
return &req, nil | |
} | |
// cancelBody wraps a response body and calls done when a Read reports io.EOF
// and again when the body is closed. cancelerTransport uses done to drop the
// cancel registration of the request that produced the body.
type cancelBody struct {
	rc   io.ReadCloser
	done func()
}

// Read reads from the wrapped body. When the wrapped Read reports io.EOF,
// done is called before the results are returned.
func (b *cancelBody) Read(p []byte) (int, error) {
	count, readErr := b.rc.Read(p)
	if readErr != io.EOF {
		return count, readErr
	}
	// TODO: It might be overkill to do this here.
	// Perhaps only on Close is sufficient.
	b.done()
	return count, readErr
}

// Close closes the wrapped body, calls done regardless of the outcome, and
// returns the wrapped body's Close error.
func (b *cancelBody) Close() error {
	closeErr := b.rc.Close()
	b.done()
	return closeErr
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package oauth2 | |
import ( | |
"io/ioutil" | |
"net/http" | |
"net/http/httptest" | |
"strings" | |
"sync" | |
"testing" | |
"time" | |
) | |
type testTokenStore struct { | |
mu sync.Mutex | |
tok *Token | |
d time.Duration | |
} | |
func (t *testTokenStore) Token() *Token { | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
return t.tok | |
} | |
func (t *testTokenStore) SetToken(tok *Token) error { | |
t.mu.Lock() | |
t.tok = tok | |
t.mu.Unlock() | |
return nil | |
} | |
func (t *testTokenStore) RefreshToken() (*Token, error) { | |
if t.d > 0 { | |
time.Sleep(t.d) | |
} | |
t.mu.Lock() | |
defer t.mu.Unlock() | |
tok := new(Token) | |
*tok = *t.tok | |
tok.Expiry = time.Now().Add(10 * time.Second) | |
t.tok = tok | |
return tok, nil | |
} | |
// requestCanceler is the method set the tests use to check whether a
// Transport returned by NewTransport supports cancellation.
type requestCanceler interface {
	CancelRequest(r *http.Request)
}
func TestHeaderTimeout(t *testing.T) { | |
t.Parallel() | |
ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}} | |
transport := NewTransport(nil, ts) | |
if _, ok := transport.(requestCanceler); !ok { | |
t.Fatal("transport does not implement CancelRequest(*http.Request)") | |
} | |
client := &http.Client{ | |
Transport: transport, | |
Timeout: 10 * time.Millisecond, | |
} | |
done := make(chan struct{}) | |
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
select { | |
case <-done: | |
case <-time.After(client.Timeout * 10): | |
} | |
})) | |
defer srv.Close() | |
defer close(done) | |
_, err := client.Get(srv.URL) | |
if err == nil { | |
t.Error("no error getting headers") | |
} | |
} | |
func TestBodyTimeout(t *testing.T) { | |
t.Parallel() | |
ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}} | |
transport := NewTransport(nil, ts) | |
if _, ok := transport.(requestCanceler); !ok { | |
t.Fatal("transport does not implement CancelRequest(*http.Request)") | |
} | |
client := &http.Client{ | |
Transport: transport, | |
Timeout: 100 * time.Millisecond, | |
} | |
done := make(chan struct{}) | |
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
w.Write(make([]byte, 1<<14)) // large enough to flush the buffer | |
select { | |
case <-done: | |
case <-time.After(client.Timeout * 10): | |
} | |
})) | |
defer srv.Close() | |
defer close(done) | |
resp, err := client.Get(srv.URL) | |
if err != nil { | |
// TODO: This might be a little flaky since it relies on timing | |
// but a better solution isn't jumping out at me at the moment. | |
t.Errorf("unexpected error: %q", err.Error()) | |
return | |
} | |
_, err = ioutil.ReadAll(resp.Body) | |
if err == nil { | |
t.Error("no error reading body") | |
} | |
} | |
func TestRefreshTimeout(t *testing.T) { | |
t.Parallel() | |
ts := &testTokenStore{ | |
tok: &Token{ | |
AccessToken: "FOOBARTOKEN", | |
Expiry: time.Now().Add(-time.Hour), | |
}, | |
d: 5 * time.Second, | |
} | |
transport := NewTransport(nil, ts) | |
client := &http.Client{ | |
Transport: transport, | |
Timeout: 10 * time.Millisecond, | |
} | |
_, err := client.Get("http://localhost") | |
if exp := "oauth2: request canceled"; !strings.HasSuffix(err.Error(), exp) { | |
t.Errorf("error expected: %q; got: %q", exp, err.Error()) | |
} | |
} | |
func TestCancel(t *testing.T) { | |
t.Parallel() | |
ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}} | |
transport := NewTransport(nil, ts) | |
canceler, ok := transport.(requestCanceler) | |
if !ok { | |
t.Fatal("transport does not implement CancelRequest(*http.Request)") | |
} | |
client := &http.Client{Transport: transport} | |
sig := make(chan bool) | |
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | |
<-sig | |
<-sig | |
})) | |
defer srv.Close() | |
defer close(sig) | |
req, _ := http.NewRequest("GET", srv.URL, nil) | |
errCh := make(chan error, 1) | |
go func() { | |
_, err := client.Do(req) | |
errCh <- err | |
}() | |
sig <- true | |
canceler.CancelRequest(req) | |
select { | |
case err := <-errCh: | |
if err == nil { | |
t.Error("no error") | |
} | |
case <-time.After(time.Second): | |
t.Error("timeout: no error") | |
} | |
} | |
func TestUnsupportedTimeout(t *testing.T) { | |
t.Parallel() | |
type basicRoundTripper struct { | |
http.RoundTripper | |
} | |
rt := basicRoundTripper{http.DefaultTransport} | |
ts := &testTokenStore{tok: &Token{AccessToken: "FOOBARTOKEN"}} | |
transport := NewTransport(rt, ts) | |
if _, ok := transport.(requestCanceler); ok { | |
t.Fatal("transport should not implement CancelRequest(*http.Request)") | |
} | |
client := &http.Client{ | |
Transport: transport, | |
Timeout: time.Millisecond, | |
} | |
_, err := client.Get("http://localhost") | |
if exp := "Timeout not supported"; err == nil { | |
t.Errorf("error expected: %q; got: <nil>", exp) | |
} else if !strings.HasSuffix(err.Error(), exp) { | |
t.Errorf("error expected: %q; got: %q", exp, err.Error()) | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment