Skip to content

Instantly share code, notes, and snippets.

@vearutop
Created July 1, 2020 10:13
Show Gist options
  • Save vearutop/7cad49deae438409ca71b6ebf5a24e3f to your computer and use it in GitHub Desktop.
Save vearutop/7cad49deae438409ca71b6ebf5a24e3f to your computer and use it in GitHub Desktop.
Go OAuth2 password-based http transport middleware
// Package oauth2 implements password-based authentication middleware.
package oauth2
import (
"context"
"fmt"
"net/http"
"time"
"golang.org/x/oauth2"
)
// Config holds the settings needed to obtain an OAuth2 token with the
// resource owner password credentials grant.
type Config struct {
	// TokenURL is the authorization server's token endpoint.
	TokenURL string

	// ID and Secret are the OAuth2 client credentials.
	ID, Secret string

	// User and Password are the resource owner's credentials.
	User, Password string

	// TokenRetrieveTimeout bounds the initial token request. The struct tags
	// are presumably consumed by an envconfig-style loader (default "10s");
	// a Config built directly in code gets the zero value unless set.
	TokenRetrieveTimeout time.Duration `split_words:"true" default:"10s"`
}
// NewHTTPTransportMiddleware creates http client transport middleware.
//
// It eagerly obtains an access token from cfg.TokenURL using the OAuth2
// resource owner password credentials grant (cfg.User, cfg.Password) and the
// client credentials cfg.ID and cfg.Secret. That request is bounded by
// cfg.TokenRetrieveTimeout; a zero or negative value falls back to 10s, the
// same default declared in the field's struct tag.
//
// The returned middleware wraps a base RoundTripper in a new *oauth2.Transport
// on every call. All transports it returns share one token source (and thus
// its cached token); none of them is modified by later middleware calls.
//
// An error is returned if the initial token cannot be retrieved.
func NewHTTPTransportMiddleware(cfg Config) (func(tripper http.RoundTripper) http.RoundTripper, error) {
	timeout := cfg.TokenRetrieveTimeout
	if timeout <= 0 {
		// A non-positive timeout yields an already-expired context, which would
		// make token retrieval always fail (e.g. when Config is built in code
		// rather than by a loader that applies the default tag).
		timeout = 10 * time.Second
	}

	conf := &oauth2.Config{
		ClientID:     cfg.ID,
		ClientSecret: cfg.Secret,
		Endpoint: oauth2.Endpoint{
			TokenURL: cfg.TokenURL,
		},
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	t, err := conf.PasswordCredentialsToken(ctx, cfg.User, cfg.Password)
	if err != nil {
		return nil, err
	}

	// Use context.Background rather than ctx: ctx is canceled when this
	// function returns, and the token source keeps the context it is given
	// for later token refreshes.
	c := conf.Client(context.Background(), t)

	tr, ok := c.Transport.(*oauth2.Transport)
	if !ok {
		// %T prints the dynamic type of the unexpected transport.
		return nil, fmt.Errorf("failed to assert transport type: %T", c.Transport)
	}

	return func(tripper http.RoundTripper) http.RoundTripper {
		// oauth2.Config.Client documents that its Transport must not be
		// modified, and mutating the shared tr would rebind every previously
		// returned transport to the latest base. Build a fresh Transport that
		// shares only the token source.
		return &oauth2.Transport{
			Source: tr.Source,
			Base:   tripper,
		}
	}, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment