Skip to content

Instantly share code, notes, and snippets.

@Sam-Belliveau
Last active August 6, 2019 02:14
Show Gist options
  • Save Sam-Belliveau/6bea24259ddf1eb5203bd192907a34ee to your computer and use it in GitHub Desktop.
Save Sam-Belliveau/6bea24259ddf1eb5203bd192907a34ee to your computer and use it in GitHub Desktop.
Robust implementation of ChaCha in Python that is easily extensible and can generate high-quality random numbers
from numpy import uint32
from numpy import uint64
from numpy import seterr as set_numpy_err
from numpy import geterr as get_numpy_err
class ChaCha:
    """ChaCha-style keystream and random number generator.

    The generator keeps a 16-word state of numpy ``uint32`` values laid out
    as: words 0-3 constant, words 4-11 key, words 12-13 block position and
    words 14-15 nonce. Every call to ``_iter_state`` writes the current
    position into words 12-13, runs ``rounds`` rounds of quarter-round
    mixing *in place* and then increments the position. ``next_stream``
    emits words 0-3 and 12-15 (32 bytes) of the mixed state per block.

    NOTE(review): because the state is permuted in place and the input
    state is not added back after the rounds, the output is not expected
    to match the reference ChaCha20 keystream -- confirm before relying on
    interoperability with other implementations.

    Args:
        key: Up to 32 bytes; shorter keys are right-padded with zeros.
        nonce: 64-bit nonce.
        pos: Initial 64-bit block position.
        constant: Up to 16 bytes; shorter values are right-padded with zeros.
        stream_size: Default number of bytes returned by ``next_stream``.
        rounds: Number of rounds per block (a warning is printed below 12).
        endian: ``'big'`` or ``'little'``; controls how bytes are packed
            into words and how words are serialized on output.

    Raises:
        ChaCha.ByteSizeError: If ``key`` or ``constant`` is too long.
        ChaCha.EndianError: If ``endian`` is not ``'big'`` or ``'little'``.
    """

    class ByteSizeError(ValueError):
        """Raised when the key or the constant exceeds its maximum length."""

    class EndianError(ValueError):
        """Raised when the 'endian' setting is neither 'big' nor 'little'."""

    def __init__(self, key=b"", nonce=uint64(0xDEADDEADBEEFBEEF),
                 pos=0, constant=b"expand 32-byte k",
                 stream_size=16, rounds=20, endian='big'):
        self._state = [uint32(0)] * 16
        self._stream_size = int(stream_size)
        self._rounds = int(rounds)
        self._endian = endian.lower()
        self.reset(key, nonce, pos, constant)

    def _check_for_valid_inputs(self):
        """Validate the stored settings; raise on errors, print on weak ones."""
        # Check if key is correct length
        if len(self._key) > 32:
            raise self.ByteSizeError("Length of 'key' is too big! ({} > 32)"
                                     .format(len(self._key)))
        # Check if constant is correct length
        if len(self._constant) > 16:
            raise self.ByteSizeError("Length of 'constant' is too big! ({} > 16)"
                                     .format(len(self._constant)))
        # Check if rounds is valid
        if not isinstance(self._rounds, int):
            raise TypeError("Invalid type for setting 'rounds'!")
        # Check if stream_size is valid
        if not isinstance(self._stream_size, int):
            raise TypeError("Invalid type for setting 'stream_size'!")
        # Check if endian is valid
        if self._endian.lower() not in ('big', 'little'):
            raise self.EndianError("Invalid setting for 'endian'! ({})"
                                   .format(self._endian))
        # Check if settings are safe (warnings only, generation still works)
        if self._rounds < 12:
            print("ChaCha-Warning: Stream Generation is Insecure! (Rounds < 12)")
        if self._key == bytearray([0] * len(self._key)):
            print("ChaCha-Warning: Stream Generation is Insecure! (Key is Uninitialized)")

    def _split_words(self, value):
        """Split a uint64 into two uint32 words ordered by the endian setting.

        Returns ``(low, high)`` for little endian and ``(high, low)`` for
        big endian.
        """
        mask = uint64(0xffffffff)
        low = uint32(value & mask)
        high = uint32((value >> uint64(32)) & mask)
        if self._endian == 'little':
            return low, high
        return high, low

    def reset(self, key=None, nonce=None, pos=0, constant=None):
        """Rebuild the whole state, optionally replacing key/nonce/constant.

        Arguments left as ``None`` keep their previously stored value.
        ``pos`` defaults to 0, so any reset rewinds the stream unless a
        position is given explicitly.
        """
        # Set all of the state settings, zero-padding to the full row width
        if key is not None:
            self._key = bytearray(key)
            self._key += bytearray(max(0, 32 - len(self._key)))
        if nonce is not None:
            self._nonce = uint64(nonce)
        if constant is not None:
            self._constant = bytearray(constant)
            # Pad by the constant's own length (not the key's).
            self._constant += bytearray(max(0, 16 - len(self._constant)))

        # Check inputs for size errors
        self._check_for_valid_inputs()

        # Constant Row
        for i in range(4):
            self._state[i] = uint32(
                int.from_bytes(self._constant[i*4:i*4 + 4], byteorder=self._endian)
            )
        # Key Row
        for i in range(8):
            self._state[i + 4] = uint32(
                int.from_bytes(self._key[i*4:i*4 + 4], byteorder=self._endian)
            )
        # Update Position
        self.set_pos(pos)
        # Nonce Words
        self._state[14], self._state[15] = self._split_words(self._nonce)

    def set_key(self, key):
        """Replace the key and reset the state (position returns to 0)."""
        self.reset(key=key)

    def set_nonce(self, nonce):
        """Replace the nonce and reset the state (position returns to 0)."""
        self.reset(nonce=nonce)

    def set_constant(self, constant):
        """Replace the constant and reset the state (position returns to 0)."""
        self.reset(constant=constant)

    def set_pos(self, pos=None):
        """Store ``pos`` (if given) and write the position into words 12-13."""
        if pos is not None:
            self._pos = uint64(pos)
        self._state[12], self._state[13] = self._split_words(self._pos)

    def get(self):
        """Return a snapshot of the generator settings and state as a dict.

        Mutable members are copied so that modifying the returned values
        cannot corrupt the generator.
        """
        return {
            'state': list(self._state), 'key': bytearray(self._key),
            'nonce': self._nonce, 'constant': bytearray(self._constant),
            'pos': self._pos, 'rounds': self._rounds,
            'stream_size': self._stream_size, 'endian': self._endian
        }

    # Quarter Round, mixes up indexes
    def _quarter_round(self, ia, ib, ic, id):
        """Mix the four state words at the given indexes in place.

        Uses add / xor / rotate steps with rotation amounts 16, 12, 8, 7.
        The additions rely on uint32 wrap-around, so the caller is expected
        to have silenced numpy overflow warnings.
        """
        def left_rotate(x, r):
            return ((uint32(x) << uint32(r)) |
                    (uint32(x) >> uint32(32 - r)))

        state = self._state
        state[ia] += state[ib]
        state[id] ^= state[ia]
        state[id] = left_rotate(state[id], 16)
        state[ic] += state[id]
        state[ib] ^= state[ic]
        state[ib] = left_rotate(state[ib], 12)
        state[ia] += state[ib]
        state[id] ^= state[ia]
        state[id] = left_rotate(state[id], 8)
        state[ic] += state[id]
        state[ib] ^= state[ic]
        state[ib] = left_rotate(state[ib], 7)

    def _iter_state(self):
        """Mix the state for one block and advance the position by one."""
        over_err = get_numpy_err()['over']
        # Word additions are meant to wrap modulo 2**32.
        set_numpy_err(over='ignore')
        try:
            self.set_pos()
            for i in range(self._rounds):
                if (i & 1) == 0:
                    # Even Round (columns)
                    self._quarter_round(0, 4, 8, 12)
                    self._quarter_round(1, 5, 9, 13)
                    self._quarter_round(2, 6, 10, 14)
                    self._quarter_round(3, 7, 11, 15)
                else:
                    # Odd Round (diagonals)
                    self._quarter_round(0, 5, 10, 15)
                    self._quarter_round(1, 6, 11, 12)
                    self._quarter_round(2, 7, 8, 13)
                    self._quarter_round(3, 4, 9, 14)
            self._pos += uint64(1)
        finally:
            # Always restore the caller's numpy error handling.
            set_numpy_err(over=over_err)

    def next_stream(self, stream_size=None):
        """Return the next ``stream_size`` bytes of the stream as a bytearray.

        Each block contributes 32 bytes (state words 0-3 and 12-15). Whole
        blocks are always generated; bytes beyond ``stream_size`` in the
        last block are discarded. ``None`` selects the instance default.
        """
        # Default State
        if stream_size is None:
            stream_size = self._stream_size
        # Output Array
        output = bytearray()
        # Add state to output
        while len(output) < stream_size:
            # While there is stream left to make, Shuffle State
            self._iter_state()
            # Add bytes to output
            for index in (0, 1, 2, 3, 12, 13, 14, 15):
                output += int(self._state[index]).to_bytes(
                    4, byteorder=self._endian)
        # Resize and output bytes
        return output[:stream_size]

    def next_raw_int(self, bits=32, signed=False):
        """Return a random integer that is ``bits`` bits wide.

        Args:
            bits: Width of the result; rounded up to whole bytes when
                drawing from the stream, then masked down to ``bits``.
            signed: If true, interpret the ``bits``-wide value as two's
                complement, giving a result in ``[-2**(bits-1), 2**(bits-1))``.
                Otherwise the result is in ``[0, 2**bits)``.

        Raises:
            TypeError: If ``bits`` is not an int.
            ValueError: If ``bits`` is negative.
        """
        # Check if bits is an int
        if not isinstance(bits, int):
            raise TypeError("Setting 'bits' must be an int!")
        if bits < 0:
            raise ValueError("Setting 'bits' must not be negative!")
        # Round bits up by 8
        byte_len = (bits + 7) >> 3
        # Get bytes
        byte_pool = self.next_stream(byte_len)
        # Convert bytes as unsigned, then keep only the requested bits
        out_int = int.from_bytes(byte_pool, byteorder=self._endian)
        out_int &= (1 << bits) - 1
        # Sign-extend from the requested width, not from the byte width
        if signed and bits > 0 and (out_int >> (bits - 1)):
            out_int -= 1 << bits
        return out_int

    def next_real(self, min=0.0, max=1.0, bit_depth=64):
        """Return a random float scaled from ``[0, 1)`` to ``[min, max)``.

        The raw ``bit_depth``-bit integer is divided by ``2**bit_depth`` in
        floating point, so for bit depths above the float mantissa width the
        quotient can round up to 1.0 and the result can equal ``max``.

        Raises:
            TypeError: If ``bit_depth`` is not an int.
        """
        if not isinstance(bit_depth, int):
            raise TypeError("Setting 'bit_depth' must be an int!")
        # Get raw bits
        val = self.next_raw_int(bits=bit_depth) / (2.0**bit_depth)
        # Scale Value
        val *= max - min
        val += min
        return val

    def next_int(self, min=0, max=2, bit_depth=64):
        """Return a random integer in ``[min, max)``.

        The result is computed with exact integer arithmetic as
        ``min + floor(raw * (max - min) / 2**bit_depth)``, so ``max`` is
        never returned and ranges that span zero are not skewed toward 0.
        Callers are expected to pass ``min <= max``.

        Raises:
            TypeError: If ``min``, ``max`` or ``bit_depth`` is not an int.
        """
        if not isinstance(min, int):
            raise TypeError("Setting 'min' must be an int!")
        if not isinstance(max, int):
            raise TypeError("Setting 'max' must be an int!")
        if not isinstance(bit_depth, int):
            raise TypeError("Setting 'bit_depth' must be an int!")
        raw = self.next_raw_int(bits=bit_depth)
        return min + ((raw * (max - min)) >> bit_depth)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment