Last active
August 6, 2021 19:06
-
-
Save chris-wood/2205c8e79b309adec5a785470d31226e to your computer and use it in GitHub Desktop.
SPAKE2 review and test vector check
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
// Questions: | |
// - Should we include a definition of UKS attacks inline, rather than cite draft-ietf-mmusic-sdp-uks? | |
// - Should SPAKE2 require that the output length of Hash is at least 256 bits? (Its output is split in half to derive Ke and Ka, and we probably want each of those to have at least 128 bits.)
// - What does it mean to exchange messages symmetrically? (In the per-user M and N section) | |
// - Beyond scalar multiplication being constant time, are there any other constant time considerations we should include? | |
// - Why is Ke not included in the test vectors? It may be redundant, but it seems useful as an additional sanity check. | |
// - There are currently no test vectors that include AAD -- should we add some? | |
// - Why is len() a little-endian output? | |
// - Should we clarify that the transcript assumes a particular point encoding scheme, and that is to be defined by the calling application or implementation? | |
import (
	"bytes"
	"crypto/elliptic"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
	"math/big"

	"golang.org/x/crypto/hkdf"
)
// P256-SHA256-HKDF-HMAC: the primitives below (hash, kdf, mac) instantiate
// this SPAKE2 ciphersuite.

// hash is the ciphersuite's Hash function: the SHA-256 digest of input,
// returned as a freshly allocated 32-byte slice.
func hash(input []byte) []byte {
	digest := sha256.New()
	digest.Write(input) // hash.Hash.Write never returns an error
	return digest.Sum(nil)
}
func kdf(ikm, salt, info []byte) []byte { | |
hash := sha256.New | |
hkdf := hkdf.New(hash, ikm, salt, info) | |
output := make([]byte, 32) | |
if _, err := io.ReadFull(hkdf, output); err != nil { | |
panic(err) | |
} | |
return output | |
} | |
// mac returns the HMAC-SHA256 tag of input under key (32 bytes).
// main in this file does not call it; the key-confirmation MACs are not
// computed by the test-vector check.
func mac(key, input []byte) []byte {
	tagger := hmac.New(sha256.New, key)
	_, _ = tagger.Write(input) // hash.Hash.Write never returns an error
	return tagger.Sum(nil)
}
// mustDecode parses a hexadecimal string into bytes and panics with the
// decoding error if input is not valid hex. Intended for hard-coded
// test-vector constants.
func mustDecode(input string) []byte {
	decoded, err := hex.DecodeString(input)
	if err == nil {
		return decoded
	}
	panic(err)
}
// mustDecodeInt parses a hexadecimal string as a big-endian unsigned
// integer and panics with the decoding error if input is not valid hex.
// main in this file does not call it.
//
// NOTE(review): math/big documents that shallow copies of Int values are not
// supported. Returning big.Int by value is harmless here because the
// temporary is discarded, but a *big.Int result would be the safer contract.
func mustDecodeInt(input string) big.Int {
	raw, err := hex.DecodeString(input)
	if err != nil {
		panic(err)
	}
	return *new(big.Int).SetBytes(raw)
}
// toLE reverses b in place and returns the same slice. Applied to a
// big-endian integer encoding, this yields the little-endian encoding.
// The caller's slice is modified.
func toLE(b []byte) []byte {
	for lo, hi := 0, len(b)-1; lo < hi; lo, hi = lo+1, hi-1 {
		b[lo], b[hi] = b[hi], b[lo]
	}
	return b
}
// length returns len(input) encoded as an 8-byte little-endian unsigned
// integer. This is the length prefix used by the lengthEncode* helpers when
// building the SPAKE2 transcript TT.
func length(input []byte) []byte {
	buf := make([]byte, 8)
	// Encode directly; no need for a math/big round trip plus byte reversal.
	binary.LittleEndian.PutUint64(buf, uint64(len(input)))
	return buf
}
func lengthEncodeString(input string) []byte { | |
return append(length([]byte(input)), []byte(input)...) | |
} | |
func lengthEncodeSlice(input []byte) []byte { | |
return append(length(input), input...) | |
} | |
func lengthEncodePoint(curve elliptic.Curve, Px, Py *big.Int) []byte { | |
val := elliptic.Marshal(curve, Px, Py) | |
return append(length(val), val...) | |
} | |
func main() { | |
fmt.Println("SPAKE2 test vector check") | |
A := "server" | |
B := "client" | |
x := mustDecode("43dd0fd7215bdcb482879fca3220c6a968e66d70b1356cac18bb26c84a78d729") | |
y := mustDecode("dcb60106f276b02606d8ef0a328c02e4b629f84f89786af5befb0bc75b6e66be") | |
w := mustDecode("2ee57912099d31560b3a44b1184b9b4866e904c49d12ac5042c97dca461b1a5f") | |
Menc := mustDecode("02886e2f97ace46e55ba9dd7242579f2993b64e16ef3dcab95afd497333d8fa12f") | |
Nenc := mustDecode("03d8bbd6c639c62937b04d997f38c3770719c629d7014d49a24b4f98baa1292b49") | |
KaExp := mustDecode("15bdf72e2b35b5c9e5663168e960a91b") | |
KcAExp := mustDecode("00c12546835755c86d8c0db7851ae86f") | |
KcBExp := mustDecode("a9fa3406c3b781b93d804485430ca27a") | |
var wVal big.Int | |
wVal.SetBytes(w) | |
curve := elliptic.P256() | |
Mx, My := elliptic.UnmarshalCompressed(curve, Menc) | |
Nx, Ny := elliptic.UnmarshalCompressed(curve, Nenc) | |
// X = x * P | |
Xx, Xy := curve.ScalarBaseMult(x) | |
// S = w * M + X | |
Ux, Uy := curve.ScalarMult(Mx, My, w) | |
Sx, Sy := curve.Add(Ux, Uy, Xx, Xy) | |
Senc := hex.EncodeToString(elliptic.Marshal(curve, Sx, Sy)) | |
fmt.Println("S", Senc) | |
// 04a56fa807caaa53a4d28dbb9853b9815c61a411118a6fe516a8798434751470f9010153ac33d0d5f2047ffdb1a3e42c9b4e6be662766e1eeb4116988ede5f912c | |
// Y = y * P | |
Yx, Yy := curve.ScalarBaseMult(y) | |
// T = w * N + Y | |
Ux, Uy = curve.ScalarMult(Nx, Ny, w) | |
Tx, Ty := curve.Add(Ux, Uy, Yx, Yy) | |
Tenc := hex.EncodeToString(elliptic.Marshal(curve, Tx, Ty)) | |
fmt.Println("T", Tenc) | |
// 0406557e482bd03097ad0cbaa5df82115460d951e3451962f1eaf4367a420676d09857ccbc522686c83d1852abfa8ed6e4a1155cf8f1543ceca528afb591a1e0b7 | |
// A: h*x*(T-w*N), where h=1 | |
// A: x * (T - w * N) | |
Ux, Uy = curve.ScalarMult(Nx, Ny, wVal.Bytes()) | |
Uy.Neg(Uy) | |
Vx, Vy := curve.Add(Tx, Ty, Ux, Uy) | |
Gx, Gy := curve.ScalarMult(Vx, Vy, x) | |
// B: h*y*(S-w*M), where h=1 | |
// B: y * (S - w * M) | |
Ux, Uy = curve.ScalarMult(Mx, My, wVal.Bytes()) | |
Uy.Neg(Uy) | |
Vx, Vy = curve.Add(Sx, Sy, Ux, Uy) | |
Hx, Hy := curve.ScalarMult(Vx, Vy, y) | |
if Gx.Cmp(Hx) != 0 { | |
panic("Mismatch shared x coordinate") | |
} | |
if Gy.Cmp(Hy) != 0 { | |
panic("Mismatch shared y coordinate") | |
} | |
TT := lengthEncodeString(A) | |
TT = append(TT, lengthEncodeString(B)...) | |
TT = append(TT, lengthEncodePoint(curve, Sx, Sy)...) | |
TT = append(TT, lengthEncodePoint(curve, Tx, Ty)...) | |
TT = append(TT, lengthEncodePoint(curve, Gx, Gy)...) // K | |
TT = append(TT, lengthEncodeSlice(w)...) | |
fmt.Println("TT =", hex.EncodeToString(TT)) | |
// 06000000000000007365727665720600000000000000636c69656e74410000000000000004a56fa807caaa53a4d28dbb9853b9815c61a411118a6fe516a8798434751470f9010153ac33d0d5f2047ffdb1a3e42c9b4e6be662766e1eeb4116988ede5f912c41000000000000000406557e482bd03097ad0cbaa5df82115460d951e3451962f1eaf4367a420676d09857ccbc522686c83d1852abfa8ed6e4a1155cf8f1543ceca528afb591a1e0b741000000000000000412af7e89717850671913e6b469ace67bd90a4df8ce45c2af19010175e37eed69f75897996d539356e2fa6a406d528501f907e04d97515fbe83db277b715d332520000000000000002ee57912099d31560b3a44b1184b9b4866e904c49d12ac5042c97dca461b1a5f | |
secret := hash(TT) // No AAD | |
// KeA := secret[0:16] // not tested, or included in the test vectors | |
KaA := secret[16:32] | |
derivedKeys := kdf(KaA, nil, []byte("ConfirmationKeys")) | |
KcA := derivedKeys[0:16] | |
KcB := derivedKeys[16:32] | |
if !bytes.Equal(KaA, KaExp) { | |
panic("Ka mismatch") | |
} | |
if !bytes.Equal(KcAExp, KcA) { | |
panic("KcA mismatch") | |
} | |
if !bytes.Equal(KcBExp, KcB) { | |
panic("KcB mismatch") | |
} | |
fmt.Println("Done") | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment