Created
September 23, 2019 16:20
-
-
Save xeoncross/9d7bbcdded6b7c5613f7ba2645c928e1 to your computer and use it in GitHub Desktop.
Simple OAuth token with auto-renew based on a client id and secret
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package oauthtoken | |
import ( | |
"context" | |
"encoding/json" | |
"fmt" | |
"log" | |
"net/http" | |
"net/url" | |
"strings" | |
"sync" | |
"time" | |
"github.com/pkg/errors" | |
) | |
type OAuthToken struct { | |
ClientId string | |
ClientSecret string | |
token string | |
expires time.Time | |
m sync.Mutex | |
} | |
// NewToken with auto-renewing background process | |
func NewToken(ctx context.Context, clientId, clientSecret string) (*OAuthToken, error) { | |
t := &OAuthToken{ | |
ClientId: clientId, | |
ClientSecret: clientSecret, | |
} | |
// Fetch the first token (blocking) | |
err := t.fetchNew() | |
if err != nil { | |
return nil, err | |
} | |
// Continue to refresh | |
go func() { | |
for { | |
select { | |
case <-ctx.Done(): | |
return | |
default: | |
time.Sleep(time.Minute * 10) | |
// Start trying to refresh 2 hours before it expires | |
if time.Now().Sub(t.expires) < time.Hour*2 { | |
err = t.fetchNew() | |
if err != nil { | |
log.Println(err) | |
} | |
} | |
// Now the token has expired | |
// if int64(time.Now().Sub(t.expires)) < 0 { | |
// // big trouble now! | |
// } | |
} | |
} | |
}() | |
return t, nil | |
} | |
// AccessToken returns a valid access token | |
func (t *OAuthToken) AccessToken() string { | |
t.m.Lock() | |
defer t.m.Unlock() | |
return t.token | |
} | |
func (t *OAuthToken) fetchNew() error { | |
endpoint := "https://example.com/oauth/token" | |
values := url.Values{ | |
"grant_type": []string{"client_credentials"}, | |
"client_id": []string{t.ClientId}, | |
"client_secret": []string{t.ClientSecret}, | |
} | |
body := strings.NewReader(values.Encode()) | |
req, err := http.NewRequest("POST", endpoint, body) | |
if err != nil { | |
return errors.Wrap(err, "Fetch OAuth Token") | |
} | |
// covers the entire exchange, from Dial to reading the body | |
c := &http.Client{ | |
Timeout: 15 * time.Second, | |
} | |
// req = req.WithContext(ctx) | |
res, err := c.Do(req) | |
if err != nil { | |
return errors.Wrap(err, "Fetch OAuth Token") | |
} | |
defer res.Body.Close() | |
if res.StatusCode < 200 || res.StatusCode > 299 { | |
return fmt.Errorf("Bad Status Code: got %v", res.Status) | |
} | |
tokenJSON := struct { | |
AccessToken string `json:"access_token"` | |
Expires int64 `json:"expires_in"` | |
}{} | |
if err := json.NewDecoder(res.Body).Decode(&tokenJSON); err != nil { | |
return errors.Wrap(err, "Fetch OAuth Token") | |
} | |
t.m.Lock() | |
t.token = tokenJSON.AccessToken | |
t.expires = time.Now().Add(time.Duration(tokenJSON.Expires) * time.Second) | |
t.m.Unlock() | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment