Last active
December 1, 2023 22:55
-
-
Save KenjiOhtsuka/822aa5d17137d9065622740b3573e5e8 to your computer and use it in GitHub Desktop.
SHA-256 Calculation in Python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Convert a message to a bit string.
def message_to_bitstring(message, encoding="utf-8"):
    """Return the bits of *message* as a string of "0"/"1" characters.

    Every byte becomes exactly 8 bits, most significant bit first, so the
    result length is always a multiple of 8.

    >>> message_to_bitstring("A")
    '01000001'

    :param message: text (``str``) or raw bytes (``bytes``/``bytearray``).
    :param encoding: codec used to turn a ``str`` into bytes (UTF-8 by
        default). For pure-ASCII text this yields the ASCII codes.
    :return: the bit string.
    """
    if isinstance(message, str):
        # Encode first: rendering ord(char) directly would give more than
        # 8 bits for characters above U+00FF and corrupt the bit layout.
        message = message.encode(encoding)
    # format(byte, "08b") renders one byte as exactly 8 binary digits.
    return "".join(format(byte, "08b") for byte in message)
# Pad a bit string as required by SHA-256.
def padding(bitstring):
    """Return *bitstring* padded to a multiple of 512 bits.

    The padding is a single "1" bit, then the smallest number of "0" bits
    that makes the length congruent to 448 modulo 512, then the original
    message length in bits written as a 64-bit big-endian number.

    >>> len(padding(""))
    512
    >>> len(padding("0" * 448))
    1024

    :param bitstring: the message as a string of "0"/"1" characters.
    :return: the padded bit string; its length is a multiple of 512.
    :raises ValueError: if the message is 2**64 bits or longer, because its
        length would not fit in the 64-bit length field.
    """
    length = len(bitstring)
    if length >= 2 ** 64:
        raise ValueError("message too long for SHA-256 (must be < 2**64 bits)")
    # Append the mandatory "1" bit.
    bitstring += "1"
    # Add zeros until the length is 448 mod 512. Python's % is never
    # negative, so this also covers wrapping into an additional block.
    bitstring += "0" * ((448 - len(bitstring)) % 512)
    # Append the original length as 64 bits; since 448 + 64 = 512 the total
    # length is now a whole number of 512-bit blocks.
    bitstring += format(length, "064b")
    return bitstring
# Split a bit string into 512-bit blocks.
def split_blocks(bitstring, block_size=512):
    """Cut *bitstring* into consecutive pieces of *block_size* characters.

    The last piece is shorter when the length is not a multiple of
    *block_size*; an empty string yields an empty list.
    """
    return [
        bitstring[offset:offset + block_size]
        for offset in range(0, len(bitstring), block_size)
    ]
# Split a block into 32-bit words.
def split_words(block, w=32):
    """Return *block* cut into consecutive *w*-character words.

    The final word is shorter if len(block) is not a multiple of *w*.
    """
    starts = range(0, len(block), w)
    return [block[start:start + w] for start in starts]
###############################################
# Value Conversion
###############################################
class ValueConversion:
    """Conversions between integers, hex strings and bit strings.

    Hex strings use lowercase digits without a ``0x`` prefix; bit strings
    consist only of the characters "0" and "1". Invalid characters raise
    ``ValueError``; explicit checks are used rather than ``assert``, which
    is stripped when Python runs with ``-O``.
    """

    _HEX_DIGITS = "0123456789abcdef"
    _BIT_DIGITS = "01"

    @staticmethod
    def _require_digits(s, allowed, kind):
        """Raise ValueError unless every character of *s* is in *allowed*."""
        for c in s:
            if c not in allowed:
                raise ValueError("invalid %s digit %r in %r" % (kind, c, s))

    @staticmethod
    def _int_to_hex(num):
        """Return the lowercase hex digit for an integer 0 <= num < 16."""
        assert 0 <= num < 16
        if num < 10:
            return str(num)
        else:
            return chr(ord('a') + num - 10)

    @staticmethod
    def _hex_to_int(char):
        """Return the value of one lowercase hex digit.

        The caller must validate *char*; other characters give meaningless
        results.
        """
        if char.isdigit():
            return int(char)
        else:
            return ord(char) - ord('a') + 10

    @staticmethod
    def _integer_to_digits(n, base, l):
        """Render non-negative *n* in *base* (2..16), left-padded with "0" to *l*."""
        if n < 0:
            raise ValueError("n must be non-negative, got %r" % (n,))
        digits = []
        # do/while shape so that n == 0 still produces the single digit "0".
        while True:
            n, r = divmod(n, base)
            digits.append(ValueConversion._int_to_hex(r))
            if n == 0:
                break
        # rjust is a no-op when l <= len(result); nothing is truncated.
        return "".join(reversed(digits)).rjust(l, "0")

    @staticmethod
    def decimal_hex_digits(num):
        """Return the first 8 hex digits of the fractional part of *num*.

        In other words, the fractional part truncated to 32 bits. This is how
        the SHA-256 constants are derived from square and cube roots of primes.

        >>> ValueConversion.decimal_hex_digits(2 ** 0.5)
        '6a09e667'

        Multiplying by 16 and subtracting whole digits introduces no rounding
        error for a non-negative float. The digits are therefore exactly
        those of the float *num*, so accuracy is limited only by how well
        *num* approximates the intended value.
        """
        h = 2 ** 4  # 16
        u = 1 / h
        s = ""
        num %= 1  # keep only the fractional part
        for _ in range(8):
            d = num // u  # next hex digit, i.e. floor(num * 16)
            s += ValueConversion._int_to_hex(int(d))
            num = (num - u * d)
            num *= h
        return s

    @staticmethod
    def hex_to_integer(s):
        """
        Convert a lowercase hex string to an integer.

        >>> ValueConversion.hex_to_integer("1")
        1
        >>> ValueConversion.hex_to_integer("a")
        10
        >>> ValueConversion.hex_to_integer("ff")
        255

        :param s: lowercase hex digits; the empty string converts to 0.
        :return: the integer value.
        :raises ValueError: if *s* contains a non-hex or uppercase character.
        """
        ValueConversion._require_digits(s, ValueConversion._HEX_DIGITS, "hex")
        n = 0
        for c in s:
            n = (n << 4) | ValueConversion._hex_to_int(c)
        return n

    @staticmethod
    def bitstring_to_integer(s: str) -> int:
        """
        Convert a bit string to an integer.

        >>> ValueConversion.bitstring_to_integer("1")
        1
        >>> ValueConversion.bitstring_to_integer("101")
        5
        >>> ValueConversion.bitstring_to_integer("10101")
        21

        :param s: "0"/"1" characters; the empty string converts to 0.
        :return: the integer value.
        :raises ValueError: if *s* contains a character other than "0"/"1".
        """
        ValueConversion._require_digits(s, ValueConversion._BIT_DIGITS, "binary")
        n = 0
        for c in s:
            n = (n << 1) | int(c)
        return n

    @staticmethod
    def integer_to_bitstring(n: int, l: int = 0) -> str:
        """
        Convert a non-negative integer to a bit string.

        >>> ValueConversion.integer_to_bitstring(1)
        '1'
        >>> ValueConversion.integer_to_bitstring(5)
        '101'
        >>> ValueConversion.integer_to_bitstring(0)
        '0'
        >>> ValueConversion.integer_to_bitstring(5, 8)
        '00000101'

        :param n: the integer, n >= 0.
        :param l: minimum width; the result is left-padded with "0" to *l*
            characters. Longer results are not truncated.
        :return: the bit string.
        :raises ValueError: if *n* is negative.
        """
        return ValueConversion._integer_to_digits(n, 2, l)

    @staticmethod
    def integer_to_hexstring(n: int, l: int = 0) -> str:
        """
        Convert a non-negative integer to a lowercase hex string.

        >>> ValueConversion.integer_to_hexstring(1)
        '1'
        >>> ValueConversion.integer_to_hexstring(10)
        'a'
        >>> ValueConversion.integer_to_hexstring(255)
        'ff'
        >>> ValueConversion.integer_to_hexstring(0)
        '0'

        :param n: the integer, n >= 0.
        :param l: minimum width; the result is left-padded with "0" to *l*
            characters. Longer results are not truncated.
        :return: the hex string.
        :raises ValueError: if *n* is negative.
        """
        return ValueConversion._integer_to_digits(n, 16, l)

    @staticmethod
    def hexstring_to_bitstring(hexstring: str) -> str:
        """
        Convert a lowercase hex string to a bit string, 4 bits per hex digit.

        >>> ValueConversion.hexstring_to_bitstring("1")
        '0001'
        >>> ValueConversion.hexstring_to_bitstring("a")
        '1010'
        >>> ValueConversion.hexstring_to_bitstring("ff")
        '11111111'

        :param hexstring: lowercase hex digits.
        :return: the bit string (length 4 * len(hexstring)).
        :raises ValueError: if *hexstring* contains a non-hex character.
        """
        ValueConversion._require_digits(hexstring, ValueConversion._HEX_DIGITS, "hex")
        return "".join(
            ValueConversion.integer_to_bitstring(ValueConversion._hex_to_int(c), 4)
            for c in hexstring
        )
###############################################
# Basic Arithmetic
###############################################
def sqrt(n: float, p: int = 16):
    """
    Calculate the square root of a positive number digit by digit.

    Uses the schoolbook long-division method for square roots in base 10.
    It produces *p* decimal digits (the units digit followed by p - 1
    fractional digits) and accumulates them into a float. Not exact, but
    precise enough for the SHA-256 initial hash values.

    >>> 2 - sqrt(2) ** 2 < 10 ** -15
    True
    >>> 3 - sqrt(3) ** 2 < 10 ** -15
    True
    >>> 121 - sqrt(121) ** 2 < 10 ** -12
    True
    >>> 700 - sqrt(700) ** 2 < 10 ** -12
    True
    >>> 999 - sqrt(999) ** 2 < 10 ** -12
    True

    NOTE(review): the checks above are one-sided. They only bound how far
    the square falls below n, so an over-estimate would also pass.

    :param n: a positive number. Values above 100 are first divided by 100
        repeatedly until 0 < n <= 100, and the result is multiplied back by
        the matching power of 10. n <= 0 fails the internal assertion.
    :param p: number of decimal digits to generate.
    :return: approximate square root of n.
    """
    s = 0
    # Scale n into (0, 100]; each division by 100 divides the root by 10.
    while n > 100:
        n /= 100
        s += 1
    assert 0 < n <= 100
    # With A the integer formed by the digits found so far, r == 20 * A, so
    # (r + i) * i == 20*A*i + i**2 == (10A + i)**2 - 100*A**2: the growth of
    # the square when digit i is appended. n holds the current remainder.
    r = 0
    a = 0
    for j in range(p):
        # Pick the largest digit i whose contribution fits in the remainder.
        for i in range(9, -1, -1):
            if (r + i) * i <= n:
                a += i / (10 ** j)
                # Subtract the contribution and bring down the next two digits.
                n = (n - (r + i) * i) * 100
                r = (r + i * 2) * 10
                break
    return a * (10 ** s)
def curt(n: float, p: int = 16):
    """
    Calculate the cube root of a number in (0, 1000] digit by digit.

    Uses the schoolbook long-division method for cube roots in base 10.
    It produces *p* decimal digits (the units digit followed by p - 1
    fractional digits) and accumulates them into a float.

    >>> 2 - curt(2) ** 3 < 10 ** -14
    True
    >>> 3 - curt(3) ** 3 < 10 ** -14
    True
    >>> 300 - curt(300) ** 3 < 10 ** -13
    True
    >>> 999 - curt(999) ** 3 < 10 ** -12
    True

    NOTE(review): these checks are one-sided, like those of sqrt().

    :param n: a number with 0 < n <= 1000. Unlike sqrt(), larger values are
        not rescaled; they fail the assertion.
    :param p: number of decimal digits to generate.
    :return: approximate cube root of n.
    """
    assert 0 < n <= 1000
    # With A the integer formed by the digits found so far: r1 == 30 * A and
    # r2 == 300 * A**2, so (r2 + (r1 + i) * i) * i == 300*A**2*i + 30*A*i**2
    # + i**3 == (10A + i)**3 - 1000*A**3, the growth of the cube when digit
    # i is appended. n holds the current remainder.
    r1 = 0
    r2 = 0
    a = 0
    for j in range(p):
        # Pick the largest digit i whose contribution fits in the remainder.
        for i in range(9, -1, -1):
            if (r2 + (r1 + i) * i) * i <= n:
                # Subtract the contribution and bring down the next three digits.
                n = (n - (r2 + (r1 + i) * i) * i) * 1000
                r2 = (r2 + 2 * (r1 + i) * i + i ** 2) * 100
                r1 = (r1 + i * 3) * 10
                a += i / (10 ** j)
                break
    return a
class Prime:
    """Incrementally grown cache of prime numbers.

    Invariant: ``self.nums`` holds every prime <= ``self.N`` in ascending
    order, where ``self.N`` is the largest number examined so far.
    """

    def __init__(self):
        # Seed with the primes up to 23.
        self.nums = [2, 3, 5, 7, 11, 13, 17, 19, 23]
        self.N = self.nums[-1]

    def _is_prime(self, n):
        """Return True if *n* is prime, by trial division; does not modify state.

        Divides by the cached primes first, then by odd numbers beyond the
        largest cached prime. Uses exact integer comparisons (q * q <= n)
        instead of a floating-point square root.
        """
        if n < 2:
            return False
        for q in self.nums:
            if q * q > n:
                return True
            if n % q == 0:
                return False
        q = self.nums[-1] + 2  # nums[-1] is odd, so q stays odd
        while q * q <= n:
            if n % q == 0:
                return False
            q += 2
        return True

    def _append_if_prime(self, n):
        """
        Report whether n is prime, caching the result when that is safe.

        >>> Prime()._append_if_prime(121)
        False

        The result is recorded only when n is the next candidate after
        ``self.N``, so no unexamined prime can be skipped. Recording means
        appending n to ``self.nums`` if it is prime and advancing ``self.N``
        to n. Advancing ``self.N`` past unexamined numbers would make later
        ``numbers()`` calls silently omit the primes in between.

        :param n: integer to test.
        :return: True if n is prime, otherwise False.
        """
        if n <= self.N:
            # Every prime <= N is already cached.
            return n in self.nums
        result = self._is_prime(n)
        # n is the next candidate if it directly follows N, or if the only
        # number skipped (N + 1) is even and hence not prime, since N >= 23.
        if n == self.N + 1 or (n == self.N + 2 and self.N % 2 == 1):
            self.N = n
            if result:
                self.nums.append(n)
        return result

    def numbers(self, num):
        """Return a list of the first *num* primes, extending the cache as needed."""
        # Next odd candidate after N; even numbers above 2 are never prime.
        n = self.N + 1 + self.N % 2
        while len(self.nums) < num:
            self._append_if_prime(n)
            n += 2
        return self.nums[:num]
###############################################
# Bit Operations
###############################################
def rotl(n: int, x, w=32):
    """
    Rotate the w-bit word x left by n bits (ROTL^n(x) in FIPS 180-4).

    >>> hex(rotl(4, 0x12345678))
    '0x23456781'

    :param n: rotation amount, 0 <= n < w
    :param x: w-bit word, 0 <= x < 2 ** w
    :param w: word size in bits
    :return: the rotated word, always in range(2 ** w)
    """
    assert 0 <= n < w
    assert 0 <= x < 2 ** w
    # Mask to w bits. Without it, x << n keeps the bits that should have
    # wrapped around, and the result is >= 2 ** w.
    return ((x << n) | (x >> (w - n))) & ((1 << w) - 1)
def rotr(n: int, x, w=32):
    """
    Rotate the w-bit word x right by n bits (ROTR^n(x) in FIPS 180-4).

    >>> hex(rotr(4, 0x12345678))
    '0x81234567'

    :param n: rotation amount, 0 <= n < w
    :param x: w-bit word, 0 <= x < 2 ** w
    :param w: word size in bits
    :return: the rotated word, always in range(2 ** w)
    """
    assert 0 <= n < w
    assert 0 <= x < 2 ** w
    # Mask to w bits. Without it, x << (w - n) leaves the wrapped-around
    # low bits above position w (e.g. n == 0 would yield x | x << w).
    return ((x >> n) | (x << (w - n))) & ((1 << w) - 1)
def shr(n: int, x, w=32):
    """
    Logical right shift of the w-bit word x by n bits (SHR^n(x)).

    :param n: shift amount, 0 <= n < w
    :param x: w-bit word, 0 <= x < 2 ** w
    :param w: word size in bits
    :return: x moved n bits to the right; vacated high bits are zero
    """
    assert 0 <= n < w
    assert 0 <= x < (1 << w)
    shifted = x >> n
    return shifted
###############################################
# Logical Function
###############################################
def ch(x, y, z):
    """SHA-256 Ch ("choose"): each bit comes from y where x is 1, else from z.

    Written as z ^ (x & (y ^ z)), which is bit-for-bit identical to
    (x & y) ^ (~x & z).
    """
    return z ^ (x & (y ^ z))
def maj(x, y, z):
    """SHA-256 Maj ("majority"): each bit is the value held by at least two inputs."""
    both_xy = x & y
    z_with_either = z & (x | y)
    return both_xy | z_with_either
def sum_0(x):
    """SHA-256 upper-case Sigma_0: ROTR^2(x) ^ ROTR^13(x) ^ ROTR^22(x)."""
    acc = 0
    for amount in (2, 13, 22):
        acc ^= rotr(amount, x)
    return acc
def sum_1(x):
    """SHA-256 upper-case Sigma_1: ROTR^6(x) ^ ROTR^11(x) ^ ROTR^25(x)."""
    acc = 0
    for amount in (6, 11, 25):
        acc ^= rotr(amount, x)
    return acc
def sig_0(x):
    """SHA-256 lower-case sigma_0: ROTR^7(x) ^ ROTR^18(x) ^ SHR^3(x)."""
    r7 = rotr(7, x)
    r18 = rotr(18, x)
    s3 = shr(3, x)
    return r7 ^ r18 ^ s3
def sig_1(x):
    """SHA-256 lower-case sigma_1: ROTR^17(x) ^ ROTR^19(x) ^ SHR^10(x)."""
    r17 = rotr(17, x)
    r19 = rotr(19, x)
    s10 = shr(10, x)
    return r17 ^ r19 ^ s10
# Module-wide prime generator; sha256() also reads it for its initial hash values.
p = Prime()
# Round constants: K[t] is the first 32 bits of the fractional part of the
# cube root of the (t+1)-th prime, for t = 0..63.
K = [
    ValueConversion.hex_to_integer(ValueConversion.decimal_hex_digits(curt(prime)))
    for prime in p.numbers(64)
]
def sha256(bitstring: str):
    """
    Calculate the SHA-256 hash of a message given as a bit string.

    >>> text = "Python"
    >>> a = sha256(message_to_bitstring(text))
    >>> import hashlib
    >>> b = hashlib.sha256(text.encode('ascii')).hexdigest()
    >>> a == b
    True
    >>> text = "Hello" * 100
    >>> a = sha256(message_to_bitstring(text))
    >>> b = hashlib.sha256(text.encode('ascii')).hexdigest()
    >>> a == b
    True

    :param bitstring: the message as a string of "0"/"1" characters. Its
        length does not need to be a multiple of 8; padding() handles any
        bit length.
    :return: the 256-bit digest as 64 lowercase hex characters.
    """
    w = 32
    modulus = 2 ** w  # all word arithmetic is modulo 2**32
    # Initial hash value: first 32 bits of the fractional parts of the square
    # roots of the first 8 primes. K and p are module-level and only read
    # here, so no `global` declaration is needed.
    H = [
        ValueConversion.hex_to_integer(ValueConversion.decimal_hex_digits(sqrt(prime)))
        for prime in p.numbers(8)
    ]
    for m in split_blocks(padding(bitstring)):
        # Message schedule: the block's 16 words, extended to 64.
        W = [ValueConversion.bitstring_to_integer(word) for word in split_words(m)]
        for t in range(16, 64):
            W.append(
                (sig_1(W[t - 2]) + W[t - 7] + sig_0(W[t - 15]) + W[t - 16]) % modulus
            )
        a, b, c, d, e, f, g, h = H
        # 64 compression rounds; each statement reads the previous round's values.
        for t in range(64):
            T1 = (h + sum_1(e) + ch(e, f, g) + K[t] + W[t]) % modulus
            T2 = (sum_0(a) + maj(a, b, c)) % modulus
            h = g
            g = f
            f = e
            e = (d + T1) % modulus
            d = c
            c = b
            b = a
            a = (T1 + T2) % modulus
        # Add this block's compressed result into the running hash value.
        H = [(old + new) % modulus for old, new in zip(H, (a, b, c, d, e, f, g, h))]
    return "".join(ValueConversion.integer_to_hexstring(word, 8) for word in H)
if __name__ == '__main__':
    # Running the file directly executes the doctests in this module's docstrings.
    from doctest import testmod
    testmod()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment