Skip to content

Instantly share code, notes, and snippets.

@ludydoo

ludydoo/main.go Secret

Created December 1, 2023 19:01
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ludydoo/9fef0858e63b79275342a2000bc68560 to your computer and use it in GitHub Desktop.
Save ludydoo/9fef0858e63b79275342a2000bc68560 to your computer and use it in GitHub Desktop.
Minimal reproduction for USER_SRP_AUTH + MFA
package main
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"fmt"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider"
"github.com/aws/aws-sdk-go-v2/service/cognitoidentityprovider/types"
"github.com/manifoldco/promptui"
"math/big"
"net/mail"
"os"
"os/signal"
"strings"
"time"
)
// main runs the CLI with a context that is cancelled on os.Interrupt (Ctrl-C).
// If run fails, the error is printed to stderr and the process exits with status 1.
func main() {
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
	err := run(ctx)
	// os.Exit does not run deferred calls, so release the signal
	// registration explicitly before a possible exit.
	stop()
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
// run drives the interactive USER_SRP_AUTH flow against Cognito:
//
//  1. prompt for client ID/secret, user pool ID, email and password
//     (defaults are taken from the COGNITO_* environment variables);
//  2. call InitiateAuth with the client's SRP_A value;
//  3. answer PASSWORD_VERIFIER and SOFTWARE_TOKEN_MFA challenges until
//     Cognito stops returning a challenge;
//  4. if the result carries new device metadata, confirm the device and
//     optionally mark it as remembered;
//  5. print the resulting tokens.
//
// Any prompt, validation or AWS API error aborts the flow and is returned.
func run(ctx context.Context) error {
	fmt.Println("Starting Cognito SRP CLI")
	fmt.Println("Environment variables can be used to set default values for prompts...")
	fmt.Println()
	fmt.Println("COGNITO_CLIENT_ID Cognito Client ID")
	fmt.Println("COGNITO_CLIENT_SECRET Cognito Client Secret")
	fmt.Println("COGNITO_USER_POOL_ID Cognito User Pool ID")
	fmt.Println("COGNITO_EMAIL Email")
	fmt.Println("COGNITO_PASSWORD Password")
	fmt.Println()

	cognitoConfig, err := config.LoadDefaultConfig(ctx)
	if err != nil {
		return fmt.Errorf("failed to load default config, %w", err)
	}
	cognitoObj := cognitoidentityprovider.NewFromConfig(cognitoConfig)

	// poolNameOf returns the part after "<region>_" in a user pool ID; the
	// PASSWORD_VERIFIER signature is computed over it. Validating here avoids
	// an index-out-of-range panic on a malformed ID.
	poolNameOf := func(id string) (string, error) {
		parts := strings.Split(id, "_")
		if len(parts) < 2 || parts[1] == "" {
			return "", fmt.Errorf("user pool ID %q is not of the form <region>_<id>", id)
		}
		return parts[1], nil
	}
	// derefString renders a possibly-nil SDK string pointer without panicking.
	derefString := func(p *string) string {
		if p == nil {
			return ""
		}
		return *p
	}

	clientIdPrompt := promptui.Prompt{
		Label:   "Cognito Client ID",
		Default: os.Getenv("COGNITO_CLIENT_ID"),
	}
	clientId, err := clientIdPrompt.Run()
	if err != nil {
		return err
	}

	clientSecretPrompt := promptui.Prompt{
		Label:   "Cognito Client Secret",
		Default: os.Getenv("COGNITO_CLIENT_SECRET"),
	}
	clientSecret, err := clientSecretPrompt.Run()
	if err != nil {
		return err
	}

	poolIdPrompt := promptui.Prompt{
		Label:   "Cognito User Pool ID",
		Default: os.Getenv("COGNITO_USER_POOL_ID"),
		Validate: func(s string) error {
			_, err := poolNameOf(s)
			return err
		},
	}
	poolId, err := poolIdPrompt.Run()
	if err != nil {
		return err
	}
	poolName, err := poolNameOf(poolId)
	if err != nil {
		return err
	}

	emailPrompt := promptui.Prompt{
		Label:   "Email",
		Default: os.Getenv("COGNITO_EMAIL"),
		Validate: func(s string) error {
			email, err := mail.ParseAddress(s)
			if err != nil {
				return err
			}
			if email.Address == "" {
				return fmt.Errorf("address not allowed")
			}
			if email.Name != "" {
				return fmt.Errorf("name not allowed")
			}
			return nil
		},
	}
	email, err := emailPrompt.Run()
	if err != nil {
		return err
	}

	passwordPrompt := promptui.Prompt{
		Label:   "Password",
		Default: os.Getenv("COGNITO_PASSWORD"),
		Mask:    '*',
	}
	password, err := passwordPrompt.Run()
	if err != nil {
		return err
	}

	// SECRET_HASH is only sent for app clients that have a secret.
	var secretHash string
	if clientSecret != "" {
		secretHash, err = getSecretHash(email, clientId, clientSecret)
		if err != nil {
			return err
		}
	}

	fmt.Println("Initiating auth with Cognito. AuthFlow:", types.AuthFlowTypeUserSrpAuth)
	initiateAuthResp, err := cognitoObj.InitiateAuth(ctx, &cognitoidentityprovider.InitiateAuthInput{
		AuthFlow:       types.AuthFlowTypeUserSrpAuth,
		ClientId:       &clientId,
		AuthParameters: GetAuthParams(email, clientId, clientSecret),
	})
	if err != nil {
		return fmt.Errorf("initiating auth: %w", err)
	}

	authResult := initiateAuthResp.AuthenticationResult
	challengeName := initiateAuthResp.ChallengeName
	challengeParams := initiateAuthResp.ChallengeParameters
	session := initiateAuthResp.Session

	for challengeName != "" {
		fmt.Println("Handling challenge:", challengeName)

		var responses map[string]string
		switch challengeName {
		case types.ChallengeNameTypePasswordVerifier:
			responses, err = passwordVerifierChallenge(poolName, clientId, clientSecret, password, challengeParams, time.Now())
			if err != nil {
				return err
			}
		case types.ChallengeNameTypeSoftwareTokenMfa:
			responsePrompt := promptui.Prompt{
				Label: "MFA Code",
			}
			code, err := responsePrompt.Run()
			if err != nil {
				return err
			}
			// NOTE(review): USERNAME and SECRET_HASH use the entered email here.
			// If the pool treats email as an alias, Cognito may expect the
			// internal username (USER_ID_FOR_SRP) instead — confirm.
			responses = map[string]string{
				"SOFTWARE_TOKEN_MFA_CODE": code,
				"USERNAME":                email,
			}
			if secretHash != "" {
				responses["SECRET_HASH"] = secretHash
			}
		default:
			return fmt.Errorf("unsupported challenge: %s", challengeName)
		}

		resp, err := cognitoObj.RespondToAuthChallenge(ctx, &cognitoidentityprovider.RespondToAuthChallengeInput{
			ChallengeName:      challengeName,
			ChallengeResponses: responses,
			ClientId:           &clientId,
			Session:            session,
		})
		if err != nil {
			return fmt.Errorf("responding to %s challenge: %w", challengeName, err)
		}
		challengeName = resp.ChallengeName
		challengeParams = resp.ChallengeParameters
		session = resp.Session
		authResult = resp.AuthenticationResult
	}

	if authResult == nil {
		return fmt.Errorf("authentication finished without an authentication result")
	}

	if meta := authResult.NewDeviceMetadata; meta != nil {
		fmt.Println("New device metadata obtained")
		fmt.Println("Device Key:", derefString(meta.DeviceKey))
		fmt.Println("Device Group Key:", derefString(meta.DeviceGroupKey))

		devicePassword, err := generateDevicePassword()
		if err != nil {
			return err
		}
		verifierConfig, err := secretVerifierConfig(authResult, devicePassword)
		if err != nil {
			return err
		}
		confirmResp, err := cognitoObj.ConfirmDevice(ctx, &cognitoidentityprovider.ConfirmDeviceInput{
			AccessToken:                authResult.AccessToken,
			DeviceKey:                  meta.DeviceKey,
			DeviceName:                 nil,
			DeviceSecretVerifierConfig: verifierConfig,
		})
		if err != nil {
			return fmt.Errorf("confirming device: %w", err)
		}

		if confirmResp.UserConfirmationNecessary {
			fmt.Println("User confirmation necessary")
			confirmPrompt := promptui.Prompt{
				Label: "Do you want to remember this device?",
				Validate: func(s string) error {
					if s == "y" || s == "n" {
						return nil
					}
					return fmt.Errorf("invalid response")
				},
			}
			confirm, err := confirmPrompt.Run()
			if err != nil {
				return err
			}
			if confirm == "y" {
				if _, err := cognitoObj.UpdateDeviceStatus(ctx, &cognitoidentityprovider.UpdateDeviceStatusInput{
					AccessToken:            authResult.AccessToken,
					DeviceKey:              meta.DeviceKey,
					DeviceRememberedStatus: types.DeviceRememberedStatusTypeRemembered,
				}); err != nil {
					return fmt.Errorf("remembering device: %w", err)
				}
			} else {
				fmt.Println("Device not remembered")
			}
		}
	}

	fmt.Println("All done!")
	fmt.Println("Access Token:", derefString(authResult.AccessToken))
	fmt.Println("Refresh Token:", derefString(authResult.RefreshToken))
	fmt.Println("Id Token:", derefString(authResult.IdToken))
	return nil
}
// SRP group parameters and the client's ephemeral key pair. They are set once
// by init and read by the SRP helpers in this file.
var (
	bigN *big.Int // group modulus N, parsed from nHex
	g    *big.Int // generator g, parsed from gHex
	k    *big.Int // multiplier k = SHA-256 of byte-padded N followed by byte-padded g
	a    *big.Int // client's private ephemeral value; never sent
	bigA *big.Int // client's public ephemeral value A = g^a mod N (sent as SRP_A)
)

// init parses the group constants and generates the ephemeral key pair used
// for the lifetime of the process.
//
// It panics if a constant is not valid hex or if the system random source
// fails. Silently continuing would leave a = 0, so A = 1, which makes the
// "secret" trivially predictable.
func init() {
	parseHex := func(name, s string) *big.Int {
		v, ok := new(big.Int).SetString(s, 16)
		if !ok {
			panic(fmt.Sprintf("srp: %s is not valid hex", name))
		}
		return v
	}
	bigN = parseHex("N", nHex)
	g = parseHex("g", gHex)
	// "00"+nHex and "0"+gHex are the byte-padded forms of N and g.
	k = parseHex("k", hexHash("00"+nHex+"0"+gHex))

	// a: 128 random bytes read as a big-endian integer, reduced mod N.
	// Retry in the vanishingly unlikely case that it is zero.
	buf := make([]byte, 128)
	for {
		if _, err := rand.Read(buf); err != nil {
			panic(fmt.Sprintf("srp: reading random bytes: %v", err))
		}
		a = new(big.Int).Mod(new(big.Int).SetBytes(buf), bigN)
		if a.Sign() != 0 {
			break
		}
	}

	bigA = new(big.Int).Exp(g, a, bigN)
}
// GetAuthParams returns the AuthParameters for a USER_SRP_AUTH InitiateAuth
// call. It always contains USERNAME and SRP_A (the client's public SRP value
// as lower-case hex). It contains SECRET_HASH only when clientSecret is
// non-empty, matching the guards used for the challenge responses.
func GetAuthParams(email, clientId, clientSecret string) map[string]string {
	params := map[string]string{
		"USERNAME": email,
		"SRP_A":    bigA.Text(16),
	}
	if clientSecret == "" {
		return params
	}
	// Best effort: getSecretHash does not currently fail. If it ever does,
	// the request is sent without SECRET_HASH and Cognito reports the problem.
	if secret, err := getSecretHash(email, clientId, clientSecret); err == nil {
		params["SECRET_HASH"] = secret
	}
	return params
}
func passwordVerifierChallenge(poolName, clientId, clientSecret, password string, challengeParams map[string]string, ts time.Time) (map[string]string, error) {
var (
internalUsername = challengeParams["USERNAME"]
userId = challengeParams["USER_ID_FOR_SRP"]
saltHex = challengeParams["SALT"]
srpBHex = challengeParams["SRP_B"]
secretBlockB64 = challengeParams["SECRET_BLOCK"]
srpB, _ = big.NewInt(0).SetString(srpBHex, 16)
salt, _ = big.NewInt(0).SetString(saltHex, 16)
timestamp = ts.In(time.UTC).Format("Mon Jan 2 03:04:05 MST 2006")
hkdf = getPasswordAuthenticationKey(poolName, userId, password, srpB, salt)
)
secretBlockBytes, err := base64.StdEncoding.DecodeString(secretBlockB64)
if err != nil {
return nil, fmt.Errorf("unable to decode challenge parameter 'SECRET_BLOCK', %s", err.Error())
}
msg := poolName + userId + string(secretBlockBytes) + timestamp
hmacObj := hmac.New(sha256.New, hkdf)
hmacObj.Write([]byte(msg))
signature := base64.StdEncoding.EncodeToString(hmacObj.Sum(nil))
response := map[string]string{
"TIMESTAMP": timestamp,
"USERNAME": internalUsername,
"PASSWORD_CLAIM_SECRET_BLOCK": secretBlockB64,
"PASSWORD_CLAIM_SIGNATURE": signature,
}
if clientSecret != "" {
secretHash, err := getSecretHash(internalUsername, clientId, clientSecret)
if err != nil {
return nil, err
}
response["SECRET_HASH"] = secretHash
}
return response, nil
}
func getPasswordAuthenticationKey(poolName, username, password string, bigB, salt *big.Int) []byte {
var (
userPass = fmt.Sprintf("%s%s:%s", poolName, username, password)
userPassHash = hashSha256([]byte(userPass))
uVal, _ = big.NewInt(0).SetString(hexHash(padHex(bigA.Text(16))+padHex(bigB.Text(16))), 16)
xVal, _ = big.NewInt(0).SetString(hexHash(padHex(salt.Text(16))+userPassHash), 16)
gModPowXN = big.NewInt(0).Exp(g, xVal, bigN)
intVal1 = big.NewInt(0).Sub(bigB, big.NewInt(0).Mul(k, gModPowXN))
intVal2 = big.NewInt(0).Add(a, big.NewInt(0).Mul(uVal, xVal))
sVal = big.NewInt(0).Exp(intVal1, intVal2, bigN)
)
return computeHKDF(padHex(sVal.Text(16)), padHex(uVal.Text(16)))
}
// secretVerifierConfig builds the DeviceSecretVerifierConfig sent to
// ConfirmDevice for the device in authResult.NewDeviceMetadata.
//
// It draws a random 16-byte salt (top bit cleared) and computes:
//
//	x = SHA-256(salt || hex(SHA-256(DeviceGroupKey + DeviceKey + ":" + password)))
//	v = g^x mod N
//
// It returns v (byte-padded) and the salt, both base64-encoded.
// An error is returned if the device metadata is missing, the system random
// source fails, or a hex value cannot be converted to base64.
func secretVerifierConfig(authResult *types.AuthenticationResultType, password string) (*types.DeviceSecretVerifierConfigType, error) {
	if authResult == nil || authResult.NewDeviceMetadata == nil ||
		authResult.NewDeviceMetadata.DeviceGroupKey == nil ||
		authResult.NewDeviceMetadata.DeviceKey == nil {
		return nil, fmt.Errorf("authentication result is missing new device metadata")
	}
	meta := authResult.NewDeviceMetadata

	// A SHA-256 digest always hex-encodes to 64 characters, so no padding is needed.
	secretSum := sha256.Sum256([]byte(*meta.DeviceGroupKey + *meta.DeviceKey + ":" + password))
	secretHash := hex.EncodeToString(secretSum[:])

	salt := make([]byte, 16)
	if _, err := rand.Read(salt); err != nil {
		return nil, fmt.Errorf("generating device salt: %w", err)
	}
	// Top bit cleared as in the original implementation; presumably so the
	// salt reads as non-negative when treated as a signed big-endian integer.
	salt[0] &= 0x7f
	saltHex := hex.EncodeToString(salt)

	// hexHash always returns 64 valid hex digits, so SetString cannot fail here.
	x, _ := new(big.Int).SetString(hexHash(saltHex+secretHash), 16)
	v := new(big.Int).Exp(g, x, bigN)

	vBase64, err := hexToBase64(padHex(v.Text(16)))
	if err != nil {
		return nil, err
	}
	saltBase64, err := hexToBase64(saltHex)
	if err != nil {
		return nil, err
	}
	return &types.DeviceSecretVerifierConfigType{
		PasswordVerifier: &vBase64,
		Salt:             &saltBase64,
	}, nil
}
// getSecretHash returns the SECRET_HASH value sent to Cognito:
// base64(HMAC-SHA256(key = clientSecret, message = username + clientId)).
// The error result is currently always nil.
func getSecretHash(username, clientId, clientSecret string) (string, error) {
	mac := hmac.New(sha256.New, []byte(clientSecret))
	// Writing the two parts in sequence is equivalent to hashing their concatenation.
	mac.Write([]byte(username))
	mac.Write([]byte(clientId))
	digest := mac.Sum(nil)
	return base64.StdEncoding.EncodeToString(digest), nil
}
// computeHKDF derives 16 bytes from two hex-encoded inputs with an
// HMAC-SHA256 HKDF (RFC 5869). The extract step keys HMAC with salt over ikm.
// The expand step computes only the first block, over infoBits followed by
// the byte 0x01, and the output is truncated to 16 bytes.
//
// Hex-decoding errors are ignored: callers in this file pass padHex output,
// which is even-length hex. On invalid input only the prefix decoded before
// the error is used.
func computeHKDF(ikm, salt string) []byte {
	ikmBytes, _ := hex.DecodeString(ikm)
	saltBytes, _ := hex.DecodeString(salt)

	// Extract: PRK = HMAC-SHA256(key = salt, message = ikm).
	mac := hmac.New(sha256.New, saltBytes)
	mac.Write(ikmBytes)
	prk := mac.Sum(nil)

	// Expand, first block only: T(1) = HMAC-SHA256(key = PRK, message = info || 0x01).
	mac = hmac.New(sha256.New, prk)
	mac.Write([]byte(infoBits))
	mac.Write([]byte{1})
	okm := mac.Sum(nil)
	return okm[:16]
}
// hexHash hex-decodes hexStr and returns the lower-case hex SHA-256 digest of
// the resulting bytes. Decoding errors are ignored: on invalid input, only the
// bytes decoded before the error are hashed.
func hexHash(hexStr string) string {
	raw, _ := hex.DecodeString(hexStr)
	digest := sha256.Sum256(raw)
	return hex.EncodeToString(digest[:])
}
// hashSha256 returns the SHA-256 digest of buf as 64 lower-case hex characters.
func hashSha256(buf []byte) string {
	digest := sha256.Sum256(buf)
	return hex.EncodeToString(digest[:])
}
// padHex normalizes a non-negative hex string for the SRP hash inputs.
// An odd-length string gets a leading "0" so it covers whole bytes. An
// even-length string whose first byte has its high bit set (leading digit
// 8-f) gets a leading "00", so the high bit of the first byte is always clear.
// Empty input is returned unchanged; previously it panicked on hexStr[0].
func padHex(hexStr string) string {
	if hexStr == "" {
		return hexStr
	}
	if len(hexStr)%2 == 1 {
		return "0" + hexStr
	}
	if strings.ContainsRune("89ABCDEFabcdef", rune(hexStr[0])) {
		return "00" + hexStr
	}
	return hexStr
}
// hexToBase64 re-encodes a hex string as standard (padded) base64. It returns
// "" and the hex decoding error if hexStr is not valid hex.
func hexToBase64(hexStr string) (string, error) {
	decoded, decodeErr := hex.DecodeString(hexStr)
	if decodeErr != nil {
		return "", decodeErr
	}
	encoded := base64.StdEncoding.EncodeToString(decoded)
	return encoded, nil
}
// generateDevicePassword returns 40 bytes from crypto/rand encoded as
// standard base64 (56 characters), or the error from the random source.
func generateDevicePassword() (string, error) {
	var secret [40]byte
	if _, err := rand.Read(secret[:]); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(secret[:]), nil
}
// SRP constants used by init and computeHKDF.
const (
// nHex is the SRP group modulus N in hexadecimal: 16 lines of 48 digits,
// i.e. 768 hex digits / 3072 bits. NOTE(review): this appears to be the
// 3072-bit group from RFC 5054 Appendix A (same value as RFC 3526 group 15),
// the group used by AWS's SRP clients; verify before changing.
nHex = "FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD1" +
"29024E088A67CC74020BBEA63B139B22514A08798E3404DD" +
"EF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245" +
"E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7ED" +
"EE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3D" +
"C2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F" +
"83655D23DCA3AD961C62F356208552BB9ED529077096966D" +
"670C354E4ABC9804F1746C08CA18217C32905E462E36CE3B" +
"E39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9" +
"DE2BCBF6955817183995497CEA956AE515D2261898FA0510" +
"15728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64" +
"ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7" +
"ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6B" +
"F12FFA06D98A0864D87602733EC86A64521F2B18177B200C" +
"BBE117577A615D6C770988C0BAD946E208E24FA074E5AB31" +
"43DB5BFCE0FD108E4B82D120A93AD2CAFFFFFFFFFFFFFFFF"
// gHex is the SRP generator g in hexadecimal.
gHex = "2"
// infoBits is the HKDF "info" string computeHKDF appends (followed by 0x01)
// in its expand step.
infoBits = "Caldera Derived Key"
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment