SDMC simulation
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <chrono>
#include <atomic>
#include <mutex>
#include <thread>
#include <optional>
#include <tuple>
#include <cmath>
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <string>
#include <string_view>
#include <random>
namespace {
static constexpr unsigned int FEATURES = 33264;
static constexpr unsigned int THREADS = 30;
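/* Fixed-size set of BITS bits, stored as 64-bit words, supporting the union,
 * intersection, difference and popcount operations used below. */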
template<unsigned BITS>
class BitSet {
    static constexpr unsigned WORDS = (BITS + 63) / 64;
    uint64_t val[WORDS];
public:
    BitSet() {
        std::fill(std::begin(val), std::end(val), 0);
    }
    BitSet(const BitSet&) = default;
    BitSet& operator=(const BitSet&) = default;
    BitSet& operator|=(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] |= other.val[i];
        return *this;
    }
    BitSet& operator&=(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] &= other.val[i];
        return *this;
    }
    BitSet& Remove(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] &= ~other.val[i];
        return *this;
    }
    uint64_t Count() const {
        uint64_t ret = 0;
        for (unsigned i = 0; i < WORDS; ++i) ret += __builtin_popcountl(val[i]);
        return ret;
    }
    explicit operator bool() const {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (val[i]) return true;
        }
        return false;
    }
    bool Get(unsigned pos) const {
        return (val[pos >> 6] >> (pos & 63)) & 1;
    }
    void Set(unsigned pos) {
        val[pos >> 6] |= uint64_t{1} << (pos & 63);
    }
    void Set(unsigned pos, bool bit) {
        val[pos >> 6] = (val[pos >> 6] & ~(uint64_t{1} << (pos & 63))) | (uint64_t{bit} << (pos & 63));
    }
    void Unset(unsigned pos) {
        val[pos >> 6] &= ~(uint64_t{1} << (pos & 63));
    }
    void OrHex(unsigned hexpos, uint8_t v) {
        val[hexpos >> 4] |= uint64_t{v} << (4 * (hexpos & 15));
    }
    friend bool operator!=(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] != b.val[i]) return true;
        }
        return false;
    }
    friend bool operator==(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] != b.val[i]) return false;
        }
        return true;
    }
    friend bool operator<(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] < b.val[i]) return true;
            if (a.val[i] > b.val[i]) return false;
        }
        return false;
    }
};
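/* Parse a (possibly 0x-prefixed) big-endian hex string into a BitSet; the last
 * hex digit ends up in the lowest 4 bits. */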
template<unsigned BITS>
BitSet<BITS> FromString(std::string_view view) {
    static constexpr int8_t TBL[256] = {
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
        -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
    };
    if (view.size() >= 2 && view[0] == '0' && view[1] == 'x') {
        view = view.substr(2, view.size() - 2);
    }
    if (view.size() * 4 > BITS) {
        assert(!"Too long string");
        abort();
    }
    BitSet<BITS> ret;
    for (unsigned i = 0; i < view.size(); ++i) {
        int c = TBL[(uint8_t)view[i]];
        if (c == -1) {
            assert(!"Invalid hex character");
            abort();
        }
        ret.OrHex(view.size() - 1 - i, uint8_t(c));
    }
    return ret;
}
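/* Load lines of the form "<index> <hex bitmask>" from istream into out[index]
 * using thread_count threads; returns the number of entries loaded. */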
template<unsigned BITS>
unsigned LoadFile(std::vector<BitSet<BITS>>& out, std::istream& istream, unsigned max_elem, unsigned thread_count) {
    std::mutex mutex;
    unsigned total_loaded = 0;
    out.resize(max_elem);
    auto thread_fn = [&]() {
        std::string line;
        unsigned loaded = 0;
        while (true) {
            {
                std::unique_lock<std::mutex> lock(mutex);
                total_loaded += loaded;
                if (loaded && (total_loaded & 0xffff) == 0) {
                    printf("Loaded %u entries.\n", total_loaded);
                }
                loaded = 0;
                if (!std::getline(istream, line)) break;
            }
            char* ptr;
            unsigned long idx = strtoul(line.c_str(), &ptr, 0);
            if (*ptr != ' ') {
                assert(!"Missing space");
                abort();
            }
            if (idx >= max_elem) break;
            if (!out[idx]) {
                out[idx] = FromString<BITS>(std::string_view{line}.substr(ptr - line.c_str() + 1));
                loaded += 1;
            }
        }
    };
    std::vector<std::thread> threads;
    threads.reserve(thread_count - 1);
    for (unsigned i = 1; i < thread_count; ++i) threads.emplace_back(thread_fn);
    thread_fn();
    for (auto& thread : threads) thread.join();
    while (!out.empty() && !out.back()) out.pop_back();
    return total_loaded;
}
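/* Feature bitmask per loaded vector, and (currently unused) per-feature weights;
 * GetScore simply counts the number of features covered. */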
std::vector<BitSet<FEATURES>> TABLE;
std::vector<uint64_t> WEIGHT;
uint64_t GetScore(const BitSet<FEATURES>& mask) {
    /* double ret = 0;
    for (unsigned i = 0; i < FEATURES; ++i) {
        if (mask.Get(i)) ret += WEIGHT[i];
    }
    return ret;*/
    return mask.Count();
}
/* Set global = local = max(global, local), atomically. */
void AtomicMax(std::atomic<uint64_t>& global, uint64_t& local) {
    uint64_t old_global = global.load();
    while (true) {
        if (old_global >= local) {
            local = old_global;
            break;
        }
        if (global.compare_exchange_weak(old_global, local)) break;
    }
}
uint64_t RdSeed() {
    uint8_t ok;
    uint64_t r;
    do {
        __asm__ volatile ("rdseed %0; setc %1" : "=a"(r), "=q"(ok) :: "cc");
        if (ok) return r;
        __asm__ volatile ("pause");
    } while(1);
}
template<int K>
inline uint64_t Rotl(const uint64_t x) {
    return (x << K) | (x >> (64 - K));
}
class RNG {
    uint64_t v0, v1, v2, v3;
public:
    RNG() : v0(RdSeed()), v1(RdSeed()), v2(RdSeed()), v3(RdSeed()) {}
    /* Xoshiro256++ */
    uint64_t operator()() {
        const uint64_t result = Rotl<23>(v0 + v3) + v0;
        const uint64_t t = v1 << 17;
        v2 ^= v0;
        v3 ^= v1;
        v1 ^= v2;
        v0 ^= v3;
        v2 ^= t;
        v3 = Rotl<45>(v3);
        return result;
    }
    uint64_t randrange(uint64_t max) {
        if (max == 0) return 0;
        static_assert(sizeof(unsigned long) == sizeof(uint64_t));
        uint64_t mask = 0xffffffffffffffff >> __builtin_clzl(max);
        while (true) {
            uint64_t rand = (*this)() & mask;
            if (rand < max) return rand;
        }
    }
    static double constexpr entropy() { return 0.0; }
    static uint64_t constexpr min() { return 0; }
    static uint64_t constexpr max() { return 0xffffffffffffffff; }
};
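/* Fisher-Yates shuffle of [first,last) using rng.randrange(). */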
template <typename I, typename R>
void Shuffle(I first, I last, R&& rng)
{
    while (first != last) {
        size_t j = rng.randrange(last - first);
        if (j) {
            using std::swap;
            swap(*first, *(first + j));
        }
        ++first;
    }
}
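/* Move all elements of [first,last) matching cond to the end and return an
 * iterator to the start of that tail (relative order is not preserved). */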
template <typename I, typename C>
I Prune(I first, I last, C cond)
{
    while (first != last) {
        if (cond(*first)) {
            std::advance(last, -1);
            if (first != last) {
                using std::swap;
                swap(*first, *last);
            }
        } else {
            std::advance(first, 1);
        }
    }
    return last;
}
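/* Beam search for a small set of TABLE entries whose union covers every feature
 * appearing in TABLE: each round extends every kept partial solution by one more
 * table entry, keeps at most bufsize of the best-scoring candidates (memsize
 * bounds the working buffers, batch_size the per-thread work units), and returns
 * the indices of the first solution that covers all features. */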
std::vector<unsigned> Optimize(unsigned bufsize, unsigned memsize, int threads, unsigned batch_size) {
    /* Each buffer element is (score, union of covered features, chosen table indices). */
    using bufelem = std::tuple<uint64_t, BitSet<FEATURES>, std::vector<unsigned>>;
    using buftype = std::vector<bufelem>;
    static const std::string LOGSTR[4] = {"Local inner", "Local final", "Merge inner", "Global final"};
    buftype buf, nbuf;
    BitSet<FEATURES> allmask;
    for (const auto& mask : TABLE) {
        allmask |= mask;
    }
    buf.emplace_back(uint64_t{0}, BitSet<FEATURES>{}, std::vector<unsigned>{});
    const unsigned tblsize = TABLE.size();
    RNG grng;
    std::atomic<uint64_t> glimit{0};
    std::mutex mutex;
    std::optional<std::vector<unsigned>> solution;
    using float_ms = std::chrono::duration<double, std::ratio<1, 1000>>;
    auto compact_fn = [&](buftype& buffer, uint64_t& limit, int thread, int stage, RNG& rng) {
        auto start = std::chrono::steady_clock::now();
        unsigned old_size = buffer.size();
        uint64_t old_limit = limit;
        if (old_size > bufsize) {
            // Prune.
            buffer.erase(Prune(std::begin(buffer), std::end(buffer), [limit](const bufelem& a) { return std::get<0>(a) < limit; }), std::end(buffer));
        }
        unsigned prn_size = buffer.size();
        if (old_size > bufsize || stage == 3) {
            std::sort(std::begin(buffer), std::end(buffer), [](const bufelem& a, const bufelem& b) {
                return std::tie(std::get<0>(a), std::get<1>(a)) > std::tie(std::get<0>(b), std::get<1>(b));
            });
            // Deduplicate.
            buffer.erase(std::unique(std::begin(buffer), std::end(buffer), [](const bufelem& a, const bufelem& b) { return std::get<0>(a) == std::get<0>(b) && std::get<1>(a) == std::get<1>(b); }), buffer.end());
        }
        unsigned mid_size = buffer.size();
        if (mid_size > bufsize) {
            // Shuffle similar scores.
            uint64_t last_score = std::get<0>(buffer[bufsize - 1]);
            auto end_last_score = buffer.begin() + bufsize;
            auto begin_last_score = std::prev(end_last_score);
            while (begin_last_score != buffer.begin() && std::get<0>(*std::prev(begin_last_score)) == last_score) {
                std::advance(begin_last_score, -1);
            }
            while (end_last_score != buffer.end() && std::get<0>(*end_last_score) == last_score) {
                std::advance(end_last_score, 1);
            }
            Shuffle(begin_last_score, end_last_score, rng);
            // Compact.
            buffer.resize(bufsize);
            limit = std::max(limit, std::get<0>(buffer.back()));
        }
        unsigned fin_size = buffer.size();
        auto stop = std::chrono::steady_clock::now();
        printf("- [thread %i/%i] %s compaction orig=%u -> prune=%u -> dedup=%u -> shrink=%u: minscore=%lu->%lu time=%gms\n", thread, threads, LOGSTR[stage].c_str(), old_size, prn_size, mid_size, fin_size, old_limit, limit, float_ms(stop - start).count());
    };
    std::atomic<uint64_t> todo_pos;
    uint64_t todo;
    auto thread_fn = [&](int thread) {
        buftype lbuf;
        uint64_t limit{0};
        std::vector<unsigned> nsol;
        const unsigned cur_bufsize = buf.size();
        RNG trng;
        while (true) {
            uint64_t pos = todo_pos.fetch_add(batch_size);
            if (pos >= todo) break;
            unsigned tbl_pos = pos / cur_bufsize;
            unsigned buf_pos = pos - uint64_t{tbl_pos} * cur_bufsize;
            int todo_local = batch_size;
            limit = std::max(limit, glimit.load());
            while (todo_local--) {
                const auto& tbl_entry = TABLE[tbl_pos];
                const auto& buf_entry = buf[buf_pos];
                auto nmask = tbl_entry;
                if (nmask) {
                    nmask.Remove(std::get<1>(buf_entry));
                    if (nmask) {
                        uint64_t nscore = GetScore(nmask) + std::get<0>(buf_entry);
                        if (nscore >= limit) {
                            nsol = std::get<2>(buf_entry);
                            nsol.push_back(tbl_pos);
                            nmask |= std::get<1>(buf_entry);
                            if (nmask == allmask) {
                                std::unique_lock<std::mutex> lock(mutex);
                                solution = nsol;
                                todo_pos = todo;
                                return;
                            }
                            lbuf.emplace_back(nscore, nmask, std::move(nsol));
                            if (lbuf.size() >= memsize) {
                                limit = std::max(limit, glimit.load());
                                compact_fn(lbuf, limit, thread, 0, trng);
                                AtomicMax(glimit, limit);
                            }
                        }
                    }
                }
                buf_pos += 1;
                if (buf_pos == cur_bufsize) {
                    buf_pos = 0;
                    tbl_pos += 1;
                    if (tbl_pos == tblsize) break;
                }
            }
        }
        limit = std::max(limit, glimit.load());
        compact_fn(lbuf, limit, thread, 1, trng);
        AtomicMax(glimit, limit);
        std::unique_lock<std::mutex> lock(mutex);
        if (solution.has_value()) return;
        for (auto& entry : lbuf) {
            if (std::get<0>(entry) >= limit) {
                nbuf.emplace_back(std::move(entry));
                if (nbuf.size() >= memsize) {
                    limit = std::max(limit, glimit.load());
                    compact_fn(nbuf, limit, thread, 2, trng);
                    AtomicMax(glimit, limit);
                }
            }
        }
    };
    while (buf.size()) {
        printf("Iteration: %u elements, %u vectors, scores %lu..%lu\n", (unsigned)buf.size(), (unsigned)std::get<2>(buf.front()).size(), std::get<0>(buf.back()), std::get<0>(buf.front()));
        nbuf.clear();
        todo = uint64_t{buf.size()} * tblsize;
        todo_pos = 0;
        std::vector<std::thread> vthreads;
        vthreads.reserve(threads);
        for (int thread = 0; thread < threads; ++thread) {
            vthreads.emplace_back(thread_fn, thread);
        }
        for (int thread = 0; thread < threads; ++thread) {
            vthreads[thread].join();
        }
        {
            std::unique_lock<std::mutex> lock(mutex);
            if (solution.has_value()) return *solution;
        }
        uint64_t limit = glimit.load();
        compact_fn(nbuf, limit, -1, 3, grng);
        std::swap(buf, nbuf);
    }
    return {};
}
} // namespace
int main(int argc, char** argv) {
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <max_elem> [file...]\n", argv[0]);
        return 1;
    }
    unsigned long max_elem = strtoul(argv[1], NULL, 0);
    unsigned total_loaded = 0;
    std::vector<BitSet<FEATURES>> load;
    load.reserve(max_elem);
    setlinebuf(stdout);
    for (int j = 2; j < argc; ++j) {
        std::ifstream ifs(argv[j]);
        unsigned loaded = LoadFile(load, ifs, max_elem, 1 + !!THREADS);
        printf("Loaded %u elements from %s\n", loaded, argv[j]);
        total_loaded += loaded;
    }
    Shuffle(load.begin(), load.end(), RNG{});
    std::vector<unsigned> feature_counts(FEATURES, 0);
    unsigned nonzero_loaded = 0;
    BitSet<FEATURES> allmask;
    for (const auto& elem : load) {
        if (elem) {
            allmask |= elem;
            for (unsigned f = 0; f < FEATURES; ++f) {
                feature_counts[f] += elem.Get(f);
            }
            nonzero_loaded += 1;
        }
    }
    printf("Loaded %u elements in total (%u nonzero)\n", total_loaded, nonzero_loaded);
    std::vector<double> feature_logs(FEATURES, 0.0);
    unsigned total_features = 0;
    unsigned cond_features = 0;
    double total_logs = 0.0;
    for (unsigned f = 0; f < FEATURES; ++f) {
        unsigned c = feature_counts[f];
        if (c == 0) {
            feature_logs[f] = 0.0;
        } else {
            if (c == total_loaded) {
                feature_logs[f] = 0.0;
            } else {
                feature_logs[f] = std::log(double(c) / total_loaded) * -1.4426950408889634073599246810;
                total_logs += feature_logs[f];
                cond_features += 1;
            }
            total_features += 1;
        }
    }
    printf("Number of used features: %u (or %u)\n", total_features, (unsigned)allmask.Count());
    printf("Number of not-always used features: %u\n", cond_features);
    printf("Rarest feature probability: 1/2^%f\n", *std::max_element(feature_logs.begin(), feature_logs.end()));
    printf("Total -log2(prob): %f\n", total_logs);
    double scale = 0xFFFFFFFFFFFFFFFF / total_logs;
    std::vector<uint64_t> feature_weights(FEATURES, 0);
    uint64_t total_weight = 0;
    for (unsigned f = 0; f < FEATURES; ++f) {
        feature_weights[f] = uint64_t(feature_logs[f] * scale);
        total_weight += feature_weights[f];
    }
    feature_logs = {};
    printf("Total weight: %lu\n", total_weight);
    TABLE = std::move(load);
    WEIGHT = std::move(feature_weights);
    auto ret = Optimize(2048, 65536, THREADS, 16384);
    BitSet<FEATURES> reconstruct;
    printf("Solution: %u vectors\n", (unsigned)ret.size());
    printf("Solution:");
    for (auto i : ret) {
        printf(" %u,", (unsigned)i);
        reconstruct |= TABLE[i];
    }
    printf("\n");
    printf("Equal: %i\n", (int)(reconstruct == allmask));
    return 0;
}
import random
# The order of the secp256k1 group.
N = 115792089237316195423570985008687907852837564279074904382605163141518161494337
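# Enumerate the (table_size, blocks, teeth, spacing) comb configurations that
# cover `bits` scalar bits with no dimension reducible by one, sorted by table size.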
def get_configs(bits):
    configs = []
    for comb_blocks in range(1, bits + 1):
        for comb_teeth in range(1, 9):
            comb_spacing = (bits - 1 + comb_blocks * comb_teeth) // (comb_blocks * comb_teeth)
            assert comb_blocks * comb_teeth * comb_spacing >= bits
            if comb_blocks * comb_teeth * (comb_spacing - 1) >= bits:
                continue
            if comb_blocks * (comb_teeth - 1) * comb_spacing >= bits:
                continue
            if (comb_blocks - 1) * comb_teeth * comb_spacing >= bits:
                continue
            tbl_size = (comb_blocks << (comb_teeth - 1)) * 64
            configs.append((tbl_size, comb_blocks, comb_teeth, comb_spacing))
    configs.sort()
    return configs
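# Count the group additions, doublings and table-entry cmovs performed by one
# signed-digit multi-comb multiplication with this configuration.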
def get_stats(comb_blocks, comb_teeth, comb_spacing):
    comb_off = comb_spacing - 1
    comb_points = 1 << (comb_teeth - 1)
    first = True
    stat_cmovs = 0
    stat_adds = 0
    stat_dbls = 0
    while True:
        bit_pos = comb_off
        for block in range(comb_blocks):
            stat_cmovs += comb_points
            if first:
                first = False
            else:
                stat_adds += 1
        if comb_off == 0:
            break
        comb_off -= 1
        stat_dbls += 1
    return (stat_adds, stat_dbls, stat_cmovs)
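# Run the signed-digit multi-comb algorithm for scalar `val` on integers mod m
# (standing in for the group), and report whether any intermediate addition
# degenerates into a doubling or a cancellation.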
def simulate(comb_blocks, comb_teeth, comb_spacing, m, val):
    # Precompute 1/2 mod m.
    half_m = pow(2, -1, m)
    # Precompute the offset of the first table's entries.
    offset0 = (-half_m * sum(1 << (j * comb_spacing) for j in range(comb_teeth))) % m
    # Precompute the first table.
    table0 = [(offset0 + sum(((q >> j) & 1) << (j * comb_spacing) for j in range(comb_teeth))) % m for q in range(1 << comb_teeth)]
    # Precompute all tables.
    tables = [[(t << (j * comb_teeth * comb_spacing)) % m for t in table0] for j in range(comb_blocks)]
    # Adjust scalar for signed-digit offset.
    scalar = (val + ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1) * half_m) % m
    # Simulate.
    comb_off = comb_spacing - 1
    comb_points = 1 << (comb_teeth - 1)
    first = True
    ret = 0
    while True:
        bit_pos = comb_off
        for block in range(comb_blocks):
            bits = 0
            for tooth in range(comb_teeth):
                bits += ((scalar >> bit_pos) & 1) << tooth
                bit_pos += comb_spacing
            sign = (bits >> (comb_teeth - 1)) & 1
            absbits = (bits ^ -sign) & (comb_points - 1)
            add = tables[block][absbits]
            if sign:
                add = (-add) % m
            if first:
                ret = add
                first = False
            else:
                if add == ret:
                    return f"double block={block} comb_off={comb_off}"
                ret = (ret + add) % m
                if ret == 0:
                    return f"cancel block={block} comb_off={comb_off}"
        if comb_off == 0:
            break
        comb_off -= 1
        ret = (ret << 1) % m
    # Verify
    assert ((ret - val) % m) == 0
    return "ok"
class MaskVal:
    """Represents a function f(v) where v is an integer in [0,O), of the form ((((v >> S) ^ X) & P) + C) * G.
    In the expressions below we use order = O, shift = S, pattern = P, xor = X, constant = C. G is the
    generator of a cyclic group of order O.
    """
    def __init__(self, constant, order, pattern=0, xor=0, shift=0):
        """Initialize as expression with specified C, S, P, X, O."""
        assert (xor & ~pattern) == 0
        assert shift >= 0
        self._order = order
        self._constant = constant % order
        self._shift = shift
        self._xor = xor
        self._pattern = pattern
    def __str__(self):
        if self._pattern == 0:
            return f"0x{self._constant:x}"
        ret = "v"
        if self._shift > 0:
            ret = f"({ret} >> {self._shift})"
        if self._xor != 0:
            ret = f"({ret} ^ 0x{self._xor:x})"
        if self._pattern != ((1 << (self._order.bit_length() - self._shift)) - 1):
            ret = f"({ret} & 0x{self._pattern:x})"
        if self._constant != 0:
            ret = f"({ret} + 0x{self._constant:x})"
        return ret
    def __add__(self, other):
        """Add two compatible expressions, or expression and integer."""
        if isinstance(other, int):
            other = MaskVal(other, self._order)
        # We cannot combine expressions from groups of different order.
        assert self._order == other._order
        # To add two expressions, the patterns may not overlap.
        assert (self._pattern & other._pattern) == 0
        # If both expressions have non-zero patterns, their shift must be equal.
        assert self._pattern == 0 or other._pattern == 0 or self._shift == other._shift
        return MaskVal(
            constant=self._constant + other._constant,
            order=self._order,
            pattern=self._pattern | other._pattern,
            xor=self._xor | other._xor,
            shift=self._shift if self._pattern != 0 else other._shift)
    def __neg__(self):
        """Negate an expression."""
        return MaskVal(
            constant=-self._constant - self._pattern,
            order=self._order,
            pattern=self._pattern,
            xor=self._xor ^ self._pattern,
            shift=self._shift)
    def __sub__(self, other):
        """Subtract two expressions."""
        return self + (-other)
    def __lshift__(self, n):
        """Multiply expression by 2^n."""
        return MaskVal(
            constant=self._constant << n,
            order=self._order,
            pattern=self._pattern << n,
            xor=self._xor << n,
            shift=self._shift - n)
    def root(self):
        """Solve for v such that f(v) = 0."""
        # ((((v >> S) ^ X) & P) + C) * G = 0
        # <=> (((v >> S) ^ X) & P) + C = k * O (for some k)
        # Let s = ((v >> S) ^ X) & P
        # <=> s + C = k * O (for some k)
        # s <= P, and thus s + C <= P + C. Loop over all values of k from 0 to floor((P + C) / O).
        k = 0
        while k * self._order <= self._pattern + self._constant:
            if k * self._order >= self._constant:
                s = k * self._order - self._constant
                # We have to check if s = ((v >> S) ^ X) & P, for some value of v.
                if (s & ~self._pattern) == 0:
                    # Note that many v solutions may exist, by choosing the bits not covered by
                    # P. First we set all of them equal to 0, as that yields the smallest solution,
                    # so if any solution less than O exists, that one is certainly included.
                    v = (s ^ self._xor) << self._shift
                    # Finally we need to check if v is in range [0,O).
                    if v >= self._order:
                        k += 1
                        continue
                    # Once we know a solution exists, randomly set the unconstrained bits until a
                    # not-too-large result appears.
                    while True:
                        r = v + (random.randrange(1 << self._order.bit_length()) & ~(self._pattern << self._shift))
                        if r < self._order:
                            return r
            k += 1
        return None
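# Symbolically execute the comb algorithm on MaskVal expressions to find scalars
# for which some intermediate addition cancels to zero or degenerates into a
# doubling, and confirm each candidate with simulate().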
def find_offences(comb_blocks, comb_teeth, comb_spacing, m):
    # Precompute the bit pattern covered by the first table (block=0).
    comb_bits = sum(1 << (j * comb_spacing) for j in range(comb_teeth))
    # Precompute 1/2 mod m.
    half_m = pow(2, -1, m)
    # Precompute the mask of all bits in the scalar.
    mask_all = (1 << m.bit_length()) - 1
    # Count number of additions.
    num_add = 0
    tot_add = comb_blocks * comb_spacing - 1
    # Symbolically execute the algorithm.
    comb_off = comb_spacing - 1
    first = True
    res = MaskVal(order=m, constant=0)
    while True:
        bit_pos = comb_off
        for block in range(comb_blocks):
            constant = -(comb_bits << (bit_pos - comb_off)) * half_m
            scalar_bits = (comb_bits << bit_pos) & mask_all
            add = MaskVal(order=m, constant=constant, pattern=scalar_bits >> comb_off, shift=comb_off)
            if first:
                res = add
                first = False
            else:
                num_add += 1
                summed = res + add
                diffed = res - add
                root_summed = summed.root()
                root_diffed = diffed.root()
                if root_summed is not None:
                    root_summed_cor = (root_summed - half_m * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)) % m
                    sim = simulate(comb_blocks, comb_teeth, comb_spacing, m, root_summed_cor)
                    print(f"- offset={comb_off}, block={block} (add {num_add}/{tot_add}): cancel 0x{root_summed_cor:x}: {sim}")
                if root_diffed is not None:
                    root_diffed_cor = (root_diffed - half_m * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)) % m
                    sim = simulate(comb_blocks, comb_teeth, comb_spacing, m, root_diffed_cor)
                    print(f"- offset={comb_off}, block={block} (add {num_add}/{tot_add}): double 0x{root_diffed_cor:x}: {sim}")
                res = summed
            bit_pos += comb_spacing * comb_teeth
        if comb_off == 0:
            break
        res = (res << 1)
        comb_off -= 1
    res = res + (half_m) * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)
    assert str(res) == "v"
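# Append, as a bitmask, which signed-digit multi-comb table entries (one bit per
# (block, table index) pair) are accessed while processing `scalar`.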
def sdmc_table_access(blocks, teeth, spacing, scalar, m, halfm, accessed):
    assert (2 * halfm) % m == 1
    off = spacing - 1
    points = 1 << (teeth - 1)
    accessed <<= (blocks << (teeth - 1))
    while True:
        bit_pos = off
        for block in range(blocks):
            bits = 0
            for tooth in range(teeth):
                bits += ((scalar >> bit_pos) & 1) << tooth
                bit_pos += spacing
            sign = (bits >> (teeth - 1)) & 1
            absbits = (bits ^ -sign) & (points - 1)
            accessed |= (1 << (block * points + absbits))
        if off == 0:
            break
        off -= 1
    return accessed
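# Same for the old fixed-window multiplication: append which table entries are
# accessed, one 2^groupsize-entry field per group.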
def old_table_access(groupsize, groups, scalar, accessed):
    mask = (1 << groupsize) - 1
    accessed = accessed << (groups << groupsize)
    for off in range(0, groupsize * groups, groupsize):
        tblpos = (scalar >> off) & mask
        accessed = (accessed << (1 << groupsize)) | (1 << tblpos)
    return accessed
from hashlib import sha256
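# For one shard (cpu out of cpus), hash counters into pseudo-random 256-bit
# scalars, record which table entries every configuration accesses, and write
# "<index> <hex bitmask>" lines in the format loaded by the C++ program above.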
def build_tests(sdmc_configs, old_configs, m, cpu, cpus, maxval):
    halfm = pow(2, -1, m)
    i = cpu
    tacc = 0
    cnt = 0
    with open(f"simsdmc-{cpu}-{cpus}.txt", "w") as f:
        while i < maxval:
            val = int.from_bytes(sha256(i.to_bytes(4, 'little')).digest(), 'big')
            sacc = 0
            for _, blocks, teeth, spacing in sdmc_configs:
                sacc = sdmc_table_access(blocks, teeth, spacing, val, m, halfm, sacc)
            for groupsize, groups in old_configs:
                sacc = old_table_access(groupsize, groups, val, sacc)
            f.write(f"{i} 0x{sacc:x}\n")
            tacc |= sacc
            cnt += 1
            if (cnt & 0xff) == 0:
                print(f"[{cpu}/{cpus}]: {i} ({tacc.bit_length()} features)")
            i += cpus
import sys
maxval = int(sys.argv[1]) if len(sys.argv) > 1 else 1048576
cpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0
cpus = int(sys.argv[3]) if len(sys.argv) > 3 else 1
build_tests(get_configs(256), [(4, 64)], N, cpu, cpus, maxval)
exit()
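# Not reached while the exit() above is in place: per-configuration operation
# counts and the symbolic offence search.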
for tbl_size, comb_blocks, comb_teeth, comb_spacing in get_configs(256):
    stat_adds, stat_dbls, stat_cmovs = get_stats(comb_blocks, comb_teeth, comb_spacing)
    print(f"(blk={comb_blocks},tth={comb_teeth},spc={comb_spacing}): tbl={tbl_size} adds={stat_adds} dbls={stat_dbls} cmovs={stat_cmovs}")
    find_offences(comb_blocks, comb_teeth, comb_spacing, N)