Skip to content

Instantly share code, notes, and snippets.

@shuffle2
Created April 6, 2021 07:24
Show Gist options
  • Save shuffle2/64c2cb794b6b7997647482321b6d1800 to your computer and use it in GitHub Desktop.
Save shuffle2/64c2cb794b6b7997647482321b6d1800 to your computer and use it in GitHub Desktop.
LZRA Decompressor
/* LZRA Decompressor
* The format appears similar to LZRC (see tlzrc.c)
*/
#include <array>
#include <climits>
#include <cstdint>
#include <cstdio>
#include <cstring>
using u8 = uint8_t;
using u32 = uint32_t;

/// Decoder for the LZRA format: a stream of literals and (distance, length)
/// back-references, entropy-coded with an adaptive binary range coder.
///
/// Call Decompress() to decode one stream. It re-initializes all decoder
/// state, so an instance may be reused. The input/output buffers are borrowed
/// (stored as raw pointers), never owned or freed.
struct LzraDecoder {
    /// Reads one input byte. Returns false once the input is exhausted.
    bool Read8(u8* val) {
        if (pos_in >= len_in) {
            return false;
        }
        *val = buf_in[pos_in++];
        return true;
    }

    /// Reads a big-endian (most significant byte first) 32-bit value.
    bool Read32(u32* val) {
        *val = 0;
        for (u32 i = sizeof(*val); i != 0; i--) {
            u8 tmp;
            if (!Read8(&tmp)) {
                return false;
            }
            *val |= static_cast<u32>(tmp) << (8 * (i - 1));
        }
        return true;
    }

    /// Appends one output byte. Returns false if the output buffer is full.
    bool Write8(u8 val) {
        if (pos_out >= len_out) {
            return false;
        }
        buf_out[pos_out++] = val;
        return true;
    }

    /// Appends a back-reference of `len + 1` bytes starting `distance` bytes
    /// before the current output position. Fails without writing anything if
    /// the source would start before the output buffer or the copy would not
    /// fit. Overlapping copies (distance <= len) repeat the pattern.
    bool WriteMatch(size_t distance, size_t len) {
        if (distance > pos_out) {
            return false;
        }
        // Require pos_out + (len + 1) <= len_out.
        if (pos_out + len >= len_out) {
            return false;
        }
        size_t copy_len = len + 1;
        auto src = &buf_out[pos_out - distance];
        if (copy_len <= distance) {
            // Source and destination ranges do not overlap.
            memcpy(&buf_out[pos_out], src, copy_len);
            pos_out += copy_len;
        } else {
            // Overlapping: later source bytes are ones this loop just wrote,
            // so a plain memcpy/memmove would be wrong here.
            for (size_t i = 0; i < copy_len; i++) {
                if (!Write8(src[i])) {
                    return false;
                }
            }
        }
        return true;
    }

    /// Binds the buffers and reads the 5-byte header: a control byte, then a
    /// big-endian 32-bit value that seeds `code` (in stored mode it is the
    /// payload length instead). Control bytes with bits 5-6 set are rejected.
    bool Init(u8* out, size_t out_len, u8* in, size_t in_len) {
        buf_out = out;
        len_out = out_len;
        buf_in = in;
        len_in = in_len;
        pos_in = pos_out = 0;
        if (!Read8(&control)) {
            return false;
        }
        if (control & (3 << 5)) {
            printf("unknown control bits %02x %x\n", control, (control >> 5) & 3);
            return false;
        }
        if (!Read32(&code)) {
            return false;
        }
        range = 0xffffffff;
        return true;
    }

    /// When the top byte of `range` is zero, shifts `range` and `code` left by
    /// 8 and pulls the next input byte into `code`.
    bool Normalize() {
        if ((range >> 24) == 0) {
            range <<= 8;
            code <<= 8;
            u8 tmp = 0;
            if (!Read8(&tmp)) {
                return false;
            }
            code |= tmp;
        }
        return true;
    }

    /// Decodes one bit with the 8-bit adaptive probability *prob.
    /// Convention: code < bound decodes as 1. Afterwards *prob moves toward
    /// the decoded bit; a larger ProbBits adapts more slowly.
    template<size_t ProbBits>
    bool DecodeBit(u8* val, u8* prob) {
        if (!Normalize()) {
            return false;
        }
        u32 bound = (range >> 8) * *prob;
        *prob -= *prob >> ProbBits;
        if (code < bound) {
            range = bound;
            *prob += 0xff >> ProbBits;
            *val = 1;
        } else {
            code -= bound;
            range -= bound;
            *val = 0;
        }
        return true;
    }

    /// Decodes a bit-tree symbol, most significant bit first. Node `num` uses
    /// probs[num - BitUnit]; nodes are BitUnit apart, so BitUnit trees can be
    /// interleaved in one table by offsetting `probs` (see the distance tree).
    /// With `limit` equal to BitUnit times a power of two (true at every call
    /// site), *val is a multiple of BitUnit in [0, limit - BitUnit], and the
    /// highest probability index touched is limit - 2 * BitUnit.
    template<size_t BitUnit>
    bool DecodeBitTree(u32* val, u8* probs, u32 limit) {
        u32 num = BitUnit;
        do {
            u8 bit = 0;
            if (!DecodeBit<3>(&bit, &probs[num - BitUnit])) {
                return false;
            }
            num <<= 1;
            num += bit * BitUnit;
        } while (num < limit);
        *val = num - limit;
        // Unreachable for the limits used here; kept as a defensive check.
        if (*val >= limit) {
            printf("%s overflow\n", __func__);
            return false;
        }
        return true;
    }

    /// Decodes a number whose leading 1 bit is at position n + 1, i.e. the
    /// result lies in [2^(n+1), 2^(n+2)). Modeled bits use probs[0], probs[8],
    /// probs[16] and probs[24] (probs[0] is used twice when n > 3). For n > 4,
    /// the n - 4 middle bits are decoded directly at a fixed probability of 1/2.
    bool DecodeNumber(u32* val, u8* probs, u32 n) {
        u32 num = 1;
        u8 bit = 0;
        if (!DecodeBit<4>(&bit, &probs[8 * 0])) {
            return false;
        }
        num = (num << 1) + bit;
        if (n > 2) {
            if (n > 3) {
                if (!DecodeBit<4>(&bit, &probs[8 * 0])) {
                    return false;
                }
                num = (num << 1) + bit;
                if (n > 4) {
                    if (!Normalize()) {
                        return false;
                    }
                    for (u32 i = 0; i < n - 4; i++) {
                        range >>= 1;
                        num <<= 1;
                        if (code < range) {
                            num += 1;
                        } else {
                            code -= range;
                        }
                    }
                }
            }
            if (!DecodeBit<4>(&bit, &probs[8 * 3])) {
                return false;
            }
            num = (num << 1) + bit;
        }
        if (n > 0) {
            if (!DecodeBit<4>(&bit, &probs[8 * 1])) {
                return false;
            }
            num = (num << 1) + bit;
            if (n > 1) {
                if (!DecodeBit<4>(&bit, &probs[8 * 2])) {
                    return false;
                }
                num = (num << 1) + bit;
            }
        }
        *val = num;
        return true;
    }

    // DecodeNumber() touches base[0], base[8], base[16] and base[24].
    static constexpr size_t kNumberProbSpan = 8 * 3 + 1;

    /// Returns &table[state] if DecodeNumber() can safely use it as its
    /// probability base, or nullptr if that would index past the table.
    template<size_t N>
    static u8* NumberProbs(u8 (&table)[N], u32 state) {
        if (state >= N || N - state < kNumberProbSpan) {
            return nullptr;
        }
        return &table[state];
    }

    /// Decompresses the stream in `in` into `out`.
    /// On entry *out_len / *in_len are the buffer capacities. On success they
    /// are updated to the bytes produced / consumed and true is returned.
    /// Returns false on truncated, malformed or oversized input.
    bool Decompress(u8* out, size_t* out_len, u8* in, size_t* in_len) {
        if (!Init(out, *out_len, in, *in_len)) {
            return false;
        }
        if ((control >> 7) & 1) {
            // Stored (uncompressed) block: `code` is the payload length.
            // Compare against the remaining input instead of computing
            // pos_in + code, which can wrap when size_t is 32 bits.
            if (code > len_in - pos_in || code > len_out) {
                return false;
            }
            *in_len = pos_in + code;
            *out_len = code;
            memcpy(buf_out, &buf_in[pos_in], code);
            return true;
        }
        // Control bits 3-4 select the initial probability of every model.
        u8 init_sel = (control >> 3) & 3;
        u8 init_val = 0x80 - init_sel * 0x10;
        memset(probs0, init_val, sizeof(probs0));
        memset(probs1, init_val, sizeof(probs1));
        memset(probs2, init_val, sizeof(probs2));
        memset(probs3, init_val, sizeof(probs3));
        memset(probs4, init_val, sizeof(probs4));
        // Control bits 0-2 shift the literal context.
        u8 control_shift = control & 7;
        // Column of probs3; set to 7 after a match, decremented per literal.
        u32 index0 = 0;
        u8 last_byte = 0;
        while (true) {
            // Every iteration starts with one literal...
            u32 literal_state = ((((pos_out & 7) << 8) | last_byte) >> control_shift) & 7;
            u32 literal_val = 0;
            if (!DecodeBitTree<1>(&literal_val, &probs0[literal_state][0], 0x100)) {
                return false;
            }
            if (!Write8(static_cast<u8>(literal_val))) {
                return false;
            }
            // ...followed by zero or more matches.
            while (true) {
                // Unary length class: count leading 1 bits (at most 8).
                u32 len_bits = 0;
                for (size_t i = 0; i < 8; i++) {
                    u8 bit = 0;
                    if (!DecodeBit<4>(&bit, &probs3[i][index0])) {
                        return false;
                    }
                    if (bit == 0) {
                        break;
                    }
                    len_bits++;
                }
                if (len_bits == 0) {
                    // No (further) match; go back to decoding a literal.
                    if (index0 > 0) {
                        index0--;
                    }
                    last_byte = buf_out[pos_out - 1];
                    break;
                }
                len_bits--;
                u32 match_len = 0;
                if (len_bits == 0) {
                    match_len = 1;
                } else {
                    u32 n = len_bits - 1;
                    u32 len_state = (n << 5) | (((pos_out << n) & 3) << 3) | (index0 & 7);
                    u8* len_probs = NumberProbs(probs4, len_state);
                    if (len_probs == nullptr || !DecodeNumber(&match_len, len_probs, n)) {
                        return false;
                    }
                    if (match_len == 0xff) {
                        // End-of-stream marker.
                        *in_len = pos_in;
                        *out_len = pos_out;
                        return true;
                    }
                }
                u32 dist_bits = 0;
                u32 limit = 0x100;
                u32 dist_offset = len_bits;
                if (match_len < 3) {
                    // Short matches use a smaller tree at offset 8 * 31 of probs1.
                    limit = 0x40;
                    dist_offset += 8 * 31;
                }
                if (!DecodeBitTree<8>(&dist_bits, &probs1[dist_offset], limit)) {
                    return false;
                }
                u32 match_dist = 0;
                if (dist_bits == 0) {
                    match_dist = 1;
                } else {
                    // NOTE len_bits is allowed to be 0 (and wrap/truncate to 7)
                    u32 n2 = len_bits - 1;
                    // dist_bits is a non-zero multiple of 8 here, so n is in [0, 30].
                    u32 n = (dist_bits >> 3) - 1;
                    u32 len_state = (n << 5) | ((((match_len > 2) << n) & 3) << 3) | (n2 & 7);
                    // probs2 only covers n <= 17; corrupt input can ask for more.
                    // Reject it instead of reading/writing past the table.
                    u8* dist_probs = NumberProbs(probs2, len_state);
                    if (dist_probs == nullptr || !DecodeNumber(&match_dist, dist_probs, n)) {
                        return false;
                    }
                }
                if (!WriteMatch(match_dist, match_len)) {
                    return false;
                }
                index0 = 7;
                last_byte = buf_out[pos_out - 1];
            }
        }
    }

    u8* buf_out{};      // borrowed output buffer
    u8* buf_in{};       // borrowed input buffer
    size_t len_out{};
    size_t len_in{};
    size_t pos_out{};
    size_t pos_in{};
    u8 control{};       // header byte, see Init()/Decompress()
    u32 code{};
    u32 range{};
    // Adaptive probability tables. Maximum indices reachable from Decompress():
    //   probs0: [literal_state 0..7][0..254]
    //   probs1: 303 (short-match tree at offset 248 + 7 + 48)
    //   probs2: checked at runtime through NumberProbs()
    //   probs3: [length bit 0..7][index0 0..7]
    //   probs4: 223 (n <= 6), also checked through NumberProbs()
    u8 probs0[8][255]{};
    u8 probs1[0x130]{};
    u8 probs2[0x240]{};
    u8 probs3[8][8]{};
    u8 probs4[8 * 4 * 8]{};
};
/// Example C entry point, modeled on an existing API.
///
/// @param buf_out      destination buffer of buf_out_len bytes.
/// @param buf_out_len  capacity of buf_out.
/// @param buf_in       compressed stream. Its length is not passed in, so the
///                     decoder is given SIZE_MAX and reads until the stream
///                     itself ends. Only pass trusted input; real code should
///                     pass the actual input length instead.
/// @param buf_in_next  optional; on success receives the address just past
///                     the consumed input.
/// @return number of bytes written, or -1 on error (including a decompressed
///         size that does not fit in an int).
extern "C"
int lzra_decompress(u8* buf_out, u32 buf_out_len, u8* buf_in, u8** buf_in_next) {
    LzraDecoder decompressor;
    size_t out_len = buf_out_len;
    size_t in_len = SIZE_MAX;
    if (!decompressor.Decompress(buf_out, &out_len, buf_in, &in_len)) {
        return -1;
    }
    // out_len can be up to UINT32_MAX; converting it to int would produce a
    // negative value that is indistinguishable from (or mistaken for) an error.
    if (out_len > static_cast<size_t>(INT_MAX)) {
        return -1;
    }
    if (buf_in_next) {
        *buf_in_next = buf_in + in_len;
    }
    return static_cast<int>(out_len);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment