Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@wheerd
Last active March 26, 2019 09:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save wheerd/13ac5a0e9b560db6201d96136fa0bbcc to your computer and use it in GitHub Desktop.
Save wheerd/13ac5a0e9b560db6201d96136fa0bbcc to your computer and use it in GitHub Desktop.
import re
import os
from sympy.core.expr import Basic
from sympy.integrals.rubi.patterns import rubi_object
from sympy.integrals.rubi.patterns import *
from sympy.core.singleton import Singleton
from matchpy.utils import get_short_lambda_source
from matchpy.matching.code_generation import CodeGenerator
if not os.path.exists('generated.py'):
    # Generation only runs when generated.py is absent from the current
    # working directory; delete that file to force a rebuild.
    rubi = rubi_object()

    class RubiCodeGenerator(CodeGenerator):
        """Code generator that emits SymPy/Rubi-flavoured matcher source.

        Overrides the hooks of ``CodeGenerator`` that decide how labels,
        constraints, expressions and operand access are rendered as text in
        the generated module.
        """

        def final_label(self, pattern_index, subst_name):
            """Return the source text to emit for a matched pattern.

            Args:
                pattern_index: Index into ``self._matcher.patterns``.
                subst_name: Name of the substitution variable in the
                    generated code.

            Returns:
                The label lambda's source with every pattern variable name
                rewritten to ``<subst_name>['<var>']``. Falls back to the
                parent implementation when the pattern has no label or the
                lambda source cannot be recovered.
            """
            label = self._matcher.patterns[pattern_index][1]
            if label is None:
                return super().final_label(pattern_index, subst_name)
            src = get_short_lambda_source(label)
            if src is None:
                # get_short_lambda_source returns None when the source is
                # unavailable; re.sub would raise TypeError on it.
                return super().final_label(pattern_index, subst_name)
            # Falsy (empty) variable names are dropped: they would add an
            # empty alternative that matches at every word boundary.
            subst_vars = [
                v for v in self._matcher.pattern_vars[pattern_index].keys() if v
            ]
            if subst_vars:
                pstr = r'\b({})\b'.format(
                    '|'.join(re.escape(v) for v in subst_vars)
                )
                src = re.sub(
                    pstr,
                    lambda m: '{}[{!r}]'.format(subst_name, m[0]),
                    src,
                )
            return src

        def constraint_repr(self, constraint):
            """Return ``(source, flag)`` describing how to emit ``constraint``.

            For a ``CustomConstraint`` wrapping a plain function/lambda, the
            lambda's source is inlined with each parameter name replaced by a
            lookup of the mapped variable in the current substitution
            (``subst<self._substs>['<variable>']``) and the flag is ``False``.
            Every other case is delegated to the parent implementation.
            NOTE(review): the meaning of the flag is defined by matchpy's
            ``CodeGenerator`` - confirm there before relying on it.
            """
            # type(lambda: 0) is the plain function type, so this accepts any
            # Python-level function, not only lambda expressions.
            if (isinstance(constraint, CustomConstraint)
                    and isinstance(constraint.constraint, type(lambda: 0))):
                src = get_short_lambda_source(constraint.constraint)
                variables = constraint._variables
                # Inline only when both the source and an explicit
                # parameter -> variable mapping are available; otherwise let
                # the parent decide instead of crashing inside re.sub.
                if src is not None and variables is not None:
                    if variables:
                        # Skipped for an empty mapping: the pattern would
                        # degenerate to r'\b()\b', match the empty string and
                        # trigger a KeyError in the replacement callback.
                        pstr = r'\b({})\b'.format(
                            '|'.join(map(re.escape, variables.keys()))
                        )
                        src = re.sub(
                            pstr,
                            lambda m: 'subst{}[{!r}]'.format(
                                self._substs, variables[m[0]]
                            ),
                            src,
                        )
                    return src, False
            return super().constraint_repr(constraint)

        def expr(self, expr):
            """Return source text that rebuilds ``expr`` in the generated module.

            Objects whose class is an instance of the ``Singleton`` metaclass
            are wrapped as ``S(<repr>)``; everything else is emitted via
            ``repr``.
            """
            if isinstance(type(expr), Singleton):
                return 'S({!r})'.format(expr)
            return repr(expr)

        def get_args(self, operation, operation_type):
            """Return source text that yields the operands of ``operation``.

            Args:
                operation: Name/expression of the operation in generated code.
                operation_type: Class of the operation being matched.
            """
            # Checked before Basic so Integral gets its special flattening
            # (presumably Integral is itself a Basic subclass - confirm).
            if issubclass(operation_type, Integral):
                # First argument followed by the contents of the second one.
                return '({0}._args[0],) + {0}._args[1]'.format(operation)
            if issubclass(operation_type, Basic):
                return '{}._args'.format(operation)
            return super().get_args(operation, operation_type)

    # The template is stripped so the coding cookie ends up on the first line
    # of the generated file; the two placeholders receive the global code and
    # the matcher code, in that order.
    GENERATED_TEMPLATE = '''
# -*- coding: utf-8 -*-
from matchpy import *
from sympy.integrals.rubi.utility_function import *
from sympy.integrals.rubi.patterns import *
{}
{}
'''.strip()

    generator = RubiCodeGenerator(rubi.matcher)
    global_code, code = generator.generate_code()
    code = GENERATED_TEMPLATE.format(global_code, code)

    # The file is opened only after generation succeeded, so a failure above
    # never leaves a truncated generated.py behind.
    with open('generated.py', 'w', encoding='utf-8') as f:
        f.write(code)
# This import must stay below the generation step: generated.py may not exist
# until that step has run.
import importlib

# generated.py can have been created by this very process a moment ago; flush
# the import system's cached directory listings so the new file is noticed.
importlib.invalidate_caches()
from generated import match_root

x = symbols('x')

# Smoke test: print the first element of every pair yielded for Integral(x, x).
for r, _ in match_root(Integral(x, x)):
    print(r)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment