@Validark · Last active January 24, 2024
comptime pext.zig
const std = @import("std");
const builtin = @import("builtin");

const HAS_FAST_PDEP_AND_PEXT = blk: {
    const cpu_name = builtin.cpu.model.llvm_name orelse builtin.cpu.model.name;
    break :blk builtin.cpu.arch == .x86_64 and
        std.Target.x86.featureSetHas(builtin.cpu.features, .bmi2) and
        // pdep and pext are microcoded (slow) on AMD architectures before Zen 3.
        !std.mem.startsWith(u8, cpu_name, "bdver") and
        (!std.mem.startsWith(u8, cpu_name, "znver") or cpu_name["znver".len] >= '3');
};
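// Illustrative note (not from the original gist): compiling with `-mcpu=znver2` resolves
// HAS_FAST_PDEP_AND_PEXT to false, since pdep/pext are microcoded on Zen 2, while
// `-mcpu=znver3` and newer resolve it to true.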

fn pext(src: anytype, comptime mask: @TypeOf(src), comptime use_vector_impl: bool) std.meta.Int(.unsigned, @popCount(mask)) {
    if (mask == 0) return 0;
    const num_one_groups = @popCount(mask & ~(mask << 1));

    if (!@inComptime() and comptime num_one_groups >= 3 and @bitSizeOf(@TypeOf(src)) <= 64 and HAS_FAST_PDEP_AND_PEXT) {
        return switch (@TypeOf(src)) {
            u64, u32 => @intCast(asm ("pext %[mask], %[src], %[ret]"
                : [ret] "=r" (-> @TypeOf(src)),
                : [src] "r" (src),
                  [mask] "r" (mask),
            )),
            // Widen sub-word sources so the hardware instruction above can be used.
            else => @intCast(pext(@as(if (@bitSizeOf(@TypeOf(src)) <= 32) u32 else u64, src), mask, use_vector_impl)),
        };
    } else if (num_one_groups >= 4) {
        blk: {
            // Attempt to produce a `global_shift` value such that the return statement at
            // the end of this block moves the desired bits into the least significant
            // bit positions.
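            //
            // A hand-worked illustration (not part of the original gist): for a u32 src with
            // mask 0x0101_0101 (bits 0, 8, 16 and 24), the loop below settles on
            // global_shift = 0x1020_4080 (bits 28, 21, 14 and 7). Multiplying (src & mask) by
            // that constant lands src bit 0 at bit 28, bit 8 at bit 29, bit 16 at bit 30 and
            // bit 24 at bit 31, while every cross term either wraps past bit 31 or stays below
            // bit 28, so the final `>> (32 - 4)` produces the packed 4-bit result.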
            comptime var global_shift: @TypeOf(src) = 0;
            comptime {
                var x = mask;
                var target = @as(@TypeOf(src), 1) << (@bitSizeOf(@TypeOf(src)) - 1);
                for (0..@popCount(x) - 1) |_| target |= target >> 1;

                // The maximum sum of the garbage data. If this overflows into the target bits,
                // we can't use the global_shift.
                var left_overs: @TypeOf(src) = 0;
                var cur_pos: @TypeOf(src) = 0;

                while (true) {
                    const shift = (@clz(x) - cur_pos);
                    global_shift |= @as(@TypeOf(src), 1) << shift;
                    var shifted_mask = x << shift;
                    cur_pos = @clz(shifted_mask);
                    cur_pos += @clz(~(shifted_mask << cur_pos));
                    shifted_mask = shifted_mask << cur_pos >> cur_pos;
                    left_overs += shifted_mask;
                    if ((target & left_overs) != 0) break :blk;
                    if ((shifted_mask & target) != 0) break :blk;
                    x = shifted_mask >> shift;
                    if (x == 0) break;
                }
            }

            return @intCast(((src & mask) *% global_shift) >> (@bitSizeOf(@TypeOf(src)) - @popCount(mask)));
        }

        // TODO: add heuristics for when this is probably the best option.
        // Most probably, when we can keep inside of the vector widths that the machine actually has.
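        //
        // Sketch of what this path builds (hand-derived, not from the original gist): for
        // mask 0b0101_0101, `vec2` becomes .{ 1, 4, 16, 64 } and `min_int` becomes u7, so
        // `(vec & vec2) == vec2` yields one bool per selected bit, which @bitCast packs
        // into a u4 with mask bit 0 in the least significant position.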
        if (use_vector_impl) {
            comptime var min_int = u0;
            const vec2 = comptime relevant_masks: {
                var relevant_indices: []const @TypeOf(src) = &[0]@TypeOf(src){};
                var x = mask;
                for (0..@popCount(mask)) |_| {
                    relevant_indices = relevant_indices ++ [1]@TypeOf(src){1 << @ctz(x)};
                    x &= x -% 1;
                }
                min_int = std.meta.Int(.unsigned, @ctz(relevant_indices[@popCount(mask) - 1]) + 1);
                break :relevant_masks relevant_indices[0..@popCount(mask)].*;
            };
            const vec = @as(@Vector(@popCount(mask), min_int), @splat(@truncate(src)));
            return @bitCast((vec & vec2) == vec2);
        }
    }

    {
        var ans: @TypeOf(src) = 0;
        comptime var cur_pos = 0;
        comptime var x = mask;
        inline while (x != 0) {
            const mask_ctz = @ctz(x);
            const num_ones = @ctz(~(x >> mask_ctz));
            comptime var ones = 1;
            inline for (0..num_ones) |_| ones <<= 1;
            ones -%= 1;
            // @compileLog(std.fmt.comptimePrint("ans |= (src >> {}) & 0b{b}", .{ mask_ctz - cur_pos, (ones << cur_pos) }));
            ans |= (src >> (mask_ctz - cur_pos)) & (ones << cur_pos);
            cur_pos += num_ones;
            inline for (0..num_ones) |_| x &= x - 1;
        }
        return @intCast(ans);
    }
}
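
// A minimal usage sketch (an illustrative addition, not part of the original gist): with
// mask 0b0110_1100, the bits of `src` at positions 2, 3, 5 and 6 are gathered into the
// low four bits of the u4 result.
test "pext gathers masked bits into the low bits" {
    const src: u32 = 0b1011_0110; // bits 2, 3, 5 and 6 hold 1, 0, 1 and 0 respectively
    try std.testing.expectEqual(@as(u4, 0b0101), pext(src, 0b0110_1100, false));
}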