Skip to content

Instantly share code, notes, and snippets.

@kprotty
Created February 9, 2024 17:52
Show Gist options
  • Save kprotty/12d24060d51d840f7388f86c4ecc22ed to your computer and use it in GitHub Desktop.
Save kprotty/12d24060d51d840f7388f86c4ecc22ed to your computer and use it in GitHub Desktop.
const std = @import("std");

/// log2 of the rANS probability scale: quantized symbol frequencies sum to at most
/// `1 << M`. Usually 8-12.
const M = 12;
/// log2 of the lower bound of the rANS state. Renormalization moves whole bytes
/// between the state and the stream to keep it at or above `1 << L`. From rans_byte.h.
const L = 23;
/// Frequencies below `m_min` are stored directly as the 3-bit header tag (small-size
/// optimization). Of the 8 tag codes, 2 are run-length markers (0b110, 0b111) and
/// `divCeil(M, 4)` select a frequency range from `m_max`, leaving `m_min` literal codes.
const m_min = 8 - 2 - (std.math.divCeil(u32, M, 4) catch unreachable);
/// Exclusive upper bounds of the frequency ranges after the literal codes. A frequency
/// in range `m` (1-based) is stored as tag `m_min + m - 1` followed by
/// `freq - m_max[m - 1]` in `@min(m * 4, M)` bits.
const m_max = [_]u16{ m_min, m_min + 16, m_min + 16 + 256, m_min + 16 + 256 + 4096, 1 << M };

comptime {
    // Frequencies (up to exactly 1 << M for a single-symbol input) are held in u16,
    // and renormalization computes (1 << L) >> M, which must not be zero.
    std.debug.assert(M <= 15 and M <= L);
}
/// Compresses `src` into `dst` using one static rANS model over byte values.
///
/// Output layout:
///   1. A bit-packed header: the 256 quantized byte frequencies (run-length encoded
///      where consecutive entries repeat), then the uncompressed length, padded to
///      a whole byte by `flushBits`.
///   2. The bytes emitted by rANS renormalization while encoding `src` back to front.
///   3. The final 32-bit rANS state, little-endian.
///
/// `dst` is wrapped in a `std.io.bitWriter` and also used via `writeByte`/`writeInt`,
/// so a `std.io.Writer`-style value is expected.
///
/// Returns `error.InvalidInputLength` if `src.len` is outside `2..maxInt(u32)`.
/// The length field stores `len - 1` as (floor(log2), remainder), which cannot
/// express 0 or 1. Longer inputs would overflow the `u32` histogram and the
/// decoder's `u32` size.
fn compress(dst: anytype, src: []const u8) !void {
    if (src.len < 2 or src.len > std.math.maxInt(u32)) return error.InvalidInputLength;

    // Histogram for the frequency of each byte in input.
    var hist = [_]u32{0} ** 256;
    for (src) |byte| hist[byte] += 1;

    // Quantize histogram into 0..(1 << M). Rounding down can leave part of the
    // range unused, so keep a running total.
    var f: [256]u16 = undefined;
    var total: u32 = 0;
    for (0..256) |i| {
        f[i] = @intCast((@as(u64, hist[i]) << M) / src.len);
        total += f[i];
    }

    // Every byte that occurs needs a frequency of at least 1 or it cannot be encoded.
    // Spend unused range first. Otherwise take one slot from the most frequent symbol,
    // where losing a slot costs the least compression. Since 1 << M > 256, a full
    // range always has some frequency > 1 to take from.
    for (0..256) |i| {
        if (f[i] != 0 or hist[i] == 0) continue;
        if (total < (1 << M)) {
            total += 1;
        } else {
            f[std.mem.indexOfMax(u16, &f)] -= 1;
        }
        f[i] = 1;
    }

    // Compress quantized frequencies using run-length-encoding when possible.
    var rle: usize = 0;
    var b = std.io.bitWriter(.little, dst);
    for (0..256) |i| {
        if (i < rle) continue;
        // rle = end (exclusive) of the run of equal frequencies starting at i.
        rle = for (i + 1..256) |j| {
            if (f[i] != f[j]) break j;
        } else 256;
        // Runs of 2+ get a 0b110 (5-bit count) or 0b111 (8-bit count) prefix
        // holding `run length - 2`.
        if (std.math.sub(usize, rle, i + 2) catch null) |r| {
            try b.writeBits(@as(u8, 0b110) + @intFromBool(r >= 1 << 5), 3);
            try b.writeBits(r, if (r >= 1 << 5) 8 else 5);
        }
        // Small frequencies are the tag itself; larger ones are a range tag plus an offset.
        const m = for (0..m_max.len) |j| {
            if (f[i] < m_max[j]) break j;
        } else unreachable;
        try b.writeBits(if (m > 0) m_min + m - 1 else f[i], 3);
        if (m > 0) try b.writeBits(f[i] - m_max[m - 1], @min(m * 4, M));
    }

    // Write out uncompressed size. This is needed for decoding to terminate.
    // Layout: `log = floor(log2(len - 1))` in 5 bits, then the low `log` bits of
    // `len - 1`. The bit count is widened because `log + 5` overflows a u5 once log >= 27.
    const len_minus_1: u32 = @intCast(src.len - 1);
    const log: u5 = @intCast(31 - @clz(len_minus_1));
    const rest = len_minus_1 - (@as(u32, 1) << log);
    try b.writeBits((@as(u64, rest) << 5) | log, @as(usize, log) + 5);
    try b.flushBits();

    // Compute prefix sum for quantized frequencies.
    var cdf = [_]u16{0} ** 256;
    for (1..256) |i| cdf[i] = cdf[i - 1] + f[i - 1];

    // Finally, compress the input backwards as ANS decodes symbols in LIFO.
    const renorm_unit: u32 = ((1 << L) >> M) << 8;
    var x: u32 = 1 << L;
    for (0..src.len) |i| {
        const s = src[src.len - i - 1];
        // Renormalize x over L: flush low bytes so that C(x) below fits in 32 bits.
        while (x >= renorm_unit * f[s]) : (x >>= 8) try dst.writeByte(@truncate(x));
        // x = C(x).
        x = ((x / f[s]) << M) + (x % f[s]) + cdf[s];
    }
    try dst.writeInt(u32, x, .little);
}
/// Decompresses a stream produced by `compress` from `src`, writing decoded bytes to `dst`.
///
/// `src` must hold the entire compressed stream. The bit-packed header is parsed from
/// the front, and the rANS payload is consumed from the back, starting with the final
/// 32-bit state in the last 4 bytes. `dst` must provide `writeByte`.
///
/// Errors:
///   - `readBitsNoEof` errors are propagated if the header is cut short.
///   - `error.Eof` if `src` is too short, the frequency table or size field is
///     malformed, or the payload runs into the header or does not decode back to
///     the encoder's initial state.
/// Bytes decoded before an error is detected have already been written to `dst`.
fn decompress(dst: anytype, src: []const u8) !void {
    // The final rANS state occupies the last 4 bytes.
    if (src.len < 4) return error.Eof;

    var header = std.io.fixedBufferStream(src);
    var b = std.io.bitReader(.little, header.reader());

    // Decode the compressed quantized frequencies.
    var flen: usize = 0;
    var f: [256]u16 = undefined;
    while (flen < f.len) {
        var rle: usize = 1;
        var tag = try b.readBitsNoEof(u16, 3);
        if (tag >= 0b110) {
            // Widen before adding: an 8-bit count of 255 plus 1 would overflow u8.
            rle += @as(usize, try b.readBitsNoEof(u8, if (tag == 0b111) 8 else 5)) + 1;
            tag = try b.readBitsNoEof(u16, 3);
        }
        if (std.math.sub(u16, tag, m_min) catch null) |m|
            tag = (try b.readBitsNoEof(u16, @min((m + 1) * 4, M))) + m_max[m];
        // A run must not extend past the last symbol.
        if (rle > f.len - flen) return error.Eof;
        for (0..rle) |i| f[flen + i] = tag;
        flen += rle;
    }

    // Decode uncompressed size used to know when to stop decompressing below.
    // `rest` < 2^log, so only the final +1 can overflow (log == 31 with all bits set).
    const log = try b.readBitsNoEof(u5, 5);
    const rest = try b.readBitsNoEof(u32, log);
    const size = std.math.add(u32, rest + (@as(u32, 1) << log), 1) catch return error.Eof;

    // Prefix sum over the frequencies, with cdf[256] holding the total. It is computed
    // in u32 so a corrupt table cannot overflow before it is rejected.
    var cdf = [_]u32{0} ** (256 + 1);
    for (1..cdf.len) |i| cdf[i] = cdf[i - 1] + f[i - 1];
    if (cdf[256] > 1 << M) return error.Eof;

    // Map every slot in 0..cdf[256] to its symbol, so each decode step is one lookup.
    // Slots at or past the total are never produced by the encoder.
    var symbol_of: [1 << M]u8 = undefined;
    for (0..256) |s| {
        for (cdf[s]..cdf[s + 1]) |slot| symbol_of[slot] = @intCast(s);
    }

    // Finally, decompress the source using the state at the end (in reverse).
    const header_end = header.pos;
    var in: usize = src.len - 4;
    if (in < header_end) return error.Eof;
    var x = std.mem.readInt(u32, src[in..][0..4], .little);
    for (0..size) |_| {
        // s, x = D(x).
        const slot = x & ((1 << M) - 1);
        if (slot >= cdf[256]) return error.Eof;
        const s = symbol_of[slot];
        x = (x >> M) * f[s] + (slot - cdf[s]);
        try dst.writeByte(s);
        // Renormalize x over L, never reading back into the header.
        while (x < (1 << L)) {
            if (in == header_end) return error.Eof;
            in -= 1;
            x = (x << 8) | src[in];
        }
    }
    // compress starts from x = 1 << L, so a well-formed stream decodes back to it.
    if (x != 1 << L) return error.Eof;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment