Created
January 3, 2016 15:58
-
-
Save niklasf/9e487c4ef57156d16b0f to your computer and use it in GitHub Desktop.
Compress and decompress chess games using Huffman encoding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import bitarray | |
import chess | |
import chess.pgn | |
import sys | |
import pickle | |
import textwrap | |
def base_counts():
    """Return a dict mapping every move the codec knows about to a count of 0.

    The table is seeded in a fixed order (promotions first, then all
    non-promoting moves). The insertion order is kept identical to the
    historical one because the Huffman tree builder breaks frequency ties by
    iteration order.

    Returns:
        dict: ``chess.Move`` -> ``0`` for every seeded move.
    """
    promotion_pieces = (chess.KNIGHT, chess.BISHOP, chess.ROOK, chess.QUEEN)

    # (origin rank mask, destination rank mask) pairs. ``zip`` pairs the i-th
    # origin square with the i-th destination square, so removing the a-file
    # square from one side shifts the pairing by one file and produces the
    # diagonal (capturing) promotions.
    # NOTE(review): this relies on SquareSet iterating squares in a fixed
    # file order -- confirm against the python-chess version in use.
    promotion_lanes = (
        # White promotes by pushing straight ahead.
        (chess.BB_RANK_7, chess.BB_RANK_8),
        # White promotes capturing one file over (destinations skip a8).
        (chess.BB_RANK_7, chess.BB_RANK_8 & ~chess.BB_A8),
        # White promotes capturing the other way (origins skip a7).
        (chess.BB_RANK_7 & ~chess.BB_A7, chess.BB_RANK_8),
        # Black promotes by pushing straight ahead.
        (chess.BB_RANK_2, chess.BB_RANK_1),
        # Black promotes capturing one file over (destinations skip a1).
        (chess.BB_RANK_2, chess.BB_RANK_1 & ~chess.BB_A1),
        # Black promotes capturing the other way (origins skip a2).
        (chess.BB_RANK_2 & ~chess.BB_A2, chess.BB_RANK_1),
    )

    counts = {}
    for origin_mask, destination_mask in promotion_lanes:
        lane = zip(chess.SquareSet(origin_mask), chess.SquareSet(destination_mask))
        for a, b in lane:
            for piece in promotion_pieces:
                counts[chess.Move(a, b, piece)] = 0

    # All normal moves: from each square, every knight jump plus every rank,
    # file and diagonal target looked up with the BB_VOID occupancy key.
    # NOTE(review): KNIGHT_MOVES / RANK_ATTACKS / FILE_ATTACKS /
    # DIAG_ATTACKS_* / BB_VOID appear to be names from an older python-chess
    # release -- confirm they exist in the installed version.
    for from_square, bb in zip(chess.SQUARES, chess.BB_SQUARES):
        reachable = chess.SquareSet(
            chess.KNIGHT_MOVES[bb]
            | chess.RANK_ATTACKS[bb][chess.BB_VOID]
            | chess.FILE_ATTACKS[bb][chess.BB_VOID]
            | chess.DIAG_ATTACKS_NE[bb][chess.BB_VOID]
            | chess.DIAG_ATTACKS_NW[bb][chess.BB_VOID]
        )
        for to_square in reachable:
            counts[chess.Move(from_square, to_square)] = 0

    return counts
class Node:
    """A node of the Huffman tree.

    Attributes are set directly by the tree builder:

    * ``b``    -- set of symbols covered by this node (a copy of the input).
    * ``freq`` -- frequency/weight of the node.
    * ``l``    -- left child, or ``None``.
    * ``r``    -- right child, or ``None``.
    """

    def __init__(self, b, freq):
        # Children start out unset; callers attach them after construction.
        self.l = self.r = None
        self.freq = freq
        # Copy into a fresh set so later changes to ``b`` do not leak in.
        self.b = set(b)
def build_tree(counts):
    """Build a Huffman tree from a symbol -> frequency mapping.

    The two lowest-frequency nodes are merged repeatedly until a single root
    remains. Frequency ties are broken by creation order (earlier-created
    node first), which selects exactly the same nodes as scanning a list in
    creation order with ``min``, but in O(n log n) instead of O(n^2).

    Args:
        counts: Mapping from symbol to its frequency.

    Returns:
        The root ``Node``. When ``counts`` has exactly one entry the root is
        that single leaf.

    Raises:
        ValueError: If ``counts`` is empty.
    """
    import heapq  # Local import: keeps this block usable on its own.

    # Heap entries are (freq, creation index, node). The creation index is
    # unique, so tuple comparison never falls through to comparing nodes.
    heap = [
        (freq, order, Node([b], freq))
        for order, (b, freq) in enumerate(counts.items())
    ]
    if not heap:
        raise ValueError("cannot build a Huffman tree from empty counts")
    heapq.heapify(heap)

    next_order = len(heap)
    while len(heap) > 1:
        _, _, left = heapq.heappop(heap)
        _, _, right = heapq.heappop(heap)
        parent = Node(left.b | right.b, left.freq + right.freq)
        parent.l = left
        parent.r = right
        heapq.heappush(heap, (parent.freq, next_order, parent))
        next_order += 1

    return heap[0][2]
def codebook(code, prefix, node):
    """Record the bit pattern of every leaf below ``node`` into ``code``.

    Args:
        code: Dict that is filled in place, symbol -> ``bitarray``.
        prefix: String of "0"/"1" characters leading to ``node``.
        node: Subtree root; a node with both children is internal, a node
            with neither is a leaf.

    A left edge appends "0" and a right edge appends "1". A node with exactly
    one child trips the assertion.
    """
    # Explicit stack instead of recursion. The right child is pushed first so
    # the left subtree is processed first and leaves are recorded
    # left-to-right.
    pending = [(prefix, node)]
    while pending:
        bits, current = pending.pop()
        if current.l and current.r:
            pending.append((bits + "1", current.r))
            pending.append((bits + "0", current.l))
            continue
        assert not current.l and not current.r
        # A leaf holds a single symbol in its set.
        code[next(iter(current.b))] = bitarray.bitarray(bits)
def build_codebook(counts):
    """Return a symbol -> bitarray codebook derived from ``counts``.

    Builds the Huffman tree for ``counts`` and walks it from the root with an
    empty prefix.
    """
    table = {}
    codebook(table, "", build_tree(counts))
    return table
def compress(code, in_file, out_file):
    """Encode the main line of the first game in a PGN file.

    Args:
        code: Codebook mapping ``chess.Move`` to ``bitarray``.
        in_file: Path of the PGN file to read; only its first game is used.
        out_file: Path of the binary file to write.

    Raises:
        ValueError: If ``in_file`` contains no game.

    The output file is opened only after the input has been parsed and
    encoded, so a failure does not leave a truncated or empty ``out_file``.
    """
    with open(in_file, "r") as pgn:
        game = chess.pgn.read_game(pgn)
    if game is None:
        raise ValueError("no game found in {!r}".format(in_file))

    # Follow the first variation at every node, i.e. the main line.
    moves = []
    node = game
    while not node.is_end():
        node = node.variation(0)
        moves.append(node.move)

    output = bitarray.bitarray()
    output.encode(code, moves)

    # NOTE(review): the bit stream is written in whole bytes and the number
    # of padding bits is not recorded -- confirm how the decoder treats the
    # trailing bits.
    with open(out_file, "wb") as out:
        output.tofile(out)
def decompress(code, in_file):
    """Decode a compressed game and print it in standard algebraic notation.

    Args:
        code: Codebook mapping ``chess.Move`` to ``bitarray``.
        in_file: Path of the binary file to read.

    The decoded moves are rendered from the standard starting position and
    written to standard output.
    """
    compressed = bitarray.bitarray()
    with open(in_file, "rb") as f:
        compressed.fromfile(f)

    # NOTE(review): the file holds whole bytes, so any padding bits at the
    # end are decoded as well -- confirm this cannot yield extra moves or a
    # decode error.
    moves = compressed.decode(code)
    print(chess.Board().variation_san(moves))
def train(in_file, codebook_path="codebook"):
    """Count main-line moves in a PGN file and write a Huffman codebook.

    Args:
        in_file: Path of the PGN file holding the training games.
        codebook_path: Where the pickled codebook is written. Defaults to
            ``"codebook"`` in the current directory; an existing file is
            overwritten.

    Raises:
        KeyError: If a game contains a move that ``base_counts()`` did not
            seed.

    After saving, every move and its codeword are printed tab-separated,
    shortest codewords first.
    """
    counts = base_counts()

    with open(in_file, "r") as pgn:
        while True:
            game = chess.pgn.read_game(pgn)
            if game is None:
                break

            # Follow the first variation at every node, i.e. the main line.
            node = game
            while not node.is_end():
                node = node.variation(0)
                counts[node.move] += 1

    code = build_codebook(counts)

    with open(codebook_path, "wb") as codebook_file:
        pickle.dump(code, codebook_file)

    for move, codeword in sorted(code.items(), key=lambda item: len(item[1])):
        print(move, codeword.to01(), sep="\t")
def head(in_file, i=100):
    """Print up to ``i`` games from a PGN file to stdout, without headers.

    Args:
        in_file: Path of the PGN file to read.
        i: Maximum number of games to export.

    Stops early when the file runs out of games.
    """
    remaining = i
    with open(in_file, "r") as pgn:
        while remaining > 0:
            remaining -= 1
            game = chess.pgn.read_game(pgn)
            if game is None:
                return
            # A fresh exporter is created for every game.
            game.accept(chess.pgn.FileExporter(sys.stdout, headers=False))
def load_code(path="codebook"):
    """Load and return a pickled codebook.

    Args:
        path: File to read. Defaults to ``"codebook"`` in the current
            directory.

    Returns:
        The object stored in the file.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load codebook
    # files that come from a trusted source.
    with open(path, "rb") as codebook_file:
        return pickle.load(codebook_file)
def usage():
    """Print the command-line help text to standard output."""
    help_text = textwrap.dedent(
        """\
        Usage: huffman-pgn.py [command]
        Creates a huffman codebook for chess move encoding. Compresses and
        decompresses PGNs.
        huffman-pgn.py head big.pgn > stripped.pgn
        huffman-pgn.py train stripped.pgn # Creates/overwrites codebook
        huffman-pgn.py compress in.pgn out.pgnc
        huffman-pgn.py decompress out.pgnc"""
    )
    print(help_text)
if __name__ == "__main__":
    # Command-line entry point: dispatch on the first argument.
    #
    # With no arguments the usage text is printed and the script exits with
    # status 0. An unknown command, or a known command with too few
    # operands, prints the usage text and exits with status 1 instead of
    # silently doing nothing or failing with an IndexError.

    # Number of operands each command needs after the command name.
    REQUIRED_OPERANDS = {"head": 1, "train": 1, "compress": 2, "decompress": 1}

    if len(sys.argv) < 2:
        sys.exit(usage())

    command, operands = sys.argv[1], sys.argv[2:]
    if len(operands) < REQUIRED_OPERANDS.get(command, len(operands) + 1):
        usage()
        sys.exit(1)

    if command == "head":
        sys.exit(head(operands[0]))
    if command == "train":
        sys.exit(train(operands[0]))

    # The remaining commands both need the previously trained codebook.
    code = load_code()
    if command == "compress":
        compress(code, operands[0], operands[1])
    else:
        decompress(code, operands[0])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment