Skip to content

Instantly share code, notes, and snippets.

@robertmaxwilliams
Last active February 2, 2021 22:25
Show Gist options
  • Save robertmaxwilliams/f823d9948a894227ff810dd9c4660712 to your computer and use it in GitHub Desktop.
Save robertmaxwilliams/f823d9948a894227ff810dd9c4660712 to your computer and use it in GitHub Desktop.
Simple and mildly ugly (and not very fast) implementation of the 32-bit Mersenne Twister MT19937. mers_min.c is a shortened version; mersenne_twister.c is the longer, parameterized version.
// Compile with gcc -std=c99 mers_min.c -o /dev/random
// ps the -o /dev/random part is a joke, it's not actually meant to be a random device.
// Absolutely no warranty, Kopyleft I guess who gives a shit
#include <stdint.h>
#include <stdio.h>
// MT19937 generator state: 624 words plus the index of the next word that
// get_number() will temper and return. Both start zeroed; with an all-zero
// state every output is 0, so initialize() has to seed it first.
uint32_t x[624] = {0};
uint32_t index = 0;
// Regenerate all 624 state words in place (the MT19937 "twist").
// Each new word joins the top bit of x[k] with the low 31 bits of x[k+1],
// shifts that right by one, XORs in 0x9908B0DF when the shifted-out bit was 1,
// and finally XORs with x[k+397]. Indices wrap modulo 624, so late iterations
// read words that this same pass has already replaced.
void twist() {
    int k = 0;
    while (k < 624) {
        uint32_t next = x[(k + 1) % 624];
        uint32_t joined = (x[k] & 0x80000000u) | (next & 0x7fffffffu);
        uint32_t mixed = joined >> 1;
        if (joined & 1u) {
            mixed ^= 0x9908B0DFu;
        }
        x[k] = x[(k + 397) % 624] ^ mixed;
        k++;
    }
}
// Seed the generator. x[0] = seed, and each later word is derived from its
// predecessor as 1812433253 * (prev ^ (prev >> 30)) + i, truncated to 32 bits
// when stored. The state is then twisted once and the read index reset, so the
// next get_number() returns the first output for this seed.
void initialize(uint32_t seed) {
    x[0] = seed;
    for (uint32_t i = 1; i < 624; ++i) {
        uint32_t prev = x[i - 1];
        x[i] = (uint32_t)(1812433253UL * (prev ^ (prev >> 30)) + i);
    }
    index = 0;
    twist();
}
// Return the next 32-bit output. Once all 624 buffered words have been handed
// out, the state is regenerated with twist(). The raw word is then tempered
// with the MT19937 shift/mask sequence (11, 7/0x9D2C5680, 15/0xEFC60000, 18).
uint32_t get_number() {
    if (!(index < 624)) {
        index = 0;
        twist();
    }
    uint32_t z = x[index];
    index += 1;
    z ^= (z >> 11);
    z ^= (z << 7) & 0x9D2C5680u;
    z ^= (z << 15) & 0xEFC60000u;
    z ^= (z >> 18);
    return z;
}
// Demo: seed with 5489, print the 10000th output in hex, then reseed and print
// the first 100 outputs on one line.
int main(void) {
    initialize(5489u);
    // Keep outputs in uint32_t: storing values above INT_MAX in an int is an
    // implementation-defined conversion.
    uint32_t last = 0;
    for (int i = 0; i < 10000; i++) {
        last = get_number();
    }
    // Widen to unsigned long so %lx matches the argument type whether
    // uint32_t is unsigned int or unsigned long on this platform.
    printf("%08lx\n", (unsigned long)last);
    initialize(5489u);
    for (int i = 0; i < 100; i++) {
        printf("%08lx ", (unsigned long)get_number());
    }
    printf("\n");
    return 0;
}
// Compile with gcc -std=c99 mersenne_twister.c -o /dev/random
// ps the -o /dev/random part is a joke, it's not actually meant to be a random device.
// Absolutely no warranty, Kopyleft I guess who gives a shit
#include <stdint.h>
#include <stdio.h>
// from here: https://en.wikipedia.org/wiki/Mersenne_Twister
// constants for MT19937
typedef uint32_t word_t; // one generator word (w bits wide)
typedef uint32_t scal_t; // unsigned scalar type for counters and indices

// The integer parameters are enumerators rather than `const` objects. In C a
// `const` object is not an integer constant expression, so it cannot portably
// size a file-scope array (the state is declared as `word_t x[n]`) or appear
// in another file-scope initializer. All values are small and non-negative,
// so they convert unchanged to the unsigned types they are mixed with.
enum {
    w = 8 * sizeof(word_t), // word size in bits; 32 for MT19937
    n = 624,                // degree of recurrence
    m = 397,                // middle word offset, 1 <= m < n
    r = 31,                 // separation point of one word: number of bits in
                            // the lower bitmask, 0 <= r <= w-1
    s = 7,                  // TGFSR tempering bit shift
    t = 15,                 // TGFSR tempering bit shift
    u = 11,                 // additional tempering shift
    l = 18,                 // additional tempering shift
};

// Masks splitting a word at bit r: lower_mask keeps the low r bits and
// upper_mask keeps the high w-r bits. The shift uses an unsigned operand
// because `1 << 31` on a plain int shifts into the sign bit, which is
// undefined behavior.
const word_t lower_mask = ((word_t)1 << r) - 1;
const word_t upper_mask = ~(((word_t)1 << r) - 1);
const word_t a = 0x9908B0DF; // coefficient of the rational normal form twist matrix
const word_t b = 0x9D2C5680; // TGFSR tempering bitmask
const word_t c = 0xEFC60000; // TGFSR tempering bitmask
const word_t d = 0xFFFFFFFF; // additional tempering bitmask
// Restriction: 2^(nw-r) - 1 must be a Mersenne prime.
// The series x is defined by the recurrence
//   x_(k+n) = x_(k+m) XOR ((upper(x_k, w-r) | lower(x_(k+1), r)) * A),  k = 0, 1, 2, ...
// where upper(v, w-r) keeps the upper w-r bits of v, lower(v, r) keeps the
// lower r bits of v, and A is the w x w matrix
//   A = [ 0        I_(w-1)             ]
//       [ a_(w-1)  (a_(w-2), ..., a_0) ]
// with 0 a column of w-1 zeros, I_(w-1) the (w-1) x (w-1) identity matrix,
// and a_(w-1), ..., a_0 the bits of a.
// (Note: we're in F2, so XOR takes the place of addition.)
// So xA has the shortcut, where x_0 is the lowest-order bit of x:
//   xA = x >> 1              if x_0 = 0
//   xA = (x >> 1) XOR a      otherwise
// Tempering transform:
//   y = x XOR ((x >> u) & d)
//   y = y XOR ((y << s) & b)
//   y = y XOR ((y << t) & c)
//   z = y XOR (y >> l)
// where x is the next value in the series, y is a local temporary, and z is
// the output.
// Restriction: s + t >= floor(w/2) - 1
// Generator state: n words, regenerated in place by twist().
word_t x[n];
// Multiplier of the seeding recurrence applied in initialize().
const scal_t f = 1812433253u;
void twist() {
    // Regenerate all n state words in place. Indices wrap modulo n, which is
    // equivalent to the three-loop formulation (that one just avoids the
    // modulo); once k + m >= n the recurrence reads words already replaced
    // earlier in this same pass.
    // Reference that helped get this right:
    // https://create.stephan-brumme.com/mersenne-twister/
    for (int k = 0; k < n; ++k) {
        const word_t joined = (x[k] & upper_mask) | (x[(k + 1) % n] & lower_mask);
        word_t twisted = joined >> 1;
        if (joined & 1) {
            // Low bit set: this is the "(x >> 1) XOR a" case of xA.
            twisted ^= a;
        }
        x[k] = twisted ^ x[(k + m) % n];
    }
}
// Index of the next state word get_number() will temper; once it reaches n
// the state is twisted again.
scal_t index = 0;

// Seed the generator. The state is cleared, x[0] is set to seed, and every
// later word comes from its predecessor as
//   x[i] = f * (x[i-1] ^ (x[i-1] >> (w-2))) + i
// with the result wrapping to 32 bits in the uint32_t state. The state is then
// twisted once, so the next get_number() yields the first output.
void initialize(word_t seed) {
    index = 0;
    int j = 0;
    while (j < n) {
        x[j++] = 0;
    }
    x[0] = seed;
    for (int i = 1; i < n; i++) {
        const word_t prev = x[i - 1];
        x[i] = f * (prev ^ (prev >> (w - 2))) + i;
    }
    twist();
}
// Return the next tempered output word, regenerating the state first when all
// n buffered words have already been consumed.
word_t get_number() {
    if (index >= n) {
        twist();
        index = 0;
    }
    const word_t raw = x[index++];
    // Tempering: y ^= (y >> u) & d; y ^= (y << s) & b; y ^= (y << t) & c;
    // then the result is y ^ (y >> l).
    word_t y = raw ^ ((raw >> u) & d);
    y ^= (y << s) & b;
    y ^= (y << t) & c;
    return y ^ (y >> l);
}
int main() {
initialize(5489u);
// The 10000th call to get_number should yield 4123659995
// so that should be the last number printed here
// in hex, it should be 0xF5CA0EDB
int foo;
for (int n = 0; n < 10000; n++) {
foo = get_number();
}
printf("%08x\n", foo);
printf("%u\n", foo);
printf("%d\n", foo);
// Now get some fat stacks of RNG
initialize(5489u);
for (int j = 0; j < 3; j++) {
for (int i = 0; i < 100; i++) {
printf("%08x ", get_number());
// Alternative ways to print
//printf("%01x", get_number() >> 28);
//printf("%u ", get_number());
}
printf("\n");
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment