Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Predict output of Mersenne Twister after seeing 624 values
#!/usr/bin/env python
#
# Mersenne Twister predictor
#
# Feed this program the output of any 32-bit MT19937 Mersenne Twister and
# after seeing 624 values it will correctly predict the rest.
#
# The values may come from any point in the sequence -- the program does not
# need to see the first 624 values, just *any* 624 consecutive values. The
# seed used is also irrelevant, and it will work even if the generator was
# seeded from /dev/random or any other high quality source.
#
# The values should be in decimal, one per line, on standard input.
#
# The program expects the actual unsigned 32 bit integer values taken directly
# from the output of the Mersenne Twister. It won't work if they've been
# scaled or otherwise modified, such as by using modulo or
# std::uniform_int_distribution to alter the distribution/range. In principle
# it would be possible to cope with such a scenario if you knew the exact
# parameters used by such an algorithm, but this program does not have any
# such knowledge.
#
# For more information, refer to the original 1998 paper:
#
# "Mersenne Twister: A 623-dimensionally equidistributed uniform pseudorandom
# number generator", Makoto Matsumoto, Takuji Nishimura, 1998
#
# http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.215.1141
#
# This code is not written with speed or efficiency in mind; instead it
# follows the terminology and naming of the paper as closely as possible.
#
# License: CC0 http://creativecommons.org/publicdomain/zero/1.0/
from __future__ import print_function
import sys
import collections
class Params:
    ''' MT19937 constants, using the single-letter names from the paper '''
    # clearly a mathematician and not a programmer came up with these names
    # because a dozen single-letter names would ordinarily be insane

    # word size, degree of recursion, middle term, separation point of one word
    w, n, m, r = 32, 624, 397, 31

    # bottom row of matrix A
    a = 0x9908B0DF

    # tempering shifts
    u, s, t, l = 11, 7, 15, 18

    # tempering masks
    b, c = 0x9D2C5680, 0xEFC60000
def undo_xor_rshift(x, shift):
    ''' invert the operation x ^= (x >> shift) on a Params.w-bit word

    Calling the input y, the original x satisfies x = y ^ (x >> shift).
    Expanding that recursively gives y ^ (y >> shift) ^ (y >> 2*shift) ^ ...,
    which stops once the shift reaches the word width.
    '''
    shifted_terms = [x >> offset for offset in range(shift, Params.w, shift)]
    undone = x
    for term in shifted_terms:
        undone ^= term
    return undone
def undo_xor_lshiftmask(x, shift, mask):
    ''' invert the operation x ^= ((x << shift) & mask)

    The lowest `shift` bits are untouched by the forward operation, so they
    are already correct. Each round uses the bits recovered so far to restore
    the next `shift` bits up. Params.w // shift rounds cover the whole word.
    '''
    recovered = (1 << shift) - 1  # bits of x known to be correct so far
    rounds_left = Params.w // shift
    while rounds_left > 0:
        x ^= ((x & recovered) << shift) & mask
        recovered <<= shift
        rounds_left -= 1
    return x
def temper(x):
    ''' apply the output tempering transform to a raw state word

    The tempering is meant to improve the k-distribution of the outputs.
    It is four invertible xor-shift steps; untemper() undoes them.
    '''
    y = x ^ (x >> Params.u)
    y = y ^ ((y << Params.s) & Params.b)
    y = y ^ ((y << Params.t) & Params.c)
    return y ^ (y >> Params.l)
def untemper(x):
    ''' recover the raw state word from a tempered output value

    Undoes temper()'s four steps in reverse order: the right shift by l,
    then the masked left shifts by t and s, then the right shift by u.
    '''
    after_l = undo_xor_rshift(x, Params.l)
    after_t = undo_xor_lshiftmask(after_l, Params.t, Params.c)
    after_s = undo_xor_lshiftmask(after_t, Params.s, Params.b)
    return undo_xor_rshift(after_s, Params.u)
def upper(x):
    ''' keep only the upper (w - r) bits of x, i.e. bits r .. w-1 '''
    high_mask = (1 << Params.w) - (1 << Params.r)
    return high_mask & x
def lower(x):
    ''' keep only the lower r bits of x, i.e. bits 0 .. r-1 '''
    low_mask = (1 << Params.r) - 1
    return low_mask & x
def timesA(x):
    ''' multiply the row vector x by the twist matrix A

    A is a shift right by one. When the bit shifted out is set, the result is
    also xored with the bottom row of A (Params.a).
    '''
    halved = x >> 1
    return halved ^ Params.a if x & 1 else halved
def _read_output():
    ''' return the next generator output read from stdin, or None at end of input

    Blank and whitespace-only lines are skipped instead of crashing int().
    Raises ValueError if a line is not a decimal integer, or if the value does
    not fit in an unsigned Params.w-bit word. Untempering such a value would
    silently produce a meaningless state word, and every later prediction
    would fail.
    '''
    for line in iter(sys.stdin.readline, ''):
        line = line.strip()
        if not line:
            continue
        value = int(line)
        if not 0 <= value < (1 << Params.w):
            raise ValueError('{} is not an unsigned {}-bit value'.format(
                value, Params.w))
        return value
    return None


seen = collections.deque(maxlen=Params.n)
print('waiting for {} previous inputs'.format(Params.n))
for _ in range(Params.n):
    raw = _read_output()
    if raw is None:
        # the predictor needs a full window of n consecutive outputs
        sys.exit('end of input after only {} of {} values'.format(
            len(seen), Params.n))
    val = untemper(raw)
    seen.append(val)
num_correct = num_incorrect = 0
print('ready to predict')
while True:
    # The recurrence relation is:
    #
    #     x[k + n] = x[k + m] ^ timesA(upper(x[k]) | lower(x[k + 1]))
    #
    # Substituting j = k + n gives
    #
    #     x[j] = x[j - n + m] ^ timesA(upper(x[j - n]) | lower(x[j - n + 1]))
    #
    # The 'seen' deque holds the last 'n' seen values, where seen[-1] is the
    # most recently seen, therefore letting j = 0 gives the equation for the
    # next predicted value.
    next_val = seen[-Params.n + Params.m] ^ timesA(
        upper(seen[-Params.n]) | lower(seen[-Params.n + 1]))
    # The predicted state word (not the observed one) goes into the window,
    # so a wrong prediction feeds into the predictions that follow it.
    seen.append(next_val)
    predicted = temper(next_val)
    actual = _read_output()
    if actual is None:
        print('end of input -- {} predicted correctly, {} failures'.format(
            num_correct, num_incorrect))
        sys.exit(0)
    if predicted == actual:
        status = 'CORRECT'
        num_correct += 1
    else:
        status = 'FAIL'
        num_incorrect += 1
    print('predicted {} got {} -- {}'.format(predicted, actual, status))
#include <iostream>
#include <random>
int main()
{
    // The predictor needs 624 consecutive outputs to reconstruct the state;
    // the 10 extra values are what it then tries to predict.
    constexpr int state_words = 624;
    constexpr int values_to_check = 10;

    std::mt19937 engine{std::random_device{}()};
    // optional -- advance to some arbitrary point in the sequence
    engine.discard(8675309);

    for (int remaining = state_words + values_to_check; remaining > 0; --remaining) {
        std::cout << engine() << '\n';
    }
}
$ clang++ -Wall -Wextra -pedantic -std=c++14 -g -O2 testcase.cpp -o testcase
$ ./testcase | python mersenne_predict.py
waiting for 624 previous inputs
ready to predict
predicted 4191587413 got 4191587413 -- CORRECT
predicted 3803830168 got 3803830168 -- CORRECT
predicted 3060092316 got 3060092316 -- CORRECT
predicted 1846923718 got 1846923718 -- CORRECT
predicted 201058667 got 201058667 -- CORRECT
predicted 3665647888 got 3665647888 -- CORRECT
predicted 2520008872 got 2520008872 -- CORRECT
predicted 656224810 got 656224810 -- CORRECT
predicted 176364743 got 176364743 -- CORRECT
predicted 4252528975 got 4252528975 -- CORRECT
end of input -- 10 predicted correctly, 0 failures
$ ./testcase | python mersenne_predict.py
waiting for 624 previous inputs
ready to predict
predicted 329654205 got 329654205 -- CORRECT
predicted 539185382 got 539185382 -- CORRECT
predicted 2864548263 got 2864548263 -- CORRECT
predicted 1977004707 got 1977004707 -- CORRECT
predicted 2299828616 got 2299828616 -- CORRECT
predicted 2990397916 got 2990397916 -- CORRECT
predicted 1480674209 got 1480674209 -- CORRECT
predicted 3199083133 got 3199083133 -- CORRECT
predicted 1235004829 got 1235004829 -- CORRECT
predicted 771504621 got 771504621 -- CORRECT
end of input -- 10 predicted correctly, 0 failures
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.