Skip to content

Instantly share code, notes, and snippets.

@windshadow233
Last active November 21, 2024 17:40
Show Gist options
  • Save windshadow233/229ec53e67577bedb8965e652fdc7466 to your computer and use it in GitHub Desktop.
Save windshadow233/229ec53e67577bedb8965e652fdc7466 to your computer and use it in GitHub Desktop.
MT19937 Predictor Written in Python
import hashlib
import math
import os
import struct
import time
from itertools import repeat, accumulate
from bisect import bisect
from typing import Set, Sequence
def _int32(x):
return int(x & 0xffffffff)
class MT19937Predictor:
    """Pure-Python MT19937 generator with a ``random.Random``-like API.

    Besides generating numbers, the state can be reconstructed from observed
    32-bit outputs (``setrand_int32`` / ``setrandbits``) and stepped backwards
    (``unextract_number``).
    """

    # MT19937 parameters (class-level; still readable as instance attributes).
    N = 624
    M = 397
    MATRIX_A = 0x9908b0df
    UPPER_MASK = 0x80000000
    LOWER_MASK = 0x7fffffff
    _MASK32 = 0xffffffff

    def __init__(self, twist_per_step=True):
        """
        :param twist_per_step:
            Determines how often the twist operation is performed.
            - False: Behaves like CPython's random (performs 624 twist operations when _mti reaches 624).
            - True: Performs one twist operation each time a 32-bit random number is extracted.
        """
        self._mt = [0] * self.N
        self._mti = self.N
        self.gauss_next = None
        self._twist_per_step = twist_per_step

    def getstate(self):
        """Return ``(3, (*words, index), gauss_next)``.

        With ``twist_per_step=True`` the words at positions >= index have not
        been twisted yet, so a mid-cycle state is laid out differently from a
        batch-twisted one.
        """
        return 3, (*self._mt, self._mti), self.gauss_next

    def setstate(self, state):
        """Load a state tuple shaped like the one produced by ``getstate``."""
        *words, index = state[1]
        self._mt = list(words)
        self._mti = index
        self.gauss_next = state[2]

    def init_genrand(self, seed):
        """Fill the state from a single 32-bit seed."""
        mt = self._mt
        mask = self._MASK32
        mt[0] = seed & mask
        for i in range(1, self.N):
            prev = mt[i - 1]
            mt[i] = (1812433253 * (prev ^ (prev >> 30)) + i) & mask
        self._mti = self.N

    def init_by_array(self, key, length):
        """Fill the state from ``key``, a sequence of ``length`` integers."""
        self.init_genrand(19650218)
        mt = self._mt
        n = self.N
        mask = self._MASK32
        i, j = 1, 0
        for _ in range(max(n, length)):
            prev = mt[i - 1]
            mt[i] = ((mt[i] ^ ((prev ^ (prev >> 30)) * 1664525)) + key[j] + j) & mask
            i += 1
            j += 1
            if i >= n:
                mt[0] = mt[n - 1]
                i = 1
            if j >= length:
                j = 0
        # ``i`` deliberately carries over from the first loop.
        for _ in range(n - 1):
            prev = mt[i - 1]
            mt[i] = ((mt[i] ^ ((prev ^ (prev >> 30)) * 1566083941)) - i) & mask
            i += 1
            if i >= n:
                mt[0] = mt[n - 1]
                i = 1
        mt[0] = 0x80000000

    def random_seed_urandom(self):
        """Seed from ``N`` 32-bit words read from ``os.urandom``."""
        random_bytes = os.urandom(self.N * 4)
        key = [int.from_bytes(random_bytes[i:i + 4], 'big') for i in range(0, self.N * 4, 4)]
        self.init_by_array(key, len(key))

    def random_seed_time_pid(self):
        """Seed from wall-clock time, process id and the monotonic clock."""
        now = int(time.time())
        now_mono = int(time.monotonic_ns())
        key = [now & self._MASK32,
               now >> 32,
               os.getpid() & self._MASK32,
               now_mono & self._MASK32,
               now_mono >> 32]
        self.init_by_array(key, 5)

    @staticmethod
    def tempering(y):
        """Apply the MT19937 output tempering to a 32-bit state word."""
        y ^= (y >> 11)
        y ^= (y << 7) & 0x9d2c5680
        y ^= (y << 15) & 0xefc60000
        y ^= (y >> 18)
        return y

    @staticmethod
    def untempering(y):
        """Invert ``tempering``: map a 32-bit output back to its state word."""
        y ^= (y >> 18)
        y ^= (y << 15) & 0xefc60000
        # The 7-bit step is undone in one go with pre-combined masks.
        y ^= ((y << 7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000) ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000)
        y ^= (y >> 11) ^ (y >> 22)
        return y

    def extract_number(self):
        """Return the next 32-bit output, twisting the state as configured."""
        if self._twist_per_step:
            self._mti %= self.N
            self.twist(self._mti)
        elif self._mti >= self.N:
            for i in range(self.N):
                self.twist(i)
            self._mti = 0
        y = self.tempering(self._mt[self._mti])
        self._mti += 1
        return y & self._MASK32

    def unextract_number(self):
        """Step the generator back by one 32-bit output."""
        self._mti = (self._mti - 1) % self.N
        if self._twist_per_step:
            self.untwist(self._mti)
            if self._mti == 0:
                self._mti = self.N
        elif self._mti == 0:
            # Rewound past the first word of a batch: undo the whole batch,
            # last word first.
            for i in range(self.N - 1, -1, -1):
                self.untwist(i)
            self._mti = self.N

    def twist(self, i):
        """Advance state word ``i`` to its next-generation value in place."""
        y = (self._mt[i] & self.UPPER_MASK) | (self._mt[(i + 1) % self.N] & self.LOWER_MASK)
        word = y >> 1
        if y & 0x1:
            word ^= self.MATRIX_A
        self._mt[i] = word ^ self._mt[(i + self.M) % self.N]

    def _untwist_mix(self, i):
        """Recover the mixed word ``y`` that ``twist(i)`` shifted right."""
        tmp = self._mt[i] ^ self._mt[(i + self.M) % self.N]
        if tmp & self.UPPER_MASK:
            # Top bit set means MATRIX_A was XORed in, i.e. ``y`` was odd.
            return ((tmp ^ self.MATRIX_A) << 1) | 1
        return tmp << 1

    def untwist(self, i):
        """Undo ``twist(i)``, restoring the previous value of state word ``i``.

        The top bit comes from inverting twist ``i``; the low 31 bits come
        from inverting twist ``i - 1`` (index -1 wraps to the last word).
        """
        upper = self._untwist_mix(i) & self.UPPER_MASK
        lower = self._untwist_mix(i - 1) & self.LOWER_MASK
        self._mt[i] = upper | lower

    def _seed(self, a=None):
        """Seed from ``a``; ``None`` uses OS entropy, falling back to time/pid.

        Integers are seeded by absolute value. Any other object is reduced to
        an integer via SHA-512 of ``str(a)``.
        NOTE(review): CPython derives the integer from ``hash(a)`` for such
        objects, so those seeds are not expected to reproduce CPython's stream.
        """
        if a is None:
            try:
                self.random_seed_urandom()
            except Exception:
                # Best-effort fallback when OS entropy is unavailable.
                self.random_seed_time_pid()
            return
        if isinstance(a, int):
            n = abs(a)
        else:
            digest = hashlib.sha512(str(a).encode('utf-8')).digest()
            n = int.from_bytes(digest, 'big')
        keyused = max(1, (n.bit_length() + 31) // 32)
        # Words are extracted arithmetically, least-significant first, so the
        # key is identical on every platform and needs no byte-order fix-up.
        key = [(n >> (32 * i)) & self._MASK32 for i in range(keyused)]
        self.init_by_array(key, keyused)

    def seed(self, a=None):
        """Seed the generator.

        ``str``/``bytes``/``bytearray`` seeds are combined with their SHA-512
        digest and read as a big-endian integer. The caller's object is never
        modified, and any cached ``gauss`` value is discarded.
        """
        if isinstance(a, str):
            a = a.encode()
        if isinstance(a, (bytes, bytearray)):
            # ``a + ...`` (not ``+=``) so a caller's bytearray is not extended.
            a = int.from_bytes(a + hashlib.sha512(a).digest(), 'big')
        self._seed(a)
        self.gauss_next = None

    def setrand_int32(self, y):
        """Record one observed 32-bit output ``y`` as the next state word.

        :raises ValueError: if ``y`` is outside ``[0, 2**32)``.
        """
        if not 0 <= y < 2 ** 32:
            raise ValueError('y must be in range(0, 2 ** 32)')
        self._mti %= self.N
        self._mt[self._mti] = self.untempering(y)
        self._mti += 1

    def setrandbits(self, y, bits):
        """Record an observed ``getrandbits(bits)`` result ``y``.

        ``y`` is consumed 32 bits at a time, least-significant word first.

        :raises ValueError: if ``bits`` is not a multiple of 32 or ``y`` does
            not fit in ``bits`` bits.
        """
        if bits % 32 != 0:
            raise ValueError('number of bits must be a multiple of 32')
        if not (0 <= y < 2 ** bits):
            raise ValueError('invalid state')
        while bits > 0:
            self.setrand_int32(y & self._MASK32)
            y >>= 32
            bits -= 32

    def getrandbits(self, k):
        """Return a non-negative integer with ``k`` random bits.

        Outputs are stacked least-significant word first; the final word
        contributes only its top bits.

        :raises ValueError: if ``k`` is not positive.
        """
        if k <= 0:
            raise ValueError('number of bits must be greater than zero')
        result = 0
        shift = 0
        remaining = k
        while remaining > 0:
            word = self.extract_number()
            if remaining < 32:
                word >>= 32 - remaining
            result |= word << shift
            shift += 32
            remaining -= 32
        return result

    def _randbelow(self, n):
        """Return a random integer in ``[0, n)`` by rejection sampling.

        :raises ValueError: if ``n`` is not positive.
        """
        if n <= 0:
            raise ValueError('n must be greater than zero')
        k = n.bit_length()
        r = self.getrandbits(k)
        while r >= n:
            r = self.getrandbits(k)
        return r

    def randrange(self, start, stop=None, step=1):
        """Return a random element of ``range(start, stop, step)``.

        :raises ValueError: if the range is empty or ``step`` is zero.
        """
        if stop is None:
            if start > 0:
                return self._randbelow(start)
            raise ValueError("empty range for randrange()")
        width = stop - start
        if step == 1:
            if width > 0:
                return start + self._randbelow(width)
            raise ValueError("empty range for randrange() (%d, %d, %d)" % (start, stop, width))
        if step > 0:
            n = (width + step - 1) // step
        elif step < 0:
            n = (width + step + 1) // step
        else:
            raise ValueError("zero step for randrange()")
        if n <= 0:
            raise ValueError("empty range for randrange()")
        return start + step * self._randbelow(n)

    def randint(self, a, b):
        """Return a random integer in ``[a, b]``, both ends included."""
        return self.randrange(a, b + 1)

    def random(self):
        """Return a float in ``[0.0, 1.0)`` built from two 32-bit outputs."""
        a = self.extract_number() >> 5
        b = self.extract_number() >> 6
        return (a * 67108864.0 + b) / 9007199254740992.0

    def uniform(self, a, b):
        """Return ``a + (b - a) * random()``."""
        return a + (b - a) * self.random()

    def choice(self, seq):
        """Return a random element of the non-empty sequence ``seq``.

        :raises IndexError: if ``seq`` is empty.
        """
        try:
            i = self._randbelow(len(seq))
        except ValueError:
            raise IndexError('Cannot choose from an empty sequence') from None
        return seq[i]

    def choices(self, population, weights=None, *, cum_weights=None, k=1):
        """Return ``k`` elements chosen from ``population`` with replacement.

        :raises TypeError: if both ``weights`` and ``cum_weights`` are given.
        :raises ValueError: if the number of weights differs from the
            population size.
        """
        random = self.random
        n = len(population)
        if cum_weights is None:
            if weights is None:
                return [population[int(random() * n)] for _ in range(k)]
            cum_weights = list(accumulate(weights))
        elif weights is not None:
            raise TypeError('Cannot specify both weights and cumulative weights')
        if len(cum_weights) != n:
            raise ValueError('The number of weights does not match the population')
        total = cum_weights[-1] + 0.0  # convert to float
        hi = n - 1
        return [population[bisect(cum_weights, random() * total, 0, hi)]
                for _ in range(k)]

    def shuffle(self, x, random=None):
        """Shuffle list ``x`` in place.

        When ``random`` is given it must return floats in ``[0.0, 1.0)`` and
        is used instead of the generator's own integer sampling.
        """
        for i in reversed(range(1, len(x))):
            if random is None:
                j = self._randbelow(i + 1)
            else:
                j = int(random() * (i + 1))
            x[i], x[j] = x[j], x[i]

    def gauss(self, mu, sigma):
        """Return a Gaussian variate with mean ``mu`` and deviation ``sigma``.

        Values are produced in pairs; the second of each pair is cached in
        ``gauss_next`` and returned by the following call.
        """
        z = self.gauss_next
        self.gauss_next = None
        if z is None:
            x2pi = self.random() * 2 * math.pi
            g2rad = math.sqrt(-2.0 * math.log(1.0 - self.random()))
            z = math.cos(x2pi) * g2rad
            self.gauss_next = math.sin(x2pi) * g2rad
        return mu + z * sigma

    def sample(self, population, k):
        """Return ``k`` unique elements chosen from ``population``.

        Sets (including frozensets and dict key views) are converted to a
        tuple first.

        :raises TypeError: if ``population`` is neither a sequence nor a set.
        :raises ValueError: if ``k`` is negative or exceeds the population.
        """
        # Use the abstract base classes: ``typing.Set`` only matches the
        # builtin ``set`` and would reject frozensets.
        from collections.abc import Sequence as _Sequence, Set as _Set

        if isinstance(population, _Set):
            population = tuple(population)
        if not isinstance(population, _Sequence):
            raise TypeError("Population must be a sequence or set. For dicts, use list(d).")
        randbelow = self._randbelow
        n = len(population)
        if not 0 <= k <= n:
            raise ValueError("Sample larger than population or is negative")
        result = [None] * k
        setsize = 21
        if k > 5:
            setsize += 4 ** math.ceil(math.log(k * 3, 4))
        if n <= setsize:
            # Small population: draw from a shrinking pool.
            pool = list(population)
            for i in range(k):
                j = randbelow(n - i)
                result[i] = pool[j]
                pool[j] = pool[n - i - 1]
        else:
            # Large population: remember chosen indices and redraw duplicates.
            selected = set()
            for i in range(k):
                j = randbelow(n)
                while j in selected:
                    j = randbelow(n)
                selected.add(j)
                result[i] = population[j]
        return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment