Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@jorendorff
Last active February 16, 2019 07:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jorendorff/55fd6cb694f69b802912340b5a6bba21 to your computer and use it in GitHub Desktop.
Save jorendorff/55fd6cb694f69b802912340b5a6bba21 to your computer and use it in GitHub Desktop.
import random
class IndexedList:
    """An append-only list of distinct hashable values with fast reverse lookup.

    Appending a value that is already present does nothing, so every value
    keeps the position it was given when first added. Values must be
    hashable because each one is also used as a dictionary key.
    """

    def __init__(self, values=None):
        self._list = []   # stored values, in insertion order
        self._index = {}  # value -> its position in self._list
        for item in (() if values is None else values):
            self.append(item)

    def append(self, value):
        """Add `value` at the end, unless it is already stored (then do nothing)."""
        if value in self._index:
            return
        self._index[value] = len(self._list)
        self._list.append(value)

    def __contains__(self, value):
        """Membership test answered by the index dict rather than a list scan."""
        return value in self._index

    def __len__(self):
        """Number of distinct values stored."""
        return len(self._list)

    def __getitem__(self, index):
        """Value(s) at `index`; integers, negative indices and slices act as on a list."""
        return self._list[index]

    def __iter__(self):
        """Iterate over the stored values in insertion order."""
        return iter(self._list)

    def index(self, value):
        """Position of `value`; raises KeyError (not ValueError) if it is absent."""
        return self._index[value]
def topo_sort(iterable, predecessors):
    """Topologically sort the values in `iterable`, discarding duplicates.

    predecessors(value) returns the value's predecessors; it should return an
    iterable of values that are in `iterable` (any other value it returns is
    visited as well and ends up in the output). It is called at most once per
    distinct value. Values must be hashable.

    If possible, return a list that is a permutation of `list(set(iterable))`,
    such that for all X, Y in iterable, if X in predecessors(Y), then X
    appears before Y in the output. The result is the post-order of a
    depth-first search started from each value of `iterable` in turn.

    If no such permutation exists, raise a ValueError.

    The search is recursive, so a predecessor chain longer than the
    interpreter's recursion limit raises RecursionError.
    """
    seen = set()      # values whose visit has started
    finished = set()  # values already appended to `out` (set for O(1) lookup)
    out = []

    def add(value):
        if value in seen:
            # Started but not finished means `value` is still on the DFS
            # stack: we reached it again through its own predecessors.
            if value not in finished:
                raise ValueError(f"cycle detected involving {value!r}")
            return
        seen.add(value)
        for p in predecessors(value):
            add(p)
        finished.add(value)
        out.append(value)

    for value in iterable:
        add(value)
    return out
# Demo expression grammar. Each key is a nonterminal whose value lists its
# alternative productions, each a list of symbols; any symbol that is not a
# key (such as '+', 'N' or '(') is a terminal. 'add' handles + and -, 'mul'
# handles * and /, 'pre' handles prefix minus and right-recursive '^', and
# 'prim' is a number, a variable, or a parenthesized 'add'. Productions are
# written as space-separated strings below and split into symbol lists.
grammar = {
    nt: [alternative.split() for alternative in alternatives]
    for nt, alternatives in (
        ('add', ('mul', 'add + mul', 'add - mul')),
        ('mul', ('pre', 'mul * pre', 'mul / pre')),
        ('pre', ('prim', '- pre', 'prim ^ pre')),
        ('prim', ('N', 'v', '( add )')),
    )
}
def is_nt(symbol):
    """Report whether `symbol` is a nonterminal of the module-level `grammar`.

    Nonterminals are exactly the keys of `grammar`; every other symbol is
    treated as a terminal.
    """
    return symbol in grammar
class Grammar:
    """A grammar compiled for counting and unranking derivations.

    Built by `compile_grammar`. For each nonterminal and each length up to the
    compiled maximum it stores how many derivations produce exactly that many
    terminals, which lets `sentence` map an integer "fuel" in
    ``range(count)`` to one specific sentence.

    Attributes:
        indexified_grammar: maps each nonterminal to the suffix indexes of
            its productions, in grammar order.
        suffixes: indexable collection of ``(head, tail)`` pairs, where
            ``head`` is a symbol and ``tail`` is the index of the rest of the
            production, or None if ``head`` is its last symbol.
        counts_by_nt: maps each nonterminal to a list whose entry ``n`` is the
            number of derivations of exactly ``n`` terminals.
        counts_by_seq: ``counts_by_seq[i][n]`` is the number of derivations
            of exactly ``n`` terminals for suffix ``i``.
    """

    def __init__(self, indexified_grammar, suffixes, counts_by_nt, counts_by_seq):
        self.indexified_grammar = indexified_grammar
        self.suffixes = suffixes
        self.counts_by_nt = counts_by_nt
        self.counts_by_seq = counts_by_seq

    def is_nt(self, symbol):
        """Return True if `symbol` is a nonterminal of this grammar."""
        return symbol in self.indexified_grammar

    def sentence(self, nt, length, fuel):
        """Return the list of terminals of derivation number `fuel` of `nt`.

        Only derivations of exactly `length` terminals are numbered.

        Raises:
            IndexError: if `length` is negative or beyond the compiled
                maximum, or `fuel` is not in
                ``range(self.counts_by_nt[nt][length])``.
        """
        counts = self.counts_by_nt[nt]
        # Check length explicitly: a negative length would otherwise be
        # accepted through Python's negative indexing and silently read the
        # counts of some other length.
        if not (0 <= length < len(counts)):
            raise IndexError("sentence length out of range")
        if not (0 <= fuel < counts[length]):
            raise IndexError("sentence index out of range")
        # Select a production: productions occupy consecutive ranges of the
        # numbering, so skip past each one's count until `fuel` falls inside.
        for index in self.indexified_grammar[nt]:
            d = self.counts_by_seq[index][length]
            if fuel < d:
                return self.sequence(index, length, fuel)
            fuel -= d
        raise AssertionError("counts_by_nt is inconsistent with counts_by_seq")

    def sequence(self, index, length, fuel):
        """Return the terminals of derivation number `fuel` of suffix `index`.

        The derivation spans exactly `length` terminals. `fuel` is assumed to
        be in ``range(self.counts_by_seq[index][length])``.
        """
        head, tail = self.suffixes[index]
        if tail is None:
            if self.is_nt(head):
                return self.sentence(head, length, fuel)
            # A lone terminal has exactly one derivation.
            assert fuel == 0
            return [head]
        if not self.is_nt(head):
            # A terminal head consumes exactly one position.
            return [head] + self.sequence(tail, length - 1, fuel)
        # Nonterminal head: try each split point k, where the head covers k
        # terminals and the tail covers the rest. Within one split the
        # numbering is row-major over (head derivation, tail derivation).
        for k in range(1, length):
            d_head = self.counts_by_nt[head][k]
            d_tail = self.counts_by_seq[tail][length - k]
            d = d_head * d_tail
            if fuel < d:
                return (self.sentence(head, k, fuel // d_tail) +
                        self.sequence(tail, length - k, fuel % d_tail))
            fuel -= d
        raise AssertionError("fuel exceeds the number of derivations for this suffix")
def compile_grammar(grammar, length):
    """Count the derivations of every nonterminal for lengths 0..`length`.

    `grammar` maps each nonterminal to a list of productions; each production
    is a non-empty sequence of hashable symbols. A symbol is a nonterminal
    exactly when it is a key of `grammar`; every other symbol is a terminal
    that fills one position of the output.

    Each production is stored as a chain of shared suffixes. Suffix ``i`` is
    a pair ``(head, tail)`` where ``tail`` is the index of the rest of the
    production, or None for its last symbol. Equal suffixes of different
    productions share one index.

    Returns a `Grammar` in which ``counts_by_nt[nt][n]`` and
    ``counts_by_seq[i][n]`` are the numbers of derivations of exactly ``n``
    terminals, for ``0 <= n <= length``.

    Raises:
        ValueError: if a production is empty, or (from `topo_sort`) if unit
            productions form a cycle such as ``A -> B``, ``B -> A``.
    """
    nts = list(grammar.keys())

    def is_nt(symbol):
        return symbol in grammar

    # Phase 1: intern every suffix, recording the index of each whole
    # production.
    indexified_grammar = {nt: [] for nt in nts}
    suffixes = IndexedList()
    for nt in nts:
        for prod in grammar[nt]:
            if not prod:
                # The counting below assumes no symbol matches the empty
                # string.
                raise ValueError(f"empty production for {nt!r} is not supported")
            acc = None
            for symbol in reversed(prod):
                pair = (symbol, acc)
                suffixes.append(pair)  # no-op if already interned
                acc = suffixes.index(pair)
            indexified_grammar[nt].append(acc)

    # Phase 2: order the suffixes so that, within one length, everything a
    # suffix reads has already been computed.
    #
    # A multi-symbol suffix such as `( term )` only reads counts for strictly
    # shorter lengths, because no symbol matches the empty string; those were
    # finished in earlier rounds. The only same-length dependency is a lone
    # nonterminal `(B, None)`: its count equals B's total, so it must come
    # after every production of B.
    #
    # This holds whether `(B, None)` is a whole unit production or merely the
    # last symbol of a longer production. Its stored count is reused by
    # longer suffixes in later rounds, so it must be complete when computed.
    dependencies = [[] for _ in suffixes]
    for index, (head, tail) in enumerate(suffixes):
        if tail is None and is_nt(head):
            dependencies[index] = indexified_grammar[head]
    sorted_indexes = topo_sort(range(len(suffixes)), dependencies.__getitem__)

    worklist = []
    for index in sorted_indexes:
        head, tail = suffixes[index]
        # Nonterminals for which this suffix is a whole production.
        nts_satisfied = [nt for nt in nts
                         if index in indexified_grammar[nt]]
        worklist.append((index, head, tail, nts_satisfied))

    # Phase 3: fill the count tables one length at a time.
    counts_by_seq = [[0] for _ in suffixes]  # nothing derives length 0
    counts_by_nt = {nt: [0] * (length + 1) for nt in nts}
    for current_length in range(1, length + 1):
        for index, head, tail, nts_satisfied in worklist:
            if tail is None:
                if is_nt(head):
                    n = counts_by_nt[head][current_length]
                else:
                    n = 1 if current_length == 1 else 0
            elif is_nt(head):
                # Head covers k terminals, tail covers the remaining ones.
                n = sum(counts_by_nt[head][k] * counts_by_seq[tail][current_length - k]
                        for k in range(1, current_length))
            else:
                n = counts_by_seq[tail][current_length - 1]
            counts_by_seq[index].append(n)
            for nt in nts_satisfied:
                counts_by_nt[nt][current_length] += n
    return Grammar(indexified_grammar, suffixes, counts_by_nt, counts_by_seq)
def main(maxlen=20, samples=15):
    """Print how many 'add' sentences have exactly `maxlen` terminals.

    Also prints `samples` of them, each as
    ``<rank>: <space-separated terminals>``. The defaults reproduce the
    original demo.
    """
    g = compile_grammar(grammar, maxlen)
    n = g.counts_by_nt['add'][maxlen]
    print(f"There are {n} strings " +
          f"of length {maxlen} that match the 'add' nonterminal. Here are a few uniformly selected ones:")
    if n == 0:
        # random.randrange(0) would raise ValueError; there is nothing to sample.
        return
    for _ in range(samples):
        # A uniformly random rank, unranked, is a uniform pick among the
        # counted derivations.
        i = random.randrange(n)
        print(f"{i}: {' '.join(g.sentence('add', maxlen, i))}")
# Run the demo. This call is unguarded, so it also runs whenever the module
# is imported. NOTE(review): wrap it in `if __name__ == "__main__":` if this
# file is ever imported as a library.
main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment