public
Last active

Dump from Sympy to Fortran

  • Download Gist
sympyfformat.py
Python

"""
Dump matrices and whatnot from Sympy to Fortran.
 
P. Virtanen 2013, Public Domain.
"""
 
import os
import shutil
import subprocess
import argparse
import re
 
import sympy
 
def example():
    """Write a demo file: the 2x2 matrix ``foo = [[a, b], [b, a+2]]`` to out.inc."""
    a, b = sympy.Symbol('a'), sympy.Symbol('b')
    foo_matrix = sympy.Matrix([[a, b], [b, a+2]])
    dump_fortran_values('out.inc', {sympy.Symbol('foo'): foo_matrix})
 
def dump_fortran_values(filename, values):
    """Generate Fortran code for *values* and write it to *filename*.

    Symbols that occur in the expressions but are not themselves keys of
    *values* are treated as inputs. They are declared and listed in a header
    comment, but never assigned.

    Parameters
    ----------
    filename : str
        Output path; an existing file is overwritten.
    values : dict
        Mapping from sympy Symbol to the expression (or Matrix) assigned to it.
    """
    print("Generating %s" % filename)

    # Free symbols that are not defined by `values` itself.
    inputs = {
        atom
        for v in values.values()
        for atom in iter_atoms(v)
        if atom not in values and atom.is_Symbol and not atom.is_NumberSymbol
    }

    code = format_fortran_values(values, extra_symbols=True,
                                 extra_items=inputs)

    # Text mode: every write below is a str. With 'wb', Python 3 raises
    # TypeError on the first write.
    with open(filename, 'w') as f:
        write_fortran_header(f)
        f.write("!! Inputs: %s\n" % ", ".join(sorted(map(str, inputs))))
        f.write("\n")
        f.write(code)
 
def write_fortran_header(f):
    """Emit the Fortran comment banner (mode line + do-not-edit note) to *f*."""
    banner = (
        "!! -*-fortran-*-\n",
        "!! This is automatically generated by %s -- do not edit\n" % __file__,
    )
    for text in banner:
        f.write(text)
 
def iter_atoms(m):
    """Yield the atoms of a sympy expression, or of every element of a Matrix.

    For a Matrix, the elements are visited row by row. Atoms shared by
    several elements are yielded once per element (no global de-duplication).
    """
    if not isinstance(m, sympy.Matrix):
        for x in m.atoms():
            yield x
        return
    n_rows, n_cols = m.shape
    for row in range(n_rows):
        for col in range(n_cols):
            for atom in m[row, col].atoms():
                yield atom
 
def _fortran_declaration(sym):
    """Return the Fortran declaration line (with trailing newline) for *sym*.

    The type follows the symbol's sympy assumptions:
    - ``integer`` when ``is_integer`` is true;
    - otherwise ``double precision`` when ``is_real`` is true;
    - otherwise ``double complex``, including when the assumption is unknown.
    """
    if getattr(sym, 'is_integer', False):
        return " integer :: %s\n" % str(sym)
    if getattr(sym, 'is_real', False):
        return " double precision :: %s\n" % str(sym)
    return " double complex :: %s\n" % str(sym)


def format_fortran_values(values, extra_symbols=False, extra_items=()):
    """Render *values* as Fortran declarations followed by assignments.

    Parameters
    ----------
    values : dict
        Mapping from sympy Symbol to the expression or Matrix assigned to it.
    extra_symbols : bool, optional
        If true, also declare every Symbol / NumberSymbol that occurs in the
        expressions and has not been declared yet.
    extra_items : iterable, optional
        Additional symbols to declare without assigning them (e.g. inputs).

    Returns
    -------
    str
        Fortran source. ``fcode_postprocess`` moves all declaration lines to
        the front. The assignments follow, ordered so that each symbol is
        assigned before any expression that uses it.

    Raises
    ------
    ValueError
        If the definitions depend on each other cyclically. This includes a
        symbol that occurs in its own expression.
    """
    # -- Sort items in dependency order. Repeatedly take an item whose symbol
    # no remaining expression uses and put it at the front of the result, so
    # every item ends up after all items it depends on.
    pending = list(values.items())
    pending.sort(key=lambda kv: kv[0].name)
    pending.reverse()
    items = []
    while pending:
        for item in pending:
            if not any(in_expression(item[0], expr) for _, expr in pending):
                items.insert(0, item)
                pending.remove(item)
                break
        else:
            raise ValueError("dependency loop in definitions")

    # -- Declarations. Sets are sorted by name so the generated file is the
    # same from run to run; raw set order depends on hash randomization.
    parts = []
    defined = set()

    def declare(sym):
        # Declare each symbol at most once.
        if sym not in defined:
            defined.add(sym)
            parts.append(_fortran_declaration(sym))

    extras = [(z, None) for z in sorted(extra_items, key=str)]
    for k, v in items + extras:
        declare(k)
        if extra_symbols and v is not None:
            free = [a for a in set(iter_atoms(v))
                    if a.is_Symbol or a.is_NumberSymbol]
            for k2 in sorted(free, key=str):
                declare(k2)

    # -- Assignments, in dependency order.
    for k, v in items:
        parts.append(myfcode(v, assign_to=str(k)) + "\n")
    return fcode_postprocess("".join(parts))
 
def in_expression(sym, expr):
    """Return True if *sym* is an atom of *expr* (or of any element of a Matrix)."""
    if isinstance(expr, sympy.Matrix):
        n_rows, n_cols = expr.shape
        return any(in_expression(sym, expr[r, c])
                   for r in range(n_rows)
                   for c in range(n_cols))
    return sym in expr.atoms()
 
 
from sympy.printing.fcode import FCodePrinter
 
class MyFCodePrinter(FCodePrinter):
    """FCodePrinter with extra rules for Sum, a few functions, and Matrix values.

    NOTE(review): ``_doprint_a_piece`` (overridden below) is a hook of the
    FCodePrinter API in the SymPy release this was written against (2013).
    Confirm it still exists in the SymPy version in use.
    """

    def _print_Sum(self, expr):
        """Print a single-index Sum as ``sum( (/ ( term , i=a, b ) /) )``.

        This is a Fortran ``sum`` over an implied-do array constructor.
        """
        # Only sums over exactly one index variable are supported.
        assert len(expr.variables) == 1
        sum_var = expr.variables[0]
        # presumably expr.args[1] is the (index, lower, upper) limits tuple;
        # verify against SymPy's Sum.args layout
        a = expr.args[1][1]
        b = expr.args[1][2]
        assert sum_var.is_Atom
        # The bounds and index go through str(), not the Fortran printer.
        return "sum( (/ ( %s , %s=%s, %s ) /) )" % (
            self._print(expr.args[0]), str(sum_var), str(a), str(b))

    def _print_Function(self, expr):
        """Map selected SymPy functions to Fortran intrinsics.

        sign -> sign(1d0, x), im -> aimag, conjugate -> conjg. Any other
        function is printed as ``name(args)``.
        """
        if expr.func == sympy.sign:
            # NOTE(review): Fortran SIGN(1d0, x) returns +1d0 for x == 0,
            # whereas sympy.sign(0) is 0. Confirm this is acceptable.
            return "sign(1d0, %s)" % (self.stringify(expr.args, ", "),)
        elif expr.func == sympy.im:
            return "aimag(%s)" % (self.stringify(expr.args, ", "),)
        elif expr.func == sympy.conjugate:
            return "conjg(%s)" % (self.stringify(expr.args, ", "),)
        else:
            return "%s(%s)" % (str(expr.func),
                               self.stringify(expr.args, ", "),)

    def _doprint_a_piece(self, expr, assign_to=None):
        """Send Matrix values to _doprint_a_matrix; defer others to FCodePrinter."""
        if isinstance(expr, sympy.Matrix):
            return self._doprint_a_matrix(expr, assign_to=assign_to)
        return FCodePrinter._doprint_a_piece(self, expr, assign_to)

    def _doprint_a_matrix(self, expr, assign_to=None):
        """Return a list of Fortran lines that assign the Matrix *expr* elementwise.

        The lines are:
        - a ``dimension NAME(rows,cols)`` statement;
        - ``NAME(:,:) = 0``;
        - one assignment per nonzero element, using 1-based indices.

        NAME is *assign_to*, or ``matrix`` when no target is given.
        """
        lines = []
        if assign_to is not None:
            lhs_printed = self._print(assign_to)
        else:
            lhs_printed = "matrix"

        lines.append("dimension %s(%d,%d)"
                     % (lhs_printed, expr.shape[0], expr.shape[1]))
        # Zero-fill first so that zero elements need no assignment of their own.
        lines.append("%s(:,:) = 0" % (lhs_printed,))
        for i in range(expr.shape[0]):
            for j in range(expr.shape[1]):
                if expr[i,j] == 0:
                    pass
                else:
                    # Fortran arrays are 1-based; sympy indices are 0-based.
                    lines.append("%s(%d,%d) = %s" % (lhs_printed, i+1, j+1, self._print(expr[i,j])))
        return lines
 
def myfcode(expr, **settings):
    """Print *expr* as Fortran with MyFCodePrinter and fix up continuations.

    *settings* are forwarded to MyFCodePrinter (e.g. ``assign_to``).
    """
    s = MyFCodePrinter(settings).doprint(expr)
    # A line starting with "@" (after indentation) continues the previous
    # line. Rewrite it as a trailing " &" on the previous line instead.
    s = re.sub('\n(\\s*)@', ' &\n\\1 ', s)
    # Re-join an array-constructor "/)" that the continuation split apart.
    # re.sub's 4th positional argument is `count`, not `flags`. The former
    # `re.S` there capped the replacements at 16. DOTALL is not needed anyway,
    # because the pattern contains no ".".
    s = re.sub('/ &\n \\)', '/) &\n ', s)
    return s
 
# A declaration statement has a declaration keyword at the start of the line
# (after indentation). The keyword must be a whole word, so identifiers like
# `integer_part` are not matched, and it must not be the target of an
# assignment such as `integer = 3`.
_DECLARATION_RE = re.compile(
    r'\s*(?:double precision|double complex|integer|parameter|dimension)'
    r'\b(?!\s*=)')


def fcode_postprocess(s):
    """Move declaration lines in front of executable lines and de-duplicate them.

    Declaration lines keep their first-occurrence order, and exact duplicates
    are dropped. All other lines keep their relative order, duplicates
    included. Lines are rejoined with "\\n"; a trailing newline in *s* is not
    preserved.
    """
    declarations = []
    seen = set()
    statements = []
    for line in s.splitlines():
        if _DECLARATION_RE.match(line):
            if line not in seen:
                seen.add(line)
                declarations.append(line)
        else:
            statements.append(line)
    return "\n".join(declarations + statements)
 
def main():
    """Command-line entry point: write the built-in example to out.inc.

    The script previously called an undefined ``main()`` here and crashed
    with NameError.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.parse_args()
    example()


if __name__ == "__main__":
    main()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.