Skip to content

Instantly share code, notes, and snippets.

@KenjiOhtsuka
Last active December 1, 2023 22:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KenjiOhtsuka/822aa5d17137d9065622740b3573e5e8 to your computer and use it in GitHub Desktop.
Save KenjiOhtsuka/822aa5d17137d9065622740b3573e5e8 to your computer and use it in GitHub Desktop.
SHA-256 Calculation in Python
def message_to_bitstring(message, encoding="utf-8"):
    """
    Convert a message into a string of '0'/'1' characters, 8 bits per byte.

    Text is first encoded to bytes so that every character contributes a
    whole number of bytes.  Previously each character was rendered as
    ``bin(ord(char))`` padded to 8 bits, which produced more than 8 bits
    (and a bit string that was not byte-aligned) for any code point above
    255.  For pure ASCII input the result is unchanged.  Pass
    ``encoding="latin-1"`` to keep the old one-byte mapping for code
    points 128-255.

    >>> message_to_bitstring("A")
    '01000001'
    >>> message_to_bitstring("")
    ''

    :param message: text (``str``) or raw data (``bytes``/``bytearray``)
    :param encoding: codec used when ``message`` is a ``str``
    :return: bit string whose length is 8 times the number of bytes
    """
    if isinstance(message, str):
        data = message.encode(encoding)
    else:
        # Already binary: use the bytes as they are.
        data = bytes(message)
    return "".join(format(byte, "08b") for byte in data)
def padding(bitstring):
    """
    Pad a bit string so its length becomes a multiple of 512 bits.

    The padded message is: the original bits, a single '1' bit, as many
    '0' bits as are needed to reach 448 (mod 512), and finally the
    original length written as a 64-character binary field.

    ``zfill`` only pads and never truncates, so an original length of
    2 ** 64 bits or more would not fit in the 64-bit length field.

    :param bitstring: message as a string of '0'/'1' characters
    :return: padded bit string
    """
    # Remember the unpadded length; it goes into the trailing length field.
    message_bits = len(bitstring)
    padded = bitstring + "1"

    # Number of zeros needed to land on 448 mod 512; wrap into the next
    # block when the '1' bit already pushed us past position 448.
    used = len(padded) % 512
    zero_count = 448 - used
    if used > 448:
        zero_count += 512

    length_field = bin(message_bits)[2:].zfill(64)
    return padded + "0" * zero_count + length_field
def split_blocks(bitstring, block_size=512):
    """
    Cut a bit string into consecutive chunks of ``block_size`` characters.

    The last chunk is shorter when the length is not a multiple of
    ``block_size``; an empty input yields an empty list.

    :param bitstring: string to split
    :param block_size: chunk length (512 by default)
    :return: list of chunks in order
    """
    return [
        bitstring[start:start + block_size]
        for start in range(0, len(bitstring), block_size)
    ]
def split_words(block, w=32):
    """
    Cut a block into consecutive words of ``w`` characters each.

    The last word is shorter when the block length is not a multiple of
    ``w``; an empty block yields an empty list.

    :param block: string to split
    :param w: word length (32 by default)
    :return: list of words in order
    """
    return [block[offset:offset + w] for offset in range(0, len(block), w)]
###############################################
# Value Conversion
###############################################
class ValueConversion:
    """
    Hand-written conversions between integers, hex strings and bit strings.

    Hex strings use lowercase digits only (``0-9a-f``).  Input validation
    is done with ``assert`` statements.
    """

    @staticmethod
    def _int_to_hex(num):
        """Return the lowercase hex digit for ``num``, which must be in 0..15."""
        assert 0 <= num < 16
        if num < 10:
            return str(num)
        return chr(ord('a') + num - 10)

    @staticmethod
    def _hex_to_int(char):
        """Return the value of one lowercase hex digit (no validation)."""
        if char.isdigit():
            return int(char)
        return ord(char) - ord('a') + 10

    @staticmethod
    def decimal_hex_digits(num):
        """
        Return the first 8 hexadecimal digits of the fractional part of ``num``.

        example:

        >>> ValueConversion.decimal_hex_digits(0.5)
        '80000000'

        :param num: number whose fractional part (``num % 1``) is expanded
        :return: string of 8 lowercase hex digits
        """
        base = 2 ** 4  # 16
        unit = 1 / base
        digits = ""
        # Keep only the fractional part.
        num %= 1
        for _ in range(8):
            # How many sixteenths fit: that is the next hex digit.
            digit = num // unit
            digits += ValueConversion._int_to_hex(int(digit))
            # Drop the digit just emitted and shift the next one into place.
            num = (num - unit * digit)
            num *= base
        return digits

    @staticmethod
    def hex_to_integer(s):
        """
        convert hex string to integer

        example:

        >>> ValueConversion.hex_to_integer("1")
        1
        >>> ValueConversion.hex_to_integer("a")
        10
        >>> ValueConversion.hex_to_integer("ff")
        255

        :param s: string of lowercase hex digits; the empty string gives 0
        :return: integer value
        """
        n = 0
        for c in s:
            assert c in "0123456789abcdef"
            n <<= 4
            n += int(c, 16)
        return n

    @staticmethod
    def bitstring_to_integer(s: str) -> int:
        """
        convert bit string to integer

        example:

        >>> ValueConversion.bitstring_to_integer("1")
        1
        >>> ValueConversion.bitstring_to_integer("101")
        5
        >>> ValueConversion.bitstring_to_integer("10101")
        21

        :param s: string of '0'/'1' characters; the empty string gives 0
        :return: integer value
        """
        n = 0
        for c in s:
            assert c in "01"
            n <<= 1
            n += int(c)
        return n

    @staticmethod
    def integer_to_bitstring(n: int, l: int = 0) -> str:
        """
        convert integer to bit string

        Zero is rendered as ``'0'`` (it used to come back as an empty
        string when no width was requested).

        example:

        >>> ValueConversion.integer_to_bitstring(1)
        '1'
        >>> ValueConversion.integer_to_bitstring(5)
        '101'
        >>> ValueConversion.integer_to_bitstring(0)
        '0'
        >>> ValueConversion.integer_to_bitstring(5, 8)
        '00000101'

        :param n: non-negative integer
        :param l: minimum width; the result is left-padded with zeros but
            never truncated.  0 means no padding.
        :return: bit string
        """
        assert n >= 0
        if n == 0:
            s = "0"
        else:
            s = ""
            while n > 0:
                s = str(n % 2) + s
                n //= 2
        if l > 0:
            s = "0" * (l - len(s)) + s
        return s

    @staticmethod
    def integer_to_hexstring(n: int, l: int = 0) -> str:
        """
        convert integer to hex string

        Zero is rendered as ``'0'`` (it used to come back as an empty
        string when no width was requested).

        example:

        >>> ValueConversion.integer_to_hexstring(1)
        '1'
        >>> ValueConversion.integer_to_hexstring(10)
        'a'
        >>> ValueConversion.integer_to_hexstring(255)
        'ff'
        >>> ValueConversion.integer_to_hexstring(0)
        '0'

        :param n: non-negative integer
        :param l: minimum width; the result is left-padded with zeros but
            never truncated.  0 means no padding.
        :return: lowercase hex string
        """
        assert n >= 0
        if n == 0:
            s = "0"
        else:
            s = ""
            while n > 0:
                s = ValueConversion._int_to_hex(n % 16) + s
                n //= 16
        if l > 0:
            s = "0" * (l - len(s)) + s
        return s

    @staticmethod
    def hexstring_to_bitstring(hexstring: str) -> str:
        """
        convert hex string to bit string

        example:

        >>> ValueConversion.hexstring_to_bitstring("1")
        '0001'
        >>> ValueConversion.hexstring_to_bitstring("a")
        '1010'
        >>> ValueConversion.hexstring_to_bitstring("ff")
        '11111111'

        :param hexstring: string of lowercase hex digits
        :return: bit string, 4 bits per hex digit
        """
        for c in hexstring:
            assert c in "0123456789abcdef"

        def _hex_to_bitstring(c):
            # Expand one hex digit into exactly four bits, MSB first.
            n = ValueConversion._hex_to_int(c)
            s = ""
            for _ in range(4):
                s = str(n % 2) + s
                n >>= 1
            return s

        return "".join(_hex_to_bitstring(c) for c in hexstring)
###############################################
# Basic Arithmetic
###############################################
def sqrt(n: int, p: int = 16):
"""
calculate square root
>>> 2 - sqrt(2) ** 2 < 10 ** -15
True
>>> 3 - sqrt(3) ** 2 < 10 ** -15
True
>>> 121 - sqrt(121) ** 2 < 10 ** -12
True
>>> 700 - sqrt(700) ** 2 < 10 ** -12
True
>>> 999 - sqrt(999) ** 2 < 10 ** -12
True
not precise but sufficient for sha-256 calculation.
:param n: number that is not more than 100
:param p:
:return:
"""
s = 0
while n > 100:
n /= 100
s += 1
assert 0 < n <= 100
r = 0
a = 0
for j in range(p):
for i in range(9, -1, -1):
if (r + i) * i <= n:
a += i / (10 ** j)
n = (n - (r + i) * i) * 100
r = (r + i * 2) * 10
break
return a * (10 ** s)
def curt(n: int, p: int = 16):
"""
calculate cube root
>>> 2 - curt(2) ** 3 < 10 ** -14
True
>>> 3 - curt(3) ** 3 < 10 ** -14
True
>>> 300 - curt(300) ** 3 < 10 ** -13
True
>>> 999 - curt(999) ** 3 < 10 ** -12
True
:param n: number that is not more than 1000
:param p:
:return:
"""
assert 0 < n <= 1000
r1 = 0
r2 = 0
a = 0
for j in range(p):
for i in range(9, -1, -1):
if (r2 + (r1 + i) * i) * i <= n:
n = (n - (r2 + (r1 + i) * i) * i) * 1000
r2 = (r2 + 2 * (r1 + i) * i + i ** 2) * 100
r1 = (r1 + i * 3) * 10
a += i / (10 ** j)
break
return a
class Prime:
    """
    Incrementally grown list of primes.

    ``nums`` holds the primes found so far and ``N`` is the last number
    that has been examined.
    """

    def __init__(self):
        self.nums = [2, 3, 5, 7, 11, 13, 17, 19, 23]
        self.N = self.nums[-1]

    def _append_if_prime(self, n):
        """
        append n to self.nums if n is prime

        >>> Prime()._append_if_prime(121)
        False

        :param n: candidate number
        :return: True if n is (now) in self.nums, False otherwise.  Numbers
            not above self.N that are missing from self.nums are reported
            as not prime without further checking.
        """
        if n in self.nums:
            return True
        if n <= self.N:
            return False

        # Trial division only needs divisors up to sqrt(n); the small
        # offset compensates for sqrt() returning slightly low values.
        limit = int(sqrt(n) + 0.001)
        # Every outcome below records n as the last examined number.
        self.N = n
        for known in self.nums:
            if known > limit:
                break
            if n % known == 0:
                return False
        else:
            # Known primes ran out before reaching the limit: keep trying
            # divisors in steps of 2, starting from the last known prime.
            divisor = known
            while divisor <= limit:
                if n % divisor == 0:
                    return False
                divisor += 2
        self.nums.append(n)
        return True

    def numbers(self, num):
        """
        Return the first ``num`` entries of the prime list, extending it
        first when it is too short.

        :param num: how many primes are wanted
        :return: ``self.nums[:num]``
        """
        found = len(self.nums)
        # Start at the next odd number after the last examined one.
        candidate = self.N + 1 + self.N % 2
        while found < num:
            if self._append_if_prime(candidate):
                found += 1
            candidate += 2
        return self.nums[:num]
###############################################
# Bit Operations
###############################################
def rotl(n: int, x, w=32):
    """
    rotate left

    Bits shifted out at the top re-enter at the bottom.  The result is
    masked to ``w`` bits; without the mask the left-shifted copy kept its
    overflowing high bits and the return value could exceed ``2 ** w - 1``.

    >>> rotl(1, 0x80000001) == 0x00000003
    True

    :param n: number of bit positions, 0 <= n < w
    :param x: w-bit word
    :param w: word size in bits
    :return: rotated w-bit word
    """
    assert 0 <= n < w
    assert 0 <= x < 2 ** w
    return ((x << n) | (x >> (w - n))) & (2 ** w - 1)
def rotr(n: int, x, w=32):
    """
    rotate right

    Bits shifted out at the bottom re-enter at the top.  The result is
    masked to ``w`` bits; without the mask the left-shifted copy kept its
    overflowing high bits and the return value could exceed ``2 ** w - 1``
    (for example ``rotr(1, 3)`` gave ``0x180000001``).

    >>> rotr(1, 3) == 0x80000001
    True

    :param n: number of bit positions, 0 <= n < w
    :param x: w-bit word
    :param w: word size in bits
    :return: rotated w-bit word
    """
    assert 0 <= n < w
    assert 0 <= x < 2 ** w
    return ((x >> n) | (x << (w - n))) & (2 ** w - 1)
def shr(n: int, x, w=32):
    """
    Logical right shift of a w-bit word: the low ``n`` bits are dropped
    and nothing is shifted in from the top.

    :param n: number of bit positions, 0 <= n < w
    :param x: w-bit word
    :param w: word size in bits
    :return: ``x`` shifted right by ``n``
    """
    assert 0 <= n < w
    assert 0 <= x < 2 ** w
    shifted = x >> n
    return shifted
###############################################
# Logical Function
###############################################
def ch(x, y, z):
    """
    Bitwise "choose": each result bit comes from ``y`` where the matching
    bit of ``x`` is 1 and from ``z`` where it is 0.
    """
    # Start from z and flip it to y exactly where x is set and y, z differ.
    return z ^ (x & (y ^ z))
def maj(x, y, z):
    """
    Bitwise majority: each result bit is 1 when at least two of the
    matching bits of ``x``, ``y`` and ``z`` are 1.
    """
    # Both x and y set, or z set together with either of them.
    return (x & y) | (z & (x | y))
def sum_0(x):
    """XOR of ``x`` rotated right by 2, 13 and 22 bits (applied to ``a`` in each sha256 round)."""
    mixed = 0
    for amount in (2, 13, 22):
        mixed ^= rotr(amount, x)
    return mixed
def sum_1(x):
    """XOR of ``x`` rotated right by 6, 11 and 25 bits (applied to ``e`` in each sha256 round)."""
    mixed = 0
    for amount in (6, 11, 25):
        mixed ^= rotr(amount, x)
    return mixed
def sig_0(x):
    """XOR of ``x`` rotated right by 7 and 18 bits and ``x`` shifted right by 3 bits (used when extending the message schedule)."""
    rotated_7 = rotr(7, x)
    rotated_18 = rotr(18, x)
    shifted_3 = shr(3, x)
    return rotated_7 ^ rotated_18 ^ shifted_3
def sig_1(x):
    """XOR of ``x`` rotated right by 17 and 19 bits and ``x`` shifted right by 10 bits (used when extending the message schedule)."""
    rotated_17 = rotr(17, x)
    rotated_19 = rotr(19, x)
    shifted_10 = shr(10, x)
    return rotated_17 ^ rotated_19 ^ shifted_10
# Shared prime generator; sha256() reuses it for its initial hash values.
p = Prime()

# Round constants: for each of the first 64 primes, the first 8 hex digits
# of the fractional part of its cube root, read as an integer.
K = [
    ValueConversion.hex_to_integer(
        ValueConversion.decimal_hex_digits(curt(prime))
    )
    for prime in p.numbers(64)
]
def sha256(bitstring: str):
    """
    calculate sha256 hash value

    >>> text = "Python"
    >>> a = sha256(message_to_bitstring(text))
    >>> import hashlib
    >>> b = hashlib.sha256(text.encode('ascii')).hexdigest()
    >>> a == b
    True
    >>> text = "Hello" * 100
    >>> a = sha256(message_to_bitstring(text))
    >>> import hashlib
    >>> b = hashlib.sha256(text.encode('ascii')).hexdigest()
    >>> a == b
    True

    :param bitstring: unpadded message as a string of '0'/'1' characters
    :return: digest as 64 lowercase hex characters
    """
    modulus = 2 ** 32  # all additions wrap at the 32-bit word size

    # Initial hash values: first 8 hex digits of the fractional part of the
    # square root of each of the first 8 primes.
    digest = [
        ValueConversion.hex_to_integer(
            ValueConversion.decimal_hex_digits(sqrt(prime))
        )
        for prime in p.numbers(8)
    ]

    for block in split_blocks(padding(bitstring)):
        # Message schedule: 16 words from the block, extended to 64.
        schedule = [
            ValueConversion.bitstring_to_integer(word)
            for word in split_words(block)
        ]
        for t in range(16, 64):
            schedule.append(
                (
                    sig_1(schedule[t - 2])
                    + schedule[t - 7]
                    + sig_0(schedule[t - 15])
                    + schedule[t - 16]
                ) % modulus
            )

        # Working variables start from the current hash state.
        a, b, c, d, e, f, g, h = digest
        for t in range(64):
            temp_1 = (h + sum_1(e) + ch(e, f, g) + K[t] + schedule[t]) % modulus
            temp_2 = (sum_0(a) + maj(a, b, c)) % modulus
            # Shift every variable down one slot; only a and e get new values.
            a, b, c, d, e, f, g, h = (
                (temp_1 + temp_2) % modulus,
                a,
                b,
                c,
                (d + temp_1) % modulus,
                e,
                f,
                g,
            )

        # Fold the working variables back into the hash state.
        digest = [
            (working + previous) % modulus
            for working, previous in zip((a, b, c, d, e, f, g, h), digest)
        ]

    return "".join(
        ValueConversion.integer_to_hexstring(value, 8) for value in digest
    )
if __name__ == "__main__":
    # When run as a script, execute the doctests embedded in the docstrings.
    import doctest

    doctest.testmod()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment