Skip to content

Instantly share code, notes, and snippets.

@paulbdavis
Last active February 12, 2020 04:01
Show Gist options
  • Save paulbdavis/e684956357bde4a18ab1b8d1957e21fe to your computer and use it in GitHub Desktop.
Save paulbdavis/e684956357bde4a18ab1b8d1957e21fe to your computer and use it in GitHub Desktop.
package auth
// tokenRefresher remembers which token each presented access token was
// refreshed into, so that a client which keeps sending an already-refreshed
// access token is redirected to the newer token instead of triggering another
// refresh.
//
// A single tokenRefresher is shared by every request served by
// CheckAuthHandler, and HTTP handlers run concurrently, so all map access is
// guarded by mu.
type tokenRefresher struct {
	oauth *oauth2.Config

	// mu guards tokens, locks and accessed. It is held only for map access,
	// never across the token endpoint round trip.
	mu sync.Mutex

	// tokens maps a presented access token to the token it was refreshed into.
	tokens map[string]*oauth2.Token

	// locks holds one mutex per presented access token so that refresh checks
	// for the same token are serialized: a request that waits behind another
	// sees the refreshed token that request cached.
	locks map[string]*sync.Mutex

	// accessed records when each access token was last looked up; cleanup
	// evicts entries that have not been looked up recently.
	accessed map[string]time.Time
}

// newRefresher returns an empty tokenRefresher that refreshes tokens using
// oauthConfig. The zero value of mu is ready to use.
func newRefresher(oauthConfig *oauth2.Config) *tokenRefresher {
	return &tokenRefresher{
		oauth:    oauthConfig,
		tokens:   map[string]*oauth2.Token{},
		locks:    map[string]*sync.Mutex{},
		accessed: map[string]time.Time{},
	}
}

// lockFor returns the per-token mutex for key, creating it on first use, and
// records key as accessed now. Touching the access time here (rather than only
// after a successful refresh) means cleanup also evicts keys whose refresh
// failed, and does not evict old-to-new mappings that are still being used.
func (tr *tokenRefresher) lockFor(key string) *sync.Mutex {
	tr.mu.Lock()
	defer tr.mu.Unlock()

	mx, ok := tr.locks[key]
	if !ok {
		mx = &sync.Mutex{}
		tr.locks[key] = mx
	}
	tr.accessed[key] = time.Now()
	return mx
}

// maybeRefreshToken returns a usable token for t.
//
// If t's access token was previously refreshed by this refresher, the cached
// newer token is checked recursively instead. Otherwise t is passed through
// tr.oauth.TokenSource, which is expected to refresh it only when it is no
// longer valid. If the resulting access token differs from t's, the mapping
// is cached for later calls presenting t.
//
// The per-token mutex for t is held for the whole call, including while the
// chain to newer tokens is followed. Errors from the token source are wrapped
// with "checking for refresh".
func (tr *tokenRefresher) maybeRefreshToken(ctx context.Context, t *oauth2.Token) (*oauth2.Token, error) {
	key := t.AccessToken
	// NOTE(review): this logs the raw access token (and, below, whole tokens
	// including refresh tokens) at debug level; consider redacting.
	log := logger.Ctx(ctx).With().Str("access token", key).Logger()

	mx := tr.lockFor(key)
	log.Debug().
		Msg("setting mutex lock")
	mx.Lock()
	defer mx.Unlock()

	// if we already refreshed this access token, try to refresh the new one
	tr.mu.Lock()
	refreshedToken := tr.tokens[key]
	tr.mu.Unlock()
	if refreshedToken != nil {
		log.Debug().
			Interface("refreshed token", refreshedToken).
			Msg("found a new token from this one")
		return tr.maybeRefreshToken(ctx, refreshedToken)
	}

	log.Debug().
		Msg("no existing refresh, checking this token")
	// Called without tr.mu held so requests for other tokens are not blocked
	// behind a network round trip.
	newToken, err := tr.oauth.TokenSource(ctx, t).Token()
	if err != nil {
		return nil, fmt.Errorf("checking for refresh: %w", err)
	}

	tr.mu.Lock()
	// if this is a new token, save it so the next request presenting the old
	// access token is redirected to it
	if newToken.AccessToken != t.AccessToken {
		log.Debug().
			Msg("token refreshed, adding to refresher cache")
		tr.tokens[key] = newToken
	}
	tr.accessed[key] = time.Now()
	tr.mu.Unlock()

	log.Debug().
		Msg("finished refresh check")
	return newToken, nil
}

// cleanup evicts every access token that has not been looked up in the last
// 15 minutes, dropping its cached refreshed token, its per-token mutex and its
// access time.
func (tr *tokenRefresher) cleanup(ctx context.Context) {
	log := logger.Ctx(ctx)
	cutoff := time.Now().Add(-15 * time.Minute)

	tr.mu.Lock()
	defer tr.mu.Unlock()

	// Deleting map entries while ranging over the map is permitted in Go.
	for key, accessed := range tr.accessed {
		if !accessed.Before(cutoff) {
			continue
		}
		log.Debug().
			Interface("old token", tr.tokens[key]).
			Msg("removing old token record from refresher")
		delete(tr.tokens, key)
		delete(tr.locks, key)
		delete(tr.accessed, key)
	}
}
// CheckAuthHandler returns middleware that authenticates each request by the
// OAuth2 token in its "x-auth-token" header.
//
// Processing steps:
//   - The header is parsed with parseTokenString. A parse failure is reported
//     as an "unparsable token" error with status 401.
//   - The token is passed through a tokenRefresher shared by all requests
//     handled by the returned handler, and stale refresher records are
//     cleaned up.
//   - The user info is fetched with getUserInfo.
//   - An OAuth2 HTTP client built from the (possibly refreshed) token is
//     stored in the request context under ContextKeyUserClient, and the
//     request is passed to next.
//
// Any error is written with util.SendError and next is not called.
func CheckAuthHandler(oauthConf *oauth2.Config, next http.Handler) http.Handler {
	refresher := newRefresher(oauthConf)
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		tokenStr := r.Header.Get("x-auth-token")
		incomingToken, err := parseTokenString(tokenStr)
		if err != nil {
			util.SendError(w, errors.New("unparsable token", http.StatusUnauthorized))
			return
		}

		token, err := refresher.maybeRefreshToken(ctx, incomingToken)
		if err != nil {
			util.SendError(w, err)
			return
		}
		refresher.cleanup(ctx)

		// NOTE(review): the user info itself is not used here, only the
		// success of the lookup; presumably it was meant to be attached to the
		// context as well — confirm the intended context key.
		if _, err := getUserInfo(ctx, oauthConf, token); err != nil {
			util.SendError(w, err)
			return
		}

		ctx = context.WithValue(ctx, ContextKeyUserClient, oauthConf.Client(ctx, token))
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment