Skip to content

Instantly share code, notes, and snippets.

@amallia
Last active December 19, 2017 22:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amallia/6b096a40b7dde339d7e59dfdec5afaaa to your computer and use it in GitHub Desktop.
Save amallia/6b096a40b7dde339d7e59dfdec5afaaa to your computer and use it in GitHub Desktop.
// search.cpp
// g++ -march=native -std=c++14 search.cpp -o search
#include <random>
#include <iostream>
#include <cstdlib>
#include <smmintrin.h>
#include <vector>
#include <algorithm>
#include <chrono>
// Scalar reference implementation of an exact-match linear scan.
//
// Walks `data` front to back and reports the index of the first element equal
// to `key`, or -1 when no element matches (including when `data` is empty).
int search(const std::vector<uint32_t>& data, uint32_t key) {
    const unsigned count = data.size();
    unsigned idx = 0;
    while (idx != count) {
        if (data[idx] == key) {
            return idx;
        }
        ++idx;
    }
    return -1;
}
// SSE2 exact-match linear scan.
//
// Processes eight 32-bit elements per iteration (two unaligned 128-bit loads),
// then finishes any remaining (size % 8) elements with a scalar loop.
// Returns the index of the first element equal to `key`, or -1 if none is.
int sse_search(const std::vector<uint32_t>& data, uint32_t key) {
    const auto count = data.size();
    const auto simd_end = 8 * (count / 8);  // largest multiple of 8 <= count
    const __m128i needle = _mm_set1_epi32(key);

    size_t pos = 0;
    while (pos < simd_end) {
        const __m128i* block = reinterpret_cast<const __m128i*>(data.data() + pos);
        const __m128i lo = _mm_loadu_si128(block);      // elements pos   .. pos+3
        const __m128i hi = _mm_loadu_si128(block + 1);  // elements pos+4 .. pos+7

        // Each comparison lane is all-ones or all-zeros; the signed-saturating
        // pack narrows them to eight 16-bit lanes that keep that property.
        const __m128i packed = _mm_packs_epi32(_mm_cmpeq_epi32(lo, needle),
                                               _mm_cmpeq_epi32(hi, needle));

        // movemask yields one bit per byte, i.e. two bits per element, hence
        // the division by two when turning a bit position into a lane index.
        const uint32_t hits = _mm_movemask_epi8(packed);
        if (hits != 0) {
            return pos + __builtin_ctz(hits) / 2;
        }
        pos += 8;
    }

    // Scalar tail for the elements that do not fill a whole 8-wide group.
    for (; pos < count; ++pos) {
        if (data[pos] == key) {
            return pos;
        }
    }
    return -1;
}
// Expands to a vector with every bit set: each 32-bit lane of `a` compares
// equal to itself, and an equal lane is written as all-ones.
#define _mm_setbits_si128(a) _mm_cmpeq_epi32(a, a)

// Bitwise complement (~a) of all 128 bits, computed as a XOR all-ones.
__m128i _mm_invert_si128(__m128i a) {
    const __m128i all_ones = _mm_setbits_si128(a);
    return _mm_xor_si128(a, all_ones);
}

// Per-lane *signed* 32-bit "a >= b", built as the complement of "a < b".
// A lane is all-ones where the relation holds and all-zeros where it does not.
__m128i _mm_cmpge_epi32(__m128i a, __m128i b) {
    const __m128i less_than = _mm_cmplt_epi32(a, b);
    return _mm_invert_si128(less_than);
}
// SSE2 scan for the first element that is greater than or equal to `key`.
//
// Processes eight 32-bit elements per iteration (two unaligned 128-bit loads),
// then finishes any remaining (size % 8) elements with a scalar loop.
// Elements are compared as unsigned values, matching the scalar next_geq().
// Returns the index of the first element >= `key`, or -1 if there is none
// (including when `data` is empty).
int sse_next_geq(const std::vector<uint32_t>& data, uint32_t key) {
    // SSE2 only offers signed 32-bit comparisons. Flipping the top bit of both
    // operands maps unsigned order onto signed order, so values >= 2^31 are
    // ranked correctly.
    const __m128i sign_bias = _mm_set1_epi32(static_cast<int>(0x80000000u));
    const __m128i keys = _mm_xor_si128(_mm_set1_epi32(static_cast<int>(key)), sign_bias);

    const size_t n = data.size();
    const size_t rounded = 8 * (n / 8);  // largest multiple of 8 <= n

    for (size_t i = 0; i < rounded; i += 8) {
        const __m128i* block = reinterpret_cast<const __m128i*>(data.data() + i);
        const __m128i vec1 = _mm_xor_si128(_mm_loadu_si128(block), sign_bias);
        const __m128i vec2 = _mm_xor_si128(_mm_loadu_si128(block + 1), sign_bias);

        // Compute "element < key" per lane; "element >= key" is its complement.
        const __m128i lt1 = _mm_cmplt_epi32(vec1, keys);
        const __m128i lt2 = _mm_cmplt_epi32(vec2, keys);

        // Lanes are all-ones or all-zeros, so the signed-saturating pack keeps
        // them intact while narrowing to eight 16-bit lanes.
        const __m128i packed = _mm_packs_epi32(lt1, lt2);

        // movemask yields two bits per element; invert within the 16 valid
        // bits to get the ">=" mask.
        const uint32_t ge_mask =
            ~static_cast<uint32_t>(_mm_movemask_epi8(packed)) & 0xFFFFu;
        if (ge_mask != 0) {
            return static_cast<int>(i + __builtin_ctz(ge_mask) / 2);
        }
    }

    // Scalar tail for the elements that do not fill a whole 8-wide group.
    for (size_t i = rounded; i < n; i++) {
        if (data[i] >= key) {
            return static_cast<int>(i);
        }
    }
    return -1;
}
// Scalar reference implementation of "first element not less than key".
//
// Skips every leading element that is strictly smaller than `key` and returns
// the index where that run ends, or -1 when all elements are smaller (or
// `data` is empty).
int next_geq(const std::vector<uint32_t>& data, uint32_t key) {
    const unsigned count = data.size();
    unsigned idx = 0;
    while (idx < count && data[idx] < key) {
        ++idx;
    }
    if (idx == count) {
        return -1;
    }
    return idx;
}
int main()
{
std::random_device rd;
std::mt19937 gen(rd());
auto size = 1000 * 1000 * 10;
std::uniform_int_distribution<> dis(1, size);
std::vector<uint32_t> data;
for(size_t i = 0; i < size; ++i) {
data.push_back(dis(gen));
}
std::sort(data.begin(), data.end());
int pos = 0;
auto start = std::chrono::steady_clock::now();
pos = sse_search(data, size+1);
auto end = std::chrono::steady_clock::now();
std::cout << "Elapsed time in milliseconds : "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< " ms" << std::endl;
std::cout << pos << std::endl;
start = std::chrono::steady_clock::now();
pos = search(data, size+1);
end = std::chrono::steady_clock::now();
std::cout << "Elapsed time in milliseconds : "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< " ms" << std::endl;
std::cout << pos << std::endl;
start = std::chrono::steady_clock::now();
pos = sse_next_geq(data, size+1);
end = std::chrono::steady_clock::now();
std::cout << "Elapsed time in milliseconds : "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< " ms" << std::endl;
std::cout << pos << std::endl;
start = std::chrono::steady_clock::now();
pos = next_geq(data, size+1);
end = std::chrono::steady_clock::now();
std::cout << "Elapsed time in milliseconds : "
<< std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count()
<< " ms" << std::endl;
std::cout << pos << std::endl;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment