-
-
Save ludydoo/9fef0858e63b79275342a2000bc68560 to your computer and use it in GitHub Desktop.
Minimal reproduction for USER_SRP_AUTH + MFA
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"context" | |
"crypto/hmac" | |
"crypto/rand" | |
"crypto/sha256" | |
"encoding/base64" | |
"encoding/hex" | |
"fmt" | |
"github.com/aws/aws-sdk-go-v2/config" | |
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider" | |
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider/types" | |
"github.com/manifoldco/promptui" | |
"math/big" | |
"net/mail" | |
"os" | |
"os/signal" | |
"strings" | |
"time" | |
) | |
func main() { | |
ctx, _ := signal.NotifyContext(context.Background(), os.Interrupt) | |
if err := run(ctx); err != nil { | |
fmt.Println(err) | |
os.Exit(1) | |
} | |
} | |
func run(ctx context.Context) error { | |
fmt.Println("Starting Cognito SRP CLI") | |
fmt.Println("Environment variables can be used to set default values for prompts...") | |
fmt.Println() | |
fmt.Println("COGNITO_CLIENT_ID Cognito Client ID") | |
fmt.Println("COGNITO_CLIENT_SECRET Cognito Client Secret") | |
fmt.Println("COGNITO_USER_POOL_ID Cognito User Pool ID") | |
fmt.Println("COGNITO_EMAIL Email") | |
fmt.Println("COGNITO_PASSWORD Password") | |
fmt.Println() | |
cognitoConfig, err := config.LoadDefaultConfig(ctx) | |
if err != nil { | |
return fmt.Errorf("failed to load default config, %v", err) | |
} | |
cognitoObj := cognitoidentityprovider.NewFromConfig(cognitoConfig) | |
clientIdPrompt := promptui.Prompt{ | |
Label: "Cognito Client ID", | |
Default: os.Getenv("COGNITO_CLIENT_ID"), | |
} | |
clientId, err := clientIdPrompt.Run() | |
if err != nil { | |
return err | |
} | |
clientSecretPrompt := promptui.Prompt{ | |
Label: "Cognito Client Secret", | |
Default: os.Getenv("COGNITO_CLIENT_SECRET"), | |
} | |
clientSecret, err := clientSecretPrompt.Run() | |
if err != nil { | |
return err | |
} | |
poolIdPrompt := promptui.Prompt{ | |
Label: "Cognito User Pool ID", | |
Default: os.Getenv("COGNITO_USER_POOL_ID"), | |
} | |
poolId, err := poolIdPrompt.Run() | |
if err != nil { | |
return err | |
} | |
emailPrompt := promptui.Prompt{ | |
Label: "Email", | |
Default: os.Getenv("COGNITO_EMAIL"), | |
Validate: func(s string) error { | |
email, err := mail.ParseAddress(s) | |
if err != nil { | |
return err | |
} | |
if email.Address == "" { | |
return fmt.Errorf("address not allowed") | |
} | |
if email.Name != "" { | |
return fmt.Errorf("name not allowed") | |
} | |
return nil | |
}, | |
} | |
email, err := emailPrompt.Run() | |
if err != nil { | |
return err | |
} | |
hmacObj := hmac.New(sha256.New, []byte(clientSecret)) | |
hmacObj.Write([]byte(email + clientId)) | |
passwordPrompt := promptui.Prompt{ | |
Label: "Password", | |
Default: os.Getenv("COGNITO_PASSWORD"), | |
Mask: '*', | |
} | |
password, err := passwordPrompt.Run() | |
if err != nil { | |
return err | |
} | |
var secretHash string | |
if clientSecret != "" { | |
secretHash, err = getSecretHash(email, clientId, clientSecret) | |
if err != nil { | |
return err | |
} | |
} | |
fmt.Println("Initiating auth with Cognito. AuthFlow:", types.AuthFlowTypeUserSrpAuth) | |
initiateAuthResp, err := cognitoObj.InitiateAuth(ctx, &cognitoidentityprovider.InitiateAuthInput{ | |
AuthFlow: types.AuthFlowTypeUserSrpAuth, | |
ClientId: &clientId, | |
AuthParameters: GetAuthParams(email, clientId, clientSecret), | |
}) | |
if err != nil { | |
return err | |
} | |
authResult := initiateAuthResp.AuthenticationResult | |
challengeName := initiateAuthResp.ChallengeName | |
challengeParams := initiateAuthResp.ChallengeParameters | |
session := initiateAuthResp.Session | |
for challengeName != "" { | |
fmt.Println("Handling challenge:", challengeName) | |
if challengeName == types.ChallengeNameTypePasswordVerifier { | |
params, err := passwordVerifierChallenge(strings.Split(poolId, "_")[1], clientId, clientSecret, password, challengeParams, time.Now()) | |
if err != nil { | |
return err | |
} | |
respondToAuthChallengeResp, err := cognitoObj.RespondToAuthChallenge(ctx, &cognitoidentityprovider.RespondToAuthChallengeInput{ | |
ChallengeName: challengeName, | |
ChallengeResponses: params, | |
ClientId: &clientId, | |
Session: session, | |
}) | |
if err != nil { | |
return err | |
} | |
challengeName = respondToAuthChallengeResp.ChallengeName | |
challengeParams = respondToAuthChallengeResp.ChallengeParameters | |
session = respondToAuthChallengeResp.Session | |
authResult = respondToAuthChallengeResp.AuthenticationResult | |
} else if challengeName == types.ChallengeNameTypeSoftwareTokenMfa { | |
responsePrompt := promptui.Prompt{ | |
Label: "MFA Code", | |
} | |
response, err := responsePrompt.Run() | |
if err != nil { | |
return err | |
} | |
challengeResponse := map[string]string{ | |
"SOFTWARE_TOKEN_MFA_CODE": response, | |
"USERNAME": email, | |
} | |
if secretHash != "" { | |
challengeResponse["SECRET_HASH"] = secretHash | |
} | |
respondToAuthChallengeResp, err := cognitoObj.RespondToAuthChallenge(ctx, &cognitoidentityprovider.RespondToAuthChallengeInput{ | |
ChallengeName: challengeName, | |
ChallengeResponses: challengeResponse, | |
ClientId: &clientId, | |
Session: session, | |
}) | |
if err != nil { | |
return err | |
} | |
challengeName = respondToAuthChallengeResp.ChallengeName | |
challengeParams = respondToAuthChallengeResp.ChallengeParameters | |
session = respondToAuthChallengeResp.Session | |
authResult = respondToAuthChallengeResp.AuthenticationResult | |
} else { | |
return fmt.Errorf("unsupported challenge: %s", challengeName) | |
} | |
} | |
if authResult.NewDeviceMetadata != nil { | |
fmt.Println("New device metadata obtained") | |
fmt.Println("Device Key:", *authResult.NewDeviceMetadata.DeviceKey) | |
fmt.Println("Device Group Key:", *authResult.NewDeviceMetadata.DeviceGroupKey) | |
devicePassword, err := generateDevicePassword() | |
if err != nil { | |
return err | |
} | |
verifierConfig, err := secretVerifierConfig(authResult, devicePassword) | |
if err != nil { | |
return err | |
} | |
confirmResp, err := cognitoObj.ConfirmDevice(ctx, &cognitoidentityprovider.ConfirmDeviceInput{ | |
AccessToken: authResult.AccessToken, | |
DeviceKey: authResult.NewDeviceMetadata.DeviceKey, | |
DeviceName: nil, | |
DeviceSecretVerifierConfig: verifierConfig, | |
}) | |
if err != nil { | |
return err | |
} | |
if confirmResp.UserConfirmationNecessary { | |
fmt.Println("User confirmation necessary") | |
confirmPrompt := promptui.Prompt{ | |
Label: "Do you want to remember this device?", | |
Validate: func(s string) error { | |
if s == "y" || s == "n" { | |
return nil | |
} | |
return fmt.Errorf("invalid response") | |
}, | |
} | |
confirm, err := confirmPrompt.Run() | |
if err != nil { | |
return err | |
} | |
if confirm == "y" { | |
_, err := cognitoObj.UpdateDeviceStatus(ctx, &cognitoidentityprovider.UpdateDeviceStatusInput{ | |
AccessToken: authResult.AccessToken, | |
DeviceKey: authResult.NewDeviceMetadata.DeviceKey, | |
DeviceRememberedStatus: types.DeviceRememberedStatusTypeRemembered, | |
}) | |
if err != nil { | |
return err | |
} | |
} else { | |
fmt.Println("Device not remembered") | |
} | |
} | |
} | |
fmt.Println("All done!") | |
fmt.Println("Access Token:", *authResult.AccessToken) | |
fmt.Println("Refresh Token:", *authResult.RefreshToken) | |
fmt.Println("Id Token:", *authResult.IdToken) | |
return nil | |
} | |
// Process-wide SRP state, populated once by init.
var (
	// bigN is the group modulus parsed from nHex, g the generator parsed from
	// gHex, and k the multiplier derived by hashing "00"+nHex+"0"+gHex.
	bigN, g, k *big.Int

	// a is the client's random private value (reduced mod bigN) and
	// bigA = g^a mod bigN is the public value sent as SRP_A.
	a, bigA *big.Int
)
func init() { | |
bigN, _ = big.NewInt(0).SetString(nHex, 16) | |
g, _ = big.NewInt(0).SetString(gHex, 16) | |
k, _ = big.NewInt(0).SetString(hexHash("00"+nHex+"0"+gHex), 16) | |
b := make([]byte, 128) | |
// small A | |
rand.Read(b) | |
randomLongInt, _ := big.NewInt(0).SetString(hex.EncodeToString(b), 16) | |
a = big.NewInt(0).Mod(randomLongInt, bigN) | |
// big A | |
bigA = big.NewInt(0).Exp(g, a, bigN) | |
} | |
func GetAuthParams(email, clientId, clientSecret string) map[string]string { | |
params := map[string]string{ | |
"USERNAME": email, | |
"SRP_A": bigA.Text(16), | |
} | |
if secret, err := getSecretHash(email, clientId, clientSecret); err == nil { | |
params["SECRET_HASH"] = secret | |
} | |
return params | |
} | |
func passwordVerifierChallenge(poolName, clientId, clientSecret, password string, challengeParams map[string]string, ts time.Time) (map[string]string, error) { | |
var ( | |
internalUsername = challengeParams["USERNAME"] | |
userId = challengeParams["USER_ID_FOR_SRP"] | |
saltHex = challengeParams["SALT"] | |
srpBHex = challengeParams["SRP_B"] | |
secretBlockB64 = challengeParams["SECRET_BLOCK"] | |
srpB, _ = big.NewInt(0).SetString(srpBHex, 16) | |
salt, _ = big.NewInt(0).SetString(saltHex, 16) | |
timestamp = ts.In(time.UTC).Format("Mon Jan 2 03:04:05 MST 2006") | |
hkdf = getPasswordAuthenticationKey(poolName, userId, password, srpB, salt) | |
) | |
secretBlockBytes, err := base64.StdEncoding.DecodeString(secretBlockB64) | |
if err != nil { | |
return nil, fmt.Errorf("unable to decode challenge parameter 'SECRET_BLOCK', %s", err.Error()) | |
} | |
msg := poolName + userId + string(secretBlockBytes) + timestamp | |
hmacObj := hmac.New(sha256.New, hkdf) | |
hmacObj.Write([]byte(msg)) | |
signature := base64.StdEncoding.EncodeToString(hmacObj.Sum(nil)) | |
response := map[string]string{ | |
"TIMESTAMP": timestamp, | |
"USERNAME": internalUsername, | |
"PASSWORD_CLAIM_SECRET_BLOCK": secretBlockB64, | |
"PASSWORD_CLAIM_SIGNATURE": signature, | |
} | |
if clientSecret != "" { | |
secretHash, err := getSecretHash(internalUsername, clientId, clientSecret) | |
if err != nil { | |
return nil, err | |
} | |
response["SECRET_HASH"] = secretHash | |
} | |
return response, nil | |
} | |
func getPasswordAuthenticationKey(poolName, username, password string, bigB, salt *big.Int) []byte { | |
var ( | |
userPass = fmt.Sprintf("%s%s:%s", poolName, username, password) | |
userPassHash = hashSha256([]byte(userPass)) | |
uVal, _ = big.NewInt(0).SetString(hexHash(padHex(bigA.Text(16))+padHex(bigB.Text(16))), 16) | |
xVal, _ = big.NewInt(0).SetString(hexHash(padHex(salt.Text(16))+userPassHash), 16) | |
gModPowXN = big.NewInt(0).Exp(g, xVal, bigN) | |
intVal1 = big.NewInt(0).Sub(bigB, big.NewInt(0).Mul(k, gModPowXN)) | |
intVal2 = big.NewInt(0).Add(a, big.NewInt(0).Mul(uVal, xVal)) | |
sVal = big.NewInt(0).Exp(intVal1, intVal2, bigN) | |
) | |
return computeHKDF(padHex(sVal.Text(16)), padHex(uVal.Text(16))) | |
} | |
func secretVerifierConfig(authResult *types.AuthenticationResultType, password string) (*types.DeviceSecretVerifierConfigType, error) { | |
secret := fmt.Sprintf("%s%s:%s", *authResult.NewDeviceMetadata.DeviceGroupKey, *authResult.NewDeviceMetadata.DeviceKey, password) | |
secretSha256 := sha256.Sum256([]byte(secret)) | |
secretHash := hex.EncodeToString(secretSha256[:]) | |
secretHash = strings.Repeat("0", 64-len(secretHash)) + secretHash | |
salt := make([]byte, 16) | |
_, _ = rand.Read(salt) | |
salt[0] = salt[0] & 127 | |
saltHex := hex.EncodeToString(salt) | |
xHex := hexHash(saltHex + secretHash) | |
x, _ := big.NewInt(0).SetString(xHex, 16) | |
v := big.NewInt(0).Exp(g, x, bigN) | |
vHex := padHex(v.Text(16)) | |
vBase64, err := hexToBase64(vHex) | |
if err != nil { | |
return nil, err | |
} | |
saltBase64, err := hexToBase64(saltHex) | |
if err != nil { | |
return nil, err | |
} | |
return &types.DeviceSecretVerifierConfigType{ | |
PasswordVerifier: &vBase64, | |
Salt: &saltBase64, | |
}, nil | |
} | |
// getSecretHash returns the base64-encoded HMAC-SHA256 of username followed
// by clientId, keyed with clientSecret. The returned error is always nil.
func getSecretHash(username, clientId, clientSecret string) (string, error) {
	mac := hmac.New(sha256.New, []byte(clientSecret))
	// Feeding the two parts separately hashes the same bytes as their
	// concatenation.
	mac.Write([]byte(username))
	mac.Write([]byte(clientId))
	return base64.StdEncoding.EncodeToString(mac.Sum(nil)), nil
}
func computeHKDF(ikm, salt string) []byte { | |
ikmb, _ := hex.DecodeString(ikm) | |
saltb, _ := hex.DecodeString(salt) | |
extractor := hmac.New(sha256.New, saltb) | |
extractor.Write(ikmb) | |
prk := extractor.Sum(nil) | |
infoBitsUpdate := append([]byte(infoBits), byte(1)) | |
extractor = hmac.New(sha256.New, prk) | |
extractor.Write(infoBitsUpdate) | |
hmacHash := extractor.Sum(nil) | |
return hmacHash[:16] | |
} | |
func hexHash(hexStr string) string { | |
buf, _ := hex.DecodeString(hexStr) | |
return hashSha256(buf) | |
} | |
// hashSha256 returns the SHA-256 digest of buf as a lowercase hex string.
func hashSha256(buf []byte) string {
	digest := sha256.Sum256(buf)
	return hex.EncodeToString(digest[:])
}
// padHex normalises a hex string before it is hashed or decoded:
//
//   - an odd-length string gets a single leading "0";
//   - an even-length string whose first digit is 8..f gets a leading "00"
//     (presumably so the value is never read as negative by a signed
//     big-integer encoding; confirm against the peer implementation);
//   - anything else, including the empty string, is returned unchanged.
//
// The empty string is handled explicitly because indexing its first byte
// would panic.
func padHex(hexStr string) string {
	switch {
	case hexStr == "":
		return hexStr
	case len(hexStr)%2 == 1:
		return "0" + hexStr
	case strings.IndexByte("89ABCDEFabcdef", hexStr[0]) >= 0:
		return "00" + hexStr
	default:
		return hexStr
	}
}
// hexToBase64 re-encodes a hex string as standard (padded) base64. It returns
// the hex decoder's error unchanged when hexStr is not valid hex.
func hexToBase64(hexStr string) (string, error) {
	raw := make([]byte, hex.DecodedLen(len(hexStr)))
	if _, err := hex.Decode(raw, []byte(hexStr)); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(raw), nil
}
// generateDevicePassword returns 40 bytes from crypto/rand encoded as
// standard base64, or the error reported by the random source.
func generateDevicePassword() (string, error) {
	var seed [40]byte
	_, err := rand.Read(seed[:])
	if err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(seed[:]), nil
}
// Fixed SRP parameters.
const (
	// gHex is the group generator in hex.
	gHex = "2"

	// infoBits is the info string fed into the key derivation in computeHKDF.
	infoBits = "Caldera Derived Key"

	// nHex is the group modulus: 768 hex digits, i.e. 3072 bits.
	// NOTE(review): this looks like the 3072-bit MODP prime from RFC 3526;
	// confirm against the service's SRP parameters before changing it.
	nHex = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1" +
		"29024E088A67CC74020BBEA63B139B22514A08798E3404DD" +
		"EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245" +
		"E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED" +
		"EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D" +
		"C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F" +
		"83655D23DCA3AD961C62F356208552BB9ED529077096966D" +
		"670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B" +
		"E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9" +
		"DE2BCBF6955817183995497CEA956AE515D2261898FA0510" +
		"15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64" +
		"ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7" +
		"ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B" +
		"F12FFA06D98A0864D87602733EC86A64521F2B18177B200C" +
		"BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31" +
		"43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment