Last active
December 16, 2015 22:20
-
-
Save pv/5506308 to your computer and use it in GitHub Desktop.
Dump from Sympy to Fortran
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Dump matrices and whatnot from Sympy to Fortran. | |
P. Virtanen 2013, Public Domain. | |
""" | |
import os | |
import shutil | |
import subprocess | |
import argparse | |
import re | |
import sympy | |
def example():
    """Write a small demonstration include file, ``out.inc``.

    A single 2x2 matrix ``foo`` is defined in terms of the free symbols
    ``a`` and ``b`` and handed to ``dump_fortran_values``.
    """
    sym_a = sympy.Symbol('a')
    sym_b = sympy.Symbol('b')
    definitions = {}
    definitions[sympy.Symbol('foo')] = sympy.Matrix([[sym_a, sym_b],
                                                     [sym_b, sym_a + 2]])
    dump_fortran_values('out.inc', definitions)
def dump_fortran_values(filename, values):
    """Write Fortran code declaring and assigning *values* to *filename*.

    Parameters
    ----------
    filename : str
        Path of the include file to (over)write.
    values : dict
        Maps sympy Symbols (the Fortran variable names) to sympy
        expressions or matrices giving their values.

    Every Symbol that occurs in the values, is not a NumberSymbol and is
    not itself a key of *values* is treated as an input.  Inputs are
    declared in the generated code and listed on an ``!! Inputs:`` line.
    """
    print("Generating %s" % filename)
    inputs = set()
    for v in values.values():
        for atom in iter_atoms(v):
            if atom not in values and atom.is_Symbol and not atom.is_NumberSymbol:
                inputs.add(atom)
    # Hand the inputs on in a fixed order so the generated file does not
    # depend on set iteration order, which can differ between runs.
    sorted_inputs = sorted(inputs, key=str)
    code = format_fortran_values(values, extra_symbols=True,
                                 extra_items=sorted_inputs)
    # Text mode: every write below is a str.  Binary mode ('wb') rejects
    # str writes on Python 3.
    with open(filename, 'w') as f:
        write_fortran_header(f)
        f.write("!! Inputs: %s\n" % ", ".join(str(x) for x in sorted_inputs))
        f.write("\n")
        f.write(code)
def write_fortran_header(f):
    """Emit the two-line banner comment that opens every generated file.

    *f* is a writable text stream.  The second line names this script
    (``__file__``) as the generator and warns against hand edits.
    """
    f.write("!! -*-fortran-*-\n")
    generator = __file__
    f.write("!! This is automatically generated by %s -- do not edit\n"
            % (generator,))
def iter_atoms(m):
    """Yield the atoms of a sympy expression, or of every matrix entry.

    For a ``sympy.Matrix`` the entries are visited row by row, and the atoms
    of each entry are yielded in turn.  An atom shared by several entries
    is therefore yielded more than once.  Any other object is simply asked
    for its ``atoms()``.
    """
    if isinstance(m, sympy.Matrix):
        entries = (m[row, col]
                   for row in range(m.shape[0])
                   for col in range(m.shape[1]))
    else:
        entries = (m,)
    for entry in entries:
        for atom in entry.atoms():
            yield atom
def _fortran_declaration(sym):
    """Return the Fortran type-declaration line (newline included) for *sym*.

    Integer symbols become ``integer`` and real ones ``double precision``.
    Anything else, including objects without sympy assumption attributes,
    falls back to ``double complex``.
    """
    if getattr(sym, 'is_integer', False):
        return " integer :: %s\n" % str(sym)
    if getattr(sym, 'is_real', False):
        return " double precision :: %s\n" % str(sym)
    return " double complex :: %s\n" % str(sym)


def _sort_by_dependency(values):
    """Return the items of *values* ordered so definitions precede uses.

    Ties are resolved deterministically via the key names.  Raises
    ValueError if the definitions reference each other cyclically; a key
    that appears in its own definition also counts as a cycle.
    """
    items = list(values.items())
    items.sort(key=lambda x: x[0].name)
    items.reverse()
    ordered = []
    while items:
        for it in items:
            # An item that no remaining definition refers to can go last.
            if not any(in_expression(it[0], expr) for _, expr in items):
                ordered.append(it)
                items.remove(it)  # safe: we break out of the loop at once
                break
        else:
            raise ValueError("dependency loop in definitions")
    # Items were collected from last to first.
    ordered.reverse()
    return ordered


def format_fortran_values(values, extra_symbols=False, extra_items=()):
    """Return Fortran source declaring and assigning *values*.

    Parameters
    ----------
    values : dict
        Maps sympy Symbols to expressions or matrices.  A value may refer
        to other keys; assignments are emitted in dependency order.
    extra_symbols : bool
        If true, also declare every Symbol or NumberSymbol appearing in the
        values that has not been declared yet.  These declarations are
        emitted in a deterministic (string-sorted) order.
    extra_items : iterable
        Additional symbols to declare without assigning them.

    Returns
    -------
    str
        Type declarations followed by one assignment per key, passed
        through ``fcode_postprocess`` (declaration lines hoisted to the top
        and exact duplicates dropped).

    Raises
    ------
    ValueError
        If the definitions depend on each other cyclically.
    """
    items = _sort_by_dependency(values)

    decls = []
    defined = set()
    for k, v in items + [(z, None) for z in extra_items]:
        decls.append(_fortran_declaration(k))
        defined.add(k)
        if extra_symbols and v is not None:
            # Sorting makes the declaration order independent of set
            # iteration order, so regenerated files are reproducible.
            for k2 in sorted(set(iter_atoms(v)), key=str):
                if (k2.is_Symbol or k2.is_NumberSymbol) and k2 not in defined:
                    defined.add(k2)
                    decls.append(_fortran_declaration(k2))

    assignments = [myfcode(v, assign_to=str(k)) + "\n" for k, v in items]
    return fcode_postprocess("".join(decls) + "".join(assignments))
def in_expression(sym, expr):
    """Return True if *sym* is one of the atoms of *expr*.

    *expr* may also be a ``sympy.Matrix``; then every entry is checked,
    row by row, stopping at the first hit.
    """
    if not isinstance(expr, sympy.Matrix):
        return sym in expr.atoms()
    n_rows, n_cols = expr.shape
    return any(in_expression(sym, expr[r, c])
               for r in range(n_rows)
               for c in range(n_cols))
from sympy.printing.fcode import FCodePrinter


class MyFCodePrinter(FCodePrinter):
    """FCodePrinter with support for Sum, a few functions, and matrices.

    NOTE(review): ``_doprint_a_piece`` is a hook of the FCodePrinter this
    was written against; newer sympy versions may not call it -- confirm
    against the installed sympy.
    """

    def _print_Sum(self, expr):
        """Print a single-index ``Sum`` as a Fortran array-constructor sum.

        ``Sum(f, (i, a, b))`` is printed as ``sum( (/ ( f , i=a, b ) /) )``.

        Raises NotImplementedError for sums over several indices, sums
        without explicit bounds, or a non-atomic summation index.
        """
        if len(expr.limits) != 1 or len(expr.limits[0]) != 3:
            raise NotImplementedError(
                "only single-index sums with explicit bounds are supported")
        sum_var, a, b = expr.limits[0]
        if not sum_var.is_Atom:
            raise NotImplementedError("summation index must be an atom")
        # Print index and bounds with the same printer as the summand, so
        # any name transformation the printer applies stays consistent.
        return "sum( (/ ( %s , %s=%s, %s ) /) )" % (
            self._print(expr.function), self._print(sum_var),
            self._print(a), self._print(b))

    def _print_Function(self, expr):
        """Print a function call, mapping a few sympy functions to Fortran.

        ``sign(x)`` becomes ``sign(1d0, x)``, ``im`` becomes ``aimag``, and
        ``conjugate`` becomes ``conjg``.  Any other function is printed as
        ``name(args)``.

        NOTE(review): Fortran SIGN(1d0, 0d0) yields 1 whereas sympy's
        sign(0) is 0 -- confirm that difference is acceptable.
        """
        if expr.func == sympy.sign:
            return "sign(1d0, %s)" % (self.stringify(expr.args, ", "),)
        if expr.func == sympy.im:
            return "aimag(%s)" % (self.stringify(expr.args, ", "),)
        if expr.func == sympy.conjugate:
            return "conjg(%s)" % (self.stringify(expr.args, ", "),)
        return "%s(%s)" % (str(expr.func),
                           self.stringify(expr.args, ", "),)

    def _doprint_a_piece(self, expr, assign_to=None):
        """Send matrices to ``_doprint_a_matrix``; defer everything else."""
        if isinstance(expr, sympy.Matrix):
            return self._doprint_a_matrix(expr, assign_to=assign_to)
        return FCodePrinter._doprint_a_piece(self, expr, assign_to)

    def _doprint_a_matrix(self, expr, assign_to=None):
        """Return the list of Fortran lines assigning matrix *expr*.

        The lines are, in order:

        - a ``dimension`` statement;
        - a statement zeroing the whole array;
        - one assignment per nonzero entry, using 1-based indices.

        Without *assign_to*, the array is named ``matrix``.
        """
        if assign_to is not None:
            lhs_printed = self._print(assign_to)
        else:
            lhs_printed = "matrix"
        n_rows, n_cols = expr.shape
        lines = ["dimension %s(%d,%d)" % (lhs_printed, n_rows, n_cols),
                 "%s(:,:) = 0" % (lhs_printed,)]
        for i in range(n_rows):
            for j in range(n_cols):
                # Zero entries are covered by the blanket assignment above.
                if expr[i, j] != 0:
                    lines.append("%s(%d,%d) = %s" % (
                        lhs_printed, i + 1, j + 1, self._print(expr[i, j])))
        return lines
def myfcode(expr, **settings):
    """Return Fortran code for *expr* printed by ``MyFCodePrinter``.

    *settings* go straight to the printer (e.g. ``assign_to``).  The
    printer's output is then post-processed:

    1. Each newline followed by optional whitespace and ``@`` (the
       printer's continuation marker, as assumed here) is rewritten into a
       trailing `` &`` on the previous line.
    2. A ``/`` ... ``)`` pair split across such a continuation is glued back
       into ``/)``, so the array-constructor terminator is not broken.
    """
    s = MyFCodePrinter(settings).doprint(expr)
    s = re.sub(r'\n(\s*)@', r' &\n\1 ', s)
    # Accept any horizontal indentation before the stray ")" and keep it.
    # The pattern contains no ".", so no flags are needed.  Note that the
    # 4th positional argument of re.sub is *count*, not *flags*: passing
    # re.S there would cap the number of fixes at 16.
    s = re.sub(r'/ &\n([ \t]*)\)', r'/) &\n\1', s)
    return s
# A declaration keyword as a whole word at the start of a line (after
# indentation), not immediately assigned to.  This keeps statements such as
# "dimensionless = ..." or "integer_part = ..." out of the declarations.
_DECLARATION_RE = re.compile(
    r'\s*(?:double precision|double complex|integer|parameter|dimension)'
    r'\b(?!\s*=)')


def fcode_postprocess(s):
    """Hoist declaration lines of the Fortran code *s* to the top.

    A declaration line starts, after indentation, with one of the whole
    words ``double precision``, ``double complex``, ``integer``,
    ``parameter`` or ``dimension``, not directly followed by ``=``.

    Declaration lines are kept in first-seen order with exact duplicates
    dropped, and placed ahead of all other lines.  The remaining lines keep
    their original order.  Returns the lines joined with ``"\\n"``, without
    a trailing newline.
    """
    decl_lines = []
    seen = set()
    other_lines = []
    for line in s.splitlines():
        if _DECLARATION_RE.match(line):
            if line not in seen:
                seen.add(line)
                decl_lines.append(line)
        else:
            other_lines.append(line)
    return "\n".join(decl_lines + other_lines)
if __name__ == "__main__": | |
main() |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment