Skip to content

Instantly share code, notes, and snippets.

@madmann91
Created March 15, 2018 13:03
Show Gist options
  • Save madmann91/6eb386afbc1a89f4adef5905adc6f383 to your computer and use it in GitHub Desktop.
Save madmann91/6eb386afbc1a89f4adef5905adc6f383 to your computer and use it in GitHub Desktop.
Simple Huffman byte encoding/decoding
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <string.h>
/*
 * Node of a Huffman tree, also used as a singly linked list element.
 * A node whose child[0] is NULL is a leaf standing for `byte`; internal
 * nodes have both children set. `freq` is the occurrence count of the byte
 * (for internal nodes: the sum of the children's counts), and `next` chains
 * nodes into the sorted lists and queues used while building the tree.
 */
typedef struct symb_s {
    uint8_t byte;
    size_t freq;
    struct symb_s* next;
    struct symb_s* child[2];
} symb_t;

/* FIFO of nodes linked through `next`: `begin` is the front, `end` the
 * back; both are NULL when the queue is empty. */
typedef struct queue_s {
    symb_t* begin;
    symb_t* end;
} queue_t;

/* Code table indexed by byte value: bits[b] holds the code of b with its
 * first bit in bit 0, and len[b] is the code length in bits. */
typedef struct dict_s {
    uint32_t bits[256];
    uint8_t len[256];
} dict_t;
/*
 * Sorts the singly linked list whose head is *symb by ascending `freq`
 * (recursive merge sort), stores the new head back into *symb and returns
 * the new tail. An empty list (*symb == NULL) is left untouched and NULL is
 * returned. No particular order is guaranteed among equal frequencies.
 */
static symb_t* sort_symbols(symb_t** symb) {
    symb_t* head = *symb;
    /* Lists of zero or one element are already sorted. */
    if (head == NULL)
        return NULL;
    symb_t* second = head->next;
    if (second == NULL)
        return head;
    /* Two elements: swap them if they are out of order. */
    if (second->next == NULL) {
        if (second->freq < head->freq) {
            second->next = head;
            head->next = NULL;
            *symb = second;
            return head;
        }
        return second;
    }

    /* Split: deal the nodes alternately onto two sublists (by prepending). */
    size_t i = 0;
    symb_t* left = NULL;
    symb_t* right = NULL;
    for (symb_t* cur = head; cur != NULL;) {
        symb_t* following = cur->next;
        if (i++ & 1) {
            cur->next = left;
            left = cur;
        } else {
            cur->next = right;
            right = cur;
        }
        cur = following;
    }
    symb_t* end_left = sort_symbols(&left);
    symb_t* end_right = sort_symbols(&right);

    /* Merge: repeatedly move the lighter front node to the output
     * (the left sublist wins ties). */
    symb_t* begin = NULL;
    symb_t* end = NULL;
    while (right && left) {
        symb_t** from = right->freq < left->freq ? &right : &left;
        symb_t* node = *from;
        *from = node->next;
        node->next = NULL;
        if (end)
            end->next = node;
        else
            begin = node;
        end = node;
    }
    /* Both sublists were non-empty, so `end` is set; append the remainder,
     * whose tail is the tail of the whole list. */
    *symb = begin;
    if (left) {
        end->next = left;
        return end_left;
    }
    if (right) {
        end->next = right;
        return end_right;
    }
    return end;
}
/*
 * Returns (by value) a new internal node with `first` as child[0] and
 * `second` as child[1]. Its frequency is the sum of the children's
 * frequencies, its byte is 0 and it is not linked into any list.
 */
static inline symb_t combine_symbols(symb_t* first, symb_t* second) {
    symb_t parent;
    parent.byte = 0;
    parent.freq = first->freq + second->freq;
    parent.next = NULL;
    parent.child[0] = first;
    parent.child[1] = second;
    return parent;
}
/*
 * Removes and returns the front node of a non-empty queue (asserted).
 * The removed node's `next` field is left unchanged.
 */
static inline symb_t* dequeue(queue_t* queue) {
    symb_t* const head = queue->begin;
    assert(head);
    if (head == queue->end) {
        /* Taking the only element empties the queue. */
        queue->begin = NULL;
        queue->end = NULL;
    } else {
        queue->begin = head->next;
    }
    return head;
}
/* Appends `symb` at the back of `queue`, terminating it with next = NULL. */
static inline void enqueue(queue_t* queue, symb_t* symb) {
    symb->next = NULL;
    if (!queue->end)
        queue->begin = symb;
    else
        queue->end->next = symb;
    queue->end = symb;
}
/*
 * Fills `dict` with the code of every leaf reachable from `root`.
 *
 * Codes are stored first-bit-in-bit-0: bit k of dict->bits[b] is the branch
 * taken at depth k (0 = child[0], 1 = child[1]) and dict->len[b] is the code
 * length. Bytes that do not occur in the tree get length 0.
 *
 * The tree is walked depth-first with an explicit stack. Every node is
 * pushed exactly once, so a stack with as many entries as the 512-node
 * symbol array can never overflow, whatever the tree's depth. Codes longer
 * than 32 bits cannot be stored in uint32_t and trip an assertion.
 */
void huff_dict(const symb_t* root, dict_t* dict) {
    typedef struct {
        const symb_t* node;
        uint32_t bits;
        uint8_t len;
    } frame_t;
    enum { STACK_SIZE = 512 };
    frame_t stack[STACK_SIZE];
    size_t top = 0; /* number of pending entries */

    /* Entries for bytes absent from the tree must not hold garbage. */
    memset(dict, 0, sizeof *dict);

    stack[top++] = (frame_t) { .node = root, .bits = 0, .len = 0 };
    while (top > 0) {
        const frame_t elem = stack[--top];
        assert(elem.node);
        if (elem.node->child[0]) {
            assert(elem.node->child[1]);
            /* Children get length len + 1, which must still fit in 32 bits. */
            assert(elem.len < 32);
            assert(top + 2 <= STACK_SIZE);
            stack[top++] = (frame_t) {
                .node = elem.node->child[0],
                .bits = elem.bits,
                .len = (uint8_t)(elem.len + 1)
            };
            stack[top++] = (frame_t) {
                .node = elem.node->child[1],
                .bits = elem.bits | (UINT32_C(1) << elem.len),
                .len = (uint8_t)(elem.len + 1)
            };
        } else {
            assert(!elem.node->child[1]);
            dict->bits[elem.node->byte] = elem.bits;
            dict->len[elem.node->byte] = elem.len;
        }
    }
}
/*
 * Unlinks every node with zero frequency from the list starting at `symb`
 * and returns the head of the remaining list (NULL if nothing remains).
 * Kept nodes are relinked in their original order; removed nodes are not
 * modified.
 */
symb_t* prune_symbols(symb_t* symb) {
    symb_t* head = NULL;
    symb_t* tail = NULL;
    for (symb_t* node = symb; node != NULL; node = node->next) {
        if (node->freq == 0)
            continue;
        if (tail)
            tail->next = node;
        else
            head = node;
        tail = node;
    }
    if (tail)
        tail->next = NULL;
    return head;
}
size_t huff_tree(const uint8_t* bytes, size_t n, symb_t* symbs) {
for (size_t i = 0; i < 256; ++i) {
symbs[i].byte = i;
symbs[i].freq = 0;
symbs[i].next = i < 255 ? &symbs[i + 1] : NULL;
symbs[i].child[0] = NULL;
symbs[i].child[1] = NULL;
}
for (size_t i = 0; i < n; ++i) symbs[bytes[i]].freq++;
queue_t left_queue = { .begin = NULL, .end = NULL };
queue_t right_queue = { .begin = NULL, .end = NULL };
left_queue.begin = prune_symbols(symbs);
left_queue.end = sort_symbols(&left_queue.begin);
size_t n_symbols = 256;
while (left_queue.begin != left_queue.end ||
right_queue.begin != right_queue.end ||
(left_queue.begin && right_queue.begin)) {
symb_t* first = NULL;
if (!right_queue.begin)
first = dequeue(&left_queue);
else if (!left_queue.begin)
first = dequeue(&right_queue);
else {
if (right_queue.begin->freq < left_queue.begin->freq)
first = dequeue(&right_queue);
else
first = dequeue(&left_queue);
}
symb_t* second = NULL;
if (!right_queue.begin)
second = dequeue(&left_queue);
else if (!left_queue.begin)
second = dequeue(&right_queue);
else {
if (right_queue.begin->freq < left_queue.begin->freq)
second = dequeue(&right_queue);
else
second = dequeue(&left_queue);
}
assert(n_symbols < 512);
symb_t* comb = &symbs[n_symbols++];
*comb = combine_symbols(first, second);
enqueue(&right_queue, comb);
}
assert(!left_queue.begin);
assert(!left_queue.end);
assert(right_queue.begin == right_queue.end);
assert(right_queue.begin == &symbs[n_symbols - 1]);
return n_symbols;
}
/*
 * Encodes bytes[0..n) into `res` using the per-byte codes in `dict`.
 *
 * Code bits are taken starting from bit 0 of dict->bits[b] and packed into
 * output bytes least-significant bit first. The last output byte is
 * zero-padded. Returns the number of output bytes produced.
 *
 * `res` must be large enough for the encoded stream. Each output byte is
 * cleared before its first bit is written, so `res` need not be
 * zero-initialized. Bits of dict->bits[b] at or above dict->len[b] are
 * ignored.
 */
size_t huff_encode(const uint8_t* bytes, size_t n, const dict_t* dict, uint8_t* res) {
    size_t p = 0;   /* index of the output byte being filled */
    unsigned q = 0; /* number of bits already used in res[p] */
    for (size_t i = 0; i < n; ++i) {
        const uint8_t byte = bytes[i];
        const uint32_t bits = dict->bits[byte];
        const unsigned len = dict->len[byte];
        for (unsigned j = 0; j < len;) {
            /* Copy as many code bits as fit in the rest of res[p]. */
            const unsigned w = len - j < 8 - q ? len - j : 8 - q;
            const uint32_t chunk = (bits >> j) & ((UINT32_C(1) << w) - 1);
            if (q == 0)
                res[p] = 0;
            res[p] |= (uint8_t)(chunk << q);
            j += w;
            q += w;
            if (q == 8) {
                p++;
                q = 0;
            }
        }
    }
    return q > 0 ? p + 1 : p;
}
/*
 * Decodes bytes[0..n) by walking the tree from `root`, reading each input
 * byte's bits from least to most significant (bit set = child[1]). Each leaf
 * reached writes its byte to `res` and restarts at the root. Stops after
 * `pmax` output bytes and returns the number written.
 *
 * Every bit of the input is decoded, including the zero padding at the end
 * of the last byte. Passing the original byte count as `pmax` keeps that
 * padding from producing extra output.
 */
size_t huff_decode(const uint8_t* bytes, size_t n, const symb_t* root, uint8_t* res, size_t pmax) {
    const symb_t* node = root;
    size_t out = 0;
    for (size_t i = 0; i < n; ++i) {
        const unsigned octet = bytes[i];
        for (unsigned mask = 1u; mask != 0x100u; mask <<= 1) {
            assert(node->child[0] && node->child[1]);
            node = node->child[(octet & mask) != 0];
            if (node->child[0] != NULL)
                continue; /* still at an internal node */
            if (out >= pmax)
                return out;
            res[out++] = node->byte;
            node = root;
        }
    }
    return out;
}
/*
 * Reads the whole file at `path` into a freshly malloc'ed buffer and stores
 * its size in *out_len. Returns NULL (and leaves *out_len untouched) if the
 * file cannot be opened, sized, allocated for, or fully read.
 */
static uint8_t* read_file(const char* path, size_t* out_len) {
    FILE* fp = fopen(path, "rb");
    if (!fp)
        return NULL;
    uint8_t* data = NULL;
    long size = -1;
    if (fseek(fp, 0, SEEK_END) == 0)
        size = ftell(fp); /* -1L on failure */
    if (size >= 0 && fseek(fp, 0, SEEK_SET) == 0) {
        const size_t len = (size_t)size;
        /* malloc(0) may legitimately return NULL; request at least 1 byte. */
        data = malloc(len > 0 ? len : 1);
        if (data && fread(data, 1, len, fp) != len) {
            free(data);
            data = NULL;
        }
        if (data)
            *out_len = len;
    }
    fclose(fp);
    return data;
}

/*
 * Usage: prog FILE
 * Huffman-encodes FILE, decodes it again, writes the decoded bytes to
 * stdout and then prints the encoded size as a percentage of the original.
 * Returns 1 on missing argument, I/O error or allocation failure.
 */
int main(int argc, char** argv) {
    if (argc < 2)
        return 1;
    size_t len = 0;
    uint8_t* bytes = read_file(argv[1], &len);
    if (!bytes)
        return 1;

    /* A Huffman code is an optimal prefix code, so it never needs more bits
     * than a fixed 8-bit code: the encoded stream fits in `len` bytes.
     * Allocate at least one byte because calloc(0, 1) may return NULL. */
    const size_t cap = len > 0 ? len : 1;
    uint8_t* encoded = calloc(cap, 1);
    uint8_t* decoded = calloc(cap, 1);
    int status = 1;
    if (encoded && decoded) {
        symb_t symbs[512];
        dict_t dict;
        size_t n_symbols = huff_tree(bytes, len, symbs);
        symb_t* root = &symbs[n_symbols - 1];
        huff_dict(root, &dict);
        size_t n_encoded = huff_encode(bytes, len, &dict, encoded);
        size_t n_decoded = huff_decode(encoded, n_encoded, root, decoded, len);
        if (fwrite(decoded, 1, n_decoded, stdout) == n_decoded)
            status = 0;
        /* Guard the division: an empty file decodes to zero bytes. */
        printf("%zu%%\n", n_decoded > 0 ? n_encoded * 100 / n_decoded : (size_t)0);
    }
    free(encoded);
    free(decoded);
    free(bytes);
    return status;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment