Skip to content

Instantly share code, notes, and snippets.

@xeoncross
Created September 23, 2019 16:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xeoncross/9d7bbcdded6b7c5613f7ba2645c928e1 to your computer and use it in GitHub Desktop.
Save xeoncross/9d7bbcdded6b7c5613f7ba2645c928e1 to your computer and use it in GitHub Desktop.
Simple OAuth token with automatic renewal, based on a client ID and secret
package oauthtoken
import (
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/pkg/errors"
)
// OAuthToken holds a pair of OAuth client credentials together with the most
// recently fetched access token and the time at which that token expires.
//
// Construct it with NewToken; the zero value has no token.
type OAuthToken struct {
	// ClientId and ClientSecret are sent to the token endpoint on every fetch.
	ClientId, ClientSecret string

	token   string     // last access token received from the endpoint
	expires time.Time  // fetch time plus the endpoint's expires_in
	m       sync.Mutex // guards token and expires; hold it to read or write either
}
// NewToken with auto-renewing background process
// NewToken fetches an initial access token for the given client credentials
// and starts a background goroutine that keeps it refreshed.
//
// The first fetch is synchronous: if it fails, NewToken returns the error and
// no goroutine is started. Afterwards the goroutine wakes every 10 minutes and,
// once the stored token is within 2 hours of expiring (or has already
// expired), fetches a replacement. Refresh failures are logged and retried on
// the next tick. The goroutine exits when ctx is done.
func NewToken(ctx context.Context, clientId, clientSecret string) (*OAuthToken, error) {
	t := &OAuthToken{
		ClientId:     clientId,
		ClientSecret: clientSecret,
	}

	// Fetch the first token (blocking) so the caller gets a usable token.
	if err := t.fetchNew(); err != nil {
		return nil, err
	}

	go t.refreshLoop(ctx, 10*time.Minute, 2*time.Hour)

	return t, nil
}

// refreshLoop checks the token every interval and re-fetches it whenever it
// expires within window, until ctx is done. A ticker plus select (rather than
// time.Sleep) lets cancellation stop the loop immediately.
// NOTE(review): a token whose lifetime is shorter than interval can still
// lapse between ticks.
func (t *OAuthToken) refreshLoop(ctx context.Context, interval, window time.Duration) {
	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if !t.expiresWithin(window) {
				continue
			}
			if err := t.fetchNew(); err != nil {
				// Keep the previous token and retry on the next tick.
				log.Println(err)
			}
		}
	}
}

// expiresWithin reports whether the stored token expires less than d from now,
// including when it has already expired. It reads expires under t.m because
// fetchNew writes it concurrently from the refresh goroutine.
func (t *OAuthToken) expiresWithin(d time.Duration) bool {
	t.m.Lock()
	defer t.m.Unlock()
	return time.Until(t.expires) < d
}
// AccessToken returns a valid access token
// AccessToken returns the most recently fetched access token. It is safe to
// call from multiple goroutines.
//
// It does not check the token's expiry, so if background refreshes have been
// failing, the returned token may already have expired.
func (t *OAuthToken) AccessToken() string {
	t.m.Lock()
	current := t.token
	t.m.Unlock()
	return current
}
// fetchNew requests a fresh access token from the OAuth endpoint using the
// client_credentials grant. On success it stores the token on t, with an
// expiry of now plus the response's expires_in seconds.
//
// The request is a form-encoded POST. The client timeout covers the whole
// exchange, from dialing to reading the body. A transport error, a non-2xx
// status, an undecodable body, or a response without an access_token is
// returned as an error, and the previously stored token is left untouched.
func (t *OAuthToken) fetchNew() error {
	const endpoint = "https://example.com/oauth/token"

	values := url.Values{
		"grant_type":    {"client_credentials"},
		"client_id":     {t.ClientId},
		"client_secret": {t.ClientSecret},
	}

	req, err := http.NewRequest(http.MethodPost, endpoint, strings.NewReader(values.Encode()))
	if err != nil {
		return errors.Wrap(err, "Fetch OAuth Token")
	}
	// http.NewRequest does not set a Content-Type. Token endpoints expect the
	// parameters as application/x-www-form-urlencoded (RFC 6749 §4.4.2).
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	// Timeout covers the entire exchange, from Dial to reading the body.
	c := &http.Client{
		Timeout: 15 * time.Second,
	}
	res, err := c.Do(req)
	if err != nil {
		return errors.Wrap(err, "Fetch OAuth Token")
	}
	defer res.Body.Close()

	if res.StatusCode < 200 || res.StatusCode > 299 {
		return fmt.Errorf("Bad Status Code: got %v", res.Status)
	}

	var tokenJSON struct {
		AccessToken string `json:"access_token"`
		Expires     int64  `json:"expires_in"`
	}
	if err := json.NewDecoder(res.Body).Decode(&tokenJSON); err != nil {
		return errors.Wrap(err, "Fetch OAuth Token")
	}
	// Never replace a working token with an empty one.
	if tokenJSON.AccessToken == "" {
		return errors.New("Fetch OAuth Token: response contained no access_token")
	}

	// NOTE(review): a missing expires_in decodes as 0, which makes the token
	// look expired immediately, so the refresh loop will keep re-fetching it.
	expires := time.Now().Add(time.Duration(tokenJSON.Expires) * time.Second)

	t.m.Lock()
	t.token = tokenJSON.AccessToken
	t.expires = expires
	t.m.Unlock()
	return nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment