Skip to content

Instantly share code, notes, and snippets.

@plammens
Created December 18, 2020 18:59
Show Gist options
  • Save plammens/27746b7df2f3109a1753e2f48da31222 to your computer and use it in GitHub Desktop.
Save plammens/27746b7df2f3109a1753e2f48da31222 to your computer and use it in GitHub Desktop.
Command-line script to generate a quick Huffman tree
#!/usr/bin/env python3
import argparse
import functools
import heapq
from collections import Counter
from dataclasses import dataclass
from typing import *
import more_itertools as mitt
import networkx as nx
import numpy as np
@dataclass(unsafe_hash=True, frozen=True)
@functools.total_ordering
class HuffmanNode:
    """A node of a Huffman tree: a set of characters and their total frequency.

    Instances are immutable and hashable, so they can be used directly as
    graph nodes. Equality (generated by ``dataclass``) compares both fields,
    whereas ordering compares ``freq`` only, which is what a frequency-keyed
    min-heap needs.
    """

    # Characters covered by this node (exactly one character for a leaf).
    chars: FrozenSet[str]
    # Combined number of occurrences of those characters.
    freq: int

    def __lt__(self, other):
        """Order nodes by frequency; ``total_ordering`` derives the rest."""
        if not isinstance(other, HuffmanNode):
            # Defer to Python's fallback (reflected operation or TypeError)
            # instead of failing with an AttributeError on ``other.freq``.
            return NotImplemented
        return self.freq < other.freq

    def __add__(self, other):
        """Merge two nodes into their parent: union of chars, sum of freqs."""
        if not isinstance(other, HuffmanNode):
            return NotImplemented
        return HuffmanNode(self.chars | other.chars, self.freq + other.freq)

    def __str__(self):
        """Return the character set and, on a second line, the frequency."""
        return f"{set(self.chars)}\n({self.freq})"
def huffman(s: str) -> nx.Graph:
    """Build the Huffman tree for the characters of *s*.

    Each distinct character becomes a leaf ``HuffmanNode`` weighted by its
    number of occurrences. The two lowest-frequency parentless nodes are then
    merged repeatedly until a single root remains.

    Args:
        s: The text whose character frequencies define the tree.

    Returns:
        An undirected graph whose nodes are ``HuffmanNode`` objects, with one
        edge between every merged node and each of its two children. An empty
        string yields an empty graph; a string with one distinct character
        yields a single node and no edges.
    """
    freqs = Counter(s)
    g = nx.Graph()

    # Nodes that have not been attached to a parent yet, kept as a min-heap
    # ordered by frequency (see ``HuffmanNode.__lt__``).
    parentless: List[HuffmanNode] = []
    for char, freq in freqs.items():
        node = HuffmanNode(frozenset({char}), freq)
        g.add_node(node)
        parentless.append(node)
    heapq.heapify(parentless)

    while len(parentless) > 1:
        a = heapq.heappop(parentless)
        b = heapq.heappop(parentless)
        parent = a + b
        g.add_node(parent)
        g.add_edge(a, parent)
        g.add_edge(b, parent)
        # Must be a heap push: a plain ``append`` would break the heap
        # invariant, so later pops could return nodes that are not the two
        # least frequent ones and the result would not be a Huffman tree.
        heapq.heappush(parentless, parent)
    return g
def flip_pos(pos: Dict[Any, Tuple[float, float]]) -> Dict[Any, Tuple[float, ...]]:
    """Mirror a layout vertically while keeping its overall y-range.

    Every y coordinate ``y`` is replaced by ``y_max - (y - y_min)``, so the
    topmost point becomes the bottommost and vice versa; x coordinates are
    left untouched.

    Args:
        pos: Mapping from node to its ``(x, y)`` position.

    Returns:
        A new mapping with the same keys, in the same order, and flipped
        positions. An empty mapping is returned for an empty *pos*.
    """
    if not pos:
        # Nothing to flip; indexing column 1 of an empty array would raise.
        return {}

    coords = np.array(list(pos.values()))
    y = coords[:, 1]
    # The right-hand side builds a fresh array before the column is
    # overwritten, so reading from the view ``y`` here is safe.
    coords[:, 1] = y.max() - (y - y.min())
    return {key: tuple(row) for key, row in zip(pos, coords)}
def get_label(node: HuffmanNode):
    """Return the text drawn on *node*.

    Nodes covering several characters are labelled with their frequency only.
    Single-character nodes show the character (a space is spelled out as
    ``<SPACE>``) followed by the frequency in parentheses on a second line.
    """
    if len(node.chars) != 1:
        return f"{node.freq}"

    symbol = mitt.only(node.chars)
    if symbol == " ":
        # A bare space would be invisible in the drawing.
        symbol = "<SPACE>"
    return f"{symbol}\n({node.freq})"
def get_color(node: HuffmanNode):
    """Return the fill colour for *node*: orange for single-character nodes, sky blue otherwise."""
    return "orange" if len(node.chars) == 1 else "skyblue"
def get_shape(node: HuffmanNode):
    """Return a marker code for *node*: ``"s"`` for single-character nodes, ``"o"`` otherwise.

    NOTE(review): presumably matplotlib marker codes (square / circle); the
    only reference to this helper visible here is commented out.
    """
    if len(node.chars) == 1:
        return "s"
    return "o"
def show_huffman(g: nx.Graph):
    """Draw the Huffman tree *g* in a matplotlib window.

    Node positions come from Graphviz's ``dot`` layout, passed through
    ``flip_pos``. Labels and colours are taken from ``get_label`` and
    ``get_color``.
    """
    # Plotting dependencies are imported lazily, only when a drawing is needed.
    import matplotlib.pyplot as plt
    from networkx.drawing.nx_pydot import graphviz_layout

    plt.rcParams["figure.figsize"] = (10, 8)

    layout = flip_pos(graphviz_layout(g, "dot"))
    labels = {n: get_label(n) for n in g.nodes}
    colors = [get_color(n) for n in g.nodes]

    nx.draw_networkx(
        g,
        pos=layout,
        with_labels=True,
        labels=labels,
        node_color=colors,
        node_size=1000,
        # node_shape=list(map(get_shape, g.nodes)),
    )

    plt.margins(0.2)
    plt.axis("off")
    plt.show()
def main():
    """Command-line entry point: read one positional string and display its Huffman tree."""
    cli = argparse.ArgumentParser()
    cli.add_argument("string")
    text = cli.parse_args().string
    show_huffman(huffman(text))
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment