Skip to content

Instantly share code, notes, and snippets.

@caspark
Created February 18, 2020 05:25
Show Gist options
  • Save caspark/8826ae47aef433d2dbbfae21604c5f03 to your computer and use it in GitHub Desktop.
Save caspark/8826ae47aef433d2dbbfae21604c5f03 to your computer and use it in GitHub Desktop.
2020-02-17-kaldi-breaking-grammar
import logging, os
import dragonfly
# Set to True when reproducing a problem to get DEBUG-level output from the
# engine and Kaldi loggers; the default keeps the root logger at INFO.
VERBOSE_LOGGING = False

if VERBOSE_LOGGING:
    logging.basicConfig(level=logging.DEBUG)
    # Keep these two loggers at INFO even in verbose mode.
    logging.getLogger('grammar.decode').setLevel(logging.INFO)
    logging.getLogger('compound').setLevel(logging.INFO)
    logging.getLogger('engine').setLevel(logging.DEBUG)
    logging.getLogger('kaldi').setLevel(logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)

# NOTE(review): confirm whether setup_log() was meant to run only in the
# non-verbose branch; with VERBOSE_LOGGING off, both placements execute the
# same statements.
from dragonfly.log import setup_log
setup_log()
class KaldiBreakerRule(dragonfly.CompoundRule):
    """Deliberately ambiguous rule used to reproduce a Kaldi engine problem.

    Every element here matches only the word "escape":
    - the ``alpha`` alternative lists the same literal twice;
    - ``alpha`` and ``beta`` are both optional repetitions of that word.
    An utterance such as "escape escape" can therefore be parsed in more than
    one way. The redundancy is intentional; do not simplify it, or the
    reproduction is lost.
    """

    # Both references are optional ([...]), so the spec also admits an
    # empty utterance.
    spec = "[<alpha>] [<beta>]"
    extras = [
        # Repetition of a two-way alternative whose branches are identical.
        # NOTE(review): dragonfly documents Repetition's `max` as exclusive,
        # i.e. 1-2 repetitions here — confirm that is what was intended.
        dragonfly.Repetition(name="alpha", min=1, max=3,
            child=dragonfly.Alternative(name="alpha_alternative", children=[
                dragonfly.Literal("escape"),
                dragonfly.Literal("escape"),
            ])
        ),
        # Same word again, without the alternative wrapper.
        dragonfly.Repetition(min=1, max=3, name="beta",
            child=dragonfly.Literal("escape"),
        ),
    ]

    def _process_recognition(self, node, extras):
        # Recognition callback: just echo the matched node and extras.
        print(f"Breaker recognized! node={node} and extras={extras}")
# Create the Kaldi engine from a locally downloaded model directory.
engine = dragonfly.get_engine("kaldi",
    model_dir='models/daanzu_20200201_1ep-biglm',
)
engine.connect()
try:
    grammar = dragonfly.Grammar(name="mygrammar")
    grammar.add_rule(KaldiBreakerRule())
    grammar.load()
    # To inspect what was loaded, print the grammar's complexity tree:
    #   import utils_dragonfly
    #   print(f"Grammar loaded: {utils_dragonfly.get_grammar_complexity_tree(grammar)}")
    print("Preparing for recognition...")
    engine.prepare_for_recognition()
    print("Listening...")
    # Runs the recognition loop; per dragonfly's docs it continues until
    # interrupted.
    engine.do_recognition()
finally:
    # Pair the connect() above with a disconnect() however we leave: normal
    # return, KeyboardInterrupt, or a failure while loading the grammar.
    engine.disconnect()
import logging
class ComplexityNode:
    """A single node in a complexity tree.

    Pairs a grammar object with the nodes built for its children and a
    running count of how many nodes the subtree rooted here contains.
    """

    def __init__(self, item):
        # The rule or element this node describes.
        self.item = item
        # Child ComplexityNode objects; builders append to this list.
        self.children = list()
        # Starts at 1 to count this node itself; builders add child totals.
        self.total_descendents = 1
def build_complexity_tree(thing):
    """Build a ComplexityNode tree mirroring the structure of *thing*.

    How the walk descends depends on the type of *thing*:
    - a dragonfly ``Rule``: into ``rule.element``;
    - a ``RuleRef``: into the referenced rule's element;
    - anything else: into its ``children``.

    Each node's ``total_descendents`` counts the node itself plus every node
    beneath it.

    Children of an ``Alternative`` are reordered so that the largest subtree
    comes first, with ties broken by ``str(item)``. Alternatives are
    order-insensitive, so this puts the most complex branches first and makes
    the output deterministic. Children of every other node keep their
    declared order, because sequence order is meaningful.

    Raises:
        ImportError: if dragonfly is not installed.
    """
    # This module does not otherwise bring these dragonfly classes into
    # scope. Importing them once per top-level call keeps the recursive walk
    # below free of repeated import statements.
    from dragonfly import Alternative, Rule, RuleRef

    def build(item):
        node = ComplexityNode(item)
        if isinstance(item, Rule):
            children = [item.element]
        elif isinstance(item, RuleRef):
            # Follow the reference into the rule it points at.
            children = [item.rule.element]
        else:
            # Any other element: descend into its direct children.
            children = item.children
        for child in children:
            child_node = build(child)
            node.children.append(child_node)
            node.total_descendents += child_node.total_descendents
        if isinstance(item, Alternative):
            # One stable sort, equivalent to sorting by name and then
            # stable-sorting by size descending.
            node.children.sort(key=lambda n: (-n.total_descendents, str(n.item)))
        return node

    return build(thing)
def get_rule_complexity_tree(rule, depth_threshold=10, complexity_threshold=10):
    """Render *rule*'s complexity tree as an indented, multi-line string.

    Each line shows a node's ``repr``, indented by its depth and
    left-justified to 75 characters, followed by the node's
    ``total_descendents`` count. The whole tree is rendered; nothing is
    truncated.

    Args:
        rule: The rule (or element) to describe; passed to
            ``build_complexity_tree``.
        depth_threshold: Currently unused; kept for backward compatibility.
        complexity_threshold: Currently unused; kept for backward
            compatibility.

    Returns:
        The rendered tree, or ``""`` if building or rendering it raised. The
        exception is logged via ``logging.exception``.
    """
    def render(node, depth):
        line = "%-75s %d" % (" " * depth + "- " + repr(node.item),
                             node.total_descendents)
        # Every rendered subtree is non-empty, so child subtrees can be
        # joined directly beneath their parent line.
        return "\n".join([line] + [render(child, depth + 1) for child in node.children])

    try:
        return render(build_complexity_tree(rule), 0)
    except Exception:
        # Best-effort diagnostic helper: a reporting failure is logged rather
        # than propagated to the caller.
        logging.exception("failed to build complexity tree")
        return ""
def get_grammar_complexity_score(grammar):
    """Return the total node count across *grammar*'s exported rules.

    Sums the ``total_descendents`` of the complexity tree built for each
    exported rule; non-exported rules are skipped. If any tree fails to
    build, the exception is logged and 0 is returned.
    """
    try:
        # Generator instead of a list: sum() only needs to iterate once.
        return sum(build_complexity_tree(rule).total_descendents
                   for rule in grammar.rules if rule.exported)
    except Exception:
        logging.exception("failed to build grammar complexity score")
        return 0
def get_grammar_complexity_tree(grammar, threshold=5):
    """Return a text report of *grammar*'s rules and their complexity trees.

    The first line summarises the total, exported and imported rule counts.
    It is followed by one rendered tree (from ``get_rule_complexity_tree``)
    per exported rule, with *threshold* forwarded as that function's second
    positional argument.
    """
    every_rule = grammar.rules
    exported_rules = [r for r in grammar.rules if r.exported]
    imported_count = sum(1 for r in grammar.rules if r.imported)
    header = "%s: %d rules (%d exported, %d imported):" % (
        grammar, len(every_rule), len(exported_rules), imported_count)
    trees = "".join("\n%s" % get_rule_complexity_tree(r, threshold)
                    for r in exported_rules)
    return header + trees
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment