Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@Arianxx
Last active December 19, 2019 15:34
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Arianxx/603dc688a4b68f207ada2c4534758637 to your computer and use it in GitHub Desktop.
Save Arianxx/603dc688a4b68f207ada2c4534758637 to your computer and use it in GitHub Desktop.
Huffman Coding Python Implementation
import heapq
class HuffmanNode:
    """A node of a binary Huffman tree.

    Attributes:
        symbol: The value stored at this node, or ``None`` when no symbol
            was given.
        freq: Weight of the node; used for ordering via ``<``.
        parent: The parent node, or ``None`` for a root / unattached node.
        left: Left child (the ``'0'`` branch), or ``None``.
        right: Right child (the ``'1'`` branch), or ``None``.
    """

    def __init__(self, symbol=None, freq=None):
        self.symbol = symbol
        self.freq = freq
        self.parent = None
        self.left = None
        self.right = None

    def __lt__(self, other):
        """Order nodes by frequency so they can live in a ``heapq`` heap."""
        return self.freq < other.freq

    def is_leaf(self):
        """Return ``True`` when the node has neither a left nor a right child."""
        return self.left is None and self.right is None

    def get_code(self):
        """Return this leaf's code as a string of ``'0'``/``'1'`` characters.

        Intended as a debugging aid. The code is found by walking the
        ``parent`` links up to the root: a step taken from a left child
        contributes ``'0'``, any other step contributes ``'1'``.

        Returns:
            The bit string read from the root down to this leaf; an empty
            string when the node has no parent.

        Raises:
            ValueError: If the node is not a leaf.
        """
        if not self.is_leaf():
            raise ValueError("Not a leaf node.")
        bits = []
        node = self
        while node.parent is not None:
            # Identity, not equality: we want to know which child slot this
            # exact node occupies.
            bits.append('0' if node.parent.left is node else '1')
            node = node.parent
        # Bits were collected leaf-to-root; the code reads root-to-leaf.
        bits.reverse()
        return ''.join(bits)
class Huffman:
    """Byte-oriented Huffman encoder/decoder.

    Layout of the bytes produced by :meth:`encode`:

    * ``BYTE_MAX_NUM + 1`` bytes: the scaled frequency (0-255) of every
      byte value, in ascending byte-value order;
    * 1 byte: the number of zero bits appended to the payload (1-8);
    * the payload: the concatenated codes of the input bytes, zero-padded
      to a whole number of bytes.

    :meth:`decode` rebuilds the tree from the stored frequencies, so the
    tree construction must be identical on both sides.
    """

    BYTE_MAX_NUM = 255

    def __init__(self):
        self.origin = None
        self.compressed = None
        self.huffman_tree = None
        # Indexed by byte value.
        self.freqs = [0] * (self.BYTE_MAX_NUM + 1)
        # Indexed by byte value; entries become bit strings once built.
        self.coding_table = [0] * (self.BYTE_MAX_NUM + 1)
        # Bit string -> byte value.
        self.reverse_table = {}
        self.coding_str = ''

    def _minimize_frequencies(self):
        """Scale ``self.freqs`` in place so every entry fits in one byte.

        Symbols that occur at least once keep a frequency of at least 1.
        When no symbol occurs at all (empty input) the table is left
        untouched instead of dividing by zero.
        """
        max_freq = max(self.freqs)
        if not max_freq:
            return
        for symbol, freq in enumerate(self.freqs):
            scaled = int(self.BYTE_MAX_NUM * (freq / max_freq))
            if freq and not scaled:
                scaled = 1
            self.freqs[symbol] = scaled

    def _get_symbol_frequencies(self):
        """Count each byte of ``self.origin``, then scale the counts."""
        for symbol in self.origin:
            self.freqs[symbol] += 1
        self._minimize_frequencies()

    def _initial_node_heap(self):
        """Create ``self._heap`` with one node per byte value.

        Nodes are pushed one at a time in ascending byte-value order. Nodes
        compare by frequency only, so this fixed push sequence is what makes
        the resulting tree reproducible from the frequency table alone.
        """
        self._heap = []
        for symbol, freq in enumerate(self.freqs):
            heapq.heappush(self._heap, HuffmanNode(symbol, freq))

    def _build_huffman_tree(self):
        """Build the tree from ``self.freqs`` and return its root."""
        self._initial_node_heap()
        while len(self._heap) > 1:
            node1 = heapq.heappop(self._heap)
            node2 = heapq.heappop(self._heap)
            new_node = HuffmanNode(symbol=None, freq=node1.freq + node2.freq)
            new_node.left, new_node.right = node1, node2
            node1.parent, node2.parent = new_node, new_node
            heapq.heappush(self._heap, new_node)
        self.huffman_tree = heapq.heappop(self._heap)
        del self._heap
        return self.huffman_tree

    def _build_coding_table(self, node, code_str=''):
        """Fill ``coding_table`` and ``reverse_table`` for the subtree at *node*.

        Args:
            node: Root of the subtree to walk; ``None`` is a no-op.
            code_str: Bit string of the path taken to reach *node*.
        """
        if node is None:
            return
        if node.symbol is not None:
            self.coding_table[node.symbol] = code_str
            self.reverse_table[code_str] = node.symbol
        self._build_coding_table(node.left, code_str + '0')
        self._build_coding_table(node.right, code_str + '1')

    def _pading_coding_str(self):
        """Zero-pad ``coding_str`` to a byte boundary and prefix the pad count.

        The pad count is always 1-8 (a full byte is added when the payload
        is already aligned), so the decoder's ``[:-count]`` slice never sees
        a count of zero.
        """
        pading_count = 8 - len(self.coding_str) % 8
        self.coding_str += '0' * pading_count
        self.coding_str = '{:08b}'.format(pading_count) + self.coding_str

    def _prefix_coding_freqs(self):
        """Prepend the frequency table, 8 bits per entry, to ``coding_str``."""
        header = ''.join('{:08b}'.format(freq) for freq in self.freqs)
        self.coding_str = header + self.coding_str

    def _build_codeing_str(self):
        """Build and return the complete bit string for ``self.origin``."""
        table = self.coding_table
        self.coding_str = ''.join([table[symbol] for symbol in self.origin])
        self._pading_coding_str()
        self._prefix_coding_freqs()
        return self.coding_str

    def _get_compressed(self):
        """Pack ``coding_str`` into bytes, store it in ``compressed``, return it."""
        bits = self.coding_str
        assert len(bits) % 8 == 0
        self.compressed = bytes(
            int(bits[index:index + 8], 2) for index in range(0, len(bits), 8)
        )
        return self.compressed

    def _read_frequencies_from_compressed(self):
        """Load ``self.freqs`` from the leading bytes of ``self.compressed``."""
        coding_freqs = self.compressed[:self.BYTE_MAX_NUM + 1]
        for index, freq in enumerate(coding_freqs):
            self.freqs[index] = freq

    def _get_real_coding_from_compressed(self):
        """Return the payload bit string with the trailing padding removed.

        Raises:
            IndexError: If ``self.compressed`` is too short to contain the
                pad-count byte.
        """
        pading_count = self.compressed[self.BYTE_MAX_NUM + 1]
        payload = self.compressed[self.BYTE_MAX_NUM + 2:]
        # '{:08b}' keeps the leading zeros that bin() would drop.
        coding_str = ''.join('{:08b}'.format(num) for num in payload)
        return coding_str[:-pading_count]

    def _decode_compressed(self):
        """Walk the tree along the payload bits and return the decoded bytes."""
        decoded = bytearray()
        root = self.huffman_tree
        node = root
        for state in self._get_real_coding_from_compressed():
            if state == '0':
                node = node.left
            elif state == '1':
                node = node.right
            if node.symbol is not None:
                # A symbol was reached: emit it and restart from the root.
                decoded.append(node.symbol)
                node = root
        return bytes(decoded)

    def clear(self):
        """Reset all state to that of a freshly constructed instance."""
        self.__init__()

    def encode(self, origin):
        """Compress *origin* and return the result as ``bytes``.

        Args:
            origin: An iterable of integers in ``range(256)``, e.g. a
                ``bytes`` object. May be empty.
        """
        self.clear()
        self.origin = origin
        self._get_symbol_frequencies()
        self._build_huffman_tree()
        self._build_coding_table(self.huffman_tree)
        self._build_codeing_str()
        return self._get_compressed()

    def compresse(self, filename, output_filename=None):
        """Compress the file *filename*.

        Args:
            filename: Path of the file to read.
            output_filename: Path to write; defaults to ``filename + '.hfm'``.

        Returns:
            ``True`` once the output has been written.
        """
        with open(filename, 'rb') as file:
            origin = file.read()
        compressed_content = self.encode(origin)
        if output_filename is None:
            output_filename = filename + '.hfm'
        with open(output_filename, 'wb') as file:
            file.write(compressed_content)
        return True

    def decode(self, compressed):
        """Decompress *compressed* (as produced by :meth:`encode`) to ``bytes``."""
        self.clear()
        self.compressed = compressed
        self._read_frequencies_from_compressed()
        self._build_huffman_tree()
        return self._decode_compressed()

    def uncompresse(self, filename, output_filename=None):
        """Decompress the file *filename*.

        Args:
            filename: Path of the compressed file to read.
            output_filename: Path to write. Defaults to *filename* without
                its ``.hfm`` suffix, or ``filename + '.dhfm'`` when the
                name does not end in ``.hfm``.

        Returns:
            ``True`` once the output has been written.
        """
        with open(filename, 'rb') as file:
            compressed = file.read()
        decode_content = self.decode(compressed)
        if output_filename is None:
            if filename.endswith('.hfm'):
                output_filename = filename[:-4]
            else:
                output_filename = filename + '.dhfm'
        with open(output_filename, 'wb') as file:
            file.write(decode_content)
        return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment