Skip to content

Instantly share code, notes, and snippets.

@ykm11
Last active November 21, 2018 13:28
Show Gist options
  • Save ykm11/3118f316d9372dddf214bb1865f319c1 to your computer and use it in GitHub Desktop.
Save ykm11/3118f316d9372dddf214bb1865f319c1 to your computer and use it in GitHub Desktop.
ChaCha20-Poly1305 reimplementation
from chacha20 import *
from poly1305 import *
from utils import *
from calc import *
import struct
def chacha20_aead_encrypt(aad, key, iv, constant, plaintext):
    """Encrypt and authenticate *plaintext* with AEAD_CHACHA20_POLY1305.

    Follows https://tools.ietf.org/html/rfc7539#section-2.8.1 :

        nonce = constant | iv
        otk = poly1305_key_gen(key, nonce)
        ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
        mac_data = aad | pad16(aad)
        mac_data |= ciphertext | pad16(ciphertext)
        mac_data |= num_to_8_le_bytes(aad.length)
        mac_data |= num_to_8_le_bytes(ciphertext.length)
        tag = poly1305_mac(mac_data, otk)
        return (ciphertext, tag)

    Args:
        aad: additional authenticated data (bytes); covered by the tag but
            not encrypted.
        key: 32-byte key.
        iv: nonce suffix, appended after *constant*.
        constant: nonce prefix. ``constant + iv`` must be the 12-byte nonce
            that ``chacha20_initial_state`` asserts on.
        plaintext: bytes to encrypt.

    Returns:
        ``(ciphertext, tag)``: the ciphertext has the same length as
        *plaintext*, and *tag* is the Poly1305 tag from ``poly1305_mac``.
    """
    nonce = constant + iv
    # The one-time Poly1305 key comes from keystream block 0, so encryption
    # starts at block counter 1.
    otk = poly1305_key_gen(key, nonce)
    ciphertext = chacha20_encrypt(key, 1, nonce, plaintext)
    mac_data = b"".join([
        aad, pad16(aad),
        ciphertext, pad16(ciphertext),
        # Both lengths are 64-bit little-endian ("<Q"), not 32-bit.
        struct.pack("<Q", len(aad)),
        struct.pack("<Q", len(ciphertext)),
    ])
    tag = poly1305_mac(mac_data, otk)
    return ciphertext, tag
def Add(x, y):
    """Return ``x + y`` reduced modulo 2**32 (wrapping 32-bit addition)."""
    total = x + y
    return total % (1 << 32)
def Xor(x, y):
    """Return ``x ^ y`` reduced modulo 2**32."""
    mixed = x ^ y
    return mixed % (1 << 32)
def LeftRotate(x, kbits):
    """Rotate the 32-bit word *x* left by *kbits* bits.

    Each shifted half is reduced modulo 2**32 before the halves are OR-ed
    together. In this file QuarterRound calls it with kbits of 16, 12, 8
    and 7.
    """
    modulus = 0x100000000
    high_part = (x << kbits) % modulus
    low_part = (x >> (32 - kbits)) % modulus
    return high_part | low_part
def QuarterRound(a, b, c, d):
    """Apply the ChaCha20 quarter round to four 32-bit words.

    The steps, all modulo 0x100000000, are:
        1. a += b; d ^= a; d <<<= 16;
        2. c += d; b ^= c; b <<<= 12;
        3. a += b; d ^= a; d <<<= 8;
        4. c += d; b ^= c; b <<<= 7;

    Returns the updated ``(a, b, c, d)`` tuple.
    """
    # The four steps form two identical half-rounds that differ only in
    # their rotation amounts.
    for rot_d, rot_b in ((16, 12), (8, 7)):
        a = Add(a, b)
        d = LeftRotate(Xor(d, a), rot_d)
        c = Add(c, d)
        b = LeftRotate(Xor(b, c), rot_b)
    return a, b, c, d
def clamp(r):
    """Clamp a Poly1305 ``r`` value with the RFC 7539 mask.

    Numbering bytes little-endian (byte 0 least significant), the mask clears
    the top four bits of bytes 3, 7, 11 and 15, clears the bottom two bits of
    bytes 4, 8 and 12, and clears every bit above bit 127.
    """
    r_mask = 0x0FFFFFFC_0FFFFFFC_0FFFFFFC_0FFFFFFF
    return r & r_mask
import binascii
from calc import *
from utils import *
# https://tools.ietf.org/html/rfc7539
def chacha20_initial_state(key: bytes, block_counter: int, nonce: bytes):
    """Build the 16-word ChaCha20 input state.

    Layout (one letter per hex nibble):
        cccccccc cccccccc cccccccc cccccccc
        kkkkkkkk kkkkkkkk kkkkkkkk kkkkkkkk
        kkkkkkkk kkkkkkkk kkkkkkkk kkkkkkkk
        bbbbbbbb nnnnnnnn nnnnnnnn nnnnnnnn
    c=constant[16 bytes] k=key[32 bytes] b=blockcount[4 bytes] n=nonce[12 bytes]

    Every word is returned as a 4-byte ``bytes`` object:
    - constant and counter words are big-endian encodings of their values
      (via ``num2bytes``);
    - key and nonce words are byte-reversed 4-byte chunks (via
      ``make_array``).
    So a big-endian decode of any word yields its little-endian value.
    """
    assert len(key) == 32
    assert len(nonce) == 12
    # Read as little-endian 32-bit words, these constants spell the ASCII
    # text "expand 32-byte k".
    constant_words = [
        num2bytes(word)
        for word in (0x61707865, 0x3320646e, 0x79622d32, 0x6b206574)
    ]
    key_words = make_array(key)
    counter_word = num2bytes(block_counter)
    nonce_words = make_array(nonce)
    return [*constant_words, *key_words, counter_word, *nonce_words]
def chacha20_inner_block(state):
    """Run one ChaCha20 double round on *state* in place.

    *state* is a mutable sequence of 16 integer words in c + k + b + n order.
    Four column quarter rounds run first, then four diagonal ones:
        QUARTERROUND ( 0, 4, 8,12)   QUARTERROUND ( 0, 5,10,15)
        QUARTERROUND ( 1, 5, 9,13)   QUARTERROUND ( 1, 6,11,12)
        QUARTERROUND ( 2, 6,10,14)   QUARTERROUND ( 2, 7, 8,13)
        QUARTERROUND ( 3, 7,11,15)   QUARTERROUND ( 3, 4, 9,14)

    Returns None; the caller sees the result through *state*.
    """
    schedule = (
        # Column rounds
        (0, 4, 8, 12),
        (1, 5, 9, 13),
        (2, 6, 10, 14),
        (3, 7, 11, 15),
        # Diagonal rounds
        (0, 5, 10, 15),
        (1, 6, 11, 12),
        (2, 7, 8, 13),
        (3, 4, 9, 14),
    )
    for w, x, y, z in schedule:
        state[w], state[x], state[y], state[z] = QuarterRound(
            state[w], state[x], state[y], state[z]
        )
def chacha20_block(key: bytes, counter: int, nonce: bytes):
    """Produce one ChaCha20 keystream block.

    Pseudocode:
        state = constants | key | counter | nonce
        working_state = state
        for i=1 upto 10
            inner_block(working_state)
        end
        state += working_state
        return serialize(state)

    Returns a list of 16 four-byte ``bytes`` objects. Each is a state word
    (initial + mixed, mod 2**32) serialized little-endian; joined together
    they form the 64-byte block.
    """
    initial = [bytes2num(word) for word in chacha20_initial_state(key, counter, nonce)]
    mixed = list(initial)
    # Ten double rounds = twenty ChaCha rounds.
    for _ in range(10):
        chacha20_inner_block(mixed)
    return [
        le_num2bytes((start + end) % 0x100000000)
        for start, end in zip(initial, mixed)
    ]
def chacha20_encrypt(key, counter, nonce, plaintext):
    """XOR *plaintext* with the ChaCha20 keystream.

    Follows https://tools.ietf.org/html/rfc7539#section-2.4.1 :

        for j = 0 upto floor(len(plaintext)/64)-1
            key_stream = chacha20_block(key, counter+j, nonce)
            block = plaintext[(j*64)..(j*64+63)]
            encrypted_message += block ^ key_stream
        end
        if ((len(plaintext) % 64) != 0)
            j = floor(len(plaintext)/64)
            key_stream = chacha20_block(key, counter+j, nonce)
            block = plaintext[(j*64)..len(plaintext)-1]
            encrypted_message += (block^key_stream)[0..len(plaintext)%64]
        end
        return encrypted_message

    Full blocks and the trailing partial block are handled by one loop, and
    the output is assembled with a single join.

    Args:
        key: 32-byte key.
        counter: block counter for the first 64-byte block; block j uses
            ``counter + j``.
        nonce: 12-byte nonce.
        plaintext: bytes to transform. Encryption and decryption are the
            same operation.

    Returns:
        ``bytes`` of the same length as *plaintext*. Empty input returns
        ``b""``.
    """
    pieces = []
    # ceil(len / 64) blocks; the last one may be shorter than 64 bytes.
    for j in range((len(plaintext) + 63) // 64):
        block = plaintext[64 * j : 64 * j + 64]
        key_stream = b"".join(chacha20_block(key, counter + j, nonce))
        # zip() stops at the shorter input, so a partial final block uses
        # only as many keystream bytes as it needs.
        pieces.append(bytes(p ^ k for p, k in zip(block, key_stream)))
    return b"".join(pieces)
from utils import *
from calc import *
from chacha20 import *
def poly1305_mac(msg, key: bytes):
    """Compute the Poly1305 tag of *msg* under the one-time *key*.

    Follows https://tools.ietf.org/html/rfc7539#section-2.5.1 :

        r = le_bytes_to_num(key[0..15])
        clamp(r)
        s = le_bytes_to_num(key[16..31])
        accumulator = 0
        p = (1<<130)-5
        for i=1 upto ceil(msg length in bytes / 16)
            n = le_bytes_to_num(msg[((i-1)*16)..(i*16)] | [0x01])
            a += n
            a = (r * a) % p
        end
        a += s
        return num_to_16_le_bytes(a)

    Args:
        msg: bytes to authenticate; may be empty.
        key: one-time key, ``r`` (bytes 0-15) followed by ``s`` (bytes 16+).

    Returns:
        The tag as exactly 16 little-endian bytes. A tag with leading zero
        bytes keeps its full length.
    """
    # Same mask as ``clamp`` (RFC 7539 r clamping), inlined here.
    r = int.from_bytes(key[:16], "little") & 0x0FFFFFFC0FFFFFFC0FFFFFFC0FFFFFFF
    s = int.from_bytes(key[16:], "little")
    p = (1 << 130) - 5
    acc = 0
    for offset in range(0, len(msg), 16):
        # The 0x01 byte appended right after the chunk sets the 2**(8*len)
        # "high bit" the RFC adds to every block, including a short final one.
        n = int.from_bytes(msg[offset:offset + 16] + b"\x01", "little")
        acc = (r * (acc + n)) % p
    # Keep the low 128 bits and always serialize exactly 16 bytes.
    return ((acc + s) % (1 << 128)).to_bytes(16, "little")
def poly1305_key_gen(key, nonce):
    """Derive the 32-byte Poly1305 one-time key for *key* and *nonce*.

    Takes the first 32 bytes (the first eight 4-byte words) of the ChaCha20
    block generated with block counter 0.
    """
    first_block = b"".join(chacha20_block(key, 0, nonce))
    return first_block[:32]
import unittest
from binascii import unhexlify
from chacha20 import *
from poly1305 import *
from aead_enc import *
from utils import *
from calc import *
class TestTashizan(unittest.TestCase):
    """Known-answer tests for the AEAD construction, Poly1305 and key generation.

    NOTE(review): the vectors appear to come from RFC 7539 (section 2.8.2 and
    appendix A); confirm against the RFC before adding new cases.
    """

    def test_aead1(self):
        """AEAD encryption: check both ciphertext and tag."""
        # Use the ``unhexlify`` imported at module top. This module never
        # imports ``binascii`` by name, so ``binascii.unhexlify`` only worked
        # because star imports from other modules leaked the module object.
        aad = unhexlify(b'50515253c0c1c2c3c4c5c6c7')
        key = unhexlify(b'808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f')
        iv = unhexlify(b'4041424344454647')
        constant = unhexlify(b'07000000')
        plaintext = unhexlify(b"".join(b"""
            4c 61 64 69 65 73 20 61 6e 64 20 47 65 6e 74 6c
            65 6d 65 6e 20 6f 66 20 74 68 65 20 63 6c 61 73
            73 20 6f 66 20 27 39 39 3a 20 49 66 20 49 20 63
            6f 75 6c 64 20 6f 66 66 65 72 20 79 6f 75 20 6f
            6e 6c 79 20 6f 6e 65 20 74 69 70 20 66 6f 72 20
            74 68 65 20 66 75 74 75 72 65 2c 20 73 75 6e 73
            63 72 65 65 6e 20 77 6f 75 6c 64 20 62 65 20 69
            74 2e
            """.split()))
        enc, tag = chacha20_aead_encrypt(aad, key, iv, constant, plaintext)
        expected_enc = unhexlify(b''.join(b'''
            d3 1a 8d 34 64 8e 60 db 7b 86 af bc 53 ef 7e c2
            a4 ad ed 51 29 6e 08 fe a9 e2 b5 a7 36 ee 62 d6
            3d be a4 5e 8c a9 67 12 82 fa fb 69 da 92 72 8b
            1a 71 de 0a 9e 06 0b 29 05 d6 a5 b6 7e cd 3b 36
            92 dd bd 7f 2d 77 8b 8c 98 03 ae e3 28 09 1b 58
            fa b3 24 e4 fa d6 75 94 55 85 80 8b 48 31 d7 bc
            3f f4 de f0 8e 4b 7a 9d e5 76 d2 65 86 ce c6 4b
            61 16
            '''.split()))
        expected_tag = unhexlify(b'1ae10b594f09e26a7e902ecbd0600691')
        self.assertEqual(expected_enc, enc)
        self.assertEqual(expected_tag, tag)

    def test_aead2(self):
        """AEAD encryption with a 265-byte plaintext; check tag and ciphertext."""
        aad = unhexlify(b''.join(b'f3 33 88 86 00 00 00 00 00 00 4e 91'.split()))
        key = unhexlify(b''.join(b'1c 92 40 a5 eb 55 d3 8a f3 33 88 86 04 f6 b5 f0 47 39 17 c1 40 2b 80 09 9d ca 5c bc 20 70 75 c0'.split()))
        constant = unhexlify(b''.join(b'00 00 00 00'.split()))
        iv = unhexlify(b''.join(b'01 02 03 04 05 06 07 08'.split()))
        plaintext = unhexlify(b''.join(b'''
            49 6e 74 65 72 6e 65 74 2d 44 72 61 66 74 73 20
            61 72 65 20 64 72 61 66 74 20 64 6f 63 75 6d 65
            6e 74 73 20 76 61 6c 69 64 20 66 6f 72 20 61 20
            6d 61 78 69 6d 75 6d 20 6f 66 20 73 69 78 20 6d
            6f 6e 74 68 73 20 61 6e 64 20 6d 61 79 20 62 65
            20 75 70 64 61 74 65 64 2c 20 72 65 70 6c 61 63
            65 64 2c 20 6f 72 20 6f 62 73 6f 6c 65 74 65 64
            20 62 79 20 6f 74 68 65 72 20 64 6f 63 75 6d 65
            6e 74 73 20 61 74 20 61 6e 79 20 74 69 6d 65 2e
            20 49 74 20 69 73 20 69 6e 61 70 70 72 6f 70 72
            69 61 74 65 20 74 6f 20 75 73 65 20 49 6e 74 65
            72 6e 65 74 2d 44 72 61 66 74 73 20 61 73 20 72
            65 66 65 72 65 6e 63 65 20 6d 61 74 65 72 69 61
            6c 20 6f 72 20 74 6f 20 63 69 74 65 20 74 68 65
            6d 20 6f 74 68 65 72 20 74 68 61 6e 20 61 73 20
            2f e2 80 9c 77 6f 72 6b 20 69 6e 20 70 72 6f 67
            72 65 73 73 2e 2f e2 80 9d
            '''.split()))
        enc, tag = chacha20_aead_encrypt(aad, key, iv, constant, plaintext)
        expected_tag = unhexlify(b''.join(b'ee ad 9d 67 89 0c bb 22 39 23 36 fe a1 85 1f 38'.split()))
        expected_enc = unhexlify(b''.join(b'''
            64 a0 86 15 75 86 1a f4 60 f0 62 c7 9b e6 43 bd
            5e 80 5c fd 34 5c f3 89 f1 08 67 0a c7 6c 8c b2
            4c 6c fc 18 75 5d 43 ee a0 9e e9 4e 38 2d 26 b0
            bd b7 b7 3c 32 1b 01 00 d4 f0 3b 7f 35 58 94 cf
            33 2f 83 0e 71 0b 97 ce 98 c8 a8 4a bd 0b 94 81
            14 ad 17 6e 00 8d 33 bd 60 f9 82 b1 ff 37 c8 55
            97 97 a0 6e f4 f0 ef 61 c1 86 32 4e 2b 35 06 38
            36 06 90 7b 6a 7c 02 b0 f9 f6 15 7b 53 c8 67 e4
            b9 16 6c 76 7b 80 4d 46 a5 9b 52 16 cd e7 a4 e9
            90 40 c5 a4 04 33 22 5e e2 82 a1 b0 a0 6c 52 3e
            af 45 34 d7 f8 3f a1 15 5b 00 47 71 8c bc 54 6a
            0d 07 2b 04 b3 56 4e ea 1b 42 22 73 f5 48 27 1a
            0b b2 31 60 53 fa 76 99 19 55 eb d6 31 59 43 4e
            ce bb 4e 46 6d ae 5a 10 73 a6 72 76 27 09 7a 10
            49 e6 17 d9 1d 36 10 94 fa 68 f0 ff 77 98 71 30
            30 5b ea ba 2e da 04 df 99 7b 71 4d 6c 6f 2c 29
            a6 ad 5c b4 02 2b 02 70 9b
            '''.split()))
        self.assertEqual(expected_tag, tag)
        self.assertEqual(expected_enc, enc)

    def test_mac1(self):
        """Poly1305 with r = 0 (tag must equal s), then with s = 0."""
        key1 = unhexlify(b"".join(b"""
            00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00
            36 e5 f6 b5 c5 e0 60 70 f0 ef ca 96 22 7a 86 3e""".split()))
        key2 = unhexlify(b"".join(b"""
            36 e5 f6 b5 c5 e0 60 70 f0 ef ca 96 22 7a 86 3e
            00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00""".split()))
        msg = unhexlify(b"".join(b'''
            41 6e 79 20 73 75 62 6d 69 73 73 69 6f 6e 20 74
            6f 20 74 68 65 20 49 45 54 46 20 69 6e 74 65 6e
            64 65 64 20 62 79 20 74 68 65 20 43 6f 6e 74 72
            69 62 75 74 6f 72 20 66 6f 72 20 70 75 62 6c 69
            63 61 74 69 6f 6e 20 61 73 20 61 6c 6c 20 6f 72
            20 70 61 72 74 20 6f 66 20 61 6e 20 49 45 54 46
            20 49 6e 74 65 72 6e 65 74 2d 44 72 61 66 74 20
            6f 72 20 52 46 43 20 61 6e 64 20 61 6e 79 20 73
            74 61 74 65 6d 65 6e 74 20 6d 61 64 65 20 77 69
            74 68 69 6e 20 74 68 65 20 63 6f 6e 74 65 78 74
            20 6f 66 20 61 6e 20 49 45 54 46 20 61 63 74 69
            76 69 74 79 20 69 73 20 63 6f 6e 73 69 64 65 72
            65 64 20 61 6e 20 22 49 45 54 46 20 43 6f 6e 74
            72 69 62 75 74 69 6f 6e 22 2e 20 53 75 63 68 20
            73 74 61 74 65 6d 65 6e 74 73 20 69 6e 63 6c 75
            64 65 20 6f 72 61 6c 20 73 74 61 74 65 6d 65 6e
            74 73 20 69 6e 20 49 45 54 46 20 73 65 73 73 69
            6f 6e 73 2c 20 61 73 20 77 65 6c 6c 20 61 73 20
            77 72 69 74 74 65 6e 20 61 6e 64 20 65 6c 65 63
            74 72 6f 6e 69 63 20 63 6f 6d 6d 75 6e 69 63 61
            74 69 6f 6e 73 20 6d 61 64 65 20 61 74 20 61 6e
            79 20 74 69 6d 65 20 6f 72 20 70 6c 61 63 65 2c
            20 77 68 69 63 68 20 61 72 65 20 61 64 64 72 65
            73 73 65 64 20 74 6f
            '''.split()))
        tag1 = poly1305_mac(msg, key1)
        expected_tag1 = unhexlify(b''.join(b'36 e5 f6 b5 c5 e0 60 70 f0 ef ca 96 22 7a 86 3e'.split()))
        tag2 = poly1305_mac(msg, key2)
        expected_tag2 = unhexlify(b''.join(b'f3 47 7e 7c d9 54 17 af 89 a6 b8 79 4c 31 0c f0'.split()))
        self.assertEqual(expected_tag1, tag1)
        self.assertEqual(expected_tag2, tag2)

    def test_mac2(self):
        """Poly1305 over a 127-byte message, so the last block is partial."""
        key = unhexlify(b"".join(b"""
            1c 92 40 a5 eb 55 d3 8a f3 33 88 86 04 f6 b5 f0
            47 39 17 c1 40 2b 80 09 9d ca 5c bc 20 70 75 c0
            """.split()))
        msg = unhexlify(b"".join(b"""
            27 54 77 61 73 20 62 72 69 6c 6c 69 67 2c 20 61
            6e 64 20 74 68 65 20 73 6c 69 74 68 79 20 74 6f
            76 65 73 0a 44 69 64 20 67 79 72 65 20 61 6e 64
            20 67 69 6d 62 6c 65 20 69 6e 20 74 68 65 20 77
            61 62 65 3a 0a 41 6c 6c 20 6d 69 6d 73 79 20 77
            65 72 65 20 74 68 65 20 62 6f 72 6f 67 6f 76 65
            73 2c 0a 41 6e 64 20 74 68 65 20 6d 6f 6d 65 20
            72 61 74 68 73 20 6f 75 74 67 72 61 62 65 2e
            """.split()))
        tag = poly1305_mac(msg, key)
        expected_tag = unhexlify(b''.join(b'45 41 66 9a 7e aa ee 61 e7 08 dc 7c bc c5 eb 62'.split()))
        self.assertEqual(expected_tag, tag)

    def test_otk1(self):
        """One-time key generation with an all-zero key and nonce."""
        key = b'\x00'*32
        nonce = b'\x00'*12
        otk = poly1305_key_gen(key, nonce)
        expected_otk = unhexlify(b''.join(b"76 b8 e0 ad a0 f1 3d 90 40 5d 6a e5 53 86 bd 28 bd d2 19 b8 a0 8d ed 1a a8 36 ef cc 8b 77 0d c7".split()))
        self.assertEqual(otk, expected_otk)

    def test_otk2(self):
        """One-time key generation with nonzero final key and nonce bytes."""
        key = b'\x00'*31 + b'\x01'
        nonce = b'\x00'*11 + b'\x02'
        otk = poly1305_key_gen(key, nonce)
        expected_otk = unhexlify(b''.join(b"ec fa 25 4f 84 5f 64 74 73 d3 cb 14 0d a9 e8 76 06 cb 33 06 6c 44 7b 87 bc 26 66 dd e3 fb b7 39".split()))
        self.assertEqual(otk, expected_otk)


if __name__ == "__main__":
    unittest.main()
import struct
import binascii
def le_num2bytes(x):
    """Serialize *x*, reduced modulo 2**32, as 4 little-endian bytes."""
    word = x % 0x100000000
    encoded = struct.pack("<I", word)
    return encoded
def num2bytes(x):
    """Serialize *x*, reduced modulo 2**32, as 4 big-endian bytes."""
    word = x % 0x100000000
    encoded = struct.pack(">I", word)
    return encoded
def bytes2num(data):
    """Interpret *data* as a big-endian unsigned integer of any length.

    Empty input yields 0.
    """
    return int.from_bytes(data, "big")
def le_bytes2num(data):
    """Decode exactly four bytes as an unsigned little-endian 32-bit integer."""
    (value,) = struct.unpack("<I", data)
    return value
def make_array(data: bytes, endian='little'):
    """Split *data* into consecutive 4-byte chunks.

    With ``endian='little'`` each chunk is byte-reversed. With
    ``endian='big'`` chunks are returned unchanged. If ``len(data)`` is not
    a multiple of 4, the last chunk is shorter.

    Raises:
        ValueError: *endian* is neither ``'little'`` nor ``'big'``.
    """
    if endian not in ('little', 'big'):
        raise ValueError("Invalid endian")
    step = -1 if endian == 'little' else 1
    return [data[pos:pos + 4][::step] for pos in range(0, len(data), 4)]
def pad16(x):
    """Return the zero bytes needed to extend *x* to a multiple of 16 bytes.

    Returns ``b''`` when ``len(x)`` is already a multiple of 16.
    """
    shortfall = -len(x) % 16
    return bytes(shortfall)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment