Last active
January 3, 2024 14:37
-
-
Save Validark/1c49b2b00ff930df76a3ee1d22f18244 to your computer and use it in GitHub Desktop.
Select nth bit from bitstring
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
const std = @import("std"); | |
const builtin = @import("builtin"); | |
const assert = std.debug.assert; | |
/// Thin wrapper around the x86 BMI2 `pdep` (parallel bit deposit) instruction:
/// the low-order bits of `src` are scattered, in order, to the positions of the
/// set bits of `mask`; every other result bit is zero.
/// Callers must only reach this on x86_64 targets that have BMI2.
inline fn pdep(src: u64, mask: u64) u64 {
    // AT&T operand order: selector mask first, then the source bits, then the destination.
    const deposited = asm ("pdep %[selector], %[bits], %[out]"
        : [out] "=r" (-> u64),
        : [bits] "r" (src),
          [selector] "r" (mask),
    );
    return deposited;
}
/// Comptime switch between the two ways `select` and `selectByte` count the
/// flagged lanes of their comparison mask: `@popCount` when true, a
/// multiply-and-shift horizontal sum when false.
/// The value is chosen per target architecture, in most cases by checking a
/// CPU feature flag (presumably the one that provides a native population
/// count instruction on that architecture -- verify when adding targets).
const USE_POPCNT = blk: {
    const features = builtin.cpu.features;
    break :blk switch (builtin.cpu.arch) {
        // Unconditionally use @popCount on these targets.
        .avr, .msp430, .ve, .wasm32, .wasm64 => true,

        // Decided by a CPU feature flag.
        .x86, .x86_64 => std.Target.x86.featureSetHas(features, .popcnt),
        .riscv32, .riscv64 => std.Target.riscv.featureSetHas(features, .zbb),
        .mips, .mipsel, .mips64, .mips64el => std.Target.mips.featureSetHas(features, .cnmips),
        .powerpc, .powerpcle, .powerpc64, .powerpc64le => std.Target.powerpc.featureSetHas(features, .popcntd),
        .s390x => std.Target.s390x.featureSetHas(features, .miscellaneous_extensions_3),
        .sparc, .sparcel, .sparc64 => std.Target.sparc.featureSetHas(features, .popc),

        // Explicitly disabled (NOTE(review): presumably because popcount is not
        // a cheap scalar operation on AArch64 -- confirm).
        .aarch64, .aarch64_be, .aarch64_32 => false,

        // Anything unlisted falls back to the multiply-based sum.
        else => false,
    };
};
/// Returns the position of the nth set bit of `m` (0-indexed).
/// 64 is the error condition, i.e. it is returned when @popCount(m) <= n.
fn select(m: u64, n: u6) u8 {
    const use_lookup_table = true;

    // Hardware path: deposit a single bit at rank n into the set bits of m and
    // count trailing zeros. If m has no nth set bit the deposit is 0 and @ctz
    // yields 64. Skipped during comptime evaluation, where inline asm cannot run.
    const model_name = builtin.cpu.model.llvm_name orelse builtin.cpu.model.name;
    const pdep_is_fast = comptime builtin.cpu.arch == .x86_64 and
        std.Target.x86.featureSetHas(builtin.cpu.features, .bmi2) and
        // pdep is microcoded (slow) on AMD architectures before Zen 3.
        (!std.mem.startsWith(u8, model_name, "znver") or model_name["znver".len] >= '3');
    if (!@inComptime() and pdep_is_fast) {
        return @ctz(pdep(@as(u64, 1) << n, m));
    }

    // byte_select_table[b][k] is selectByte(b, k), precomputed at compile time.
    const byte_select_table: [256][8]u8 = comptime table: {
        var entries: [256][8]u8 = undefined;
        if (use_lookup_table) {
            @setEvalBranchQuota(std.math.maxInt(u32));
            for (&entries, 0..) |*row, byte_value| {
                for (row, 0..) |*cell, rank| {
                    cell.* = selectByte(byte_value, rank);
                }
            }
        }
        break :table entries;
    };

    const ones: u64 = 0x0101010101010101;

    // Classic SWAR population count, stopped at per-byte counts.
    const pair_counts = m - ((m >> 1) & 0x5555555555555555);
    const nibble_counts = (pair_counts & 0x3333333333333333) + ((pair_counts >> 2) & 0x3333333333333333);
    const byte_counts = (nibble_counts + (nibble_counts >> 4)) & 0x0F0F0F0F0F0F0F0F;

    // Multiplying by `ones` makes byte k hold the number of set bits in bytes 0..k of m.
    const prefix_sums = byte_counts *% ones;
    // Every prefix sum is at most 64, so bit 7 of each byte is free for the comparison below.
    assert((prefix_sums & 0x8080808080808080) == 0);

    // Per-byte compare: after the subtraction, bit 7 of byte k is still set
    // exactly when prefix_sums[k] <= n, i.e. the target bit lies beyond byte k.
    const broadcasted = ones * (@as(u64, n) | 0x80);
    const lane_flag: u64 = if (USE_POPCNT) 0x80 else 0x01;
    const lane_shift: u6 = if (USE_POPCNT) 0 else 7;
    const all_lanes = ones * lane_flag;
    const mask = ((broadcasted - prefix_sums) >> lane_shift) & all_lanes;

    // prove it is safe to optimize (x >> 56) << 3 to (x >> 53)
    const max_lane_sums = ones *% ones;
    assert(((max_lane_sums >> 53) & 0b111) == 0);

    // Every byte flagged: m has n or fewer set bits in total.
    if (mask == all_lanes) return 64;

    // The number of flagged bytes is the index of the byte that holds the
    // target bit; multiplied by 8 it becomes that byte's bit offset in m.
    const byte_offset: u6 = if (USE_POPCNT)
        @intCast(@popCount(mask) << 3)
    else
        @intCast((mask *% ones) >> 53);

    // Set bits located strictly below the target byte.
    const bits_below: u6 = @truncate((prefix_sums << 8) >> byte_offset);
    const target_byte: u8 = @truncate(m >> byte_offset);
    const rank_in_byte: u3 = @intCast(n - bits_below);

    if (use_lookup_table) {
        return byte_select_table[target_byte][rank_in_byte] + byte_offset;
    } else {
        return selectByte(target_byte, rank_in_byte) + byte_offset;
    }
}
/// Returns the position of the nth set bit of the byte `m` (0-indexed).
/// Returns 8 when `m` has n or fewer set bits.
fn selectByte(m: u8, n: u3) u4 {
    const ones: u64 = 0x0101010101010101;
    // Byte k of this constant has only bit k set.
    const one_bit_per_lane: u64 = 0x8040_2010_0804_0201;
    // Added per byte so that a surviving bit k carries up into bit 7 of its lane.
    const lift_to_msb = (ones * 0x80) - one_bit_per_lane;

    // Replicate m into all eight bytes and keep bit k in byte k only.
    const spread = (m *% ones) & one_bit_per_lane;
    // Byte k becomes 1 if bit k of m is set, else 0.
    const bit_flags = ((spread + lift_to_msb) >> 7) & ones;
    // Byte k holds the number of set bits among bits 0..k of m.
    const prefix_sums = bit_flags *% ones;

    // Per-byte compare: bit 7 of byte k stays set exactly when prefix_sums[k] <= n.
    const broadcasted = ones * (@as(u64, n) | 0x80);
    const lane_flag: u64 = if (USE_POPCNT) 0x80 else 0x01;
    const lane_shift: u6 = if (USE_POPCNT) 0 else 7;
    const all_lanes = ones * lane_flag;
    const mask = ((broadcasted - prefix_sums) >> lane_shift) & all_lanes;

    // Every lane flagged: there is no nth set bit in m.
    if (mask == all_lanes) return 8;

    // The number of flagged lanes is the index of the wanted bit.
    const position: u3 = if (USE_POPCNT)
        @intCast(@popCount(mask))
    else
        @intCast((mask *% ones) >> 56);
    return position;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment