@kylebgorman
Last active July 10, 2019 14:08
FAR compilation for tokens, with simple UNKing support
#!/usr/bin/env python
"""Compiles compact FAR from tokenized data for LM construction."""

import argparse
import collections
import heapq
import logging
import operator

from typing import List

import pywrapfst as fst


def _tokenize(line: str) -> List[str]:
    """Splits a line into tokens."""
    return line.split()


class Compiler:
    """Helper class for string FST compilation."""

    def __init__(self, sym: fst.SymbolTable, OOV_index: int, arc_type: str):
        self.sym = sym
        self.OOV_index = OOV_index
        self.arc_type = arc_type

    def __call__(self, line: str, attach_symbols: bool = False) -> fst.Fst:
        """Compiles and compacts the FST."""
        tokens = _tokenize(line)
        f = fst.Fst(self.arc_type)
        f.reserve_states(len(tokens) + 1)
        one = fst.Weight.One(f.weight_type())
        src = f.add_state()
        f.set_start(src)
        for token in tokens:
            dst = f.add_state()
            index = self.sym.find(token)
            if index == fst.NO_SYMBOL:
                index = self.OOV_index
            f.add_arc(src, fst.Arc(index, index, one, dst))
            src = dst
        f.set_final(src)
        if attach_symbols:
            f.set_input_symbols(self.sym)
            f.set_output_symbols(self.sym)
        assert f.verify(), "Ill-formed FST"
        return fst.convert(f, "compact_string")


def main(args: argparse.Namespace) -> None:
    with open(args.input_path, "r") as source:
        counts = collections.Counter()
        # First pass: collects token counts.
        for line in source:
            counts.update(_tokenize(line))
        logging.info(f"Total types:\t{len(counts):,d}")
        logging.info(f"Total tokens:\t{sum(counts.values()):,d}")
        # Populates the symbol table.
        sym = fst.SymbolTable("lmcompile")
        sym.add_symbol(args.epsilon_symbol)
        # Filters counts on the vocabulary threshold.
        if args.vocabulary_size:
            vocabulary = [
                token
                for (token, _) in heapq.nlargest(
                    args.vocabulary_size,
                    counts.items(),
                    key=operator.itemgetter(1),
                )
            ]
            logging.info(f"Min token count:\t{counts[vocabulary[-1]]:,d}")
        elif args.count_threshold:
            vocabulary = [
                token
                for (token, count) in counts.items()
                if count >= args.count_threshold
            ]
        # Else unreachable.
        del counts  # This could be huge.
        logging.info(f"Vocabulary size:\t{len(vocabulary):,d}")
        for token in vocabulary:
            sym.add_symbol(token)
        oov_index = sym.add_symbol(args.OOV_symbol)
        del vocabulary  # This could be big.
        # Second pass: builds the FAR.
        source.seek(0)
        compiler = Compiler(sym, oov_index, args.arc_type)
        sink = fst.FarWriter.create(
            args.output_path, arc_type=args.arc_type, far_type=args.far_type
        )
        # If --keep_symbols is set, we attach the symbol table to the first
        # FST only.
        sink[f"{0:08x}"] = compiler(
            source.readline(), attach_symbols=args.keep_symbols
        )
        for (linenum, line) in enumerate(source, 1):
            sink[f"{linenum:08x}"] = compiler(line)


if __name__ == "__main__":
    logging.basicConfig(level="INFO", format="%(levelname)s: %(message)s")
    parser = argparse.ArgumentParser(description=__doc__)
    # Mandatory arguments.
    parser.add_argument("input_path", help="path to tokenized input")
    parser.add_argument("output_path", help="path for FAR output")
    # Threshold arguments.
    thresholds = parser.add_mutually_exclusive_group(required=True)
    thresholds.add_argument(
        "--vocabulary_size", type=int, help="maximum vocabulary size"
    )
    thresholds.add_argument(
        "--count_threshold", type=int, help="minimum count"
    )
    # Shadowing `ngramsymbols` flags.
    parser.add_argument(
        "--OOV_symbol",
        default="<unk>",
        help="class label for OOV symbols (default: %(default)s)",
    )
    parser.add_argument(
        "--epsilon_symbol",
        default="<epsilon>",
        help="label for epsilon (default: %(default)s)",
    )
    # Shadowing `farcompilestrings` flags.
    parser.add_argument(
        "--arc_type",
        default="standard",
        help="arc type (default: %(default)s)",
    )
    parser.add_argument(
        "--far_type", default="default", help="FAR type (default: %(default)s)"
    )
    parser.add_argument(
        "--keep_symbols",
        action="store_true",
        help="store symbol table in the FAR file (default: %(default)s)",
    )
    main(parser.parse_args())
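
As a quick sanity check of the UNKing behavior, the Compiler class can be exercised on its own. The sketch below is illustrative only: it assumes the script has been saved under the hypothetical module name lmcompile.py so the class is importable, and it builds a tiny symbol table by hand.

import pywrapfst as fst

from lmcompile import Compiler  # Hypothetical module name for this script.

sym = fst.SymbolTable("demo")
sym.add_symbol("<epsilon>")  # Index 0 is reserved for epsilon.
sym.add_symbol("the")
sym.add_symbol("cat")
oov_index = sym.add_symbol("<unk>")
compiler = Compiler(sym, oov_index, "standard")
# "dog" is not in the symbol table, so its arc carries the OOV index.
f = compiler("the dog", attach_symbols=True)
state = f.start()
for _ in range(2):
    arc = next(iter(f.arcs(state)))
    print(sym.find(arc.ilabel))  # Prints "the", then "<unk>".
    state = arc.nextstate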
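
Once the FAR has been written, it can be inspected with pywrapfst's FarReader before handing it off to LM construction tooling. A minimal sketch, assuming the output was written to out.far (a hypothetical path):

import pywrapfst as fst

reader = fst.FarReader.open("out.far")  # Hypothetical output path.
n_fsts = 0
while not reader.done():
    # Each entry is keyed by the zero-padded hex line number assigned above.
    assert reader.get_fst().verify(), reader.get_key()
    n_fsts += 1
    reader.next()
print(f"Read {n_fsts:,d} well-formed FSTs")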