Skip to content

Instantly share code, notes, and snippets.

@sipa
Last active December 11, 2022 19:16
Show Gist options
  • Save sipa/868c39e29af9e8baf22845c8af2d316d to your computer and use it in GitHub Desktop.
Save sipa/868c39e29af9e8baf22845c8af2d316d to your computer and use it in GitHub Desktop.
SDMC simulation
#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <chrono>
#include <atomic>
#include <mutex>
#include <thread>
#include <optional>
#include <tuple>
#include <cmath>
#include <iostream>
#include <fstream>
#include <vector>
#include <algorithm>
#include <string>
#include <string_view>
#include <random>
namespace {
// Width of each feature bitset, i.e. the number of distinct table-access
// features an input vector can cover (matches the bitmaps produced by the
// companion Python generator).
static constexpr unsigned int FEATURES = 33264;
// Number of worker threads passed to Optimize() from main().
static constexpr unsigned int THREADS = 30;
/** Fixed-size set of BITS bits, packed into 64-bit words.
 *  Word 0 holds bits 0..63, word 1 bits 64..127, and so on. */
template<unsigned BITS>
class BitSet {
    static constexpr unsigned WORDS = (BITS + 63) / 64;
    uint64_t val[WORDS];
public:
    /** Construct an empty (all-zero) set. */
    BitSet() {
        std::fill(std::begin(val), std::end(val), 0);
    }
    BitSet(const BitSet&) = default;
    BitSet& operator=(const BitSet&) = default;
    /** In-place union with other. */
    BitSet& operator|=(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] |= other.val[i];
        return *this;
    }
    /** In-place intersection with other. */
    BitSet& operator&=(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] &= other.val[i];
        return *this;
    }
    /** In-place set difference: clear every bit that is set in other. */
    BitSet& Remove(const BitSet& other) {
        for (unsigned i = 0; i < WORDS; ++i) val[i] &= ~other.val[i];
        return *this;
    }
    /** Number of set bits. */
    uint64_t Count() const {
        uint64_t ret = 0;
        // __builtin_popcountll takes unsigned long long (always >= 64 bits),
        // unlike __builtin_popcountl which silently truncates where long is
        // 32 bits (e.g. Windows LLP64).
        for (unsigned i = 0; i < WORDS; ++i) ret += __builtin_popcountll(val[i]);
        return ret;
    }
    /** True if any bit is set. */
    explicit operator bool() const {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (val[i]) return true;
        }
        return false;
    }
    /** Read bit pos. */
    bool Get(unsigned pos) const {
        return (val[pos >> 6] >> (pos & 63)) & 1;
    }
    /** Set bit pos to 1. */
    void Set(unsigned pos) {
        val[pos >> 6] |= uint64_t{1} << (pos & 63);
    }
    /** Set bit pos to the given value. */
    void Set(unsigned pos, bool bit) {
        val[pos >> 6] = (val[pos >> 6] & ~(uint64_t{1} << (pos & 63))) | (uint64_t{bit} << (pos & 63));
    }
    /** Set bit pos to 0. */
    void Unset(unsigned pos) {
        val[pos >> 6] &= ~(uint64_t{1} << (pos & 63));
    }
    /** OR the 4-bit value v into hex digit position hexpos (16 digits per word). */
    void OrHex(unsigned hexpos, uint8_t v) {
        val[hexpos >> 4] |= uint64_t{v} << (4 * (hexpos & 15));
    }
    friend bool operator!=(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] != b.val[i]) return true;
        }
        return false;
    }
    friend bool operator==(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] != b.val[i]) return false;
        }
        return true;
    }
    /** Arbitrary strict total order (word-lexicographic, low words first) —
     *  used only as a deterministic tie-break, not a numeric comparison. */
    friend bool operator<(const BitSet& a, const BitSet& b) {
        for (unsigned i = 0; i < WORDS; ++i) {
            if (a.val[i] < b.val[i]) return true;
            if (a.val[i] > b.val[i]) return false;
        }
        return false;
    }
};
/** Parse a hex string (optionally "0x"-prefixed) into a BitSet<BITS>.
 *  The last hex character of the string maps to bits 0..3, the one before it
 *  to bits 4..7, etc. Aborts on overlong input or non-hex characters. */
template<unsigned BITS>
BitSet<BITS> FromString(std::string_view view) {
    // ASCII -> hex digit value; -1 marks invalid characters.
    static constexpr int8_t TBL[256] = {
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        0,1,2,3,4,5,6,7,8,9,-1,-1,-1,-1,-1,-1,
        -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,0xa,0xb,0xc,0xd,0xe,0xf,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,
        -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1
    };
    // Skip an optional "0x" prefix.
    if (view.size() >= 2 && view[0] == '0' && view[1] == 'x') {
        view = view.substr(2, view.size() - 2);
    }
    if (view.size() * 4 > BITS) {
        // (message typo "strong" fixed to "string")
        assert(!"Too long string");
        abort();
    }
    BitSet<BITS> ret;
    for (unsigned i = 0; i < view.size(); ++i) {
        int c = TBL[(uint8_t)view[i]];
        if (c == -1) {
            assert(!"Invalid hex character");
            abort();
        }
        // Character i from the front is hex digit (size-1-i) from the back.
        ret.OrHex(view.size() - 1 - i, uint8_t(c));
    }
    return ret;
}
/** Load "<index> <hexmask>" lines from istream into out (out[index] = parsed mask),
 *  using thread_count threads that share the stream under a mutex.
 *  out is resized to max_elem, then trailing empty entries are trimmed.
 *  Returns the number of entries actually stored (duplicates are skipped). */
template<unsigned BITS>
unsigned LoadFile(std::vector<BitSet<BITS>>& out, std::istream& istream, unsigned max_elem, unsigned thread_count) {
    std::mutex mutex;
    unsigned total_loaded = 0;
    out.resize(max_elem);
    auto thread_fn = [&]() {
        std::string line;
        unsigned loaded = 0;
        while (true) {
            {
                // The lock covers both the shared counter and the shared stream.
                std::unique_lock<std::mutex> lock(mutex);
                total_loaded += loaded;
                // Progress output roughly every 64k entries.
                if (loaded && (total_loaded & 0xffff) == 0) {
                    printf("Loaded %u entries.\n", total_loaded);
                }
                loaded = 0;
                if (!std::getline(istream, line)) break;
            }
            // Line format: "<index><space><hexmask>"; parsing happens outside the lock.
            char* ptr;
            unsigned long idx = strtoul(line.c_str(), &ptr, 0);
            if (*ptr != ' ') {
                assert(!"Missing space");
                abort();
            }
            // An out-of-range index ends this thread's processing entirely.
            if (idx >= max_elem) break;
            // NOTE(review): out[idx] is tested and written without holding the
            // lock; this is only race-free if no two input lines share an
            // index — TODO confirm the generator guarantees distinct indices.
            if (!out[idx]) {
                out[idx] = FromString<BITS>(std::string_view{line}.substr(ptr - line.c_str() + 1));
                loaded += 1;
            }
        }
    };
    // This thread acts as worker 0; spawn thread_count-1 helpers.
    std::vector<std::thread> threads;
    threads.reserve(thread_count - 1);
    for (unsigned i = 1; i < thread_count; ++i) threads.emplace_back(thread_fn);
    thread_fn();
    for (auto& thread : threads) thread.join();
    // Trim unused (all-zero) entries from the tail.
    while (!out.empty() && !out.back()) out.pop_back();
    return total_loaded;
}
// All loaded feature bitsets; filled in main() and consumed by Optimize()/GetScore callers.
std::vector<BitSet<FEATURES>> TABLE;
// Per-feature weights, assigned in main(); only referenced by the disabled
// weighted-score code inside GetScore().
std::vector<uint64_t> WEIGHT;
/** Score for a mask of newly covered features.
 *
 *  Currently simply the number of set bits. A WEIGHT-based variant existed
 *  and is kept here disabled:
 *
 *      double ret = 0;
 *      for (unsigned i = 0; i < FEATURES; ++i) {
 *          if (mask.Get(i)) ret += WEIGHT[i];
 *      }
 *      return ret;
 */
uint64_t GetScore(const BitSet<FEATURES>& mask) {
    return mask.Count();
}
/** Atomically raise both values to max(global, local):
 *  on return, local == global's value at the linearization point == the max. */
void AtomicMax(std::atomic<uint64_t>& global, uint64_t& local) {
    uint64_t observed = global.load();
    do {
        if (observed >= local) {
            // Someone else already published an equal-or-larger value; adopt it.
            local = observed;
            return;
        }
        // compare_exchange_weak refreshes 'observed' on failure, so we retry
        // against the latest published value.
    } while (!global.compare_exchange_weak(observed, local));
}
/** Obtain a 64-bit hardware random value via the x86 RDSEED instruction.
 *  RDSEED signals success through the carry flag (captured here with SETC);
 *  on failure we execute PAUSE and retry until it succeeds.
 *  NOTE(review): x86-64 specific (inline asm); will not build elsewhere. */
uint64_t RdSeed() {
    uint8_t ok;
    uint64_t r;
    do {
        __asm__ volatile ("rdseed %0; setc %1" : "=a"(r), "=q"(ok) :: "cc");
        if (ok) return r;
        __asm__ volatile ("pause");
    } while(1);
}
/** Rotate a 64-bit value left by K bits, for K in [0, 63].
 *  The right-shift amount is masked with 63 so that K == 0 does not produce
 *  a shift by 64, which is undefined behavior (the previous version had
 *  that latent UB; callers only used K=23 and K=45). */
template<int K>
inline uint64_t Rotl(const uint64_t x) {
    static_assert(K >= 0 && K < 64, "rotation amount out of range");
    return (x << K) | (x >> ((64 - K) & 63));
}
/** Fast non-cryptographic PRNG (xoshiro256++), seeded from hardware RDSEED. */
class RNG {
    uint64_t v0, v1, v2, v3; // 256-bit xoshiro state
public:
    /** Seed each state word independently from the hardware RNG. */
    RNG() : v0(RdSeed()), v1(RdSeed()), v2(RdSeed()), v3(RdSeed()) {}
    /* Xoshiro256++ */
    uint64_t operator()() {
        const uint64_t result = Rotl<23>(v0 + v3) + v0;
        const uint64_t t = v1 << 17;
        // State transition; the update order below is significant.
        v2 ^= v0;
        v3 ^= v1;
        v1 ^= v2;
        v0 ^= v3;
        v2 ^= t;
        v3 = Rotl<45>(v3);
        return result;
    }
    /** Uniform integer in [0, max) by rejection sampling; returns 0 for max == 0. */
    uint64_t randrange(uint64_t max) {
        if (max == 0) return 0;
        // __builtin_clzl operates on unsigned long; require it to be 64-bit.
        static_assert(sizeof(unsigned long) == sizeof(uint64_t));
        // Smallest all-ones mask >= max-1, so each draw is accepted with
        // probability at least 1/2.
        uint64_t mask = 0xffffffffffffffff >> __builtin_clzl(max);
        while (true) {
            uint64_t rand = (*this)() & mask;
            if (rand < max) return rand;
        }
    }
    // Minimal UniformRandomBitGenerator-style interface.
    static double constexpr entropy() { return 0.0; }
    static uint64_t constexpr min() { return 0; }
    static uint64_t constexpr max() { return 0xffffffffffffffff; }
};
/** Uniformly shuffle [first, last) in place (Fisher-Yates).
 *
 *  Fixed: the previous version drew j from randrange(last - first - 1),
 *  i.e. from [0, n-2], so the last element of the remaining tail could
 *  never be selected for the current position — a biased shuffle (e.g. the
 *  original last element could never end up first). The correct draw is
 *  from [0, n-1). */
template <typename I, typename R>
void Shuffle(I first, I last, R&& rng)
{
    while (first != last) {
        // Choose a uniform position in the not-yet-fixed tail [first, last).
        size_t j = rng.randrange(last - first);
        if (j) {
            using std::swap;
            swap(*first, *(first + j));
        }
        ++first;
    }
}
/** Move every element matching cond to the tail of [first, last) and return an
 *  iterator to the start of that tail (an unstable partition; survivor order
 *  is not guaranteed). Nothing is destroyed — callers erase the tail. */
template <typename I, typename C>
I Prune(I first, I last, C cond)
{
    for (;;) {
        if (first == last) break;
        if (!cond(*first)) {
            ++first;
            continue;
        }
        // *first matches: retire it by swapping in the last unexamined element.
        --last;
        if (first != last) {
            using std::swap;
            swap(*first, *last);
        }
    }
    return last;
}
/** Beam-search for a small subset of TABLE whose union of feature masks
 *  covers every feature present in TABLE (allmask).
 *
 *  buf holds candidate partial solutions as (score, covered mask, chosen
 *  TABLE indices) tuples. Each iteration extends every candidate by every
 *  TABLE entry across `threads` workers, then keeps the best `bufsize`
 *  results for the next iteration. The search stops as soon as a candidate's
 *  mask equals allmask.
 *
 *  bufsize:    beam width kept between iterations.
 *  memsize:    per-thread buffer size that triggers intermediate compaction.
 *  threads:    number of worker threads to spawn per iteration.
 *  batch_size: number of (candidate x table-entry) pairs claimed per atomic grab.
 *  Returns the index list of the first full cover found, or {} if the beam
 *  empties without covering allmask. */
std::vector<unsigned> Optimize(unsigned bufsize, unsigned memsize, int threads, unsigned batch_size) {
    using bufelem = std::tuple<uint64_t, BitSet<FEATURES>, std::vector<unsigned>>;
    using buftype = std::vector<bufelem>;
    // Log labels indexed by compact_fn's `stage` argument (0..3).
    static const std::string LOGSTR[4] = {"Local inner", "Local final", "Merge inner", "Global final"};
    buftype buf, nbuf;
    // Union of all table entries: the coverage target.
    BitSet<FEATURES> allmask;
    for (const auto& mask : TABLE) {
        allmask |= mask;
    }
    // Seed the beam with the empty selection (the 0.0 literal converts to uint64_t 0).
    buf.emplace_back(0.0, BitSet<FEATURES>{}, std::vector<unsigned>{});
    const unsigned tblsize = TABLE.size();
    RNG grng;
    // Global score lower bound: entries scoring below it can be discarded.
    std::atomic<uint64_t> glimit{0};
    std::mutex mutex;
    std::optional<std::vector<unsigned>> solution;
    using float_ms = std::chrono::duration<double, std::ratio<1, 1000>>;
    // Shrink `buffer` to at most bufsize entries: prune below-limit scores,
    // sort+deduplicate, shuffle the boundary-score run (so ties are broken
    // randomly rather than by insertion order), then truncate and raise limit.
    auto compact_fn = [&](buftype& buffer, uint64_t& limit, int thread, int stage, RNG& rng) {
        auto start = std::chrono::steady_clock::now();
        unsigned old_size = buffer.size();
        uint64_t old_limit = limit;
        if (old_size > bufsize) {
            // Prune.
            buffer.erase(Prune(std::begin(buffer), std::end(buffer), [limit](const bufelem& a) { return std::get<0>(a) < limit; }), std::end(buffer));
        }
        unsigned prn_size = buffer.size();
        if (old_size > bufsize || stage == 3) {
            // Sort by (score, mask) descending; stage 3 always sorts so the
            // final global buffer is ordered even when small.
            std::sort(std::begin(buffer), std::end(buffer), [](const bufelem& a, const bufelem& b) {
                return std::tie(std::get<0>(a), std::get<1>(a)) > std::tie(std::get<0>(b), std::get<1>(b));
            });
            // Deduplicate.
            buffer.erase(std::unique(std::begin(buffer), std::end(buffer), [](const bufelem& a, const bufelem& b) { return std::get<0>(a) == std::get<0>(b) && std::get<1>(a) == std::get<1>(b); }), buffer.end());
        }
        unsigned mid_size = buffer.size();
        if (mid_size > bufsize) {
            // Shuffle similar scores.
            // Find the full run of entries sharing the score at the cut-off
            // position, and shuffle it so truncation keeps a random subset.
            uint64_t last_score = std::get<0>(buffer[bufsize - 1]);
            auto end_last_score = buffer.begin() + bufsize;
            auto begin_last_score = std::prev(end_last_score);
            while (begin_last_score != buffer.begin() && std::get<0>(*std::prev(begin_last_score)) == last_score) {
                std::advance(begin_last_score, -1);
            }
            while (end_last_score != buffer.end() && std::get<0>(*end_last_score) == last_score) {
                std::advance(end_last_score, 1);
            }
            Shuffle(begin_last_score, end_last_score, rng);
            // Compact.
            buffer.resize(bufsize);
            limit = std::max(limit, std::get<0>(buffer.back()));
        }
        unsigned fin_size = buffer.size();
        auto stop = std::chrono::steady_clock::now();
        printf("- [thread %i/%i] %s compaction orig=%u -> prune=%u -> dedup=%u -> shrink=%u: minscore=%lu->%lu time=%gms\n", thread, threads, LOGSTR[stage].c_str(), old_size, prn_size, mid_size, fin_size, old_limit, limit, float_ms(stop - start).count());
    };
    // Work counter: positions in the flattened (table entry x beam entry) space.
    std::atomic<uint64_t> todo_pos;
    uint64_t todo;
    auto thread_fn = [&](int thread) {
        buftype lbuf;            // thread-local result buffer
        uint64_t limit{0};       // thread-local copy of the score lower bound
        std::vector<unsigned> nsol;
        // NOTE: shadows the bufsize parameter; here it is the current beam size.
        const unsigned bufsize = buf.size();
        RNG trng;
        while (true) {
            // Claim a batch of flattened positions.
            uint64_t pos = todo_pos.fetch_add(batch_size);
            if (pos >= todo) break;
            unsigned tbl_pos = pos / bufsize;
            unsigned buf_pos = pos - uint64_t{tbl_pos} * bufsize;
            int todo_local = batch_size;
            limit = std::max(limit, glimit.load());
            while (todo_local--) {
                const auto& tbl_entry = TABLE[tbl_pos];
                const auto& buf_entry = buf[buf_pos];
                auto nmask = tbl_entry;
                if (nmask) {
                    // Only features not already covered by the candidate count.
                    nmask.Remove(std::get<1>(buf_entry));
                    if (nmask) {
                        uint64_t nscore = GetScore(nmask) + std::get<0>(buf_entry);
                        if (nscore >= limit) {
                            // Extend the candidate with this table entry.
                            nsol = std::get<2>(buf_entry);
                            nsol.push_back(tbl_pos);
                            nmask |= std::get<1>(buf_entry);
                            if (nmask == allmask) {
                                // Full cover found: publish and stop all threads.
                                std::unique_lock<std::mutex> lock(mutex);
                                solution = nsol;
                                todo_pos = todo;
                                return;
                            }
                            lbuf.emplace_back(nscore, nmask, std::move(nsol));
                            if (lbuf.size() >= memsize) {
                                // Local buffer full: compact and publish the new bound.
                                limit = std::max(limit, glimit.load());
                                compact_fn(lbuf, limit, thread, 0, trng);
                                AtomicMax(glimit, limit);
                            }
                        }
                    }
                }
                // Advance through the flattened (tbl_pos, buf_pos) space.
                buf_pos += 1;
                if (buf_pos == bufsize) {
                    buf_pos = 0;
                    tbl_pos += 1;
                    if (tbl_pos == tblsize) break;
                }
            }
        }
        // Final local compaction before merging into the shared nbuf.
        limit = std::max(limit, glimit.load());
        compact_fn(lbuf, limit, thread, 1, trng);
        AtomicMax(glimit, limit);
        // The lock is held for the whole merge (including merge compactions).
        std::unique_lock<std::mutex> lock(mutex);
        if (solution.has_value()) return;
        for (auto& entry : lbuf) {
            if (std::get<0>(entry) >= limit) {
                nbuf.emplace_back(std::move(entry));
                if (nbuf.size() >= memsize) {
                    limit = std::max(limit, glimit.load());
                    compact_fn(nbuf, limit, thread, 2, trng);
                    AtomicMax(glimit, limit);
                }
            }
        }
    };
    while (buf.size()) {
        printf("Iteration: %u elements, %u vectors, scores %lu..%lu\n", (unsigned)buf.size(), (unsigned)std::get<2>(buf.front()).size(), std::get<0>(buf.back()), std::get<0>(buf.front()));
        nbuf.clear();
        todo = uint64_t{buf.size()} * tblsize;
        todo_pos = 0;
        std::vector<std::thread> vthreads;
        vthreads.reserve(threads);
        for (int thread = 0; thread < threads; ++thread) {
            vthreads.emplace_back(thread_fn, thread);
        }
        for (int thread = 0; thread < threads; ++thread) {
            vthreads[thread].join();
        }
        {
            std::unique_lock<std::mutex> lock(mutex);
            if (solution.has_value()) return *solution;
        }
        // Compact the merged buffer down to the beam for the next iteration.
        uint64_t limit = glimit.load();
        compact_fn(nbuf, limit, -1, 3, grng);
        std::swap(buf, nbuf);
    }
    return {};
}
}
/** Usage: prog <max_elem> [input files...]
 *  Loads "<index> <hexmask>" feature files, prints feature statistics, then
 *  runs the set-cover optimizer and prints the chosen vector indices. */
int main(int argc, char** argv) {
    // Previously argv[1] was dereferenced unconditionally, crashing when run
    // without arguments.
    if (argc < 2) {
        fprintf(stderr, "Usage: %s <max_elem> [file...]\n", argv[0]);
        return 1;
    }
    unsigned long max_elem = strtoul(argv[1], NULL, 0);
    unsigned total_loaded = 0;
    std::vector<BitSet<FEATURES>> load;
    load.reserve(max_elem);
    setlinebuf(stdout);
    // Load every input file (2 loader threads: 1 + !!THREADS).
    for (int j = 2; j < argc; ++j) {
        std::ifstream ifs(argv[j]);
        if (!ifs) {
            fprintf(stderr, "Failed to open %s\n", argv[j]);
            continue;
        }
        unsigned loaded = LoadFile(load, ifs, max_elem, 1 + !!THREADS);
        printf("Loaded %u elements from %s\n", loaded, argv[j]);
        total_loaded += loaded;
    }
    Shuffle(load.begin(), load.end(), RNG{});
    // Per-feature occurrence counts and the union of all masks.
    std::vector<unsigned> feature_counts(FEATURES, 0);
    unsigned nonzero_loaded = 0;
    BitSet<FEATURES> allmask;
    for (const auto& elem : load) {
        if (elem) {
            allmask |= elem;
            for (unsigned f = 0; f < FEATURES; ++f) {
                feature_counts[f] += elem.Get(f);
            }
            nonzero_loaded += 1;
        }
    }
    printf("Loaded %u elements in total (%u nonzero)\n", total_loaded, nonzero_loaded);
    // Convert counts to -log2(probability) per feature; always-present and
    // never-present features carry no information and get 0.
    std::vector<double> feature_logs(FEATURES, 0.0);
    unsigned total_features = 0;
    unsigned cond_features = 0;
    double total_logs = 0.0;
    for (unsigned f = 0; f < FEATURES; ++f) {
        unsigned c = feature_counts[f];
        if (c == 0) {
            feature_logs[f] = 0.0;
        } else {
            if (c == total_loaded) {
                feature_logs[f] = 0.0;
            } else {
                // -log2(c / total): the constant is -1/ln(2).
                feature_logs[f] = std::log(double(c) / total_loaded) * -1.4426950408889634073599246810;
                total_logs += feature_logs[f];
                cond_features += 1;
            }
            total_features += 1;
        }
    }
    printf("Number of used features: %u (or %u)\n", total_features, (unsigned)allmask.Count());
    printf("Number of not-always used features: %u\n", cond_features);
    printf("Rarest feature probability: 1/2^%f\n", *std::max_element(feature_logs.begin(), feature_logs.end()));
    printf("Total -log2(prob): %f\n", total_logs);
    // Scale log-weights so they sum to roughly 2^64-1 as integer weights.
    double scale = 0xFFFFFFFFFFFFFFFF / total_logs;
    std::vector<uint64_t> feature_weights(FEATURES, 0);
    uint64_t total_weight = 0;
    for (unsigned f = 0; f < FEATURES; ++f) {
        feature_weights[f] = uint64_t(feature_logs[f] * scale);
        total_weight += feature_weights[f];
    }
    feature_logs = {}; // release the doubles; no longer needed
    printf("Total weight: %lu\n", total_weight);
    TABLE = std::move(load);
    WEIGHT = std::move(feature_weights);
    auto ret = Optimize(2048, 65536, THREADS, 16384);
    // Rebuild the covered mask from the reported solution and verify it.
    BitSet<FEATURES> reconstruct;
    printf("Solution: %u vectors\n", (unsigned)ret.size());
    printf("Solution:");
    for (auto i : ret) {
        printf(" %u,", (unsigned)i);
        reconstruct |= TABLE[i];
    }
    printf("\n");
    printf("Equal: %i\n", (int)(reconstruct == allmask));
    return 0;
}
import random
N = 115792089237316195423570985008687907852837564279074904382605163141518161494337
def get_configs(bits):
    """Enumerate all minimal comb configurations for scalars of `bits` bits.

    Returns a list of (table_size, blocks, teeth, spacing) tuples, sorted by
    table size, where blocks * teeth * spacing is just enough to cover every
    scalar bit (reducing any single parameter would lose coverage).
    """
    result = []
    for blocks in range(1, bits + 1):
        for teeth in range(1, 9):
            prod = blocks * teeth
            spacing = -(-bits // prod)  # ceil(bits / (blocks * teeth))
            assert prod * spacing >= bits
            # Keep only configurations where no parameter can be reduced.
            minimal = (prod * (spacing - 1) < bits
                       and blocks * (teeth - 1) * spacing < bits
                       and (blocks - 1) * teeth * spacing < bits)
            if minimal:
                # Table entries: blocks * 2^(teeth-1) points of 64 bytes each.
                result.append(((blocks << (teeth - 1)) * 64, blocks, teeth, spacing))
    result.sort()
    return result
def get_stats(comb_blocks, comb_teeth, comb_spacing):
    """Count the group operations a comb multiplication with these parameters
    performs.

    Returns (adds, doublings, cmovs): one table lookup of 2^(teeth-1)
    conditional moves per block per offset, an addition for every lookup after
    the first, and a doubling between consecutive offsets.

    (Cleanup: removed an unused `bit_pos` local from the original.)
    """
    comb_off = comb_spacing - 1
    comb_points = 1 << (comb_teeth - 1)
    first = True
    stat_cmovs = 0
    stat_adds = 0
    stat_dbls = 0
    while True:
        for _ in range(comb_blocks):
            stat_cmovs += comb_points
            if first:
                first = False
            else:
                stat_adds += 1
        if comb_off == 0:
            break
        comb_off -= 1
        stat_dbls += 1
    return (stat_adds, stat_dbls, stat_cmovs)
def simulate(comb_blocks, comb_teeth, comb_spacing, m, val):
    """Simulate a signed-digit multi-comb scalar multiplication of val mod m.

    Returns "ok" when the computation finishes without any degenerate group
    operation, or a string describing where an addition would have added a
    value to itself ("double ...") or produced zero ("cancel ...").
    """
    # Precompute 1/2 mod m.
    half_m = pow(2, -1, m)
    # Precompute the offset of the first table's entries.
    offset0 = (-half_m * sum(1 << (j * comb_spacing) for j in range(comb_teeth))) % m
    # Precompute the first table.
    table0 = [(offset0 + sum(((q >> j) & 1) << (j * comb_spacing) for j in range(comb_teeth))) % m for q in range(1 << comb_teeth)]
    # Precompute all tables.
    tables = [[(t << (j * comb_teeth * comb_spacing)) % m for t in table0] for j in range(comb_blocks)]
    # Adjust scalar for signed-digit offset.
    scalar = (val + ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1) * half_m) % m
    # Simulate.
    comb_off = comb_spacing - 1
    comb_points = 1 << (comb_teeth - 1)
    first = True
    ret = 0
    while True:
        bit_pos = comb_off
        for block in range(comb_blocks):
            # Gather this block's teeth bits from the adjusted scalar.
            bits = 0
            for tooth in range(comb_teeth):
                bits += ((scalar >> bit_pos) & 1) << tooth
                bit_pos += comb_spacing
            # Top tooth bit selects the sign; negate the remaining bits
            # bitwise when set (conditional negation).
            sign = (bits >> (comb_teeth - 1)) & 1
            absbits = (bits ^ -sign) & (comb_points - 1)
            add = tables[block][absbits]
            if sign:
                add = (-add) % m
            if first:
                ret = add
                first = False
            else:
                # Degenerate cases a real group addition cannot handle:
                if add == ret:
                    return f"double block={block} comb_off={comb_off}"
                ret = (ret + add) % m
                if ret == 0:
                    return f"cancel block={block} comb_off={comb_off}"
        if comb_off == 0:
            break
        comb_off -= 1
        ret = (ret << 1) % m
    # Verify
    assert ((ret - val) % m) == 0
    return "ok"
class MaskVal:
    """Represents a function f(v) where v is an integer in [0,O), of the form ((((v >> S) ^ X) & P) + C) * G.

    In the expressions below we use order = O, shift = S, pattern = P, xor = X, constant = C. G is the
    generator of a cyclic group of order O.
    """
    def __init__(self, constant, order, pattern=0, xor=0, shift=0):
        """Initialize as expression with specified C, S, P, X, O."""
        # xor bits outside the pattern would be masked away; forbid them.
        assert (xor & ~pattern) == 0
        assert shift >= 0
        self._order = order
        self._constant = constant % order
        self._shift = shift
        self._xor = xor
        self._pattern = pattern
    def __str__(self):
        """Render the expression, omitting no-op sub-operations."""
        if self._pattern == 0:
            return f"0x{self._constant:x}"
        ret = "v"
        if self._shift > 0:
            ret = f"({ret} >> {self._shift})"
        if self._xor != 0:
            ret = f"({ret} ^ 0x{self._xor:x})"
        if self._pattern != ((1 << (self._order.bit_length() - self._shift)) - 1):
            ret = f"({ret} & 0x{self._pattern:x})"
        if self._constant != 0:
            ret = f"({ret} + 0x{self._constant:x})"
        return ret
    def __add__(self, other):
        """Add two compatible expressions, or expression and integer."""
        if isinstance(other, int):
            other = MaskVal(other, self._order)
        # We cannot combine expressions from groups of different order.
        assert self._order == other._order
        # To add two expressions, the patterns may not overlap.
        assert (self._pattern & other._pattern) == 0
        # If both expressions have non-zero patterns, their shift must be equal.
        assert self._pattern == 0 or other._pattern == 0 or self._shift == other._shift
        return MaskVal(
            constant=self._constant + other._constant,
            order=self._order,
            pattern=self._pattern | other._pattern,
            xor=self._xor | other._xor,
            shift=self._shift if self._pattern != 0 else other._shift)
    def __neg__(self):
        """Negate an expression.

        For s a submask of P, P - s == s ^ P, so -(s + C) == (s ^ P) - P - C."""
        return MaskVal(
            constant=-self._constant - self._pattern,
            order=self._order,
            pattern=self._pattern,
            xor=self._xor ^ self._pattern,
            shift=self._shift)
    def __sub__(self, other):
        """Subtract two expressions."""
        return self + (-other)
    def __lshift__(self, n):
        """Multiply expression by 2^n."""
        return MaskVal(
            constant=self._constant << n,
            order=self._order,
            pattern=self._pattern << n,
            xor=self._xor << n,
            shift=self._shift - n)
    def root(self):
        """Solve for v in [0, O) such that f(v) = 0; return None if no solution.

        BUGFIX: the previous version used `continue` when the candidate v was
        >= order, which skipped the `k += 1` at the loop's end and spun
        forever. The loop is restructured so k always advances.
        """
        # ((((v >> S) ^ X) & P) + C) * G = 0
        # <=> (((v >> S) ^ X) & P) + C = k * O (for some k)
        # Let s = ((v >> S) ^ X) & P
        # <=> s + C = k * O (for some k)
        # s <= P, and thus s + C <= P + C. Loop over all values in k from 0 to floor(P + C) / O.
        k = 0
        while k * self._order <= self._pattern + self._constant:
            if k * self._order >= self._constant:
                s = k * self._order - self._constant
                # We have to check if s = ((v >> S) ^ X) & P, for some value of v.
                if (s & ~self._pattern) == 0:
                    # Smallest candidate: all bits not covered by P set to 0,
                    # so if any solution below O exists it is included.
                    v = (s ^ self._xor) << self._shift
                    # Only usable if the candidate is in range [0, O).
                    if v < self._order:
                        # A solution exists; randomize the unconstrained bits
                        # until a not-too-large result appears.
                        while True:
                            r = v + (random.randrange(1 << self._order.bit_length()) & ~(self._pattern << self._shift))
                            if r < self._order:
                                return r
            k += 1
        return None
def find_offences(comb_blocks, comb_teeth, comb_spacing, m):
    """Symbolically execute the signed-digit multi-comb algorithm and report
    scalars that would trigger a degenerate group addition.

    The accumulator is tracked as a MaskVal expression in the scalar v. For
    every addition, solve res + add = 0 (cancellation to infinity) and
    res - add = 0 (adding a point to itself); any root is mapped back to an
    input scalar, re-checked with simulate(), and printed.
    """
    # Precompute the bit pattern covered by the first table (block=0).
    comb_bits = sum(1 << (j * comb_spacing) for j in range(comb_teeth))
    # Precompute 1/2 mod m (used for the signed-digit offset corrections).
    half_m = pow(2, -1, m)
    # Precompute the mask of all bits in the scalar.
    mask_all = (1 << m.bit_length()) - 1
    # Count number of additions.
    num_add = 0
    tot_add = comb_blocks * comb_spacing - 1
    # Symbolically execute the algorithm.
    comb_off = comb_spacing - 1
    first = True
    res = MaskVal(order=m, constant=0)
    while True:
        bit_pos = comb_off
        for block in range(comb_blocks):
            # This block's table lookup, expressed symbolically: the scalar
            # bits it reads, plus the table's constant offset.
            constant = -(comb_bits << (bit_pos - comb_off)) * half_m
            scalar_bits = (comb_bits << bit_pos) & mask_all
            add = MaskVal(order=m, constant=constant, pattern=scalar_bits >> comb_off, shift=comb_off)
            if first:
                res = add
                first = False
            else:
                num_add += 1
                summed = res + add
                diffed = res - add
                root_summed = summed.root()
                root_diffed = diffed.root()
                if root_summed is not None:
                    # Undo the signed-digit offset to recover the input scalar.
                    root_summed_cor = (root_summed - half_m * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)) % m
                    sim = simulate(comb_blocks, comb_teeth, comb_spacing, m, root_summed_cor)
                    print(f"- offset={comb_off}, block={block} (add {num_add}/{tot_add}): cancel 0x{root_summed_cor:x}: {sim}")
                if root_diffed is not None:
                    root_diffed_cor = (root_diffed - half_m * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)) % m
                    sim = simulate(comb_blocks, comb_teeth, comb_spacing, m, root_diffed_cor)
                    print(f"- offset={comb_off}, block={block} (add {num_add}/{tot_add}): double 0x{root_diffed_cor:x}: {sim}")
                res = summed
            bit_pos += comb_spacing * comb_teeth
        if comb_off == 0:
            break
        res = (res << 1)
        comb_off -= 1
    # After the offset correction the accumulator must be exactly v.
    res = res + (half_m) * ((1 << (comb_blocks * comb_teeth * comb_spacing)) - 1)
    assert str(res) == "v"
def sdmc_table_access(blocks, teeth, spacing, scalar, m, halfm, accessed):
    """Record which table entries a signed-digit multi-comb multiplication of
    `scalar` touches.

    Shifts `accessed` left by blocks * 2^(teeth-1) new feature bits (one per
    table entry) and sets the bit of every entry accessed. Returns the
    updated accumulator.
    """
    assert (2 * halfm) % m == 1
    points = 1 << (teeth - 1)
    # Make room for this configuration's feature bits.
    accessed <<= blocks * points
    # Walk offsets from spacing-1 down to 0, as the real algorithm does.
    for off in range(spacing - 1, -1, -1):
        bit_pos = off
        for block in range(blocks):
            # Collect this block's teeth bits.
            bits = 0
            for tooth in range(teeth):
                bits |= ((scalar >> bit_pos) & 1) << tooth
                bit_pos += spacing
            # Top tooth bit is the sign; conditionally bit-negate the rest.
            sign = bits >> (teeth - 1)
            absbits = (bits ^ -sign) & (points - 1)
            accessed |= 1 << (block * points + absbits)
    return accessed
def old_table_access(groupsize, groups, scalar, accessed):
    """Record which table entries a fixed-window multiplication of `scalar`
    touches.

    For each of `groups` windows of `groupsize` bits, appends a one-hot block
    of 2^groupsize feature bits marking the accessed entry. Returns the
    updated accumulator.
    """
    entries = 1 << groupsize
    # Make room for all of this configuration's feature bits up front.
    accessed <<= groups * entries
    for g in range(groups):
        window = (scalar >> (g * groupsize)) & (entries - 1)
        accessed = (accessed << entries) | (1 << window)
    return accessed
from hashlib import sha256
def build_tests(sdmc_configs, old_configs, m, cpu, cpus, maxval):
    """Generate table-access feature bitmaps for pseudorandom scalars and
    write them to simsdmc-<cpu>-<cpus>.txt as "<index> 0x<hex>" lines.

    Scalars are SHA256 digests of the 4-byte little-endian index, so the
    stream is deterministic. Work is striped across `cpus` processes; this
    instance handles indices congruent to `cpu` modulo `cpus`, up to maxval.
    """
    halfm = pow(2, -1, m)
    i = cpu
    tacc = 0   # union of all bitmaps, used only for progress reporting
    cnt = 0
    with open(f"simsdmc-{cpu}-{cpus}.txt", "w") as f:
        while i < maxval:
            # Deterministic 256-bit pseudorandom scalar for index i.
            val = int.from_bytes(sha256(i.to_bytes(4, 'little')).digest(), 'big')
            sacc = 0
            # Concatenate the feature bits of every configuration.
            for _, blocks, teeth, spacing in sdmc_configs:
                sacc = sdmc_table_access(blocks, teeth, spacing, val, m, halfm, sacc)
            for groupsize, groups in old_configs:
                sacc = old_table_access(groupsize, groups, val, sacc)
            f.write(f"{i} 0x{sacc:x}\n")
            tacc |= sacc
            cnt += 1
            # Progress line every 256 scalars.
            if (cnt & 0xff) == 0:
                print(f"[{cpu}/{cpus}]: {i} ({tacc.bit_length()} features)")
            i += cpus
import sys
# Command-line arguments: <maxval> <cpu> <cpus>. The original
# `int(sys.argv[i]) or default` pattern raised IndexError when an argument
# was omitted; keep the same "0 means default" semantics but tolerate
# missing arguments.
maxval = (int(sys.argv[1]) if len(sys.argv) > 1 else 0) or 1048576
cpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0
cpus = (int(sys.argv[3]) if len(sys.argv) > 3 else 0) or 1
build_tests(get_configs(256), [(4, 64)], N, cpu, cpus, maxval)
# Stop here; the statistics/offence sweep below is deliberately disabled.
exit()
# NOTE(review): unreachable — the exit() call above returns before this runs.
# When enabled, this prints the operation counts of every 256-bit comb
# configuration and searches each for degenerate-addition offences.
for tbl_size, comb_blocks, comb_teeth, comb_spacing in get_configs(256):
    stat_adds, stat_dbls, stat_cmovs = get_stats(comb_blocks, comb_teeth, comb_spacing)
    print(f"(blk={comb_blocks},tth={comb_teeth},spc={comb_spacing}): tbl={tbl_size} adds={stat_adds} dbls={stat_dbls} cmovs={stat_cmovs}")
    find_offences(comb_blocks, comb_teeth, comb_spacing, N)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment