Skip to content

Instantly share code, notes, and snippets.

@Deamon5550
Created October 6, 2017 00:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Deamon5550/5a14a8fce6028fe34f1311c3978a892e to your computer and use it in GitHub Desktop.
Save Deamon5550/5a14a8fce6028fe34f1311c3978a892e to your computer and use it in GitHub Desktop.
SIMD Mersenne Twister
#include <chrono>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <ctime>
#include <random>
#include <ratio>
#include <x86intrin.h>
#include "boost/random.hpp"
using namespace std;
using namespace std::chrono;
// Overlays one 128-bit SSE value (`m`) with four 32-bit unsigned words (`v`)
// so that individual 32-bit parts can be read back as scalars; v[0] aliases
// the lowest-addressed 32 bits of m.
// NOTE(review): writing `m` and then reading `v` is type punning through a
// union, which ISO C++ leaves undefined (GCC documents it as supported).
// Storing with _mm_storeu_si128 into a uint32_t[4] is the portable option.
union m128_to_v4{
__m128i m;
uint32_t v[4];
};
/// MT19937 Mersenne Twister whose state refill ("twist") processes four state
/// words per step with SSE2 intrinsics. For the same seed it produces the
/// same output sequence as std::mt19937.
class CustomTwister {
public:
    // 624 state words plus one padding word. x[624] mirrors the current x[0]
    // so the last 4-wide load of x[k + 1] (k == 620) stays in bounds and sees
    // the value the scalar algorithm uses for word 623.
    uint32_t x[625];
    // Index of the next state word to temper; any value >= 624 forces a twist.
    int32_t i = 624 + 1;

    CustomTwister() { seed(5489u); }
    CustomTwister(uint32_t s) { seed(s); }

    /// Initialises the state from `s` with the MT19937 seeding recurrence
    /// x[k] = 1812433253 * (x[k-1] ^ (x[k-1] >> 30)) + k (mod 2^32).
    void seed(uint32_t s) {
        x[0] = s;
        for (uint32_t k = 1; k < 624; ++k) {
            x[k] = 1812433253u * (x[k - 1] ^ (x[k - 1] >> 30)) + k;
        }
        x[624] = x[0];
        // Force a twist before the next output. Without this, re-seeding a
        // generator that had already produced values kept its old read index
        // and emitted tempered seed words instead of twisted state.
        i = 624;
    }

    /// Lane-wise 32-bit multiply keeping the low 32 bits of each product.
    static inline __m128i muly(const __m128i &a, const __m128i &b)
    {
#ifdef __SSE4_1__ // modern CPU - use SSE 4.1
        return _mm_mullo_epi32(a, b);
#else // old CPU - use SSE 2
        __m128i tmp1 = _mm_mul_epu32(a,b); /* mul 2,0*/
        __m128i tmp2 = _mm_mul_epu32( _mm_srli_si128(a,4), _mm_srli_si128(b,4)); /* mul 3,1 */
        return _mm_unpacklo_epi32(_mm_shuffle_epi32(tmp1, _MM_SHUFFLE (0,0,2,0)), _mm_shuffle_epi32(tmp2, _MM_SHUFFLE (0,0,2,0))); /* shuffle results to [63..0] and pack */
#endif
    }

    /// For the four words p[0..3], returns (y >> 1) ^ (y odd ? 0x9908B0DF : 0)
    /// where y = (p[k] & 0x80000000) | (p[k + 1] & 0x7FFFFFFF). This is the
    /// twist result before it is XORed with the partner word.
    /// Reads p[0..4]; uses only SSE2 (unaligned loads, no SSE3 lddqu).
    static inline __m128i mix4(const uint32_t* p) {
        const __m128i upper = _mm_and_si128(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(p)),
            _mm_set1_epi32(static_cast<int>(0x80000000u)));
        const __m128i lower = _mm_and_si128(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(p + 1)),
            _mm_set1_epi32(0x7FFFFFFF));
        const __m128i y = _mm_or_si128(upper, lower);
        // Logical shift: the scalar algorithm shifts an unsigned value, so the
        // top bit must become 0 (an arithmetic shift would copy it).
        const __m128i shifted = _mm_srli_epi32(y, 1);
        // The odd test applies to y itself, not to y >> 1.
        const __m128i odd = _mm_and_si128(y, _mm_set1_epi32(1));
        const __m128i mag = muly(odd, _mm_set1_epi32(static_cast<int>(0x9908B0DFu)));
        return _mm_xor_si128(shifted, mag);
    }

    /// Regenerates all 624 state words in place and resets the read index.
    /// Vectorised equivalent of:
    ///   for k in 0..623:
    ///     y = (x[k] & 0x80000000) | (x[(k+1) % 624] & 0x7FFFFFFF);
    ///     x[k] = x[(k+397) % 624] ^ (y >> 1) ^ ((y & 1) ? 0x9908B0DF : 0);
    void twist() {
        // Words 0..223: partners x[k + 397] (at most x[620]) are still old.
        for (int k = 0; k < 224; k += 4) {
            const __m128i partner = _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + k + 397));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(x + k), _mm_xor_si128(mix4(x + k), partner));
        }

        // x[0] is final now. The scalar algorithm uses this NEW x[0] for
        // word 623, so refresh the padding word before the last loop reads it.
        x[624] = x[0];

        // Words 224..227 straddle the wrap-around: partners are x[621],
        // x[622], x[623] (still old) and the already-updated x[0].
        {
            uint32_t mixed[4];
            _mm_storeu_si128(reinterpret_cast<__m128i*>(mixed), mix4(x + 224));
            x[224] = x[621] ^ mixed[0];
            x[225] = x[622] ^ mixed[1];
            x[226] = x[623] ^ mixed[2];
            x[227] = x[0] ^ mixed[3];
        }

        // Words 228..623: partners x[k - 227] were updated earlier in this pass.
        for (int k = 228; k < 624; k += 4) {
            const __m128i partner = _mm_loadu_si128(reinterpret_cast<const __m128i*>(x + k - 227));
            _mm_storeu_si128(reinterpret_cast<__m128i*>(x + k), _mm_xor_si128(mix4(x + k), partner));
        }
        i = 0;
    }

    /// Returns the next tempered 32-bit output, twisting when the state is used up.
    uint32_t operator()() {
        if (i >= 624) {
            twist();
        }
        uint32_t z = x[i++];
        z ^= z >> 11;
        z ^= (z << 7) & 0x9D2C5680u;
        // Tempering mask c is 0xEFC60000 (the original had a 7-digit 0xefc6000).
        z ^= (z << 15) & 0xEFC60000u;
        z ^= z >> 18;
        return z;
    }
};
/// Maps one 32-bit draw from `twister` onto the half-open interval [0, 1).
/// This matches the range of std::uniform_real_distribution<>(0, 1) and
/// boost::uniform_01, which main benchmarks alongside it.
double random_01(CustomTwister& twister) {
    // Divide by 2^32 (not 2^32 - 1) so that 1.0 is never produced.
    return static_cast<double>(twister()) / 4294967296.0;
}
namespace {

// Invokes `draw` `count` times and discards the values, so the generator and
// caches are warm before the timed run.
template <class Draw>
void warm_up(Draw& draw, unsigned count) {
    for (unsigned k = 0; k < count; ++k) {
        static_cast<void>(draw());
    }
}

// Invokes `draw` `count` times and stores the sum of the values in `total`.
// Consuming the results keeps the loop from being optimised away. Returns
// the elapsed wall-clock time in milliseconds.
template <class Draw>
double timed_run_ms(Draw& draw, unsigned count, double& total) {
    total = 0.0;
    // steady_clock is monotonic; high_resolution_clock may alias the
    // adjustable system_clock.
    const auto start = std::chrono::steady_clock::now();
    for (unsigned k = 0; k < count; ++k) {
        total += draw();
    }
    const auto end = std::chrono::steady_clock::now();
    return std::chrono::duration_cast<std::chrono::microseconds>(end - start).count() / 1000.0;
}

}  // namespace

// Benchmarks uniform [0, 1) generation with std::mt19937, boost::mt19937 and
// CustomTwister. All three use the same random seed. Prints each run's time
// and how far its sum strays from the expected mean.
int main() {
    std::random_device rd;
    const unsigned seed = rd();

    std::mt19937 std_generator(seed);
    std::uniform_real_distribution<> std_01dist(0.0, 1.0);
    boost::mt19937 boost_generator;
    boost_generator.seed(seed);
    boost::uniform_01<> boost_01dist;
    CustomTwister twister(seed);

    auto std_draw = [&] { return std_01dist(std_generator); };
    auto boost_draw = [&] { return boost_01dist(boost_generator); };
    auto custom_draw = [&] { return random_01(twister); };

    const unsigned num_trials = 1000000;
    const unsigned num_warmup = num_trials / 10;
    // Each draw is uniform on [0, 1), so every total should be roughly num_trials / 2.
    const double expected_total = num_trials / 2.0;

    double std_total = 0.0;
    warm_up(std_draw, num_warmup);
    const double std_duration = timed_run_ms(std_draw, num_trials, std_total);

    double boost_total = 0.0;
    warm_up(boost_draw, num_warmup);
    const double boost_duration = timed_run_ms(boost_draw, num_trials, boost_total);

    double custom_total = 0.0;
    warm_up(custom_draw, num_warmup);
    const double custom_duration = timed_run_ms(custom_draw, num_trials, custom_total);

    std::printf(" std took %.3fms with a delta of %.2f\n", std_duration, std::fabs(std_total - expected_total));
    std::printf(" boost took %.3fms with a delta of %.2f\n", boost_duration, std::fabs(boost_total - expected_total));
    std::printf("custom took %.3fms with a delta of %.2f\n", custom_duration, std::fabs(custom_total - expected_total));
    std::printf("\n");
    if (std_duration > boost_duration && custom_duration > boost_duration) {
        std::printf("boost is %.1fx faster than std\n", std_duration / boost_duration);
        std::printf("boost is %.1fx faster than custom\n", custom_duration / boost_duration);
    } else if (std_duration > custom_duration) {
        std::printf("custom is %.1fx faster than std\n", std_duration / custom_duration);
        std::printf("custom is %.1fx faster than boost\n", boost_duration / custom_duration);
    } else {
        std::printf("std is %.1fx faster than boost\n", boost_duration / std_duration);
        std::printf("std is %.1fx faster than custom\n", custom_duration / std_duration);
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment