Last active
November 21, 2024 17:40
-
-
Save windshadow233/229ec53e67577bedb8965e652fdc7466 to your computer and use it in GitHub Desktop.
MT19937 Predictor Written in Python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import hashlib | |
import math | |
import os | |
import struct | |
import time | |
from itertools import repeat, accumulate | |
from bisect import bisect | |
from typing import Set, Sequence | |
def _int32(x): | |
return int(x & 0xffffffff) | |
class MT19937Predictor:
    """Pure-Python MT19937 generator that can be seeded, cloned and rewound.

    The generator keeps the 624-word Mersenne Twister state in ``_mt`` and a
    cursor in ``_mti``.  Besides the usual ``random``-style sampling methods
    it offers:

    - ``setrand_int32`` / ``setrandbits``: feed observed 32-bit outputs back
      in (through ``untempering``) to reconstruct the internal state;
    - ``unextract_number``: step the generator backwards by one output.
    """

    # Every state word and every output is confined to 32 bits.
    _WORD_MASK = 0xFFFFFFFF

    def __init__(self, twist_per_step=True):
        """
        :param twist_per_step:
            Determines how often the twist operation is performed.
            - False: Behaves like CPython's random (performs 624 twist
              operations when _mti reaches 624).
            - True: Performs one twist operation each time a 32-bit random
              number is extracted.
        """
        self.N = 624
        self.M = 397
        self.MATRIX_A = 0x9908b0df
        self.UPPER_MASK = 0x80000000
        self.LOWER_MASK = 0x7fffffff
        self._mt = [0] * self.N
        # A cursor equal to N means "no word of the current block consumed yet".
        self._mti = self.N
        self.gauss_next = None
        self._twist_per_step = twist_per_step

    def getstate(self):
        """Return ``(3, (*state_words, cursor), gauss_next)``."""
        return 3, (*self._mt, self._mti), self.gauss_next

    def setstate(self, state):
        """Restore a state tuple shaped like the one ``getstate`` returns.

        :param state: ``(version, (*state_words, cursor), gauss_next)``;
            the version field is not inspected.
        :raises ValueError: if the word tuple does not hold N words plus
            the cursor.
        """
        *words, cursor = state[1]
        if len(words) != self.N:
            raise ValueError('state vector is the wrong size')
        self._mt = words
        self._mti = cursor
        self.gauss_next = state[2]

    def init_genrand(self, seed):
        """Fill the state from a single 32-bit ``seed`` and reset the cursor."""
        mask = self._WORD_MASK
        mt = self._mt
        mt[0] = seed & mask
        for i in range(1, self.N):
            prev = mt[i - 1]
            mt[i] = (1812433253 * (prev ^ (prev >> 30)) + i) & mask
        self._mti = self.N

    def init_by_array(self, key, length):
        """Initialise the state from the first ``length`` words of ``key``.

        :param key: indexable collection of integers.
        :param length: number of entries of ``key`` to cycle through.
        """
        self.init_genrand(19650218)
        mask = self._WORD_MASK
        mt = self._mt
        i, j = 1, 0
        # First pass: mix the key into the state, cycling through both.
        for _ in range(max(self.N, length)):
            prev = mt[i - 1]
            mt[i] = ((mt[i] ^ ((prev ^ (prev >> 30)) * 1664525)) + key[j] + j) & mask
            i += 1
            j += 1
            if i >= self.N:
                mt[0] = mt[self.N - 1]
                i = 1
            if j >= length:
                j = 0
        # Second pass: one more non-linear sweep over words 1..N-1.
        for _ in range(self.N - 1):
            prev = mt[i - 1]
            mt[i] = ((mt[i] ^ ((prev ^ (prev >> 30)) * 1566083941)) - i) & mask
            i += 1
            if i >= self.N:
                mt[0] = mt[self.N - 1]
                i = 1
        # Only the top bit of word 0 is set after initialisation.
        mt[0] = 0x80000000

    def random_seed_urandom(self):
        """Seed from ``N * 4`` bytes of ``os.urandom``."""
        random_bytes = os.urandom(self.N * 4)
        key = [int.from_bytes(random_bytes[i:i + 4], 'big')
               for i in range(0, self.N * 4, 4)]
        self.init_by_array(key, len(key))

    def random_seed_time_pid(self):
        """Seed from wall-clock time, process id and monotonic clock.

        Both clocks are read in nanoseconds so that the low and the high
        32-bit halves each contribute to the key.
        """
        mask = self._WORD_MASK
        now = time.time_ns()
        now_mono = time.monotonic_ns()
        key = [now & mask,
               (now >> 32) & mask,
               os.getpid() & mask,
               now_mono & mask,
               (now_mono >> 32) & mask]
        self.init_by_array(key, len(key))

    @staticmethod
    def tempering(y):
        """Map a raw 32-bit state word to its output value."""
        y ^= (y >> 11)
        y ^= (y << 7) & 0x9d2c5680
        y ^= (y << 15) & 0xefc60000
        y ^= (y >> 18)
        return y

    @staticmethod
    def untempering(y):
        """Invert ``tempering``: map a 32-bit output back to its state word."""
        y ^= (y >> 18)
        y ^= (y << 15) & 0xefc60000
        y ^= ((y << 7) & 0x9d2c5680) ^ ((y << 14) & 0x94284000) ^ ((y << 21) & 0x14200000) ^ ((y << 28) & 0x10000000)
        y ^= (y >> 11) ^ (y >> 22)
        return y

    def extract_number(self):
        """Advance the generator and return the next 32-bit output."""
        if self._twist_per_step:
            # Lazy mode: twist only the word about to be consumed.
            self._mti %= self.N
            self.twist(self._mti)
        elif self._mti >= self.N:
            # Block mode: regenerate all N words once the block is used up.
            for i in range(self.N):
                self.twist(i)
            self._mti = 0
        y = self.tempering(self._mt[self._mti])
        self._mti += 1
        return y & self._WORD_MASK

    def unextract_number(self):
        """Step the generator back by one ``extract_number`` call.

        Nothing is returned; the next ``extract_number`` call is meant to
        repeat the output that was just undone.
        """
        self._mti = (self._mti - 1) % self.N
        if self._twist_per_step:
            self.untwist(self._mti)
            if self._mti == 0:
                self._mti = self.N
        elif self._mti == 0:
            # Stepping back past the start of a block: undo the whole twist.
            for i in range(self.N - 1, -1, -1):
                self.untwist(i)
            self._mti = self.N

    def twist(self, i):
        """Replace state word ``i`` with its next-generation value."""
        y = (self._mt[i] & self.UPPER_MASK) | (self._mt[(i + 1) % self.N] & self.LOWER_MASK)
        word = y >> 1
        if y & 0x1:
            word ^= self.MATRIX_A
        self._mt[i] = word ^ self._mt[(i + self.M) % self.N]

    def _unmix(self, mixed):
        """Undo the shift/XOR step of ``twist`` on a word already XORed with its partner."""
        # The top bit is set exactly when MATRIX_A was applied, i.e. when the
        # bit shifted out by ``twist`` was 1.
        if mixed & self.UPPER_MASK:
            return ((mixed ^ self.MATRIX_A) << 1) | 1
        return mixed << 1

    def untwist(self, i):
        """Restore state word ``i`` to its value before ``twist(i)``.

        The top bit comes from undoing ``twist(i)`` itself; the low 31 bits
        come from undoing the twist of the preceding word (``i - 1`` wraps to
        the last word when ``i`` is 0).
        """
        mt = self._mt
        upper_src = self._unmix(mt[i] ^ mt[(i + self.M) % self.N])
        lower_src = self._unmix(mt[i - 1] ^ mt[(i + self.M - 1) % self.N])
        mt[i] = (upper_src & self.UPPER_MASK) | (lower_src & self.LOWER_MASK)

    def _seed(self, a=None):
        """Seed from ``a``: system entropy for None, ``abs(a)`` for ints,
        otherwise the SHA-512 digest of ``str(a)``."""
        if a is None:
            try:
                self.random_seed_urandom()
            except (NotImplementedError, OSError):
                # No system entropy source: fall back to clocks and pid.
                self.random_seed_time_pid()
            return
        if isinstance(a, int):
            n = abs(a)
        else:
            hash_val = hashlib.sha512(str(a).encode('utf-8')).digest()
            n = int.from_bytes(hash_val, 'big')
        bits = n.bit_length()
        keyused = 1 if bits == 0 else (bits + 31) // 32
        # Words are produced arithmetically, least-significant first, so the
        # key is the same on every host and needs no byte-order fix-up.
        key = [(n >> (32 * i)) & self._WORD_MASK for i in range(keyused)]
        self.init_by_array(key, keyused)

    def seed(self, a=None):
        """Seed the generator.

        ``str`` / ``bytes`` / ``bytearray`` seeds are extended with their
        SHA-512 digest and read as a big-endian integer; everything else is
        passed to ``_seed`` unchanged.  The argument is never modified.
        """
        if isinstance(a, (str, bytes, bytearray)):
            if isinstance(a, str):
                a = a.encode()
            # ``a + ...`` (not ``+=``) so a caller's bytearray is not mutated.
            a = int.from_bytes(a + hashlib.sha512(a).digest(), 'big')
        self._seed(a)

    def setrand_int32(self, y):
        """Record ``y`` as the output of the word at the cursor.

        :param y: observed 32-bit output.
        :raises ValueError: if ``y`` is outside ``[0, 2**32)``.
        """
        if not 0 <= y < 2 ** 32:
            raise ValueError('y must be an integer in the range [0, 2**32)')
        self._mti %= self.N
        self._mt[self._mti] = self.untempering(y)
        self._mti += 1

    def setrandbits(self, y, bits):
        """Record a ``bits``-wide output, least-significant 32-bit word first.

        :raises ValueError: if ``bits`` is not a multiple of 32 or ``y`` does
            not fit in ``bits`` bits.
        """
        if bits % 32 != 0:
            raise ValueError('number of bits must be a multiple of 32')
        if not (0 <= y < 2 ** bits):
            raise ValueError('invalid state')
        while bits > 0:
            self.setrand_int32(y & self._WORD_MASK)
            y >>= 32
            bits -= 32

    def getrandbits(self, k):
        """Return a non-negative integer with ``k`` random bits.

        Words are assembled least-significant first; the final partial word
        keeps its top bits.  ``k == 0`` returns 0 without consuming state.

        :raises ValueError: if ``k`` is negative.
        """
        if k < 0:
            raise ValueError('number of bits must be non-negative')
        result = 0
        shift = 0
        while k > 0:
            word = self.extract_number()
            if k < 32:
                word >>= 32 - k
            result |= word << shift
            shift += 32
            k -= 32
        return result

    def _randbelow(self, n):
        """Return a random integer in ``[0, n)`` by rejection sampling.

        :raises ValueError: if ``n`` is not positive (``choice`` relies on
            this to report an empty sequence).
        """
        if n <= 0:
            raise ValueError('Boundary must be a positive integer')
        k = n.bit_length()
        r = self.getrandbits(k)
        while r >= n:
            r = self.getrandbits(k)
        return r

    def randrange(self, start, stop=None, step=1):
        """Return a random element of ``range(start, stop, step)``.

        :raises ValueError: for an empty range or a zero step.
        """
        if stop is None:
            if start > 0:
                return self._randbelow(start)
            raise ValueError("empty range for randrange()")
        width = stop - start
        if step == 1:
            if width > 0:
                return start + self._randbelow(width)
            raise ValueError("empty range for randrange() (%d, %d, %d)" % (start, stop, width))
        if step > 0:
            n = (width + step - 1) // step
        elif step < 0:
            n = (width + step + 1) // step
        else:
            raise ValueError("zero step for randrange()")
        if n <= 0:
            raise ValueError("empty range for randrange()")
        return start + step * self._randbelow(n)

    def randint(self, a, b):
        """Return a random integer in ``[a, b]``, both ends included."""
        return self.randrange(a, b + 1)

    def random(self):
        """Return a float in ``[0.0, 1.0)`` built from 27 + 26 random bits."""
        a = self.extract_number() >> 5
        b = self.extract_number() >> 6
        return (a * 67108864.0 + b) / 9007199254740992.0

    def uniform(self, a, b):
        """Return ``a + (b - a) * random()``."""
        return a + (b - a) * self.random()

    def choice(self, seq):
        """Return a random element of ``seq``.

        :raises IndexError: if ``seq`` is empty.
        """
        try:
            i = self._randbelow(len(seq))
        except ValueError:
            raise IndexError('Cannot choose from an empty sequence') from None
        return seq[i]

    def choices(self, population, weights=None, *, cum_weights=None, k=1):
        """Return ``k`` elements drawn from ``population`` with replacement.

        :param weights: optional relative weights.
        :param cum_weights: optional cumulative weights (exclusive with
            ``weights``).
        :raises TypeError: if both weight arguments are given.
        :raises ValueError: if the weight count differs from the population.
        """
        rand = self.random
        n = len(population)
        if cum_weights is None:
            if weights is None:
                size = n + 0.0  # float multiplication, then truncate to an index
                return [population[int(rand() * size)] for _ in repeat(None, k)]
            cum_weights = list(accumulate(weights))
        elif weights is not None:
            raise TypeError('Cannot specify both weights and cumulative weights')
        if len(cum_weights) != n:
            raise ValueError('The number of weights does not match the population')
        total = cum_weights[-1] + 0.0  # convert to float
        hi = n - 1
        return [population[bisect(cum_weights, rand() * total, 0, hi)]
                for _ in repeat(None, k)]

    def shuffle(self, x, random=None):
        """Shuffle list ``x`` in place.

        :param random: optional zero-argument callable returning a float in
            ``[0.0, 1.0)``; when omitted the generator's own ``_randbelow``
            is used.
        """
        if random is None:
            randbelow = self._randbelow
            for i in reversed(range(1, len(x))):
                j = randbelow(i + 1)
                x[i], x[j] = x[j], x[i]
        else:
            for i in reversed(range(1, len(x))):
                j = int(random() * (i + 1))
                x[i], x[j] = x[j], x[i]

    def gauss(self, mu, sigma):
        """Return a Gaussian sample with mean ``mu`` and deviation ``sigma``.

        Values are produced in pairs; the second one is cached in
        ``gauss_next`` and returned by the following call.
        """
        z = self.gauss_next
        self.gauss_next = None
        if z is None:
            x2pi = self.random() * 2 * math.pi
            g2rad = math.sqrt(-2.0 * math.log(1.0 - self.random()))
            z = math.cos(x2pi) * g2rad
            self.gauss_next = math.sin(x2pi) * g2rad
        return mu + z * sigma

    def sample(self, population, k):
        """Return ``k`` distinct elements chosen from ``population``.

        :raises TypeError: if ``population`` is neither a sequence nor a set.
        :raises ValueError: if ``k`` is negative or exceeds the population.
        """
        if isinstance(population, Set):
            population = tuple(population)
        if not isinstance(population, Sequence):
            raise TypeError("Population must be a sequence or set. For dicts, use list(d).")
        randbelow = self._randbelow
        n = len(population)
        if not 0 <= k <= n:
            raise ValueError("Sample larger than population or is negative")
        result = [None] * k
        # Below this size, copying the population into a pool is preferred to
        # tracking already-selected indices in a set.
        setsize = 21
        if k > 5:
            setsize += 4 ** math.ceil(math.log(k * 3, 4))
        if n <= setsize:
            pool = list(population)
            for i in range(k):
                j = randbelow(n - i)
                result[i] = pool[j]
                # Move the last live element into the vacated slot.
                pool[j] = pool[n - i - 1]
        else:
            selected = set()
            for i in range(k):
                j = randbelow(n)
                while j in selected:
                    j = randbelow(n)
                selected.add(j)
                result[i] = population[j]
        return result
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment