Last active
June 14, 2021 10:49
-
-
Save acaly/3f0a5d9fadc09a632143b7618a44e894 to your computer and use it in GitHub Desktop.
Split a given amount using coins of specified denominations - SIMD optimized version.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <iostream> | |
#include <vector> | |
#include <unordered_set> | |
#include <unordered_map> | |
#include <algorithm> | |
#include <ctime> | |
#include <chrono> | |
#include <memory> | |
#include <immintrin.h> | |
// Eight packed 32-bit ints: the in-memory image of one 256-bit AVX register.
// alignas takes a byte count, so 32 bytes (= 256 bits) is the alignment that
// aligned 256-bit loads/stores need. A larger alignment would also pad sizeof
// up to that alignment and waste memory in every container element.
struct alignas(32) sse256_int
{
    int a[8];
};
// Walks the back-pointer table and tallies how often each step value is used.
//   plans     - table of step sizes; plans[k] is subtracted from k to reach
//               the previous entry of the chain.
//   remaining - index to start from; the walk stops once the index is negative.
//   result    - step value -> number of occurrences; counts are added on top
//               of whatever the map already holds.
// Every visited entry must be positive, otherwise the walk never terminates.
static void write_result(int* plans, int remaining, std::unordered_map<int, int>& result)
{
    for (int index = remaining; index >= 0; index -= plans[index])
    {
        ++result[plans[index]];
    }
}
// Relaxes one 8-entry block in place for a single coin value.
// For each lane from `coin` up to 7, in ascending order, the entry `coin`
// lanes earlier plus one is offered as a candidate; a strictly smaller
// candidate replaces min[lane] and records `coin` in mask[lane]. Because lanes
// are visited in ascending order, an updated lane can feed later lanes in the
// same call.
//   coin - lane distance between source and destination entry
//   min  - 8 values, updated in place
//   mask - 8 values, receives `coin` wherever min was lowered
static void propagate_small_coin(int coin, int* min, int* mask)
{
    constexpr int lane_count = 8;

    int lane = coin;
    while (lane < lane_count)
    {
        const int candidate = min[lane - coin] + 1;
        if (candidate < min[lane])
        {
            min[lane] = candidate;
            mask[lane] = coin;
        }
        ++lane;
    }
}
static bool solve(std::unordered_set<int>& coins, int target, std::unordered_map<int, int>& result) | |
{ | |
std::vector<sse256_int> plans_count_vec; | |
std::vector<sse256_int> plans_prev_vec; | |
int sse_length_total = (target + 7) / 8; | |
plans_count_vec.resize(sse_length_total + 2); | |
plans_prev_vec.resize(sse_length_total + 2); | |
int* plans_count = (int*)&plans_count_vec[1]; | |
int* plans_prev = (int*)&plans_prev_vec[1]; | |
sse256_int tmp256; | |
int* ptmp = (int*)&tmp256; | |
for (int i = 0; i < 7; ++i) | |
{ | |
plans_count[i - 8] = target + 1; | |
plans_prev[i - 8] = 0; | |
} | |
plans_count[-1] = 0; | |
plans_prev[-1] == 0; | |
std::vector<int> coins_sorted = { coins.begin(), coins.end() }; | |
std::sort(coins_sorted.begin(), coins_sorted.end()); | |
std::size_t next_coin_index = 0; | |
int next_coin = coins_sorted[0]; | |
__m256i min_count_init = _mm256_set1_epi32(target + 1); | |
__m256i add_one = _mm256_set1_epi32(1); | |
__m256i read_init_mask[8] = {}; | |
for (int i = 1; i < 8; ++i) | |
{ | |
int* p = (int*)&read_init_mask[i]; | |
for (int j = i; j < 8; ++j) | |
{ | |
p[j] = target + 1; | |
} | |
} | |
for (int i = 1; i <= target; i += 8) | |
{ | |
__m256i min_count = min_count_init; | |
__m256i min_prev = _mm256_setzero_si256(); | |
while (i + 8 > next_coin && next_coin_index < coins_sorted.size()) | |
{ | |
next_coin_index += 1; | |
if (next_coin_index < coins_sorted.size()) | |
{ | |
next_coin = coins_sorted[next_coin_index]; | |
} | |
} | |
for (int j = (int)next_coin_index - 1; j >= 0; --j) | |
{ | |
int coin = coins_sorted[j]; | |
__m256i read = _mm256_loadu_si256((__m256i*)&plans_count[i - coin - 1]); | |
if (coin < 8) | |
{ | |
read = _mm256_max_epi32(read, read_init_mask[coin]); | |
} | |
min_count = _mm256_min_epi32(min_count, read); | |
__m256i write_prev_flags = _mm256_cmpeq_epi32(min_count, read); | |
__m256i write_prev_value = _mm256_set1_epi32(coin); | |
min_prev = _mm256_blendv_epi8(min_prev, write_prev_value, write_prev_flags); | |
if (coin < 8) | |
{ | |
sse256_int min_count_m, min_prev_m; | |
_mm256_store_si256((__m256i*)&min_count_m, min_count); | |
_mm256_store_si256((__m256i*)&min_prev_m, min_prev); | |
propagate_small_coin(coin, &min_count_m.a[0], &min_prev_m.a[0]); | |
min_count = _mm256_load_si256((__m256i*)&min_count_m); | |
min_prev = _mm256_load_si256((__m256i*)&min_prev_m); | |
} | |
} | |
min_count = _mm256_add_epi32(min_count, add_one); | |
_mm256_storeu_si256((__m256i*)&plans_count[i - 1], min_count); | |
_mm256_storeu_si256((__m256i*)&plans_prev[i - 1], min_prev); | |
} | |
if (plans_count[target - 1] > target) | |
{ | |
result.clear(); | |
return false; | |
} | |
result.clear(); | |
write_result(plans_prev, target - 1, result); | |
return true; | |
} | |
// Returns a millisecond tick count for measuring elapsed time.
// Uses the monotonic steady_clock, so the difference of two readings is never
// negative and is unaffected by wall-clock adjustments. The absolute value has
// no calendar meaning; only differences between readings are meaningful.
static auto time_ms()
{
    using namespace std::chrono;
    return duration_cast<milliseconds>(steady_clock::now().time_since_epoch()).count();
}
//sample input: 6 672 660 600 453 445 405 20000 | |
int main() | |
{ | |
int ncoins; | |
std::cin >> ncoins; | |
std::unordered_set<int> coins; | |
for (int i = 0; i < ncoins; ++i) | |
{ | |
int c; | |
std::cin >> c; | |
coins.insert(c); | |
} | |
int target; | |
std::cin >> target; | |
std::unordered_map<int, int> result; | |
auto t1 = time_ms(); | |
bool success = solve(coins, target, result); | |
if (!success) | |
{ | |
std::cout << "无可行方案" << std::endl; | |
} | |
else | |
{ | |
std::vector<std::pair<int, int>> output_result = { result.begin(), result.end() }; | |
std::sort(output_result.begin(), output_result.end(), [](auto& a, auto& b) { return a.first < b.first; }); | |
std::cout << target << " = "; | |
for (std::size_t i = 0; i < output_result.size(); ++i) | |
{ | |
if (i != 0) | |
{ | |
std::cout << " + "; | |
} | |
std::cout << output_result[i].first << " * " << output_result[i].second; | |
} | |
std::cout << std::endl; | |
} | |
auto t2 = time_ms(); | |
std::cout << "用时" << (t2 - t1) << " ms"; | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment