Skip to content

Instantly share code, notes, and snippets.

@amykyta3
Last active May 8, 2022 19:51
Show Gist options
  • Save amykyta3/8285559e95074c4431c2836d78b36530 to your computer and use it in GitHub Desktop.
Save amykyta3/8285559e95074c4431c2836d78b36530 to your computer and use it in GitHub Desktop.
Post-process an ANTLR Python3 parser, replacing constant expressions with integer literals
#!/usr/bin/env python3
# Usage:
# ./hoist_parser_constants.py <input Parser.py file> <output Parser.py file>
import re
import os
import sys
import importlib.util
# ---------------------------------------------------------------------------
# Applies optimization described in:
# https://github.com/antlr/antlr4/issues/3698
# ---------------------------------------------------------------------------
# Matches: (1 << ParserName.TOKEN)
_SHIFT1_RE = r"\(\d+\s*<<\s*\w+\.\w+\)"
# Matches: (1 << (ParserName.TOKEN - N))
_SHIFT2_RE = r"\(\d+\s*<<\s*\(\w+\.\w+\s*[+\-]\s*\d+\)\)"
# Matches both constant shift styles
_SHIFT_RE = f"(?:{_SHIFT1_RE}|{_SHIFT2_RE})"
# Matches an entire parenthesised OR-combination of shift expressions, e.g.
# ((1 << P.A) | (1 << P.B) | (1 << (P.C - 64)))
SHIFT_EXPR_RE = re.compile(r"\(" + _SHIFT_RE + r"(?:\s*\|\s*" + _SHIFT_RE + r")*\)")

# ---------------------------------------------------------------------------
# Applies optimization described in:
# https://github.com/antlr/antlr4/issues/3703
# ---------------------------------------------------------------------------
# Matches: token in [ParserName.TOKEN, ...]; group 1 is the bracketed list.
# The lookbehind prevents matching inside a longer identifier or an attribute
# access ("mytoken in [...]", "self.token in [...]"), where the rewritten
# expression would not be valid Python.
TOKEN_MATCH_RE = re.compile(
    r"(?<![\w.])token\s+in\s*(\[\w+\.\w+(?:\s*,\s*\w+\.\w+)*\])"
)


def _load_parser_class(parser_path: str, parser_name: str) -> type:
    """Import the parser module at *parser_path* and return its parser class.

    The class is looked up under the module file's base name, i.e.
    ``FooParser.py`` is expected to define ``class FooParser``.

    Raises:
        ImportError: if no module spec/loader can be created for the path.
        AttributeError: if the module has no attribute named *parser_name*.
    """
    spec = importlib.util.spec_from_file_location(parser_name, parser_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot import parser module from {parser_path!r}")
    parser_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(parser_module)
    return getattr(parser_module, parser_name)


def _eval_constant(expr, parser_name, parser_class):
    """Evaluate a matched constant expression with only the parser class in scope.

    Returns None when the expression cannot be evaluated (unknown name, missing
    attribute, non-integer operand, negative shift), so callers can leave that
    piece of source text unchanged instead of aborting the whole run.
    """
    # NOTE: eval is tolerable here: the text comes from the parser file that
    # _load_parser_class has already executed, and the regexes restrict it to
    # integers, Name.attr references, <<, +, -, |, brackets and commas.
    namespace = {"__builtins__": {}, parser_name: parser_class}
    try:
        return eval(expr, namespace)
    except (NameError, AttributeError, TypeError, ValueError):
        return None


def hoist_constants(source: str, parser_name: str, parser_class: type) -> str:
    """Replace constant token expressions in generated parser code with literals.

    Args:
        source: Text of the ANTLR-generated Python parser module.
        parser_name: Name under which the parser class is referenced in *source*.
        parser_class: The parser class itself, used to resolve token constants.

    Returns:
        *source* with OR-ed constant shift expressions folded into hex literals
        and ``token in [...]`` list tests rewritten as a direct equality test
        (one token) or a bitmask test (several tokens). Expressions that cannot
        be resolved to integers are left unchanged.
    """

    def const_shift_exp_replacer(m: re.Match) -> str:
        value = _eval_constant(m.group(0), parser_name, parser_class)
        if not isinstance(value, int):
            return m.group(0)
        return f"0x{value:x}"

    def token_match_exp_replacer(m: re.Match) -> str:
        token_list = _eval_constant(m.group(1), parser_name, parser_class)
        if not token_list or not all(isinstance(t, int) for t in token_list):
            return m.group(0)
        if len(token_list) == 1:
            # Only one token. Do direct comparison
            return f"token == {token_list[0]}"
        if any(t < 0 for t in token_list):
            # A negative token type (e.g. EOF) has no bit position in a mask;
            # keep the original membership test.
            return m.group(0)
        token_mask = 0
        for t in token_list:
            token_mask |= 1 << t
        # At runtime `token` may be negative (EOF) and `1 << -1` raises
        # ValueError, so guard the shift. The parentheses keep the result a
        # single operand wherever the original `token in [...]` appeared.
        return f"(token >= 0 and (1 << token) & 0x{token_mask:x} != 0)"

    source = SHIFT_EXPR_RE.sub(const_shift_exp_replacer, source)
    return TOKEN_MATCH_RE.sub(token_match_exp_replacer, source)


def main(argv=None) -> int:
    """Command-line entry point.

    Usage: hoist_parser_constants.py <input Parser.py file> <output Parser.py file>

    Args:
        argv: Argument list without the program name; defaults to sys.argv[1:].

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        prog = os.path.basename(sys.argv[0]) if sys.argv else "hoist_parser_constants.py"
        sys.stderr.write(
            f"Usage: {prog} <input Parser.py file> <output Parser.py file>\n"
        )
        return 2
    parser_path, output_path = args[0], args[1]
    parser_name = os.path.splitext(os.path.basename(parser_path))[0]

    # Import parser module to gain access to the parser class
    parser_class = _load_parser_class(parser_path, parser_name)

    with open(parser_path, encoding='utf-8') as f:
        source = f.read()

    result = hoist_constants(source, parser_name, parser_class)

    # Write out new file
    with open(output_path, "w", encoding='utf-8') as f:
        f.write(result)
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment