@DavidBuchanan314
Last active March 9, 2022 23:23
Pure Python Curve25519 scalar point multiplication, as defined in RFC 7748
```python
# https://datatracker.ietf.org/doc/html/rfc7748
# This code is based directly on the pseudocode in the RFC, translated into Python 3
# This implementation is NOT CONSTANT TIME (due to Python's underlying arithmetic ops not being guaranteed constant time)
# See also: https://en.wikipedia.org/wiki/Montgomery_curve

p = 2**255 - 19

def decodeScalar25519(k):
    k = bytearray(k)
    k[0] &= 0xf8
    k[31] &= 0x7f
    k[31] |= 0x40
    return int.from_bytes(k, "little")

# This might not be constant-time, but the rest of this code isn't either...
def cswap(swap, a, b):
    return (b, a) if swap else (a, b)

def X25519(k, u):
    x_1 = u
    x_2 = 1
    z_2 = 0
    x_3 = u
    z_3 = 1
    swap = 0
    for t in range(255)[::-1]:
        k_t = (k >> t) & 1
        swap ^= k_t
        x_2, x_3 = cswap(swap, x_2, x_3)
        z_2, z_3 = cswap(swap, z_2, z_3)
        swap = k_t
        A = x_2 + z_2
        AA = pow(A, 2, p)
        B = x_2 - z_2
        BB = pow(B, 2, p)
        E = AA - BB
        C = x_3 + z_3
        D = x_3 - z_3
        DA = (D * A) % p
        CB = (C * B) % p
        x_3 = pow(DA + CB, 2, p)
        z_3 = (x_1 * pow(DA - CB, 2, p)) % p
        x_2 = (AA * BB) % p
        z_2 = (E * (AA + ((121665 * E) % p))) % p
    x_2, x_3 = cswap(swap, x_2, x_3)
    z_2, z_3 = cswap(swap, z_2, z_3)
    return (x_2 * pow(z_2, p - 2, p)) % p

if __name__ == "__main__":
    # first test vector from rfc7748
    scalar_in = decodeScalar25519(bytes.fromhex("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"))
    u_in = int.from_bytes(bytes.fromhex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"), "little")
    u_out = X25519(scalar_in, u_in)
    print(u_out.to_bytes(32, "little").hex())
```
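The `cswap` above branches on `swap`, as its own comment concedes. RFC 7748 (Section 5) describes a branch-free alternative: compute `dummy = mask(swap) AND (x_2 XOR x_3)`, where the mask is all ones or all zeros, and XOR both values with it. Below is a minimal Python sketch of that structure. It assumes every ladder value is a non-negative int below 2**256, which holds for 32-byte inputs since everything else is reduced mod p; `MASK_BITS` is a name introduced here. It removes the explicit branch, but, as the gist's header notes, Python's big-int arithmetic still offers no constant-time guarantee.

```python
MASK_BITS = 256  # assumption: all ladder values are non-negative and below 2**256


def cswap(swap, a, b):
    """Conditionally swap a and b without an if-statement.

    Mirrors the mask-and-XOR cswap described in RFC 7748, Section 5.
    swap must be 0 or 1. Python ints are still not constant time.
    """
    mask = (-swap) & ((1 << MASK_BITS) - 1)  # all ones if swap == 1, all zeros if swap == 0
    dummy = mask & (a ^ b)                   # equals a ^ b when swapping, 0 otherwise
    return a ^ dummy, b ^ dummy


if __name__ == "__main__":
    print(cswap(0, 3, 5))  # (3, 5)
    print(cswap(1, 3, 5))  # (5, 3)
```

It keeps the gist's `cswap(swap, a, b)` signature, so under the stated assumption it could replace the original without touching `X25519`.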

DavidBuchanan314 commented Mar 9, 2022

So uh, there's a bug in here somewhere. It gets the wrong answer for:

```python
scalar_in = decodeScalar25519(bytes.fromhex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"))
u_in = int.from_bytes(bytes.fromhex("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"), "little")
u_out = X25519(scalar_in, u_in)
```

(same inputs as the first RFC test vector, but swapped order)

I believe the correct answer is 94cffd93eadcebf78386fd6206c01a2f96612ab5a07850eaa109679195c49149 - but I cannot find the bug...
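One quick way to see what differs between the two cases: in little-endian encoding, the most significant bit of byte 31 is bit 255 of the integer. The byte string passed as u in the swapped case ends in 0xc4, so that bit is set and the decoded u is at least 2**255. In the first RFC vector, u ends in 0x4c and the bit is clear. The sketch below checks only that bit; `top_bit_set` and the two constant names are introduced here for illustration.

```python
def top_bit_set(u_hex):
    """Return True if bit 255 (the MSB of the final little-endian byte) is set."""
    u_bytes = bytes.fromhex(u_hex)
    return bool(u_bytes[31] & 0x80)


# u-coordinate from the first RFC 7748 test vector
RFC_VECTOR_1_U = "e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"
# the RFC scalar bytes, reused as a u-coordinate in the swapped test case
SWAPPED_U = "a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"

for name, u_hex in [("RFC vector 1 u", RFC_VECTOR_1_U), ("swapped u", SWAPPED_U)]:
    u = int.from_bytes(bytes.fromhex(u_hex), "little")
    print(f"{name}: top bit set = {top_bit_set(u_hex)}, u >= 2**255 = {u >= 2**255}")
# RFC vector 1 u: top bit set = False, u >= 2**255 = False
# swapped u: top bit set = True, u >= 2**255 = True
```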

Sample code that gets the correct output for this test case:

```python
from cryptography.hazmat.primitives.asymmetric import x25519
priv = x25519.X25519PrivateKey.from_private_bytes(bytes.fromhex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c"))
pub = x25519.X25519PublicKey.from_public_bytes(bytes.fromhex(   "a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4"))
sk = priv.exchange(pub)
print(sk.hex())
```

@DavidBuchanan314

> When receiving such an array, implementations of X25519 (but not X448) MUST mask the most significant bit in the final byte.

Of course...
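Following that requirement, here is a sketch of the fix: clear the top bit of byte 31 before converting the u-coordinate to an integer, as RFC 7748's `decodeUCoordinate` does for X25519. The ladder is the gist's `X25519`, lightly restructured but with the same arithmetic. `decodeUCoordinate` and `encodeUCoordinate` borrow the RFC's pseudocode names but are Python 3 translations written for this sketch, and `x25519_bytes` is a hypothetical bytes-in/bytes-out wrapper. Like the original, it is not constant time.

```python
p = 2**255 - 19


def decodeScalar25519(k):
    # Clamp the scalar as in RFC 7748, Section 5.
    k = bytearray(k)
    k[0] &= 0xf8
    k[31] &= 0x7f
    k[31] |= 0x40
    return int.from_bytes(k, "little")


def decodeUCoordinate(u):
    # X25519 (but not X448) MUST mask the most significant bit of the final byte.
    u = bytearray(u)
    u[31] &= 0x7f
    return int.from_bytes(u, "little")


def encodeUCoordinate(u):
    # Reduce mod p and serialize as 32 little-endian bytes.
    return (u % p).to_bytes(32, "little")


def cswap(swap, a, b):
    return (b, a) if swap else (a, b)


def X25519(k, u):
    # Montgomery ladder with the same arithmetic as the gist (not constant time).
    x_1, x_2, z_2, x_3, z_3 = u, 1, 0, u, 1
    swap = 0
    for t in reversed(range(255)):
        k_t = (k >> t) & 1
        swap ^= k_t
        x_2, x_3 = cswap(swap, x_2, x_3)
        z_2, z_3 = cswap(swap, z_2, z_3)
        swap = k_t
        A = x_2 + z_2
        AA = pow(A, 2, p)
        B = x_2 - z_2
        BB = pow(B, 2, p)
        E = AA - BB
        C = x_3 + z_3
        D = x_3 - z_3
        DA = (D * A) % p
        CB = (C * B) % p
        x_3 = pow(DA + CB, 2, p)
        z_3 = (x_1 * pow(DA - CB, 2, p)) % p
        x_2 = (AA * BB) % p
        z_2 = (E * (AA + (121665 * E) % p)) % p
    x_2, x_3 = cswap(swap, x_2, x_3)
    z_2, z_3 = cswap(swap, z_2, z_3)
    return (x_2 * pow(z_2, p - 2, p)) % p


def x25519_bytes(k_bytes, u_bytes):
    # Hypothetical wrapper: decode both inputs per RFC 7748, run the ladder, encode the result.
    return encodeUCoordinate(X25519(decodeScalar25519(k_bytes), decodeUCoordinate(u_bytes)))


if __name__ == "__main__":
    # the swapped-order case from the first comment
    k = bytes.fromhex("e6db6867583030db3594c1a424b15f7c726624ec26b3353b10a903a6d0ab1c4c")
    u = bytes.fromhex("a546e36bf0527c9d3b16154b82465edd62144c0ac1fc5a18506a2244ba449ac4")
    print(x25519_bytes(k, u).hex())
```

With the mask in place, the swapped-order case should agree with the `cryptography` output quoted in the first comment. The first RFC vector is unaffected, because its u-coordinate already has a clear top bit.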
