Skip to content

Instantly share code, notes, and snippets.

@stream-jdibling
Forked from miguelmota/auth.go
Last active December 1, 2022 12:25
Show Gist options
  • Save stream-jdibling/2c27416d01c99cbeea9fdd07d74e8b0f to your computer and use it in GitHub Desktop.
Save stream-jdibling/2c27416d01c99cbeea9fdd07d74e8b0f to your computer and use it in GitHub Desktop.
Golang AWS Cognito Validate JWT token
/* This package validates AWS Cognito JWTs. It is based on the excellent gist by miguelmota:
https://gist.github.com/miguelmota/06f563756448b0d4ce2ba508b3cbe6e2
That code had two problems, both fixed here:
1) It used a deprecated JWT library (github.com/dgrijalva/jwt-go). The code below uses its maintained successor, github.com/golang-jwt/jwt/v4.
2) The KeyFunc closure passed to jwt.Parse() always used the second JWK, even if the KID in the JWT header pointed to a different JWK. The code below looks up the JWK matching the KID in the header and verifies against that.
*/
package auth
import (
	"crypto/rsa"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"math/big"
	"net/http"
	"time"

	jwt "github.com/golang-jwt/jwt/v4"
)
// Auth verifies JWTs issued by one AWS Cognito user pool, using the JSON Web
// Key Set (JWKS) that the pool publishes. Construct it with NewAuth.
type Auth struct {
	jwk               *JWK   // last key set fetched by CacheJWK; nil until then
	jwkURL            string // JWKS endpoint built from region and user pool ID
	cognitoRegion     string // AWS region that hosts the user pool
	cognitoUserPoolID string // ID of the Cognito user pool
}
// Config carries the settings NewAuth needs to locate a Cognito user pool's
// signing keys.
type Config struct {
	// CognitoRegion is the AWS region of the user pool, e.g. "us-west-2".
	CognitoRegion string
	// CognitoUserPoolID is the user pool's identifier.
	CognitoUserPoolID string
}
// KeySet is one JSON Web Key from a Cognito user pool's JWKS document. Only
// the members needed to pick a key by ID and rebuild its RSA public key are
// decoded.
type KeySet struct {
	Alg string `json:"alg"` // algorithm the key is meant for, e.g. "RS256"
	E   string `json:"e"`   // RSA public exponent, unpadded base64url
	Kid string `json:"kid"` // key ID, matched against the JWT "kid" header
	Kty string `json:"kty"` // key type, e.g. "RSA"
	N   string `json:"n"`   // RSA modulus, unpadded base64url
}

// JWK is a JSON Web Key Set: the document served at a user pool's
// /.well-known/jwks.json endpoint.
type JWK struct {
	Keys []KeySet `json:"keys"`
}

// MapKeys indexes each KeySet by its Kid. If several keys share a Kid, the
// last one wins. A nil receiver yields an empty, non-nil map instead of
// panicking, so an Auth with no cached keys degrades to "key not found".
func (jwk *JWK) MapKeys() map[string]KeySet {
	if jwk == nil {
		return map[string]KeySet{}
	}
	keymap := make(map[string]KeySet, len(jwk.Keys))
	for _, ks := range jwk.Keys {
		keymap[ks.Kid] = ks
	}
	return keymap
}
// NewAuth builds an Auth for the user pool described by config. It fetches
// the pool's JWKS right away, so ParseJWT can be used immediately.
//
// It returns an error if config is nil, if either config field is empty, or
// if the initial JWKS download fails (wrapped with "caching jwk").
func NewAuth(config *Config) (*Auth, error) {
	if config == nil {
		return nil, fmt.Errorf("new auth; config is nil")
	}
	// Without these the JWKS URL would be malformed and the failure would only
	// surface as an obscure network error.
	if config.CognitoRegion == "" || config.CognitoUserPoolID == "" {
		return nil, fmt.Errorf("new auth; cognito region and user pool id are required")
	}
	a := &Auth{
		cognitoRegion:     config.CognitoRegion,
		cognitoUserPoolID: config.CognitoUserPoolID,
	}
	a.jwkURL = fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", a.cognitoRegion, a.cognitoUserPoolID)
	if err := a.CacheJWK(); err != nil {
		return nil, fmt.Errorf("caching jwk; %w", err)
	}
	return a, nil
}
// CacheJWK downloads the JWKS from a.jwkURL and stores it on a, replacing any
// previously cached key set. If anything fails, the existing cache is left
// untouched.
//
// It returns an error for:
//   - transport failures, including timeouts;
//   - non-200 responses;
//   - bodies that are not valid JWKS JSON;
//   - key sets that contain no keys.
//
// NOTE(review): a.jwk is written without synchronization. Callers that
// refresh keys while ParseJWT runs concurrently must add their own locking.
func (a *Auth) CacheJWK() error {
	req, err := http.NewRequest(http.MethodGet, a.jwkURL, nil)
	if err != nil {
		return err
	}
	req.Header.Add("Accept", "application/json")

	// The zero http.Client has no timeout. Bound the request so an
	// unresponsive endpoint cannot block NewAuth indefinitely.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// An error page (403/404/5xx) must not be mistaken for a key set.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("fetching jwks from %s; unexpected status %s", a.jwkURL, resp.Status)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}
	jwk := new(JWK)
	if err := json.Unmarshal(body, jwk); err != nil {
		return err
	}
	// A set with no keys can never verify a token; fail fast instead.
	if len(jwk.Keys) == 0 {
		return fmt.Errorf("jwks at %s contains no keys", a.jwkURL)
	}
	a.jwk = jwk
	return nil
}
// ParseJWT parses tokenString and verifies its signature with the cached
// JWKS key whose ID matches the token's "kid" header. Only RSA PKCS#1 v1.5
// signing methods (RS256/RS384/RS512) are accepted.
//
// Claim validation is only what jwt.Parse does by default. Issuer, client_id
// and token_use are NOT checked here; callers should verify them.
//
// As with jwt.Parse, the returned token may be non-nil even when err is
// non-nil, so callers can inspect it. Such a token must not be trusted.
func (a *Auth) ParseJWT(tokenString string) (*jwt.Token, error) {
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
		// Pin the algorithm family before choosing a key, so the token header
		// cannot select a method that would use our key in an unintended way.
		if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
			return nil, fmt.Errorf("unexpected signing method %v", token.Header["alg"])
		}
		kid, ok := token.Header["kid"].(string)
		if !ok {
			return nil, fmt.Errorf("getting kid; not a string")
		}
		jwks := a.jwk // snapshot once; CacheJWK may replace the pointer
		if jwks == nil {
			return nil, fmt.Errorf("no jwk cached; call CacheJWK first")
		}
		keyset, ok := jwks.MapKeys()[kid]
		if !ok {
			return nil, fmt.Errorf("keyset not found for kid %s", kid)
		}
		if keyset.Kty != "" && keyset.Kty != "RSA" {
			return nil, fmt.Errorf("keyset %s has key type %q, want RSA", kid, keyset.Kty)
		}
		// If the JWK declares its intended algorithm, the token must match it.
		if keyset.Alg != "" && keyset.Alg != token.Method.Alg() {
			return nil, fmt.Errorf("keyset %s is for %s, token uses %s", kid, keyset.Alg, token.Method.Alg())
		}
		key, err := parseRSAPublicKey(keyset.E, keyset.N)
		if err != nil {
			return nil, fmt.Errorf("building public key for kid %s; %w", kid, err)
		}
		return key, nil
	})
	if err != nil {
		return token, fmt.Errorf("parsing jwt; %w", err)
	}
	return token, nil
}

// JWK returns the most recently cached key set, or nil if none was fetched.
func (a *Auth) JWK() *JWK {
	return a.jwk
}

// JWKURL returns the JWKS endpoint this Auth fetches keys from.
func (a *Auth) JWKURL() string {
	return a.jwkURL
}

// parseRSAPublicKey rebuilds an RSA public key from a JWK's unpadded
// base64url exponent (rawE) and modulus (rawN). The JWKS arrives over the
// network, so malformed input is reported as an error rather than a panic.
//
// Adapted from https://gist.github.com/MathieuMailhos/361f24316d2de29e8d41e808e0071b13
func parseRSAPublicKey(rawE, rawN string) (*rsa.PublicKey, error) {
	decodedE, err := base64.RawURLEncoding.DecodeString(rawE)
	if err != nil {
		return nil, fmt.Errorf("decoding exponent; %w", err)
	}
	// The exponent is read as a big-endian uint32. Longer values would be
	// silently truncated, and an empty one would become 0, so reject both.
	if len(decodedE) == 0 || len(decodedE) > 4 {
		return nil, fmt.Errorf("exponent must be 1 to 4 bytes, got %d", len(decodedE))
	}
	var padded [4]byte
	copy(padded[4-len(decodedE):], decodedE) // left-pad with zeros
	e := binary.BigEndian.Uint32(padded[:])

	decodedN, err := base64.RawURLEncoding.DecodeString(rawN)
	if err != nil {
		return nil, fmt.Errorf("decoding modulus; %w", err)
	}
	n := new(big.Int).SetBytes(decodedN)
	if n.Sign() == 0 {
		return nil, fmt.Errorf("modulus is empty or zero")
	}
	return &rsa.PublicKey{N: n, E: int(e)}, nil
}

// convertKey is the panic-on-error form of parseRSAPublicKey. It is kept
// with its original signature for any existing callers. Prefer
// parseRSAPublicKey for key material that comes from the network.
func convertKey(rawE, rawN string) *rsa.PublicKey {
	key, err := parseRSAPublicKey(rawE, rawN)
	if err != nil {
		panic(err)
	}
	return key
}
package auth
import (
"os"
"testing"
)
// TestCacheJWT is an integration test against a real Cognito user pool.
// It runs only when AWS_COGNITO_REGION, AWS_COGNITO_USER_POOL_ID and
// AWS_COGNITO_TEST_JWT are all set. The JWT must be a current, unexpired
// token issued by that pool. A hardcoded token cannot work: it expires and
// is tied to a single pool's signing keys.
func TestCacheJWT(t *testing.T) {
	region := os.Getenv("AWS_COGNITO_REGION")
	poolID := os.Getenv("AWS_COGNITO_USER_POOL_ID")
	tokenString := os.Getenv("AWS_COGNITO_TEST_JWT")
	if region == "" || poolID == "" || tokenString == "" {
		t.Skip("requires AWS_COGNITO_REGION, AWS_COGNITO_USER_POOL_ID and AWS_COGNITO_TEST_JWT")
	}

	// NewAuth returns (*Auth, error) and already fetches the JWKS once.
	auth, err := NewAuth(&Config{
		CognitoRegion:     region,
		CognitoUserPoolID: poolID,
	})
	if err != nil {
		t.Fatalf("NewAuth: %v", err)
	}
	// An explicit refresh must succeed as well.
	if err := auth.CacheJWK(); err != nil {
		t.Fatalf("CacheJWK: %v", err)
	}

	token, err := auth.ParseJWT(tokenString)
	if err != nil {
		// Fatal, not Error: token may be nil here.
		t.Fatalf("ParseJWT: %v", err)
	}
	if !token.Valid {
		t.Fatal("token parsed without error but is not marked valid")
	}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment