Select nth bit from bitstring
const std = @import("std");
const builtin = @import("builtin");
const assert = std.debug.assert;

/// Parallel bits deposit (x86-64 BMI2): scatters the low bits of `src`, in
/// order, into the positions of the set bits of `mask`.
inline fn pdep(src: u64, mask: u64) u64 {
    return asm ("pdep %[mask], %[src], %[ret]"
        : [ret] "=r" (-> u64),
        : [src] "r" (src),
          [mask] "r" (mask),
    );
}
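
// Worked example (illustrative values, not from the original gist): for
// m = 0b1010_1010 the set bits sit at positions 1, 3, 5, 7. To select the
// n = 2 (third) set bit, pdep(0b100, m) deposits the lone 1 onto the third
// set position, giving 0b0010_0000, and @ctz recovers its index, 5.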

/// Whether the @popCount-based variant should be preferred on this target.
const USE_POPCNT = switch (builtin.cpu.arch) {
    .aarch64_32, .aarch64_be, .aarch64 => false,
    .mips, .mips64, .mips64el, .mipsel => std.Target.mips.featureSetHas(builtin.cpu.features, .cnmips),
    .powerpc, .powerpc64, .powerpc64le, .powerpcle => std.Target.powerpc.featureSetHas(builtin.cpu.features, .popcntd),
    .s390x => std.Target.s390x.featureSetHas(builtin.cpu.features, .miscellaneous_extensions_3),
    .ve => true,
    .avr => true,
    .msp430 => true,
    .riscv32, .riscv64 => std.Target.riscv.featureSetHas(builtin.cpu.features, .zbb),
    .sparc, .sparc64, .sparcel => std.Target.sparc.featureSetHas(builtin.cpu.features, .popc),
    .wasm32, .wasm64 => true,
    .x86, .x86_64 => std.Target.x86.featureSetHas(builtin.cpu.features, .popcnt),
    else => false,
};

/// Returns the position of the nth set bit of m (0-indexed).
/// Returns 64 when there is no such bit, i.e. when @popCount(m) <= n.
fn select(m: u64, n: u6) u8 {
    const USE_LOOKUP_TABLE = true;
    const cpu_name = builtin.cpu.model.llvm_name orelse builtin.cpu.model.name;

    // Fast path: on x86-64 with BMI2, deposit a single 1 onto the nth set
    // bit of m, then count trailing zeros to find its position.
    if (!@inComptime() and comptime builtin.cpu.arch == .x86_64 and
        std.Target.x86.featureSetHas(builtin.cpu.features, .bmi2) and
        // pdep is microcoded (slow) on AMD architectures before Zen 3.
        (!std.mem.startsWith(u8, cpu_name, "znver") or cpu_name["znver".len] >= '3'))
    {
        return @ctz(pdep(@as(u64, 1) << n, m));
    }

    comptime var lookup_table: [256][8]u8 = undefined;
    comptime if (USE_LOOKUP_TABLE) {
        @setEvalBranchQuota(std.math.maxInt(u32));
        for (&lookup_table, 0..) |*slot, i| {
            for (slot, 0..) |*sub_slot, j| {
                sub_slot.* = selectByte(i, j);
            }
        }
    };

    // SWAR popcount: compute the popcount of each byte of m in parallel,
    // then multiply by `ones` so byte k holds the inclusive prefix sum,
    // i.e. the popcount of bytes 0..k.
    const ones: u64 = 0x0101010101010101;
    var i = m;
    i -= (i >> 1) & 0x5555555555555555;
    i = (i & 0x3333333333333333) + ((i >> 2) & 0x3333333333333333);
    const prefix_sums = (((i + (i >> 4)) & 0x0F0F0F0F0F0F0F0F) *% ones);
    assert((prefix_sums & 0x8080808080808080) == 0);

    // Compare all eight prefix sums against n at once: byte k of `mask` is
    // nonzero exactly when the prefix sum of byte k is <= n, meaning the
    // target bit lies in a later byte.
    const broadcasted = ones * (@as(u64, n) | 0x80);
    const bit_isolate = ones * if (USE_POPCNT) 0x80 else 0x01;
    const mask = ((broadcasted - prefix_sums) >> if (USE_POPCNT) 0 else 7) & bit_isolate;

    // prove it is safe to optimize (x >> 56) << 3 to (x >> 53)
    const max_byte_index = @as(u64, ones) *% ones;
    assert(((max_byte_index >> 53) & 0b111) == 0);

    // All prefix sums <= n: m has fewer than n+1 set bits.
    if (mask == bit_isolate) return 64;
    const byte_index: u6 = if (USE_POPCNT)
        @intCast(@popCount(mask) << 3)
    else
        @intCast((mask *% ones) >> 53);
    // Shift left by one byte to make the prefix sums exclusive, then pull
    // out the sum for the target byte and select within that byte.
    const prefix_sum: u6 = @truncate(prefix_sums << 8 >> byte_index);
    const target_byte: u8 = @truncate(m >> byte_index);
    const n_for_target_byte: u3 = @intCast(n - prefix_sum);
    return if (USE_LOOKUP_TABLE)
        lookup_table[target_byte][n_for_target_byte] + byte_index
    else
        selectByte(target_byte, @intCast(n_for_target_byte)) + byte_index;
}
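
// A minimal sanity test (not part of the original gist): checks `select`
// against a naive bit scan for a few arbitrary sample values.
test "select finds the nth set bit" {
    const cases = [_]u64{ 0, 1, 0b1010_1010, 0xDEADBEEFDEADBEEF, std.math.maxInt(u64) };
    for (cases) |m| {
        // Collect the positions of the set bits of m, lowest first.
        var expected: [64]u8 = undefined;
        var count: usize = 0;
        var bit: u7 = 0;
        while (bit < 64) : (bit += 1) {
            if ((m >> @intCast(bit)) & 1 == 1) {
                expected[count] = bit;
                count += 1;
            }
        }
        var n: u7 = 0;
        while (n < 64) : (n += 1) {
            const want: u8 = if (n < count) expected[n] else 64;
            try std.testing.expectEqual(want, select(m, @intCast(n)));
        }
    }
}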

/// Returns the position of the nth set bit of the byte m (0-indexed).
/// Returns 8 when @popCount(m) <= n.
fn selectByte(m: u8, n: u3) u4 {
    const ones: u64 = 0x0101010101010101;
    // Broadcast m into all 8 bytes, isolate a different bit of m in each
    // byte, and reduce each byte to 0 or 1 depending on whether its bit was
    // set. Multiplying by `ones` then yields inclusive per-bit prefix sums,
    // mirroring the per-byte scheme used in `select` above.
    const unique_bytes: u64 = 0x8040_2010_0804_0201;
    const unique_bytes_diff_from_msb = (ones * 0x80) - unique_bytes;
    const prefix_sums = (((((m *% ones) & unique_bytes) + unique_bytes_diff_from_msb) >> 7) & ones) *% ones;
    const broadcasted = ones * (@as(u64, n) | 0x80);
    const bit_isolate = ones * if (USE_POPCNT) 0x80 else 0x01;
    const mask = (((broadcasted - prefix_sums) >> if (USE_POPCNT) 0 else 7) & bit_isolate);
    if (mask == bit_isolate) return 8;
    const bit_index: u3 = if (USE_POPCNT)
        @intCast(@popCount(mask))
    else
        @intCast((mask *% ones) >> 56);
    return bit_index;
}
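
// A minimal sanity test (not part of the original gist): exhaustively checks
// `selectByte` against a naive scan over all 256 x 8 inputs.
test "selectByte matches a naive scan" {
    var m: u9 = 0;
    while (m < 256) : (m += 1) {
        var n: u4 = 0;
        while (n < 8) : (n += 1) {
            var seen: u4 = 0;
            var want: u4 = 8; // 8 signals "m has fewer than n+1 set bits"
            var bit: u4 = 0;
            while (bit < 8) : (bit += 1) {
                if ((m >> bit) & 1 == 1) {
                    if (seen == n) {
                        want = bit;
                        break;
                    }
                    seen += 1;
                }
            }
            try std.testing.expectEqual(want, selectByte(@intCast(m), @intCast(n)));
        }
    }
}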