Skip to content

Instantly share code, notes, and snippets.

@afbase
Created March 30, 2024 01:27
Show Gist options
  • Save afbase/2c55e429c0af9ece8c8bc97c469da6b7 to your computer and use it in GitHub Desktop.
Save afbase/2c55e429c0af9ece8c8bc97c469da6b7 to your computer and use it in GitHub Desktop.
Claude AI Generated LZMA Code
// I fed Claude Sonnet the LZMA specification from https://raw.githubusercontent.com/jljusten/LZMA-SDK/master/DOC/lzma-specification.txt and asked it to generate Rust code from it.
// Claude says, "With these additions, the LZMA compression algorithm implementation in Rust should be closer to completion. However, there may still be some missing components or edge cases that need to be addressed based on the LZMA specification.
//
// Additionally, you'll need to implement the encode methods for the LenEncoder and BitTreeEncoder structs, as well as any other missing components required by the LZMA specification.
//
// This implementation is still not fully optimized and may require further modifications and error handling to be production-ready. However, it should provide a solid foundation for understanding and implementing the LZMA compression algorithm in Rust."
mod lzma {
use std::io::{Read, Write};
// Constants and types taken from the LZMA specification. The `K...` names mirror the
// specification's `k...` identifiers (for example `kTopValue` -> `KTOP_VALUE`).

/// Adaptive bit probability, stored as a fixed-point value with
/// `KNUM_BIT_MODEL_TOTAL_BITS` bits of precision.
type CProb = u16;

/// Precision, in bits, of a `CProb`.
const KNUM_BIT_MODEL_TOTAL_BITS: u8 = 11;
/// Starting value of every probability: half of the fixed-point scale.
const PROB_INIT_VAL: CProb = 1 << (KNUM_BIT_MODEL_TOTAL_BITS - 1);
/// Right shift applied to the adjustment whenever a probability adapts after a coded bit.
const KNUM_MOVE_BITS: u8 = 5;
/// Renormalization threshold: the range coder shifts by one byte once `range` drops below it.
const KTOP_VALUE: u32 = 0x0100_0000;
/// Number of length-dependent contexts used when coding the distance slot.
const KNUM_LEN_TO_POS_STATES: u32 = 4;
/// First distance slot whose low bits are no longer coded with the per-slot models.
const KEND_POS_MODEL_INDEX: u32 = 14;
/// Number of distances covered by the slots below `KEND_POS_MODEL_INDEX`.
const KNUM_FULL_DISTANCES: u32 = 1 << (KEND_POS_MODEL_INDEX / 2);
/// Number of low distance bits coded with the "align" model.
const KNUM_ALIGN_BITS: u8 = 4;
/// Number of states in the LZMA state machine (valid states are `0..KNUM_STATES`).
const KNUM_STATES: u32 = 12;
/// Largest supported number of position bits; sizes the per-position-state tables.
const KNUM_POS_BITS_MAX: u32 = 4;
/// Value added to a coded length to obtain the real match length.
const KMATCH_MIN_LEN: u32 = 2;
struct RangeEncoder<W: Write> {
range: u32,
code: u32,
output: W,
}
impl<W: Write> RangeEncoder<W> {
fn new(output: W) -> Self {
RangeEncoder {
range: 0xFFFFFFFF,
code: 0,
output,
}
}
fn encode_bit(&mut self, prob: &mut CProb, bit: u32) {
let bound = (self.range >> KNUM_BIT_MODEL_TOTAL_BITS) * (*prob as u32);
if bit == 0 {
*prob = ((*prob + ((1 << KNUM_BIT_MODEL_TOTAL_BITS) - *prob) >> KNUM_MOVE_BITS) as CProb;
self.range = bound;
} else {
self.code += bound;
self.range -= bound;
*prob = ((*prob - (*prob >> KNUM_MOVE_BITS)) as CProb);
}
if self.range < KTOP_VALUE {
self.range <<= 8;
let byte = (self.code >> 24) as u8;
self.output.write_all(&[byte]).unwrap();
self.code = (self.code << 8) & 0xFFFFFFFF;
}
}
fn decode_bit(&mut self, prob: &mut CProb) -> u32 {
let bound = (self.range >> KNUM_BIT_MODEL_TOTAL_BITS) * (*prob as u32);
let symbol = if self.code < bound {
*prob = ((*prob + ((1 << KNUM_BIT_MODEL_TOTAL_BITS) - *prob) >> KNUM_MOVE_BITS) as CProb;
self.range = bound;
0
} else {
self.code -= bound;
self.range -= bound;
*prob = ((*prob - (*prob >> KNUM_MOVE_BITS)) as CProb);
1
};
self.normalize();
symbol
}
fn decode_distance(&mut self, len: u32) -> u32 {
let len_state = if len > KNUM_LEN_TO_POS_STATES - 1 {
KNUM_LEN_TO_POS_STATES - 1
} else {
len
};
let pos_slot = pos_slot_decoder[len_state as usize].decode(self);
if pos_slot < 4 {
return pos_slot;
}
let mut dist = ((2 | (pos_slot & 1)) << (pos_slot >> 1)) as u32;
if pos_slot < KEND_POS_MODEL_INDEX {
dist += self.decode_tree(&pos_decoders[dist as usize..(dist + (pos_slot >> 1) as u32) as usize]);
} else {
let num_direct_bits = (pos_slot >> 1) - 1;
dist += (self.decode_direct_bits(num_direct_bits - KNUM_ALIGN_BITS) << KNUM_ALIGN_BITS) as u32;
dist += align_decoder.decode(self) as u32;
}
dist
}
fn decode_tree<const NUM_BITS: u8>(&mut self, probs: &[CProb]) -> u32 {
let mut m = 1;
let mut symbol = 0;
for _ in 0..NUM_BITS {
let bit = self.decode_bit(&mut probs[m as usize]);
m = (m << 1) + bit;
symbol |= bit << (NUM_BITS - 1 - i);
}
m - ((1 << NUM_BITS) as u32)
}
fn decode_direct_bits(&mut self, num_bits: u8) -> u32 {
let mut res = 0;
for _ in 0..num_bits {
self.range >>= 1;
let t = (-(self.code >> 31)) as u32;
self.code -= self.range & t;
res = (res << 1) | (1 - t);
self.normalize();
}
res
}
fn normalize(&mut self) {
if self.range < KTOP_VALUE {
self.range <<= 8;
let byte = (self.code >> 24) as u8;
self.output.write_all(&[byte]).unwrap();
self.code = (self.code << 8) & 0xFFFFFFFF;
}
}
fn flush(&mut self) {
for i in 0..5 {
self.normalize();
}
self.output.write_all(&[
(self.code >> 24) as u8,
(self.code >> 16) as u8,
(self.code >> 8) as u8,
self.code as u8,
])
.unwrap();
}
}
/// Circular dictionary buffer holding the most recently produced bytes.
struct OutWindow {
    /// Backing storage of `size` bytes.
    buf: Vec<u8>,
    /// Index in `buf` where the next byte will be written.
    pos: usize,
    /// Capacity of the window in bytes (always at least 1).
    size: usize,
    /// Set once the write position has wrapped around at least once.
    is_full: bool,
    /// Total number of bytes ever written through `put_byte`.
    total_pos: u64,
}

impl OutWindow {
    /// Creates a zero-filled window of `dict_size` bytes.
    ///
    /// A `dict_size` of 0 is raised to 1 so that the window always has a slot to write
    /// to (otherwise the first `put_byte` would index an empty buffer).
    fn new(dict_size: u32) -> Self {
        let size = (dict_size as usize).max(1);
        OutWindow {
            buf: vec![0; size],
            pos: 0,
            size,
            is_full: false,
            total_pos: 0,
        }
    }

    /// Returns `true` while no byte has been written to the window yet.
    fn is_empty(&self) -> bool {
        self.pos == 0 && !self.is_full
    }

    /// Appends `byte`, wrapping to the start of the buffer when the end is reached.
    fn put_byte(&mut self, byte: u8) {
        self.buf[self.pos] = byte;
        self.pos += 1;
        if self.pos == self.size {
            self.pos = 0;
            self.is_full = true;
        }
        self.total_pos += 1;
    }

    /// Returns the byte written `dist` positions ago (`dist == 1` is the latest byte).
    ///
    /// # Panics
    /// Panics if `dist` reaches back further than `pos + size` allows, i.e. when
    /// `dist - pos > size`.
    fn get_byte(&self, dist: u32) -> u8 {
        let dist = dist as usize;
        let index = if dist <= self.pos {
            self.pos - dist
        } else {
            // The requested byte lies before the wrap point.
            self.size - (dist - self.pos)
        };
        self.buf[index]
    }

    /// Copies `len` bytes starting `dist` positions back to the write position.
    ///
    /// The copy is performed byte by byte, so source and destination may overlap
    /// (`dist < len` repeats the last `dist` bytes).
    fn copy_match(&mut self, dist: u32, len: u32) {
        for _ in 0..len {
            let byte = self.get_byte(dist);
            self.put_byte(byte);
        }
    }
}
/// Probability models for match lengths: two choice bits select between a 3-bit "low"
/// tree, a 3-bit "mid" tree (both per position state) and one shared 8-bit "high" tree.
struct LenEncoder {
    /// Selects between the low tree (0) and the mid/high trees (1).
    choice: CProb,
    /// Selects between the mid tree (0) and the high tree (1).
    choice2: CProb,
    /// Low-length trees, indexed by position state.
    low_coders: [BitTreeEncoder<3>; 1 << KNUM_POS_BITS_MAX],
    /// Mid-length trees, indexed by position state.
    mid_coders: [BitTreeEncoder<3>; 1 << KNUM_POS_BITS_MAX],
    /// High-length tree shared by all position states.
    high_coder: BitTreeEncoder<8>,
}

impl LenEncoder {
    /// Creates a length coder with every probability at `PROB_INIT_VAL`.
    fn new() -> Self {
        LenEncoder {
            choice: PROB_INIT_VAL,
            choice2: PROB_INIT_VAL,
            // `BitTreeEncoder` is not `Copy`, so the arrays are built element by element.
            low_coders: std::array::from_fn(|_| BitTreeEncoder::new()),
            mid_coders: std::array::from_fn(|_| BitTreeEncoder::new()),
            high_coder: BitTreeEncoder::new(),
        }
    }

    /// Codes one length symbol for position state `pos_state` and returns it
    /// (0..=7 from the low tree, 8..=15 from the mid tree, 16..=271 from the high tree).
    ///
    /// NOTE(review): the signature carries no length to encode and the callers in this
    /// file consume the return value as the length, so this follows the specification's
    /// length *decoding* procedure, driven by `decode_bit` / `decode_tree`. A real
    /// encoder needs the length as an input — confirm the intended direction.
    ///
    /// # Panics
    /// Panics if `pos_state` is not below `1 << KNUM_POS_BITS_MAX`.
    fn encode<W: Write>(&mut self, range_encoder: &mut RangeEncoder<W>, pos_state: u32) -> u32 {
        let pos_state = pos_state as usize;
        if range_encoder.decode_bit(&mut self.choice) == 0 {
            return range_encoder.decode_tree::<3>(&mut self.low_coders[pos_state].probs[..]);
        }
        if range_encoder.decode_bit(&mut self.choice2) == 0 {
            return 8 + range_encoder.decode_tree::<3>(&mut self.mid_coders[pos_state].probs[..]);
        }
        16 + range_encoder.decode_tree::<8>(&mut self.high_coder.probs[..])
    }
}
/// Binary tree of adaptive probabilities used to code `NUM_BITS`-bit symbols.
struct BitTreeEncoder<const NUM_BITS: u8> {
    /// `1 << NUM_BITS` probabilities; index 1 is the root and index 0 is unused.
    ///
    /// Stored in a `Vec` because stable Rust does not accept an array length computed
    /// from a const generic parameter.
    probs: Vec<CProb>,
}

impl<const NUM_BITS: u8> BitTreeEncoder<NUM_BITS> {
    /// Creates a tree with every probability at `PROB_INIT_VAL`.
    fn new() -> Self {
        BitTreeEncoder {
            probs: vec![PROB_INIT_VAL; 1usize << NUM_BITS],
        }
    }

    /// Encodes the low `NUM_BITS` bits of `symbol`, most significant bit first, walking
    /// the tree from the root and coding each bit with the probability of the node
    /// reached so far.
    ///
    /// Returns the value that was encoded (`symbol` masked to `NUM_BITS` bits).
    fn encode<W: Write>(&mut self, range_encoder: &mut RangeEncoder<W>, symbol: u32) -> u32 {
        let mut node: usize = 1;
        for shift in (0..NUM_BITS).rev() {
            let bit = (symbol >> shift) & 1;
            range_encoder.encode_bit(&mut self.probs[node], bit);
            node = (node << 1) | bit as usize;
        }
        symbol & ((1u32 << NUM_BITS) - 1)
    }
}
/// Binary tree of adaptive probabilities used to decode `NUM_BITS`-bit symbols.
struct BitTreeDecoder<const NUM_BITS: u8> {
    /// `1 << NUM_BITS` probabilities; index 1 is the root and index 0 is unused.
    ///
    /// Stored in a `Vec` because stable Rust does not accept an array length computed
    /// from a const generic parameter.
    probs: Vec<CProb>,
}

impl<const NUM_BITS: u8> BitTreeDecoder<NUM_BITS> {
    /// Creates a tree with every probability at `PROB_INIT_VAL`.
    fn new() -> Self {
        BitTreeDecoder {
            probs: vec![PROB_INIT_VAL; 1usize << NUM_BITS],
        }
    }

    /// Decodes `NUM_BITS` bits, most significant first, by walking the tree from the
    /// root: each decoded bit selects the child node used for the next bit.
    ///
    /// Returns the decoded symbol in `0..(1 << NUM_BITS)`.
    fn decode<W: Write>(&mut self, range_encoder: &mut RangeEncoder<W>) -> u32 {
        let mut node: u32 = 1;
        for _ in 0..NUM_BITS {
            let bit = range_encoder.decode_bit(&mut self.probs[node as usize]);
            node = (node << 1) + bit;
        }
        // Drop the leading marker bit that the walk started with.
        node - (1u32 << NUM_BITS)
    }
}
pub fn compress<R: Read, W: Write>(input: &mut R, output: &mut W, dict_size: u32) -> Result<(), std::io::Error> {
let mut out_window = OutWindow::new(dict_size);
let mut range_encoder = RangeEncoder::new(output);
let mut state = 0;
let mut rep0 = 0;
let mut rep1 = 0;
let mut rep2 = 0;
let mut rep3 = 0;
let mut len_encoder = LenEncoder::new();
let mut rep_len_encoder = LenEncoder::new();
let mut lit_probs: Vec<CProb> = vec![PROB_INIT_VAL; (0x300 << (lc + lp)) as usize];
let mut is_match_probs: Vec<CProb> = vec![PROB_INIT_VAL; (KNUM_STATES << KNUM_POS_BITS_MAX) as usize];
let mut is_rep_probs: Vec<CProb> = vec![PROB_INIT_VAL; KNUM_STATES as usize];
let mut is_rep_g0_probs: Vec<CProb> = vec![PROB_INIT_VAL; KNUM_STATES as usize];
let mut is_rep_g1_probs: Vec<CProb> = vec![PROB_INIT_VAL; KNUM_STATES as usize];
let mut is_rep_g2_probs: Vec<CProb> = vec![PROB_INIT_VAL; KNUM_STATES as usize];
let mut is_rep0_long_probs: Vec<CProb> = vec![PROB_INIT_VAL; (KNUM_STATES << KNUM_POS_BITS_MAX) as usize];
let mut pos_slot_decoders: [BitTreeDecoder<6>; KNUM_LEN_TO_POS_STATES as usize] = [BitTreeDecoder::new(); KNUM_LEN_TO_POS_STATES as usize];
let mut pos_decoders: Vec<CProb> = vec![PROB_INIT_VAL; (KNUM_FULL_DISTANCES - KEND_POS_MODEL_INDEX) as usize];
let mut align_decoder = BitTreeDecoder::<KNUM_ALIGN_BITS>::new();
// Main compression loop
loop {
let byte = input.bytes().next().unwrap_or(0);
if byte == 0 {
// End of input
break;
}
out_window.put_byte(byte);
let pos_state = out_window.total_pos as u32 & ((1 << KNUM_POS_BITS_MAX) - 1);
let state2 = (state << KNUM_POS_BITS_MAX) + pos_state;
let is_match = range_encoder.decode_bit(&mut is_match_probs[state2 as usize]);
if is_match == 0 {
// Encode LITERAL
encode_literal(&mut out_window, &mut range_encoder, state, rep0, &mut lit_probs);
state = update_state_literal(state);
} else {
let is_rep = range_encoder.decode_bit(&mut is_rep_probs[state as usize]);
if is_rep == 0 {
// Simple Match
encode_simple_match(&mut out_window, &mut range_encoder, &mut len_encoder, &mut rep0, &mut rep1, &mut rep2, &mut rep3, state, pos_state);
state = update_state_match(state);
} else {
// Rep Match
encode_rep_match(&mut out_window, &mut range_encoder, &mut rep_len_encoder, &mut rep0, &mut rep1, &mut rep2, &mut rep3, state, pos_state, &mut is_rep_g0_probs, &mut is_rep0_long_probs, &mut is_rep_g1_probs, &mut is_rep_g2_probs);
state = update_state_rep(state);
}
}
}
// Flush the range encoder
range_encoder.flush();
Ok(())
}
fn encode_literal<W: Write>(
out_window: &mut OutWindow,
range_encoder: &mut RangeEncoder<W>,
state: u32,
rep0: u32,
lit_probs: &mut [CProb],
) {
let prev_byte = if out_window.is_empty() {
0
} else {
out_window.get_byte(1)
} as u32;
let lit_state = ((out_window.total_pos & ((1 << lp) - 1)) << lc) + (prev_byte >> (8 - lc));
let probs = &mut lit_probs[(lit_state as usize) * 0x300];
if state >= 7 {
let match_byte = out_window.get_byte(rep0 + 1) as u32;
let mut symbol = 1;
loop {
let match_bit = (match_byte >> 7) & 1;
let bit = range_encoder.decode_bit(&mut probs[((1 + match_bit) << 8) + symbol as usize]);
symbol = (symbol << 1) | bit;
if match_bit != bit {
break;
}
if symbol >= 0x100 {
break;
}
}
while symbol < 0x100 {
symbol = (symbol << 1) | range_encoder.decode_bit(&mut probs[symbol as usize]);
}
out_window.put_byte((symbol - 0x100) as u8);
} else {
let mut symbol = range_encoder.decode_tree(&mut probs, 8);
out_window.put_byte(symbol as u8);
}
}
/// Handles a "simple match" packet: shifts the distance history, codes the length and a
/// new distance into `rep0`, and copies the match into `out_window`.
///
/// After the call `rep1..rep3` hold the previous `rep0..rep2`. If the coded distance is
/// `0xFFFF_FFFF` (the end-of-stream marker) nothing is copied.
///
/// `state` is accepted for signature compatibility and is not used.
///
/// NOTE(review): the distance is not validated against the amount of data in the
/// window; `OutWindow::get_byte` panics if it reaches back too far.
fn encode_simple_match<W: Write>(
    out_window: &mut OutWindow,
    range_encoder: &mut RangeEncoder<W>,
    len_encoder: &mut LenEncoder,
    rep0: &mut u32,
    rep1: &mut u32,
    rep2: &mut u32,
    rep3: &mut u32,
    state: u32,
    pos_state: u32,
) {
    let _ = state;

    *rep3 = *rep2;
    *rep2 = *rep1;
    *rep1 = *rep0;

    let len = len_encoder.encode(range_encoder, pos_state);
    *rep0 = range_encoder.decode_distance(len);
    if *rep0 == 0xFFFF_FFFF {
        // End of stream marker
        return;
    }

    // `rep0` is zero-based (0 means "the previous byte"), whereas the window counts
    // distances from 1, hence the `+ 1`. It cannot overflow: 0xFFFF_FFFF returned above.
    out_window.copy_match(*rep0 + 1, len + KMATCH_MIN_LEN);
}
fn encode_rep_match<W: Write>(
out_window: &mut OutWindow,
range_encoder: &mut RangeEncoder<W>,
rep_len_encoder: &mut LenEncoder,
rep0: &mut u32,
rep1: &mut u32,
rep2: &mut u32,
rep3: &mut u32,
state: u32,
pos_state: u32,
is_rep_g0_probs: &mut [CProb],
is_rep0_long_probs: &mut [CProb],
is_rep_g1_probs: &mut [CProb],
is_rep_g2_probs: &mut [CProb],
) {
let is_rep_g0 = range_encoder.decode_bit(&mut is_rep_g0_probs[state as usize]);
if is_rep_g0 == 0 {
let is_rep0_long = range_encoder.decode_bit(&mut is_rep0_long_probs[(state << KNUM_POS_BITS_MAX) + pos_state as usize]);
if is_rep0_long == 0 {
// Short Rep Match
out_window.put_byte(out_window.get_byte(*rep0 + 1));
return;
}
} else {
let is_rep_g1 = range_encoder.decode_bit(&mut is_rep_g1_probs[state as usize]);
if is_rep_g1 == 0 {
*rep1 = *rep0;
} else {
let is_rep_g2 = range_encoder.decode_bit(&mut is_rep_g2_probs[state as usize]);
if is_rep_g2 == 0 {
*rep2 = *rep0;
} else {
*rep3 = *rep0;
*rep0 = out_window.get_byte(*rep0 + 1) as u32;
}
}
}
let len = rep_len_encoder.encode(range_encoder, pos_state);
let match_len = len + KMATCH_MIN_LEN;
out_window.copy_match(*rep0, match_len);
}
/// Returns the state that follows `state` after a literal packet:
/// states below 4 go to 0, states 4..=9 drop by 3, and all higher states drop by 6.
fn update_state_literal(state: u32) -> u32 {
    match state {
        0..=3 => 0,
        4..=9 => state - 3,
        _ => state - 6,
    }
}
/// Returns the state that follows `state` after a simple-match packet:
/// 10 when `state` is 7 or higher, otherwise 7.
fn update_state_match(state: u32) -> u32 {
    match state {
        0..=6 => 7,
        _ => 10,
    }
}
/// Returns the state that follows `state` after a rep-match packet:
/// 11 when `state` is 7 or higher, otherwise 8.
fn update_state_rep(state: u32) -> u32 {
    match state {
        0..=6 => 8,
        _ => 11,
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment