Skip to content

Instantly share code, notes, and snippets.

@nkoneko
Created March 18, 2023 15:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save nkoneko/c6f9f4332463a2a0d9b3c617043ee7a4 to your computer and use it in GitHub Desktop.
Save nkoneko/c6f9f4332463a2a0d9b3c617043ee7a4 to your computer and use it in GitHub Desktop.
def gf_mul(a, b):
    """Multiply two elements of GF(2^8) modulo the AES polynomial 0x11b.

    Shift-and-add method: for every set bit of ``b`` (lowest first) the
    current multiple of ``a`` is XOR-ed into the product, and ``a`` is then
    doubled, reducing by 0x11b whenever bit 8 becomes set.
    """
    product = 0
    multiplicand, multiplier = a, b
    while multiplier != 0:
        if (multiplier & 1) == 1:
            product ^= multiplicand
        multiplicand <<= 1
        if multiplicand & 0x100:
            # Degree-8 term appeared: reduce modulo x^8 + x^4 + x^3 + x + 1.
            multiplicand ^= 0x11B
        multiplier >>= 1
    return product
def gf_sqr(x):
    """Return ``x`` squared in GF(2^8), computed as ``gf_mul(x, x)``."""
    squared = gf_mul(x, x)
    return squared
def gf_inv(x):
    """Return the multiplicative inverse of ``x`` in GF(2^8).

    Computes x**254 by square-and-multiply. Six steps of ``y = x * y*y``
    take y from x**1 to x**127, and a final squaring yields x**254.
    Since every non-zero field element satisfies x**255 == 1, x**254 is the
    inverse. An input of 0 yields 0.
    """
    power = x  # x**1
    steps_left = 6
    while steps_left:
        power = gf_mul(x, gf_sqr(power))  # x**k -> x**(2k + 1)
        steps_left -= 1
    return gf_sqr(power)  # x**127 -> x**254
def sub_bytes(state):
    """Apply the AES S-box to every byte of the 4x4 ``state`` in place.

    Each row is run through ``sub_word``, so the S-box (GF(2^8) inversion
    followed by the affine transform) is defined in exactly one place
    rather than duplicated here. ``gf_inv(0)`` already returns 0, so the
    zero byte needs no special case.

    Returns the same ``state`` object (mutated), as before.
    """
    for i in range(4):
        row = state[i]
        # Slice assignment mutates the existing row list in place, matching
        # the original per-element assignment semantics.
        row[:4] = sub_word(row[:4])
    return state
def shift_rows(state):
    """Cyclically rotate row ``r`` of the 4x4 ``state`` left by ``r`` places.

    Row 0 is left alone. Rows 1-3 are replaced, inside the same outer list,
    by rotated copies. Returns ``state``.
    """
    for offset in (1, 2, 3):
        row = state[offset]
        state[offset] = row[offset:] + row[:offset]
    return state
def xor_sum(bs):
    """Fold an iterable of ints together with XOR (an empty iterable gives 0)."""
    accumulator = 0
    for value in bs:
        accumulator = accumulator ^ value
    return accumulator
def mix_columns(state):
    """Return a new 4x4 state with the AES MixColumns transform applied.

    Output cell ``[row][col]`` is the GF(2^8) dot product of coefficient
    row ``row`` with input column ``col``. The coefficients form the fixed
    circulant matrix whose first row is (2, 3, 1, 1). ``state`` itself is
    not modified.
    """
    coefficients = (
        (2, 3, 1, 1),
        (1, 2, 3, 1),
        (1, 1, 2, 3),
        (3, 1, 1, 2),
    )
    return [
        [
            xor_sum(gf_mul(coef, state[k][col]) for k, coef in enumerate(coeff_row))
            for col in range(4)
        ]
        for coeff_row in coefficients
    ]
def xor_bytes(a, b):
    """XOR two byte sequences element-wise into a new list.

    Like ``zip``, the result stops at the end of the shorter input.
    """
    combined = []
    for left, right in zip(a, b):
        combined.append(left ^ right)
    return combined
def rot_word(word):
    """Rotate a word left by one element: [a0, a1, a2, a3] -> [a1, a2, a3, a0].

    The result has the same sequence type as ``word`` (list, bytes, tuple...).
    """
    first = word[:1]
    rest = word[1:]
    return rest + first
def sub_word(word):
    """Apply the AES S-box to each byte of ``word`` and return a list.

    For each byte the GF(2^8) inverse is taken (``gf_inv(0)`` is 0). Then the
    affine transform
    ``inv ^ rotl(inv,1) ^ rotl(inv,2) ^ rotl(inv,3) ^ rotl(inv,4) ^ 0x63``
    is applied, where ``rotl`` is an 8-bit left rotation.
    """
    def _sbox(byte):
        inverse = gf_inv(byte)
        out = inverse ^ 0x63
        for shift in range(1, 5):
            out ^= ((inverse << shift) | (inverse >> (8 - shift))) & 0xFF
        return out

    return [_sbox(byte) for byte in word]
def key_schedule(key, rounds):
    """Expand an AES cipher key into round-key words (FIPS-197 KeyExpansion).

    Args:
        key: cipher key as a sequence of byte values (bytes or list of ints).
            Its length must be a non-zero multiple of 4. Standard AES uses
            16 bytes (AES-128, 10 rounds), 24 (AES-192, 12 rounds) or
            32 (AES-256, 14 rounds).
        rounds: number of cipher rounds; ``4 * (rounds + 1)`` words are needed.

    Returns:
        A list of ``max(len(key) // 4, 4 * (rounds + 1))`` four-byte words.
        The leading words are slices of ``key``; derived words are lists of
        ints.

    Raises:
        ValueError: if ``len(key)`` is zero or not a multiple of 4.
    """
    if not key or len(key) % 4:
        raise ValueError("key length must be a non-zero multiple of 4 bytes")
    nk = len(key) // 4  # key length in 32-bit words (Nk in FIPS-197)
    total_words = 4 * (rounds + 1)
    key_words = [key[i:i + 4] for i in range(0, len(key), 4)]

    # Round constants: successive powers of x (0x02) in GF(2^8). Generate
    # exactly as many as the expansion loop below indexes.
    rcon_count = max(1, (total_words - 1) // nk)
    rcon = [0x01]
    while len(rcon) < rcon_count:
        val = rcon[-1] << 1
        if val & 0x100:
            val ^= 0x11B
        rcon.append(val)

    # Derived words are indexed relative to Nk (not a fixed 4), so 24- and
    # 32-byte keys expand correctly.
    for i in range(nk, total_words):
        temp = key_words[i - 1]
        if i % nk == 0:
            temp = xor_bytes(sub_word(rot_word(temp)), [rcon[i // nk - 1], 0, 0, 0])
        elif nk > 6 and i % nk == 4:
            # AES-256 applies an extra SubWord in the middle of each key block.
            temp = sub_word(temp)
        key_words.append(xor_bytes(key_words[i - nk], temp))
    return key_words
def add_round_key(state, key):
    """XOR a round key into the 4x4 ``state`` in place and return ``state``.

    ``key`` holds one 4-byte word per state column, so byte ``row`` of
    word ``col`` is combined with ``state[row][col]``.
    """
    for row in range(4):
        current = state[row]
        for col in range(4):
            current[col] = current[col] ^ key[col][row]
    return state
def aes_encrypt(plaintext, key, rounds=10):
    """Encrypt a single 16-byte block with AES and return the ciphertext.

    Args:
        plaintext: byte values (bytes or list of ints); only indices 0-15
            are read.
        key: cipher key bytes, expanded by ``key_schedule``.
        rounds: number of cipher rounds (10 for a 16-byte AES-128 key).

    Returns:
        The 16 ciphertext bytes as a list of ints, in FIPS-197 output order.
    """
    key_schedule_list = key_schedule(key, rounds)
    # FIPS-197 §3.4: input byte in[r + 4c] goes to state[r][c] (column-major).
    state = [[plaintext[r + 4 * c] for c in range(4)] for r in range(4)]
    state = add_round_key(state, key_schedule_list[0:4])
    for rnd in range(1, rounds):
        state = sub_bytes(state)
        state = shift_rows(state)
        state = mix_columns(state)
        state = add_round_key(state, key_schedule_list[4 * rnd:4 * (rnd + 1)])
    # The final round omits MixColumns.
    state = sub_bytes(state)
    state = shift_rows(state)
    state = add_round_key(state, key_schedule_list[4 * rounds:4 * (rounds + 1)])
    # Output uses the same column-major mapping as the input:
    # out[r + 4c] = state[r][c]. A row-major flatten would return a
    # transposed (wrong) ciphertext.
    return [state[r][c] for c in range(4) for r in range(4)]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment