Skip to content

Instantly share code, notes, and snippets.

@LPeter1997
Last active May 5, 2020 19:29
Show Gist options
  • Save LPeter1997/1c88e7540d03552cacd875eb82caad8d to your computer and use it in GitHub Desktop.
Save LPeter1997/1c88e7540d03552cacd875eb82caad8d to your computer and use it in GitHub Desktop.
A simple (and formerly buggy) LZ77 compression implementation in Rust.
/// The maximum allowed backreference distance which is also the size of the
/// sliding window.
const WINDOW_SIZE: usize = 32768;
/// The longest match length we allow.
const MAX_MATCH_LEN: usize = 258;

/// A very simple sliding window, stored as a fixed-size ring buffer.
struct Window {
    /// Ring-buffer storage; starts out zero-filled.
    data: Box<[u8; WINDOW_SIZE]>,
    /// Index in `data` where the next byte will be written. Every method
    /// keeps it below `WINDOW_SIZE`.
    position: usize,
}

impl Window {
    /// Creates a zero-filled window with the write position at index 0.
    fn new() -> Self {
        Self {
            data: Box::new([0; WINDOW_SIZE]),
            position: 0,
        }
    }

    /// Pushes a single byte into the sliding window and advances the write
    /// position, wrapping around at `WINDOW_SIZE`.
    fn push(&mut self, byte: u8) {
        self.data[self.position] = byte;
        self.position = (self.position + 1) % WINDOW_SIZE;
    }

    /// Calculates how far behind the current write position the window index
    /// `index` lies, taking wrap-around into account.
    ///
    /// An `index` equal to the current write position yields `WINDOW_SIZE`
    /// (a full lap), never 0. `index` is expected to be a valid window index,
    /// i.e. below `WINDOW_SIZE`.
    fn distance_from(&self, index: usize) -> usize {
        if self.position > index {
            self.position - index
        } else {
            WINDOW_SIZE - index + self.position
        }
    }

    /// Inserts a back-reference based on distance and backreference length.
    /// Returns the inserted region.
    ///
    /// The region is returned as two slices because it may wrap around the end
    /// of the ring buffer; the second slice is empty when it does not.
    ///
    /// Bytes are copied one at a time, so a reference whose `len` exceeds
    /// `dist` repeats the bytes it has just written.
    ///
    /// # Panics
    ///
    /// Panics if `dist` is not in `1..=WINDOW_SIZE` or if `len` exceeds
    /// `MAX_MATCH_LEN`. A distance of 0 would copy every byte onto itself and
    /// hand back whatever stale data sits at the write position.
    fn push_reference(&mut self, dist: usize, len: usize) -> (&[u8], &[u8]) {
        assert!(
            (1..=WINDOW_SIZE).contains(&dist),
            "back-reference distance must be in 1..=WINDOW_SIZE"
        );
        assert!(len <= MAX_MATCH_LEN);

        let start = self.position;
        // Adding WINDOW_SIZE before subtracting keeps the arithmetic unsigned.
        let mut src = (start + WINDOW_SIZE - dist) % WINDOW_SIZE;
        for _ in 0..len {
            let byte = self.data[src];
            self.push(byte);
            src = (src + 1) % WINDOW_SIZE;
        }

        let end = self.position;
        if start <= end {
            // The inserted region is contiguous.
            (&self.data[start..end], &self.data[..0])
        } else {
            // The inserted region wrapped past the end of the buffer.
            (&self.data[start..], &self.data[..end])
        }
    }
}
/// How many hash-chains we manage.
const HASH_CHAIN_COUNT: usize = 4096;
/// How long a single hash-chain can grow.
const HASH_CHAIN_LENGTH: usize = 64;
/// The nil-value in the hash-chain (meaning the end of chain).
const HASH_NIL: u16 = 65535;

/// A simple hash-table that manages chains of hash values.
///
/// Every chain is a fixed-size ring of `u16` entries. The compressor stores
/// sliding-window positions in them; unused slots hold `HASH_NIL`.
struct HashTable {
    /// Number of chains; `hash` reduces its result modulo this value.
    chain_count: usize,
    /// Number of slots in each chain.
    chain_length: usize,
    /// Per chain, the slot that the next `insert` will overwrite.
    chain_offsets: Box<[usize]>,
    /// All chains stored back to back: chain `h` occupies
    /// `data[h * chain_length..(h + 1) * chain_length]`.
    data: Box<[u16]>,
}

impl HashTable {
    /// Creates a table of `HASH_CHAIN_COUNT` empty chains, each
    /// `HASH_CHAIN_LENGTH` slots long.
    fn new() -> Self {
        Self {
            chain_count: HASH_CHAIN_COUNT,
            chain_length: HASH_CHAIN_LENGTH,
            chain_offsets: vec![0; HASH_CHAIN_COUNT].into_boxed_slice(),
            data: vec![HASH_NIL; HASH_CHAIN_COUNT * HASH_CHAIN_LENGTH].into_boxed_slice(),
        }
    }

    /// Hashes the given bytes. Returns the index of the corresponding
    /// hash-chain.
    ///
    /// # Panics
    ///
    /// Panics unless `key` is exactly 3 bytes long.
    fn hash(&self, key: &[u8]) -> usize {
        // FNV hash: xor each byte in, then multiply by the prime.
        assert!(key.len() == 3);
        const FNV_BASIS: u64 = 0xcbf29ce484222325;
        const FNV_PRIME: u64 = 0x100000001b3;
        let digest = key
            .iter()
            .fold(FNV_BASIS, |acc, &b| (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME));
        digest as usize % self.chain_count
    }

    /// Returns the hash chain for the given hash value.
    fn hash_chain_for_hash(&self, hash: usize) -> &[u16] {
        let offs = hash * self.chain_length;
        &self.data[offs..(offs + self.chain_length)]
    }

    /// Returns the hash chain for the given hash value, mutably.
    fn hash_chain_for_hash_mut(&mut self, hash: usize) -> &mut [u16] {
        let offs = hash * self.chain_length;
        &mut self.data[offs..(offs + self.chain_length)]
    }

    /// Returns the maximum match length of the key and the given offset of the
    /// sliding window.
    ///
    /// Comparison stops at the first differing byte, when the scan reaches the
    /// window's current write position, or after 258 bytes (the value of the
    /// file's maximum match length), whichever comes first.
    fn match_length(&self, key: &[u8], index: usize, window: &Window) -> usize {
        let limit = key.len().min(258);
        key[..limit]
            .iter()
            .enumerate()
            .position(|(i, &byte)| {
                let w_idx = (index + i) % WINDOW_SIZE;
                w_idx == window.position || byte != window.data[w_idx]
            })
            .unwrap_or(limit)
    }

    /// Inserts the given value into the given hash-chain, overwriting the
    /// oldest slot once the chain is full.
    fn insert(&mut self, chain_idx: usize, val: u16) {
        let offs = self.chain_offsets[chain_idx];
        self.hash_chain_for_hash_mut(chain_idx)[offs] = val;
        self.chain_offsets[chain_idx] = (offs + 1) % self.chain_length;
    }

    /// Searches for a backreference using the given chain index, input and
    /// sliding window. Returns the distance-length tuple, if a match is found.
    ///
    /// The chain is walked from the most recently inserted entry backwards and
    /// the walk ends at the first `HASH_NIL` slot. Among equally long
    /// candidates the first one visited wins. Matches shorter than 3 bytes are
    /// discarded.
    fn reference(
        &self,
        chain_idx: usize,
        upcoming: &[u8],
        window: &Window,
    ) -> Option<(usize, usize)> {
        let chain_len = self.chain_length;
        let chain = self.hash_chain_for_hash(chain_idx);
        let mut slot = self.chain_offsets[chain_idx];
        // Best candidate so far as (window index, match length).
        let mut best: Option<(usize, usize)> = None;
        for _ in 0..chain_len {
            // Step one slot back around the ring. Adding `chain_len` first
            // keeps this correct for any chain length, not only powers of two
            // (which `wrapping_sub(1) % chain_len` silently relied on).
            slot = (slot + chain_len - 1) % chain_len;
            let entry = chain[slot];
            if entry == HASH_NIL {
                break;
            }
            let idx = usize::from(entry);
            let match_len = self.match_length(upcoming, idx, window);
            if match_len > best.map_or(0, |(_, len)| len) {
                best = Some((idx, match_len));
            }
        }
        best.filter(|&(_, len)| len >= 3)
            .map(|(idx, len)| (window.distance_from(idx), len))
    }
}
/// Simple LZ77 compression. Compresses into a simple text format, where bytes
/// are printed in decimal and are separated with commas, backreferences are in
/// the form <distance;length>.
///
/// Every symbol, including the last one, is followed by a comma. Prints the
/// number of input bytes and of emitted symbols to standard output before
/// returning.
fn compress(bytes: &[u8]) -> String {
    use std::fmt::Write;

    let mut window = Window::new();
    let mut tbl = HashTable::new();
    let mut result = String::new();
    let mut written = 0usize;
    let mut offs = 0;
    while offs < bytes.len() {
        let rem = &bytes[offs..];
        // A back-reference is only attempted while at least three bytes
        // remain, because the hash is computed over exactly three bytes.
        let found = if rem.len() >= 3 {
            let hash = tbl.hash(&rem[..3]);
            // The window position is always below the window size (32768), so
            // it fits in a u16 and can never equal the chain's nil marker.
            tbl.insert(hash, window.position as u16);
            tbl.reference(hash, rem, &window)
        } else {
            None
        };
        match found {
            Some((dist, len)) => {
                write!(result, "<{};{}>,", dist, len).expect("writing to a String cannot fail");
                // Replay the match into the window; the copied bytes are not
                // needed here.
                window.push_reference(dist, len);
                offs += len;
            }
            None => {
                let byte = rem[0];
                write!(result, "{},", byte).expect("writing to a String cannot fail");
                window.push(byte);
                offs += 1;
            }
        }
        written += 1;
    }
    println!("Uncompressed symbols: {}", bytes.len());
    println!(" Compressed symbols: {}", written);
    result
}
/// Simple LZ77 decompression (uncompression?).
///
/// Parses the comma-separated text format produced by `compress`: decimal
/// byte values and `<distance;length>` back-references. One trailing comma is
/// tolerated. Empty input, which is what compressing an empty slice yields,
/// decodes to an empty vector.
///
/// `orig` is not consulted; the parameter is kept so existing callers keep
/// compiling.
///
/// # Panics
///
/// Panics on malformed input: a literal that is not a decimal `u8`, a
/// back-reference without its closing `>`, or a distance/length that is
/// missing, not a number, or rejected by the sliding window.
fn uncompress(mut src: &str, orig: &[u8]) -> Vec<u8> {
    let _ = orig;
    if let Some(trimmed) = src.strip_suffix(',') {
        src = trimmed;
    }
    let mut result = Vec::new();
    // `"".split(',')` yields one empty element, which would fail to parse as a
    // byte, so empty input is answered up front.
    if src.is_empty() {
        return result;
    }
    let mut window = Window::new();
    for element in src.split(',') {
        if let Some(rest) = element.strip_prefix('<') {
            // Reference
            let body = rest
                .strip_suffix('>')
                .expect("back-reference is missing its closing '>'");
            let mut dist_and_len = body.split(';');
            let dist = dist_and_len
                .next()
                .expect("back-reference is missing its distance")
                .parse::<usize>()
                .expect("back-reference distance must be a decimal number");
            let len = dist_and_len
                .next()
                .expect("back-reference is missing its length")
                .parse::<usize>()
                .expect("back-reference length must be a decimal number");
            let (s1, s2) = window.push_reference(dist, len);
            result.extend_from_slice(s1);
            result.extend_from_slice(s2);
        } else {
            // Raw byte
            let byte = element
                .parse::<u8>()
                .expect("literal must be a decimal byte value");
            result.push(byte);
            window.push(byte);
        }
    }
    result
}
/// Reads the entire file at `path` into a byte vector.
///
/// # Panics
///
/// Panics if the file cannot be opened or read.
fn file_to_bytes(path: &str) -> Vec<u8> {
    std::fs::read(path).unwrap()
}
/// Writes `bytes` to the file at `path`, creating it if needed and replacing
/// any previous contents.
///
/// # Panics
///
/// Panics if the file cannot be created or written.
fn bytes_to_file(path: &str, bytes: &[u8]) {
    std::fs::write(path, bytes).unwrap();
}
/// Round-trips a file through `compress` and `uncompress`, writes the decoded
/// bytes out, and prints whether they equal the original.
///
/// Usage: `<program> [INPUT [OUTPUT]]`. With no arguments the original
/// defaults are used: input `C:/TMP/bigimage_out.png`, output `test_out.png`.
fn main() {
    let mut args = std::env::args().skip(1);
    let input = args
        .next()
        .unwrap_or_else(|| String::from("C:/TMP/bigimage_out.png"));
    let output = args.next().unwrap_or_else(|| String::from("test_out.png"));

    let orig = file_to_bytes(&input);
    let compressed = compress(&orig);
    let uncompressed = uncompress(&compressed, &orig);
    bytes_to_file(&output, &uncompressed);
    println!("Equality: {}", orig == uncompressed);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment