Skip to content

Instantly share code, notes, and snippets.

@readevalprint
Created December 19, 2022 12:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save readevalprint/e7e5609e7e34b914f1c27818aa495a8a to your computer and use it in GitHub Desktop.
Save readevalprint/e7e5609e7e34b914f1c27818aa495a8a to your computer and use it in GitHub Desktop.
'''
from nltk import CFG
from nltk.parse.generate import demo_grammar
grammar = CFG.fromstring(demo_grammar)
s = list(grammar._lhs_index.keys())[0]
generate_sample(grammar, s, 1, 10)
'''
def grammar_sentence(index, depth, symbols, grammar):
    """Return the ``index``-th sentence derivable from ``symbols``.

    ``grammar`` is used as a mapping from a nonterminal symbol to its list of
    options, each option being a sequence of symbols; any symbol that is not
    a key of ``grammar`` is treated as a terminal and copied to the output.

    The number of sentences reachable within ``depth`` levels is taken from
    ``len_symbols``.  ``index`` is decoded as a mixed-radix number: walking
    ``symbols`` left to right, each nonterminal consumes one "digit" whose
    radix is that nonterminal's own sentence count, and the digit selects
    both the option and the sub-index inside that option.

    Args:
        index: Which sentence to build.  Valid values follow Python
            indexing: ``-n <= index < n`` where ``n`` is the sentence count
            (negative values count from the end).
        depth: Maximum derivation depth still available.
        symbols: Sequence of symbols to expand.
        grammar: Mapping ``symbol -> list of options``.

    Returns:
        The sentence as a list of terminal symbols, or ``[]`` when ``index``
        is out of range (which includes the case of no derivable sentence).
    """
    total = len_symbols(symbols, depth, grammar)
    # Python-style bounds: ``index == total`` is out of range, it must not
    # wrap around to sentence 0, and an empty language has no valid index.
    if not -total <= index < total:
        return []

    sentence = []
    for symbol in symbols:
        if symbol not in grammar:
            # Terminal: emitted verbatim, consumes no part of the index.
            sentence.append(symbol)
            continue

        symbol_total = len_symbols([symbol], depth, grammar)
        if symbol_total <= 0:
            # Defensive: nothing derivable here, and avoids a modulo by zero.
            continue

        # Peel off this nonterminal's digit; the quotient is left for the
        # symbols that follow.
        index, digit = divmod(index, symbol_total)

        # The options partition [0, symbol_total) into consecutive ranges,
        # one per option, each as wide as that option's sentence count.
        offset = 0
        for option in grammar[symbol]:
            option_total = len_symbols(option, depth - 1, grammar)
            if offset <= digit < offset + option_total:
                sentence.extend(
                    grammar_sentence(digit - offset, depth - 1, option, grammar)
                )
                break
            offset += option_total
    return sentence
# Memo for len_symbols: (symbol, depth, id(grammar)) -> number of sentences.
len_dict = {}

# Strong references to every grammar that has entries in ``len_dict``.
# ``id()`` values are only unique among objects that are alive at the same
# time, so without pinning, a new grammar allocated at a freed grammar's
# address would be served the old grammar's cached counts.
_len_dict_grammars = {}


def len_symbols(rhs, depth, grammar):
    """Count the sentences derivable from ``rhs`` within ``depth`` levels.

    ``grammar`` is used as a mapping from a nonterminal to its list of
    options (each a sequence of symbols).  A symbol with no options in the
    mapping counts as a terminal and contributes a factor of one.  The count
    for a nonterminal is the sum of the counts of its options at
    ``depth - 1``; the count for a sequence is the product over its symbols.

    Results per ``(symbol, depth, grammar)`` are memoised in ``len_dict``.
    The grammar object is kept alive for as long as the cache exists, and it
    must not be mutated after its first use here or the cached counts go
    stale.

    Args:
        rhs: Sequence of (hashable) symbols.
        depth: Remaining derivation depth; ``depth <= 0`` yields 0.
        grammar: Mapping ``symbol -> list of options`` supporting ``.get``.

    Returns:
        The number of distinct derivations as an ``int``.
    """
    if depth <= 0:
        return 0

    grammar_id = id(grammar)
    # Pin the grammar so ``grammar_id`` cannot be recycled for another object.
    _len_dict_grammars.setdefault(grammar_id, grammar)

    total = 1
    for symbol in rhs:
        key = symbol, depth, grammar_id
        if key in len_dict:
            total *= len_dict[key]
            continue

        options = grammar.get(symbol, [])
        if not options:
            # Terminal (or a symbol without options): exactly one way.
            continue

        count = sum(len_symbols(option, depth - 1, grammar) for option in options)
        len_dict[key] = count
        total *= count
    return total
def grammar_index2(grammar, index):
    """Expand the grammar's start symbol left to right, steered by ``index``.

    Starting from ``[grammar.start()]``, each step replaces the leftmost
    symbol that has productions in ``grammar._lhs_index`` with the right-hand
    side of one of them.  The production used is number
    ``index % len(productions)``, after which ``index`` is floor-divided by
    that count.

    Expansion stops when no expandable symbol is left, or as soon as
    ``index`` has dropped to zero or below after a step.  Because of the
    second condition the result may still contain unexpanded symbols; for
    example ``index=0`` performs exactly one expansion.

    Args:
        grammar: Object providing ``start()`` and a ``_lhs_index`` mapping
            from a symbol to a sequence of productions, each with ``rhs()``
            (presumably an ``nltk`` ``CFG`` -- confirm against callers).
        index: Integer selecting which productions are applied.

    Returns:
        The (possibly partially expanded) sentence as a list of symbols.
    """
    lhs_index = grammar._lhs_index  # looked up once instead of per step
    sentence = [grammar.start()]

    while True:
        # Single scan for the leftmost symbol that can actually be rewritten.
        position = next(
            (i for i, symbol in enumerate(sentence) if lhs_index.get(symbol)),
            None,
        )
        if position is None:
            # Nothing left to expand.  This also ends the loop for a symbol
            # whose production list is empty, which would otherwise never
            # be rewritten and never leave the sentence.
            break

        productions = lhs_index[sentence[position]]
        index, choice = divmod(index, len(productions))
        sentence[position:position + 1] = productions[choice].rhs()

        if index <= 0:
            break
    return sentence
def _generate_all(grammar, items, depth, index):
if items:
try:
if depth > 0:
base = len(items)
next_index, rem = index // base, index % base
item = items[0]
if isinstance(item, Nonterminal):
for prod in grammar.productions(lhs=item):
rhs = grammar, prod.rhs()
if rhs:
frag1 = _generate_all(
grammar, [rhs[rem]], depth - 1, next_index
)
for frag2 in _generate_all(
grammar, items[1:], depth, next_index
):
return frag1 + frag2
else:
for frag2 in _generate_all(grammar, items[1:], depth, next_index):
return [item] + frag2
except RuntimeError as _error:
if _error.message == "maximum recursion depth exceeded":
# Helpful error message while still showing the recursion stack.
raise RuntimeError(
"The grammar has rule(s) that yield infinite recursion!!"
)
else:
raise
else:
return []
def generate_sample(grammar, prod, index, depth, frags=None):
    """Collect the terminals of the derivation of ``prod`` chosen by ``index``.

    If ``prod`` has productions in ``grammar._lhs_index``, production number
    ``index % len(productions)`` is taken and every symbol of its right-hand
    side is expanded recursively with ``depth - 1``.  Every other symbol is
    a terminal and is appended to ``frags`` as ``str(prod)``, regardless of
    where it appears in a right-hand side.

    Note that all symbols of one right-hand side receive the *same* reduced
    index (``index // len(productions)``), so sibling choices are not
    independent of each other.

    Args:
        grammar: Object with a ``_lhs_index`` mapping from a symbol to a
            sequence of productions, each exposing ``_rhs``.
        prod: The symbol to expand.
        index: Integer selecting which productions are used.
        depth: Maximum derivation depth; symbols reached with ``depth <= 0``
            contribute nothing.
        frags: Optional list to append to; a new list is created when
            omitted.

    Returns:
        ``frags`` (the same list object when one was passed in).
    """
    if frags is None:
        frags = []
    if depth <= 0:
        return frags

    derivations = grammar._lhs_index.get(prod)
    if not derivations:
        # Terminal.  Deliberately not checked against ``_rhs_index``: that
        # index is keyed by the first right-hand-side symbol only, so using
        # it would drop terminals that occur in later positions.
        frags.append(str(prod))
        return frags

    next_index, choice = divmod(index, len(derivations))
    for symbol in derivations[choice]._rhs:
        generate_sample(grammar, symbol, next_index, depth - 1, frags)
    return frags
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment