public
Last active

Dump from Sympy to Fortran

  • Download Gist
sympyfformat.py
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
"""
Dump matrices and whatnot from Sympy to Fortran.
 
P. Virtanen 2013, Public Domain.
"""
 
import os
import shutil
import subprocess
import argparse
import re
 
import sympy
 
def example():
    """Demo: write ``out.inc`` defining the 2x2 matrix ``foo`` in terms of inputs a and b."""
    a, b = [sympy.Symbol(name) for name in ('a', 'b')]
    foo_matrix = sympy.Matrix([[a, b],
                               [b, a + 2]])
    dump_fortran_values('out.inc', {sympy.Symbol('foo'): foo_matrix})
 
def dump_fortran_values(filename, values):
    """Write a Fortran include file that declares and assigns *values*.

    Parameters
    ----------
    filename : str
        Path of the include file to (over)write.
    values : dict
        Maps sympy Symbols to the expression (or sympy.Matrix) assigned to them.

    Every Symbol that occurs in the values but is neither a key of *values*
    nor a NumberSymbol is treated as an input. Inputs are declared, and their
    sorted names are listed in a header comment.
    """
    print("Generating %s" % filename)

    inputs = set()
    for v in values.values():
        for atom in iter_atoms(v):
            if atom not in values and atom.is_Symbol and not atom.is_NumberSymbol:
                inputs.add(atom)

    code = format_fortran_values(values, extra_symbols=True,
                                 extra_items=inputs)

    # Text mode ('w', not 'wb'): every write below is a str, which a
    # binary-mode file rejects with TypeError on Python 3.
    with open(filename, 'w') as f:
        write_fortran_header(f)
        f.write("!! Inputs: %s\n" % ", ".join(sorted(map(str, inputs))))
        f.write("\n")
        f.write(code)
 
def write_fortran_header(f):
    """Write the two-line comment banner that opens every generated file.

    *f* is any object with a ``write(str)`` method. The banner names this
    generator script (``__file__``) and warns against manual edits.
    """
    banner = (
        "!! -*-fortran-*-\n",
        "!! This is automatically generated by %s -- do not edit\n" % (__file__,),
    )
    for banner_line in banner:
        f.write(banner_line)
 
def iter_atoms(m):
    """Yield the atoms of a sympy expression, or of every entry of a sympy.Matrix.

    Matrix entries are visited row by row. An atom that occurs in several
    entries is therefore yielded once per entry.
    """
    if not isinstance(m, sympy.Matrix):
        entries = [m]
    else:
        n_rows, n_cols = m.shape[0], m.shape[1]
        entries = [m[r, c] for r in range(n_rows) for c in range(n_cols)]
    for entry in entries:
        for atom in entry.atoms():
            yield atom
 
def _fortran_declaration(sym):
    """Return the Fortran type-declaration line (with newline) for sympy symbol *sym*.

    Integer symbols map to ``integer``, real ones to ``double precision``, and
    anything else (including symbols with unknown assumptions) to ``double complex``.
    """
    if getattr(sym, 'is_integer', False):
        return " integer :: %s\n" % str(sym)
    if getattr(sym, 'is_real', False):
        return " double precision :: %s\n" % str(sym)
    return " double complex :: %s\n" % str(sym)


def _sort_by_dependency(values):
    """Return the items of *values* ordered so each key is assigned before any use.

    Raises ValueError on a dependency loop (including a self-reference).
    """
    pending = list(values.items())
    pending.sort(key=lambda item: item[0].name)
    pending.reverse()
    # Repeatedly peel off an item that no remaining expression refers to; such
    # an item can safely be assigned after all the others still pending.
    peeled = []
    while pending:
        for idx, item in enumerate(pending):
            if not any(in_expression(item[0], expr) for _, expr in pending):
                peeled.append(item)
                del pending[idx]
                break
        else:
            raise ValueError("dependency loop in definitions")
    peeled.reverse()
    return peeled


def format_fortran_values(values, extra_symbols=False, extra_items=()):
    """Format a mapping of sympy symbols to expressions as Fortran source.

    Parameters
    ----------
    values : dict
        Maps sympy Symbols to the expression (or sympy.Matrix) assigned to them.
        Values may refer to other keys; assignments are emitted in dependency order.
    extra_symbols : bool
        If true, also declare every Symbol / NumberSymbol occurring in the values.
    extra_items : iterable
        Additional symbols to declare without assigning them.

    Returns
    -------
    str
        Type declarations (hoisted and deduplicated by ``fcode_postprocess``)
        followed by the assignment statements.

    Raises
    ------
    ValueError
        If the definitions depend on each other cyclically.
    """
    items = _sort_by_dependency(values)

    # -- Declarations. Extra items and atoms are sorted by name so the generated
    # file does not depend on (hash-randomized) set iteration order.
    declarations = []
    defined = set()
    to_declare = items + [(z, None) for z in sorted(extra_items, key=str)]
    for k, v in to_declare:
        declarations.append(_fortran_declaration(k))
        defined.add(k)
        if extra_symbols and v is not None:
            for k2 in sorted(set(iter_atoms(v)), key=str):
                if (k2.is_Symbol or k2.is_NumberSymbol) and k2 not in defined:
                    defined.add(k2)
                    declarations.append(_fortran_declaration(k2))

    # -- Assignments, in dependency order.
    assignments = [myfcode(v, assign_to=str(k)) + "\n" for k, v in items]
    return fcode_postprocess("".join(declarations + assignments))
 
def in_expression(sym, expr):
    """Return True if *sym* is an atom of *expr* (of any entry, for a sympy.Matrix)."""
    if not isinstance(expr, sympy.Matrix):
        return sym in expr.atoms()
    n_rows, n_cols = expr.shape[0], expr.shape[1]
    return any(in_expression(sym, expr[r, c])
               for r in range(n_rows)
               for c in range(n_cols))
 
 
from sympy.printing.fcode import FCodePrinter
 
class MyFCodePrinter(FCodePrinter):
    """FCodePrinter with custom output for Sum, sign/im/conjugate and matrices.

    NOTE(review): relies on the ``_doprint_a_piece`` and ``stringify`` hooks of
    the (older) sympy FCodePrinter this was written against -- confirm they
    exist in the installed sympy version.
    """

    def _print_Sum(self, expr):
        """Print a single-index Sum as ``sum( (/ ( term , i=lo, hi ) /) )``.

        Raises
        ------
        NotImplementedError
            If the Sum has more than one index, or its index is not an atom.
        """
        # Explicit checks instead of asserts: asserts are stripped under
        # ``python -O``, and a multi-index Sum would then be mistranslated silently.
        if len(expr.variables) != 1:
            raise NotImplementedError(
                "only single-index Sums can be printed, got %d indices"
                % len(expr.variables))
        sum_var = expr.variables[0]
        if not sum_var.is_Atom:
            raise NotImplementedError(
                "Sum index %s is not an atom" % (sum_var,))
        # expr.args[1] is the (index, lower, upper) limit tuple of the Sum.
        a = expr.args[1][1]
        b = expr.args[1][2]
        return "sum( (/ ( %s , %s=%s, %s ) /) )" % (
            self._print(expr.args[0]), str(sum_var), str(a), str(b))

    def _print_Function(self, expr):
        """Print a function call, mapping sign/im/conjugate to Fortran intrinsics."""
        args = self.stringify(expr.args, ", ")
        if expr.func == sympy.sign:
            # Fortran SIGN(A, B) returns |A| with the sign of B.
            return "sign(1d0, %s)" % (args,)
        elif expr.func == sympy.im:
            return "aimag(%s)" % (args,)
        elif expr.func == sympy.conjugate:
            return "conjg(%s)" % (args,)
        else:
            return "%s(%s)" % (str(expr.func), args)

    def _doprint_a_piece(self, expr, assign_to=None):
        """Route sympy.Matrix values to ``_doprint_a_matrix``; defer the rest to FCodePrinter."""
        if isinstance(expr, sympy.Matrix):
            return self._doprint_a_matrix(expr, assign_to=assign_to)
        return FCodePrinter._doprint_a_piece(self, expr, assign_to)

    def _doprint_a_matrix(self, expr, assign_to=None):
        """Return Fortran lines that assign the matrix *expr* element-wise.

        The lines declare the array shape, zero the whole array, and then
        assign each nonzero entry using 1-based indices. Without *assign_to*,
        the array is named ``matrix``.
        """
        if assign_to is not None:
            lhs_printed = self._print(assign_to)
        else:
            lhs_printed = "matrix"

        lines = [
            "dimension %s(%d,%d)" % (lhs_printed, expr.shape[0], expr.shape[1]),
            "%s(:,:) = 0" % (lhs_printed,),
        ]
        for i in range(expr.shape[0]):
            for j in range(expr.shape[1]):
                # Zero entries are already covered by the blanket zeroing above.
                if expr[i, j] != 0:
                    lines.append("%s(%d,%d) = %s" % (
                        lhs_printed, i + 1, j + 1, self._print(expr[i, j])))
        return lines
 
def myfcode(expr, **settings):
    """Print *expr* as Fortran with MyFCodePrinter and post-process continuations.

    ``settings`` are passed to the printer (e.g. ``assign_to``).
    """
    s = MyFCodePrinter(settings).doprint(expr)
    # Turn "@"-prefixed continuation lines into "&"-terminated ones.
    s = re.sub('\n(\\s*)@', ' &\n\\1 ', s)
    # Move a ")" that ended up at the start of a continuation line back next to
    # its preceding "/" (so an array constructor closes with "/)").
    # flags must be passed by keyword: the 4th positional argument of re.sub is
    # `count`, so passing re.S (== 16) there capped the substitutions at 16.
    s = re.sub('/ &\n \\)', '/) &\n ', s, flags=re.S)
    return s
 
# Declaration statements that fcode_postprocess hoists to the top. The trailing
# \b keeps assignments to names such as ``integer_part`` or ``dimensionless``
# from being mistaken for declarations.
_DECLARATION_RE = re.compile(
    r'\s*(?:double precision|double complex|integer|parameter|dimension)\b')


def fcode_postprocess(s):
    """Move declaration lines of Fortran source *s* to the top, dropping duplicates.

    Declarations (``double precision``, ``double complex``, ``integer``,
    ``parameter``, ``dimension``) keep their first-seen order. All other lines
    keep their relative order and follow the declarations. The result is
    joined with newlines and has no trailing newline.
    """
    declarations = []
    seen = set()
    statements = []
    for line in s.splitlines():
        if _DECLARATION_RE.match(line):
            if line not in seen:
                seen.add(line)
                declarations.append(line)
        else:
            statements.append(line)
    return "\n".join(declarations + statements)
 
if __name__ == "__main__":
    # This module defines no main(); calling it raised NameError. Run the
    # bundled example, which writes out.inc in the current directory.
    example()

Please sign in to comment on this gist.

Something went wrong with that request. Please try again.