Last active
December 19, 2019 15:34
-
-
Save Arianxx/603dc688a4b68f207ada2c4534758637 to your computer and use it in GitHub Desktop.
Huffman Coding Python Implementation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import heapq | |
class HuffmanNode:
    """A node of a Huffman tree.

    ``symbol`` and ``freq`` are stored as given (both default to ``None``).
    ``left``/``right`` point to the children and ``parent`` to the parent
    node; all three start out as ``None`` and are wired up by whoever
    builds the tree.
    """

    def __init__(self, symbol=None, freq=None):
        self.symbol = symbol
        self.freq = freq
        self.parent = None
        self.left = None
        self.right = None

    def __lt__(self, other):
        """Order nodes by frequency so they can be stored in a ``heapq``."""
        return self.freq < other.freq

    def is_leaf(self):
        """Return True when the node has neither a left nor a right child."""
        return self.left is None and self.right is None

    def get_code(self):
        """Return the '0'/'1' path from the root down to this leaf.

        Intended for debugging. A step to a left child contributes '0',
        any other step contributes '1'. A leaf without a parent yields
        the empty string.

        Raises:
            ValueError: if this node is not a leaf.
        """
        if not self.is_leaf():
            raise ValueError("Not a leaf node.")
        bits = []
        node = self
        # Walk upwards collecting one bit per edge; the bits come out in
        # leaf-to-root order and are reversed at the end.
        while node.parent is not None:
            bits.append('0' if node.parent.left is node else '1')
            node = node.parent
        return ''.join(reversed(bits))
class Huffman:
    """Byte-oriented Huffman encoder/decoder.

    Layout of the compressed data produced by ``encode``:

    * bytes ``0 .. BYTE_MAX_NUM``: scaled frequency (0-255) of each byte
      value, used by ``decode`` to rebuild the same tree;
    * next byte: number of '0' padding bits appended to the payload
      (the encoder always writes a value between 1 and 8);
    * remaining bytes: the concatenated Huffman codes plus the padding.
    """

    BYTE_MAX_NUM = 255

    def __init__(self):
        self.origin = None
        self.compressed = None
        self.huffman_tree = None
        # One slot per possible byte value.
        self.freqs = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
        self.coding_table = [0 for _ in range(self.BYTE_MAX_NUM + 1)]
        self.reverse_table = {}
        self.coding_str = ''

    def _minimize_frequencies(self):
        """Scale the frequencies down so that each one fits in one byte.

        A symbol that occurs at least once never scales down to 0,
        otherwise it would look absent to the decoder.
        """
        max_freq = max(self.freqs)
        if not max_freq:
            # Empty input: every count is already 0 and dividing by
            # max_freq would raise ZeroDivisionError.
            return
        for symbol, freq in enumerate(self.freqs):
            scale_freq = int(self.BYTE_MAX_NUM * (freq / max_freq))
            if not scale_freq and freq:
                scale_freq = 1
            self.freqs[symbol] = scale_freq

    def _get_symbol_frequencies(self):
        """Count every byte of ``self.origin`` and scale the counts."""
        for symbol in self.origin:
            self.freqs[symbol] += 1
        self._minimize_frequencies()

    def _initial_node_heap(self):
        """Create ``self._heap`` holding one leaf per byte value.

        Nodes are pushed one by one in symbol order. Encoder and decoder
        must build the heap identically, so this order is part of the
        file format and must not change.
        """
        self._heap = []
        for symbol, freq in enumerate(self.freqs):
            heapq.heappush(self._heap, HuffmanNode(symbol, freq))

    def _build_huffman_tree(self):
        """Build the tree from ``self.freqs`` and return its root."""
        self._initial_node_heap()
        while len(self._heap) > 1:
            # Merge the two least frequent nodes under a new parent.
            node1 = heapq.heappop(self._heap)
            node2 = heapq.heappop(self._heap)
            new_node = HuffmanNode(symbol=None, freq=node1.freq + node2.freq)
            new_node.left, new_node.right = node1, node2
            node1.parent, node2.parent = new_node, new_node
            heapq.heappush(self._heap, new_node)
        self.huffman_tree = heapq.heappop(self._heap)
        del self._heap
        return self.huffman_tree

    def _build_coding_table(self, node, code_str=''):
        """Fill ``coding_table`` and ``reverse_table`` from the subtree.

        ``code_str`` is the path taken so far: '0' is appended when
        descending left and '1' when descending right.
        """
        if node is None:
            return
        if node.symbol is not None:
            self.coding_table[node.symbol] = code_str
            self.reverse_table[code_str] = node.symbol
        self._build_coding_table(node.left, code_str + '0')
        self._build_coding_table(node.right, code_str + '1')

    def _pading_coding_str(self):
        """Pad ``coding_str`` to whole bytes and prepend the pad count.

        The pad count is between 1 and 8: a payload that is already
        byte-aligned still receives a full byte of padding.
        """
        pading_count = 8 - len(self.coding_str) % 8
        self.coding_str += '0' * pading_count
        self.coding_str = format(pading_count, '08b') + self.coding_str

    def _prefix_coding_freqs(self):
        """Prepend the frequency table (8 bits per byte value)."""
        coding_freqs = ''.join(format(freq, '08b') for freq in self.freqs)
        self.coding_str = coding_freqs + self.coding_str

    def _build_codeing_str(self):
        """Build and return the complete bit string for ``self.origin``."""
        self.coding_str = ''.join(
            self.coding_table[symbol] for symbol in self.origin
        )
        self._pading_coding_str()
        self._prefix_coding_freqs()
        return self.coding_str

    def _get_compressed(self):
        """Pack ``coding_str`` into bytes, store and return them."""
        bits = self.coding_str
        # Internal invariant: padding made the length a multiple of 8.
        assert len(bits) % 8 == 0
        self.compressed = bytes(
            int(bits[index:index + 8], 2) for index in range(0, len(bits), 8)
        )
        return self.compressed

    def _read_frequencies_from_compressed(self):
        """Load the frequency table from the head of ``self.compressed``."""
        coding_freqs = self.compressed[:self.BYTE_MAX_NUM + 1]
        for index, freq in enumerate(coding_freqs):
            self.freqs[index] = freq

    def _get_real_coding_from_compressed(self):
        """Return the payload bit string with the trailing padding removed."""
        pading_count = self.compressed[self.BYTE_MAX_NUM + 1]
        byte_coding_str = self.compressed[self.BYTE_MAX_NUM + 2:]
        # '08b' restores the leading zeros that bin() would omit.
        coding_str = ''.join(format(num, '08b') for num in byte_coding_str)
        # Slice by explicit length: ``coding_str[:-0]`` would wrongly drop
        # the whole payload if the stored pad count were 0.
        return coding_str[:max(len(coding_str) - pading_count, 0)]

    def _decode_compressed(self):
        """Walk the tree along the payload bits and return the decoded bytes."""
        real_coding_str = self._get_real_coding_from_compressed()
        decode_content = bytearray()
        node = self.huffman_tree
        for state in real_coding_str:
            if state == '0':
                node = node.left
            elif state == '1':
                node = node.right
            if node.symbol is not None:
                # Reached a leaf: emit its byte and restart from the root.
                decode_content.append(node.symbol)
                node = self.huffman_tree
        return bytes(decode_content)

    def clear(self):
        """Reset all state to that of a freshly created instance."""
        self.__init__()

    def encode(self, origin):
        """Compress ``origin`` (a bytes-like object) and return the bytes.

        Empty input is supported and produces only the header plus one
        byte of padding.
        """
        self.clear()
        self.origin = origin
        self._get_symbol_frequencies()
        self._build_huffman_tree()
        self._build_coding_table(self.huffman_tree)
        self._build_codeing_str()
        return self._get_compressed()

    def compresse(self, filename, output_filename=None):
        """Compress the file ``filename``.

        The result is written to ``output_filename``, which defaults to
        ``filename + '.hfm'``. Returns True.
        """
        with open(filename, 'rb') as file:
            origin = file.read()
        compressed_content = self.encode(origin)
        if output_filename is None:
            output_filename = filename + '.hfm'
        with open(output_filename, 'wb') as file:
            file.write(compressed_content)
        return True

    def decode(self, compressed):
        """Decompress data produced by ``encode`` and return the bytes.

        Raises:
            ValueError: if ``compressed`` is too short to contain the
                frequency table and the padding byte.
        """
        header_size = self.BYTE_MAX_NUM + 2
        if len(compressed) < header_size:
            raise ValueError(
                "Compressed data is too short: expected at least "
                "{} bytes, got {}.".format(header_size, len(compressed))
            )
        self.clear()
        self.compressed = compressed
        self._read_frequencies_from_compressed()
        self._build_huffman_tree()
        return self._decode_compressed()

    def uncompresse(self, filename, output_filename=None):
        """Decompress the file ``filename``.

        When ``output_filename`` is omitted, a trailing ``'.hfm'`` is
        stripped from ``filename``; if there is none, ``'.dhfm'`` is
        appended instead. Returns True.
        """
        with open(filename, 'rb') as file:
            compressed = file.read()
        decode_content = self.decode(compressed)
        if output_filename is None:
            if filename.endswith('.hfm'):
                output_filename = filename[:-4]
            else:
                output_filename = filename + '.dhfm'
        with open(output_filename, 'wb') as file:
            file.write(decode_content)
        return True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment