Skip to content

Instantly share code, notes, and snippets.

@acaly
Last active June 14, 2021 10:49
Show Gist options
  • Save acaly/3f0a5d9fadc09a632143b7618a44e894 to your computer and use it in GitHub Desktop.
Save acaly/3f0a5d9fadc09a632143b7618a44e894 to your computer and use it in GitHub Desktop.
Split a given amount using coins of specified denominations - SIMD optimized version.
#include <iostream>
#include <vector>
#include <unordered_set>
#include <unordered_map>
#include <algorithm>
#include <ctime>
#include <chrono>
#include <memory>
#include <immintrin.h>
// Eight packed 32-bit integers, sized and aligned like one 256-bit AVX register
// so it can be used as the memory operand of the aligned 256-bit load/store
// intrinsics.
//
// Note that alignas() takes BYTES, not bits: 256 bits == 32 bytes. Asking for
// alignas(256) would also pad sizeof to 256, making every array element eight
// times larger than its 32-byte payload.
struct alignas(32) sse256_int
{
    int a[8];
};
// Follows the chain of entries stored in `plans`, starting at index
// `remaining`, and tallies every value met on the way.
//
// At each step the entry plans[index] is read, its counter in `result` is
// incremented, and the index is decreased by that entry. The walk stops as soon
// as the index becomes negative, so a negative starting index leaves `result`
// untouched.
//
// `result` is only added to, never cleared here. An entry of 0 would keep the
// index unchanged, so the caller must only pass tables whose visited entries
// are positive.
static void write_result(int* plans, int remaining, std::unordered_map<int, int>& result)
{
    for (int index = remaining; index >= 0;)
    {
        const int step = plans[index];
        ++result[step];
        index -= step;
    }
}
// Relaxes an 8-entry window in place using a single step size `coin`.
//
// For every slot from `coin` up to 7, in increasing order, the value `coin`
// slots earlier plus one is offered as a candidate. When the candidate is
// strictly smaller than the slot's current value, the slot in `min` takes the
// candidate and the matching slot in `mask` records `coin`. Because the scan
// runs left to right, a slot improved earlier in the same call can improve
// later slots too.
//
// `min` and `mask` must each point at 8 readable and writable ints.
static void propagate_small_coin(int coin, int* min, int* mask)
{
    constexpr int window = 8;
    int slot = coin;
    while (slot < window)
    {
        const int candidate = min[slot - coin] + 1;
        if (candidate < min[slot])
        {
            min[slot] = candidate;
            mask[slot] = coin;
        }
        ++slot;
    }
}
// Finds a way to pay `target` using the fewest coins, where every denomination
// in `coins` may be used any number of times.
//
// Parameters:
//   coins  - available denominations; values <= 0 are ignored.
//   target - amount to pay. `target + 2` must be representable as int.
//   result - cleared on entry; on success it maps each used denomination to the
//            number of times it is used.
//
// Returns true when `target` can be paid (0 is payable with no coins), false
// when it cannot (negative target, no usable denomination, or no combination
// adds up to it).
//
// Implementation: a bottom-up table processed eight amounts at a time with
// 256-bit integer intrinsics. plans_count[k] holds the minimum coin count for
// amount k + 1 and plans_prev[k] the last coin of one such optimal payment.
// Both tables get an 8-entry head in front of index 0 so that a vector read for
// a coin that reaches back past amount 1 stays inside the allocation:
// index -1 stands for amount 0 (count 0) and indices -8..-2 stand for negative
// amounts (marked unreachable).
static bool solve(std::unordered_set<int>& coins, int target, std::unordered_map<int, int>& result)
{
    constexpr int lanes = 8;

    result.clear();
    if (target < 0)
    {
        return false;
    }
    if (target == 0)
    {
        return true;
    }

    // Non-positive denominations cannot contribute to a payment and would
    // index outside the tables, so they are dropped up front.
    std::vector<int> coins_sorted;
    coins_sorted.reserve(coins.size());
    for (const int coin : coins)
    {
        if (coin > 0)
        {
            coins_sorted.push_back(coin);
        }
    }
    if (coins_sorted.empty())
    {
        return false;
    }
    std::sort(coins_sorted.begin(), coins_sorted.end());

    // Any count above `target` means "cannot be paid": a real payment of
    // `target` never needs more than `target` coins.
    const int unreachable = target + 1;

    // One head block, then enough blocks for every amount plus one spare block.
    const std::size_t block_count = (static_cast<std::size_t>(target) + lanes - 1) / lanes;
    const std::size_t table_size = (block_count + 2) * lanes;
    std::vector<int> plans_count_vec(table_size);
    std::vector<int> plans_prev_vec(table_size);
    int* plans_count = plans_count_vec.data() + lanes;
    int* plans_prev = plans_prev_vec.data() + lanes;

    std::fill(plans_count - lanes, plans_count - 1, unreachable);
    plans_count[-1] = 0;
    // The head of plans_prev stays at the zeros the vector was created with.

    // For a coin c < 8, lanes c..7 of the vector read would come from the block
    // that is currently being computed. read_init_mask[c] forces exactly those
    // lanes to `unreachable` (via max) and leaves lanes 0..c-1 unchanged.
    __m256i read_init_mask[lanes];
    read_init_mask[0] = _mm256_setzero_si256();
    for (int coin = 1; coin < lanes; ++coin)
    {
        alignas(32) int lane_values[lanes];
        for (int lane = 0; lane < lanes; ++lane)
        {
            lane_values[lane] = (lane >= coin) ? unreachable : 0;
        }
        read_init_mask[coin] = _mm256_load_si256(reinterpret_cast<const __m256i*>(lane_values));
    }

    const __m256i min_count_init = _mm256_set1_epi32(unreachable);
    const __m256i add_one = _mm256_set1_epi32(1);

    // Number of leading entries of coins_sorted that can touch the current block.
    std::size_t active_coins = 0;

    for (int i = 1; i <= target; i += lanes)
    {
        // The block covers amounts i .. i+7, so every coin below i + 8 matters.
        while (active_coins < coins_sorted.size() && coins_sorted[active_coins] < i + lanes)
        {
            ++active_coins;
        }

        __m256i min_count = min_count_init;
        __m256i min_prev = _mm256_setzero_si256();

        // Largest coin first: the in-block propagation for a small coin must
        // see the contributions of all larger coins, so that any optimal
        // payment (ordered from largest to smallest coin) is covered.
        for (std::size_t j = active_coins; j-- > 0;)
        {
            const int coin = coins_sorted[j];

            // Counts of the eight amounts that lie `coin` below this block.
            __m256i read = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&plans_count[i - coin - 1]));
            if (coin < lanes)
            {
                read = _mm256_max_epi32(read, read_init_mask[coin]);
            }

            min_count = _mm256_min_epi32(min_count, read);
            // Record the coin in every lane whose minimum now equals this read.
            const __m256i write_prev_flags = _mm256_cmpeq_epi32(min_count, read);
            const __m256i write_prev_value = _mm256_set1_epi32(coin);
            min_prev = _mm256_blendv_epi8(min_prev, write_prev_value, write_prev_flags);

            if (coin < lanes)
            {
                // A coin smaller than the block width can be applied between
                // lanes of the same block; that part is done with scalar code.
                alignas(32) int min_count_m[lanes];
                alignas(32) int min_prev_m[lanes];
                _mm256_store_si256(reinterpret_cast<__m256i*>(min_count_m), min_count);
                _mm256_store_si256(reinterpret_cast<__m256i*>(min_prev_m), min_prev);
                propagate_small_coin(coin, min_count_m, min_prev_m);
                min_count = _mm256_load_si256(reinterpret_cast<const __m256i*>(min_count_m));
                min_prev = _mm256_load_si256(reinterpret_cast<const __m256i*>(min_prev_m));
            }
        }

        // min_count holds the best predecessor count; paying adds one coin.
        min_count = _mm256_add_epi32(min_count, add_one);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(&plans_count[i - 1]), min_count);
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(&plans_prev[i - 1]), min_prev);
    }

    if (plans_count[target - 1] > target)
    {
        return false;
    }
    write_result(plans_prev, target - 1, result);
    return true;
}
// Returns a millisecond tick count for measuring elapsed time.
//
// The value is taken from std::chrono::steady_clock, which never goes
// backwards, so the difference between two calls is a valid duration even if
// the system wall clock is adjusted in between. The absolute value has no
// calendar meaning; only differences between two results are meaningful.
static auto time_ms()
{
    using namespace std::chrono;
    const auto since_clock_epoch = steady_clock::now().time_since_epoch();
    return duration_cast<milliseconds>(since_clock_epoch).count();
}
// Reads a problem from standard input, solves it and prints the answer.
//
// Input format: the number of denominations, then that many denominations,
// then the amount to pay.
// Sample input: 6 672 660 600 453 445 405 20000
//
// Output: either a "no feasible plan" message, or the amount written as a sum
// of "denomination * count" terms in increasing denomination order, followed
// by the elapsed time in milliseconds (measured from just before solve() until
// after the answer has been printed).
int main()
{
    int ncoins;
    std::cin >> ncoins;

    std::unordered_set<int> coins;
    for (int remaining = ncoins; remaining > 0; --remaining)
    {
        int denomination;
        std::cin >> denomination;
        coins.insert(denomination);
    }

    int target;
    std::cin >> target;

    std::unordered_map<int, int> result;
    const auto started = time_ms();

    if (solve(coins, target, result))
    {
        // Keys are unique, so the default pair ordering sorts by denomination.
        std::vector<std::pair<int, int>> terms(result.begin(), result.end());
        std::sort(terms.begin(), terms.end());

        std::cout << target << " = ";
        const char* separator = "";
        for (const auto& term : terms)
        {
            std::cout << separator << term.first << " * " << term.second;
            separator = " + ";
        }
        std::cout << std::endl;
    }
    else
    {
        std::cout << "无可行方案" << std::endl;
    }

    const auto finished = time_ms();
    std::cout << "用时" << (finished - started) << " ms";
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment