Skip to content

Instantly share code, notes, and snippets.

@pv
Last active December 16, 2015 22:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pv/5506308 to your computer and use it in GitHub Desktop.
Save pv/5506308 to your computer and use it in GitHub Desktop.
Dump from Sympy to Fortran
"""
Dump matrices and whatnot from Sympy to Fortran.
P. Virtanen 2013, Public Domain.
"""
import os
import shutil
import subprocess
import argparse
import re
import sympy
def example():
    """Demonstrate the dumper: write one symbolic 2x2 matrix to ``out.inc``."""
    a, b = sympy.Symbol('a'), sympy.Symbol('b')
    target = sympy.Symbol('foo')
    matrix = sympy.Matrix([[a, b], [b, a+2]])
    # `a` and `b` are not keys of the mapping, so they are picked up as inputs.
    dump_fortran_values('out.inc', {target: matrix})
def dump_fortran_values(filename, values):
    """
    Write Fortran declarations and assignments for `values` to `filename`.

    Parameters
    ----------
    filename : str
        Path of the file to create; an existing file is overwritten.
    values : dict
        Maps each symbol to be defined to the expression (or Matrix)
        assigned to it.

    Symbols that occur in the expressions but are not keys of `values`
    (and are not number symbols) are treated as inputs: they are declared
    in the generated code and listed in an ``!! Inputs:`` header comment.
    """
    print("Generating %s" % filename)

    inputs = set()
    for v in values.values():
        for atom in iter_atoms(v):
            if atom not in values and atom.is_Symbol and not atom.is_NumberSymbol:
                inputs.add(atom)

    # Fix the order once so both the header comment and the generated
    # declarations are reproducible from run to run (set order is not).
    ordered_inputs = sorted(inputs, key=str)

    code = format_fortran_values(values, extra_symbols=True,
                                 extra_items=ordered_inputs)

    # Text mode: everything written below is str, which a binary-mode
    # file object rejects on Python 3.
    with open(filename, 'w') as f:
        write_fortran_header(f)
        f.write("!! Inputs: %s\n" % ", ".join(str(sym) for sym in ordered_inputs))
        f.write("\n")
        f.write(code)
def write_fortran_header(f):
    """Write the two-line generated-file banner to the writable object `f`."""
    banner = (
        "!! -*-fortran-*-\n",
        "!! This is automatically generated by %s -- do not edit\n" % __file__,
    )
    for text in banner:
        f.write(text)
def iter_atoms(m):
    """
    Yield the atoms of `m`.

    For a ``sympy.Matrix`` the entries are visited row by row and the atoms
    of each entry are yielded in turn, so an atom shared by several entries
    is yielded more than once. Any other object is asked for ``m.atoms()``.
    """
    if not isinstance(m, sympy.Matrix):
        for atom in m.atoms():
            yield atom
        return

    n_rows, n_cols = m.shape[0], m.shape[1]
    for row in range(n_rows):
        for col in range(n_cols):
            for atom in m[row, col].atoms():
                yield atom
def _fortran_declaration(symbol):
    """Return the Fortran type-declaration line for `symbol`, chosen from its assumptions."""
    if getattr(symbol, 'is_integer', False):
        ftype = "integer"
    elif getattr(symbol, 'is_real', False):
        ftype = "double precision"
    else:
        # Neither known-integer nor known-real: fall back to complex.
        ftype = "double complex"
    return " %s :: %s\n" % (ftype, str(symbol))


def format_fortran_values(values, extra_symbols=False, extra_items=()):
    """
    Return Fortran source declaring and assigning every entry of `values`.

    Parameters
    ----------
    values : dict
        Maps each symbol to be defined to its expression (or Matrix).
    extra_symbols : bool, optional
        If true, also declare every symbol / number symbol that occurs in
        the expressions and has not been declared yet.
    extra_items : iterable, optional
        Additional symbols to declare without assigning them a value.

    Returns
    -------
    str
        Declarations followed by assignments, as arranged by
        `fcode_postprocess`.

    Raises
    ------
    ValueError
        If the definitions depend on each other cyclically (this includes
        a symbol that occurs in its own expression).
    """
    # -- Sort items in dependency order
    pending = list(values.items())
    pending.sort(key=lambda x: x[0].name)
    pending.reverse()
    ordered = []
    while pending:
        # Pick a definition that no remaining expression refers to; it can
        # safely come after everything still pending, so it goes in front
        # of what has been placed so far.
        for item in pending:
            if not any(in_expression(item[0], expr) for _, expr in pending):
                break
        else:
            raise ValueError("dependency loop in definitions")
        pending.remove(item)
        ordered.insert(0, item)

    # -- Produce Fortran output
    chunks = []
    defined = set()
    for k, v in ordered + [(z, None) for z in extra_items]:
        chunks.append(_fortran_declaration(k))
        defined.add(k)
        if extra_symbols and v is not None:
            # Sorted by name so the declaration order does not depend on
            # set iteration order.
            for atom in sorted(set(iter_atoms(v)), key=str):
                if (atom.is_Symbol or atom.is_NumberSymbol) and atom not in defined:
                    defined.add(atom)
                    chunks.append(_fortran_declaration(atom))

    for k, v in ordered:
        chunks.append(myfcode(v, assign_to=str(k)) + "\n")

    return fcode_postprocess("".join(chunks))
def in_expression(sym, expr):
    """
    Tell whether `sym` is one of the atoms of `expr`.

    `expr` may be a ``sympy.Matrix``, in which case its entries are checked
    one by one (row by row) and the search stops at the first hit.
    """
    if not isinstance(expr, sympy.Matrix):
        return sym in expr.atoms()

    for row in range(expr.shape[0]):
        for col in range(expr.shape[1]):
            if in_expression(sym, expr[row, col]):
                return True
    return False
from sympy.printing.fcode import FCodePrinter
class MyFCodePrinter(FCodePrinter):
    """
    Fortran code printer with extra support for ``Sum``, a few functions
    (``sign``, ``im``, ``conjugate``) and element-wise matrix assignment.
    """

    def _print_Sum(self, expr):
        """
        Print a single-variable ``Sum`` as ``sum`` over an implied-do
        array constructor.
        """
        assert len(expr.variables) == 1
        sum_var = expr.variables[0]
        lower = expr.args[1][1]
        upper = expr.args[1][2]
        assert sum_var.is_Atom
        # The bounds are expressions too: format them with this printer,
        # not with Python's str(), so they come out as Fortran.
        return "sum( (/ ( %s , %s=%s, %s ) /) )" % (
            self._print(expr.args[0]), self._print(sum_var),
            self._print(lower), self._print(upper))

    def _print_Function(self, expr):
        """
        Print a function call, mapping ``sign``, ``im`` and ``conjugate``
        to Fortran spellings; any other function keeps its own name.
        """
        printed_args = self.stringify(expr.args, ", ")
        if expr.func == sympy.sign:
            return "sign(1d0, %s)" % (printed_args,)
        elif expr.func == sympy.im:
            return "aimag(%s)" % (printed_args,)
        elif expr.func == sympy.conjugate:
            return "conjg(%s)" % (printed_args,)
        else:
            return "%s(%s)" % (str(expr.func), printed_args,)

    def _doprint_a_piece(self, expr, assign_to=None):
        """
        Route matrices to `_doprint_a_matrix`; defer everything else to the
        base class.

        NOTE(review): relies on ``FCodePrinter._doprint_a_piece`` existing
        in the installed sympy -- confirm against the sympy version in use.
        """
        if isinstance(expr, sympy.Matrix):
            return self._doprint_a_matrix(expr, assign_to=assign_to)
        return FCodePrinter._doprint_a_piece(self, expr, assign_to)

    def _doprint_a_matrix(self, expr, assign_to=None):
        """
        Return the lines that declare, zero and fill a matrix.

        The target is named after `assign_to`, or ``matrix`` when
        `assign_to` is None. Only non-zero entries get an assignment of
        their own; the rest are covered by the initial whole-array zeroing.
        """
        if assign_to is not None:
            lhs_printed = self._print(assign_to)
        else:
            lhs_printed = "matrix"

        n_rows, n_cols = expr.shape[0], expr.shape[1]
        lines = [
            "dimension %s(%d,%d)" % (lhs_printed, n_rows, n_cols),
            "%s(:,:) = 0" % (lhs_printed,),
        ]
        for i in range(n_rows):
            for j in range(n_cols):
                if expr[i,j] != 0:
                    # Fortran indices are 1-based.
                    lines.append("%s(%d,%d) = %s" % (
                        lhs_printed, i+1, j+1, self._print(expr[i,j])))
        return lines
def myfcode(expr, **settings):
    """
    Print `expr` with `MyFCodePrinter` and rewrite its line continuations.

    Parameters
    ----------
    expr
        Expression (or Matrix) to print.
    **settings
        Passed to `MyFCodePrinter` as its settings dictionary.

    Returns
    -------
    str
        The printed code after two textual fix-ups:

        1. a line break followed by optional whitespace and ``@`` becomes
           `` &`` at the end of the previous line (presumably turning the
           printer's ``@`` continuation markers into free-form ``&``
           continuations -- confirm against the printer output);
        2. a ``/`` at the end of a continued line whose ``)`` landed on the
           next line is rejoined into ``/)`` before the ``&``.
    """
    s = MyFCodePrinter(settings).doprint(expr)
    s = re.sub('\n(\\s*)@', ' &\n\\1 ', s)
    # The flag must go by keyword: the fourth positional argument of
    # re.sub is `count`, so passing re.S there capped the number of
    # replacements at 16 instead of setting a flag.
    s = re.sub('/ &\n \\)', '/) &\n ', s, flags=re.S)
    return s
def fcode_postprocess(s):
    """
    Move declaration lines to the top of the code and drop repeated ones.

    A line counts as a declaration when, ignoring leading whitespace, it
    starts with ``double precision``, ``double complex``, ``integer``,
    ``parameter`` or ``dimension``. Declarations keep the order of their
    first appearance; an exact duplicate (whitespace included) is emitted
    only once. All other lines follow in their original order.

    Parameters
    ----------
    s : str
        Generated Fortran code, one statement per line.

    Returns
    -------
    str
        The rearranged code, joined with newlines and without a trailing
        newline.
    """
    declaration_prefixes = (
        'double precision',
        'double complex',
        'integer',
        'parameter',
        'dimension',
    )
    def_lines = []
    seen_defs = set()   # O(1) duplicate check alongside the ordered list
    body_lines = []
    for line in s.splitlines():
        if line.lstrip().startswith(declaration_prefixes):
            if line not in seen_defs:
                seen_defs.add(line)
                def_lines.append(line)
        else:
            body_lines.append(line)
    return "\n".join(def_lines + body_lines)
def main():
    """Script entry point: run the bundled example, which writes ``out.inc``."""
    example()


# `main` was called here without being defined anywhere in the file, so
# running the script failed with a NameError.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment