Last active
December 19, 2017 22:23
-
-
Save amallia/6b096a40b7dde339d7e59dfdec5afaaa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// search.cpp
// Build: g++ -march=native -std=c++14 search.cpp -o search
#include <random> | |
#include <iostream> | |
#include <cstdlib> | |
#include <smmintrin.h> | |
#include <vector> | |
#include <algorithm> | |
#include <chrono> | |
/// Scalar linear search.
///
/// @param data  values to scan, front to back
/// @param key   value to look for
/// @return index of the first element equal to `key`, or -1 when absent
int search(const std::vector<uint32_t>& data, uint32_t key) {
    const unsigned count = data.size();
    unsigned pos = 0;
    // Advance until we either run off the end or land on a match.
    while (pos < count && data[pos] != key) {
        ++pos;
    }
    if (pos == count) {
        return -1;
    }
    return pos;
}
/// SSE2 linear search: examines eight 32-bit values per iteration.
///
/// @param data  values to scan, front to back
/// @param key   value to look for
/// @return index of the first element equal to `key`, or -1 when absent
int sse_search(const std::vector<uint32_t>& data, uint32_t key) {
    const size_t total = data.size();
    const size_t simd_end = total - (total % 8);  // largest multiple of 8 <= total
    const __m128i needle = _mm_set1_epi32(key);

    size_t pos = 0;
    while (pos < simd_end) {
        const __m128i lo = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&data[pos]));
        const __m128i hi = _mm_loadu_si128(reinterpret_cast<const __m128i*>(&data[pos + 4]));
        // Each equality lane is 0 or -1; the saturating pack narrows the eight
        // 32-bit lanes to eight 16-bit lanes, preserving 0 / all-ones.
        const __m128i packed = _mm_packs_epi32(_mm_cmpeq_epi32(lo, needle),
                                               _mm_cmpeq_epi32(hi, needle));
        // movemask yields one bit per byte, i.e. two bits per packed lane.
        const uint32_t hits = _mm_movemask_epi8(packed);
        if (hits != 0) {
            return pos + __builtin_ctz(hits) / 2;
        }
        pos += 8;
    }

    // Scalar pass over the (at most seven) leftover elements.
    while (pos < total) {
        if (data[pos] == key) {
            return pos;
        }
        ++pos;
    }
    return -1;
}
/// Returns a 128-bit vector with every bit set.
///
/// Comparing a register with itself for equality produces all-ones in every
/// lane, so the value of `a` does not influence the result. This is a function
/// rather than a macro so that the argument expression is evaluated exactly
/// once (the macro form expanded it twice).
static inline __m128i _mm_setbits_si128(__m128i a) {
    return _mm_cmpeq_epi32(a, a);
}
/// Bitwise NOT of a 128-bit vector.
///
/// @param a  input vector
/// @return   ~a
__m128i _mm_invert_si128(__m128i a) {
    // A register compared with itself is all-ones; XOR with all-ones flips every bit.
    const __m128i all_ones = _mm_cmpeq_epi32(a, a);
    return _mm_xor_si128(a, all_ones);
}
/// Lane-wise signed 32-bit "greater than or equal" comparison.
///
/// @param a  left-hand operands
/// @param b  right-hand operands
/// @return   per lane: all-ones where a >= b (signed), zero otherwise
__m128i _mm_cmpge_epi32(__m128i a, __m128i b) {
    // a >= b  ==  (a > b) | (a == b)
    const __m128i greater = _mm_cmpgt_epi32(a, b);
    const __m128i equal = _mm_cmpeq_epi32(a, b);
    return _mm_or_si128(greater, equal);
}
/// SSE2 scan for the first element that is greater than or equal to `key`,
/// examining eight 32-bit values per iteration.
///
/// Elements and key are compared as unsigned 32-bit integers over the full
/// uint32_t range.
///
/// @param data  values to scan, front to back
/// @param key   lower bound to look for
/// @return index of the first element `>= key`, or -1 when there is none
int sse_next_geq(const std::vector<uint32_t>& data, uint32_t key) {
    // SSE2 only offers signed 32-bit compares. Flipping the sign bit of both
    // operands maps unsigned ordering onto signed ordering, so values
    // >= 2^31 compare correctly.
    const __m128i sign_bit = _mm_set1_epi32(static_cast<int>(0x80000000u));
    const __m128i keys = _mm_xor_si128(_mm_set1_epi32(static_cast<int>(key)), sign_bit);

    const size_t n = data.size();
    const size_t rounded = 8 * (n / 8);
    for (size_t i = 0; i < rounded; i += 8) {
        const __m128i vec1 = _mm_xor_si128(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(&data[i])), sign_bit);
        const __m128i vec2 = _mm_xor_si128(
            _mm_loadu_si128(reinterpret_cast<const __m128i*>(&data[i + 4])), sign_bit);
        // Lanes where element < key become all-ones; everything else is >= key.
        const __m128i lt1 = _mm_cmplt_epi32(vec1, keys);
        const __m128i lt2 = _mm_cmplt_epi32(vec2, keys);
        // Pack to eight 16-bit lanes; movemask then gives two bits per lane.
        const uint32_t lt_mask =
            static_cast<uint32_t>(_mm_movemask_epi8(_mm_packs_epi32(lt1, lt2)));
        const uint32_t ge_mask = ~lt_mask & 0xFFFFu;
        if (ge_mask != 0) {
            return static_cast<int>(i + __builtin_ctz(ge_mask) / 2);
        }
    }
    // Scalar pass over the (at most seven) leftover elements.
    for (size_t i = rounded; i < n; i++) {
        if (data[i] >= key) {
            return static_cast<int>(i);
        }
    }
    return -1;
}
/// Scalar scan for the first element that is greater than or equal to `key`.
///
/// @param data  values to scan, front to back
/// @param key   lower bound to look for
/// @return index of the first element `>= key`, or -1 when there is none
int next_geq(const std::vector<uint32_t>& data, uint32_t key) {
    const unsigned count = data.size();
    unsigned pos = 0;
    for (; pos != count; ++pos) {
        if (!(data[pos] < key)) {
            break;  // first element that is not below the key
        }
    }
    return pos < count ? static_cast<int>(pos) : -1;
}
int main() | |
{ | |
std::random_device rd; | |
std::mt19937 gen(rd()); | |
auto size = 1000 * 1000 * 10; | |
std::uniform_int_distribution<> dis(1, size); | |
std::vector<uint32_t> data; | |
for(size_t i = 0; i < size; ++i) { | |
data.push_back(dis(gen)); | |
} | |
std::sort(data.begin(), data.end()); | |
int pos = 0; | |
auto start = std::chrono::steady_clock::now(); | |
pos = sse_search(data, size+1); | |
auto end = std::chrono::steady_clock::now(); | |
std::cout << "Elapsed time in milliseconds : " | |
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() | |
<< " ms" << std::endl; | |
std::cout << pos << std::endl; | |
start = std::chrono::steady_clock::now(); | |
pos = search(data, size+1); | |
end = std::chrono::steady_clock::now(); | |
std::cout << "Elapsed time in milliseconds : " | |
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() | |
<< " ms" << std::endl; | |
std::cout << pos << std::endl; | |
start = std::chrono::steady_clock::now(); | |
pos = sse_next_geq(data, size+1); | |
end = std::chrono::steady_clock::now(); | |
std::cout << "Elapsed time in milliseconds : " | |
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() | |
<< " ms" << std::endl; | |
std::cout << pos << std::endl; | |
start = std::chrono::steady_clock::now(); | |
pos = next_geq(data, size+1); | |
end = std::chrono::steady_clock::now(); | |
std::cout << "Elapsed time in milliseconds : " | |
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() | |
<< " ms" << std::endl; | |
std::cout << pos << std::endl; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment