Skip to content

Instantly share code, notes, and snippets.

@tehrengruber
Created July 9, 2021 21:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tehrengruber/51df8283ad5f5501d6ad9f7af74d1d6b to your computer and use it in GitHub Desktop.
Save tehrengruber/51df8283ad5f5501d6ad9f7af74d1d6b to your computer and use it in GitHub Desktop.
Field/Local execution model
import ast
from copy import copy
import types
from types import LambdaType
from typing import List, Callable, Tuple, Any, Union, Literal
import inspect
import textwrap
import tempfile
import eve.datamodels as datamodels
import typing_inspect
from eve.datamodels import DataModel, field
from gt4py_lazy_fields.utils.index_space import UnitRange, ProductSet, CartesianSet
import gt4py_lazy_fields
import numpy as np
from gt4py_lazy_fields.tracing.tracing import trace, tracable, isinstance_, if_, tuple_, zip_, OpaqueCall, PyClosureVar, PyExternalFunction, DataModelConstruction
from gt4py_lazy_fields.generic import Lambda, BuiltInFunction, Constant, Call, Symbol, Tuple_, GenericIRNode, Function, Let, Var
class Stencil(DataModel):
    """A callable paired with its per-dimension extent.

    Created by ``stencil_decorator`` (and by the ``_fmap_*`` helpers, which wrap
    a stencil into a position-based one). Calling a ``Stencil`` forwards all
    arguments to ``impl``.
    """

    # Callable invoked with whatever arguments the Stencil itself is called with.
    impl: Callable
    # Per-dimension extent, e.g. (1, 1); `_fmap_transparent_view` shrinks the
    # valid domain by these amounts ("halo lines").
    extent: Tuple[int, ...]

    @tracable
    def __call__(self, *args):
        # Forward the call to the wrapped implementation unchanged.
        return self.impl(*args)
@tracable
def stencil_decorator(*, extent: Tuple[int, ...]):
    # Decorator factory: `@stencil_decorator(extent=(1, 1))` turns the decorated
    # function into a `Stencil(func, extent)`. `extent` holds one entry per
    # dimension (it is iterated element-wise in `_fmap_transparent_view`).
    return lambda func: Stencil(func, extent)
class AbstractField:
    """Common marker base class of field-like types (``Field``, ``TransparentView``)."""
class LazyMap(DataModel):
    """Lazily evaluated image: values are computed from ``func`` on access.

    An alternative to storing a materialized ``np.ndarray`` as the image of a
    ``Field``.
    """

    # Callable mapping an index tuple (in domain coordinates) to a value.
    # NOTE(review): called with the index tuple as a single argument, whereas
    # `_map` calls stencils as `stencil(*idx)` — confirm this is intended.
    func: Callable
    # Index space over which `func` is evaluated.
    domain: ProductSet

    def __getitem__(self, memory_idx):
        """Return the value at ``memory_idx``, given relative to the domain origin.

        The index is shifted into domain ("image") coordinates by adding the
        point ``self.domain[0, 0]`` (assumed to be the domain's first point;
        2-D only) before ``func`` is evaluated.
        """
        origin = self.domain[0, 0]
        image_idx = tuple(m + o for m, o in zip(memory_idx, origin))
        # Bounds-check against this map's own domain (not any module-level one).
        assert image_idx in self.domain
        return self.func(image_idx)

    def materialize(self):
        """Evaluate ``func`` at every domain index into an array shaped like the domain."""
        return np.array([self.func(idx) for idx in self.domain]).reshape(self.domain.shape)
# Storage backend for the image (values) of a Field. The eager NumPy backend is
# active; a lazy alternative is kept here, disabled:
#   image_type = LazyMap
#   def _map(stencil, domain):
#       return LazyMap(stencil, domain)
image_type = np.ndarray
def _map(stencil, domain):
return np.array([stencil(*idx) for idx in domain]).reshape(domain.shape)
# Rewrite rule (presumably what the `InlineOnce` pass further down implements):
#   getitem(map(lambda i, j: getitem(field, (i, j)), domain), (i, j)) = field(i, j)
# i.e. indexing a freshly mapped field can be replaced by evaluating its stencil.
class Field(DataModel, AbstractField):
    """Values (``image``) defined on an index space (``domain``).

    ``image`` is stored relative to the domain origin: the point
    ``self.domain[0, 0]`` is found at memory index ``(0, 0)``.
    """

    # Index space the field is defined on.
    domain: ProductSet
    # Backing storage, indexed by origin-relative (memory) indices.
    image: image_type

    def __getitem__(self, image_idx):
        # Return the value at `image_idx`, given in domain coordinates.
        assert image_idx in self.domain
        origin = self.domain[0, 0]
        # Shift into memory coordinates; `zip_` is the zip from the tracing module.
        memory_idx = tuple(idx-o for idx, o in zip_(image_idx, origin))
        return self.image[memory_idx]

    def view(self, domain):
        # New Field restricted to `domain` (must be a subset of `self.domain`),
        # built by evaluating `self[idx]` at every idx of `domain`.
        assert domain.issubset(self.domain)
        return apply_stencil(lambda *idx: self[idx], domain)

    def transparent_view(self, domain):
        # Wrap this field (without copying) in a TransparentView whose nominal
        # domain is `domain`.
        return TransparentView(domain, self)
class TransparentView(DataModel, AbstractField):
    """View of ``field`` with a nominal ``domain`` that still reads the whole field.

    Element access is bounds-checked only against ``field.domain``, so stencils
    applied to the view can read points outside ``domain`` (its halo).
    """

    # Nominal region covered by the view.
    domain: ProductSet
    # Underlying field that every read is forwarded to.
    field: Field

    def __getitem__(self, position):
        # Checks the *underlying* field's domain, not `self.domain` — presumably
        # intentional so halo points around the view remain readable.
        assert position in self.field.domain
        return self.field[position]
@tracable
def new_accessor(field, position):
    # Build a relative accessor: `accessor(*shift)` reads `field` at
    # `position + shift` (element-wise), e.g. `accessor(-1, 0)` reads the point
    # shifted by -1 in the first dimension. Uses the tracing-aware `zip_`/`tuple_`.
    def accessor(*shift):
        gen = (idx+s for idx, s in zip_(position, shift))
        idx = tuple_(gen)
        return field[idx]
    return accessor
@tracable
def apply_stencil(stencil: "Callable", domain, *fields):
    # Evaluate `stencil` at every point of `domain` and return the resulting Field.
    #
    # Without `fields`, the stencil is called with the position components,
    # `stencil(*pos)`. With `fields`, it is called as `stencil(pos, *accessors)`,
    # with one relative accessor (see `new_accessor`) per field. Fields whose
    # domain equals `domain` are read through `Field.view`; all others through
    # `transparent_view`.
    if len(fields) > 0:
        # The wrapper gets its own name so `stencil` inside the lambda keeps
        # referring to the caller's stencil. Rebinding `stencil` to the wrapper
        # would make the late-binding closure call itself and recurse forever.
        # NOTE(review): the per-field view is rebuilt at every point, and
        # `Field.view` materializes a copy; hoisting it would be cheaper.
        wrapped = lambda *pos: stencil(pos, *(
            new_accessor(field.view(domain) if field.domain == domain else field.transparent_view(domain), pos) for field in
            fields))
        return apply_stencil(wrapped, domain)
    return Field(domain, _map(stencil, domain))
@tracable
def _fmap_field(stencil, field: "Field"):
    # Apply `stencil` at every point of `field.domain`. At each position the
    # stencil receives a relative accessor into `field` centred at that position.
    # NOTE(review): no halo check here — out-of-domain reads are only caught by
    # the assert in `Field.__getitem__`.
    wrapped_stencil = Stencil(lambda *pos: stencil(new_accessor(field, pos)), stencil.extent)
    return apply_stencil(wrapped_stencil, field.domain)
@tracable
def _fmap_transparent_view(stencil, view: "TransparentView"):
    # Apply `stencil` to a TransparentView. The stencil is evaluated on the
    # underlying field's domain extended by `-extent` in each dimension (assumed
    # to shrink it, consuming halo lines). The result is exposed again as a
    # transparent view with the original `view.domain`.
    # Raises ValueError if `view.domain` is not contained in that shrunk domain.
    wrapped_stencil = Stencil(lambda *pos: stencil(new_accessor(view, pos)), stencil.extent)
    valid_domain = view.field.domain.extend(*(-e for e in stencil.extent)) # consume halo lines
    if not view.domain.issubset(valid_domain):
        raise ValueError("Not enough halo lines.")
    new_field = apply_stencil(wrapped_stencil, valid_domain)
    return new_field.transparent_view(view.domain)
@tracable
def fmap(stencil, field):
    # Apply `stencil` point-wise to `field`, dispatching on the field kind with
    # the tracing-aware `if_`/`isinstance_`. TransparentView inputs use the
    # halo-consuming `_fmap_transparent_view`; everything else uses `_fmap_field`.
    # Branches are passed as lambdas (presumably so `if_` evaluates only the
    # selected one — confirm in the tracing module).
    return if_(isinstance_(field, TransparentView),
               lambda: _fmap_transparent_view(stencil, field),
               lambda: _fmap_field(stencil, field))
@tracable
def laplacian(field: "Field"):
    # Five-point Laplacian: -4 * centre plus the four direct neighbours, with an
    # extent of one point in each of the two dimensions.
    # todo: remove extend use apply stencil temporarily
    @stencil_decorator(extent=(1, 1))
    def stencil(f):
        # `f` is a relative accessor: f(di, dj) reads the value shifted by (di, dj).
        return -4 * f(0, 0) + f(-1, 0) + f(1, 0) + f(0, -1) + f(0, 1)
    return fmap(stencil, field)
@tracable
def laplap(field):
    # Laplacian applied twice. Each application has extent (1, 1), so a
    # TransparentView input needs two halo lines per dimension.
    lap = laplacian(field)
    laplap = laplacian(lap)  # NOTE: this local shadows the function name in its own body
    return laplap
# Example input: a 2-D domain (5x5 if UnitRange is half-open) whose field is 1
# where 1 < pos[0] < 3 and 0 elsewhere.
domain = UnitRange(0, 5)*UnitRange(0, 5)
# NOTE(review): `input` shadows the builtin of the same name at module level.
input = apply_stencil(lambda *pos: 1 if 1 < pos[0] < 3 else 0, domain)
# View with the domain extended by -2 per dimension (assumed to shrink it),
# leaving the two halo lines that `laplap` consumes.
input_view = input.transparent_view(input.domain.extend(-2, -2))
# Direct evaluation of the bi-Laplacian on the view.
result = laplap(input_view)
# Symbolic trace of `laplap`, with the input represented by a Field-typed symbol.
trc=trace(laplap, (Symbol(name="INPUT_FIELD", type_=Field),))
from eve import NodeTranslator
import gt4py_lazy_fields.tracing.tracing as tracing
from gt4py_lazy_fields.utils.uid import uid
from gt4py_lazy_fields.tracing.pass_helper.pass_manager import PassManager
class SymbolicEvalStencil(NodeTranslator):
    """Wrap non-lambda stencils passed to ``_map`` in an explicit index lambda.

    ``_map(stencil, domain)`` with a stencil that is not yet a ``Lambda`` becomes
    ``_map(lambda I_<uid>, J_<uid>: OpaqueCall(stencil, I_<uid>, J_<uid>), domain)``.
    """

    def is_applicable(self, node, *, symtable):
        # Only calls to the external `_map` whose stencil is not already a Lambda.
        if not isinstance(node, Call):
            return False
        callee = node.func
        if not (isinstance(callee, tracing.PyExternalFunction) and callee.func == _map):
            return False
        return not isinstance(node.args[0], Lambda)

    def transform(self, node: Call, *, symtable):
        assert not node.kwargs
        # One fresh int-typed index symbol per dimension (I first, then J).
        index_symbols = []
        for dim in ("I", "J"):
            index_symbols.append(Symbol(name=f"{dim}_{uid(dim)}", type_=int))
        index_symbols = tuple(index_symbols)
        stencil_call = OpaqueCall(func=node.args[0], args=index_symbols, kwargs={})
        index_lambda = Lambda(args=index_symbols, expr=stencil_call)
        return tracing.Call(node.func, args=(index_lambda, *node.args[1:]), kwargs={})
from gt4py_lazy_fields.tracing.passes.constant_folding import ConstantFold
from gt4py_lazy_fields.tracing.passes.datamodel import DataModelConstructionResolver, DataModelMethodResolution, \
DataModelGetAttrResolver, DataModelCallOperatorResolution, DataModelExternalGetAttrInliner
from gt4py_lazy_fields.tracing.passes.opaque_call_resolution import OpaqueCallResolution1, OpaqueCallResolution2
from gt4py_lazy_fields.tracing.passes.fix_call_type import FixCallType
from gt4py_lazy_fields.tracing.passes.tuple_getitem_resolver import TupleGetItemResolver
from gt4py_lazy_fields.tracing.passes.remove_constant_refs import RemoveConstantRefs
from gt4py_lazy_fields.tracing.passes.tracable_function_resolver import TracableFunctionResolver
from gt4py_lazy_fields.tracing.passes.remove_unused_symbols import RemoveUnusedSymbols
from gt4py_lazy_fields.tracing.passes.single_use_inliner import SingleUseInliner
# Main simplification pipeline over the traced expression. The passes are
# given in this order (assumes PassManager applies them in list order — TODO
# confirm): resolve DataModel constructs, turn `_map` stencils into index
# lambdas (SymbolicEvalStencil, defined above), then resolve opaque calls,
# call types, tuple indexing, constant refs and tracable functions.
pass_manager = PassManager([
    DataModelConstructionResolver(),
    DataModelCallOperatorResolution(),
    DataModelGetAttrResolver(),
    DataModelMethodResolution(),
    SymbolicEvalStencil(),
    OpaqueCallResolution1(),
    OpaqueCallResolution2(),
    FixCallType(),
    TupleGetItemResolver(),
    RemoveConstantRefs(),
    TracableFunctionResolver()
])
trc2 = pass_manager.visit(trc.expr, symtable={})
# Fold constant sub-expressions.
trc3 = ConstantFold().visit(trc2, symtable={})
# Drop symbols that are no longer referenced after folding.
trc4 = PassManager([RemoveUnusedSymbols()]).visit(trc3, symtable={})
from gt4py_lazy_fields.tracing.passes.global_symbol_collision_resolver import GlobalSymbolCollisionCollector
collisions = GlobalSymbolCollisionCollector.apply(trc4)
# The inlining below presumably relies on globally unique symbol names. Report
# the offending symbols rather than failing with a bare AssertionError.
assert not collisions, f"global symbol name collisions in traced IR: {collisions}"
# Inline symbols that are used only once.
trc5 = SingleUseInliner.apply(trc4)
# Inline external attribute accesses on DataModels, then drop unused symbols.
trc6 = PassManager([DataModelExternalGetAttrInliner(), RemoveUnusedSymbols()]).visit(trc5, symtable={})
class NAryOpsTransformer(NodeTranslator):
    """Flatten left-nested additions into a single n-ary ``__add__`` call.

    ``__add__(__add__(a, b), c)`` becomes ``__add__(a, b, c)``. Only the first
    argument is inspected, so operand order is preserved.
    """

    @staticmethod
    def _is_add(func) -> bool:
        # True if `func` is the builtin addition operator.
        return isinstance(func, BuiltInFunction) and func.name == "__add__"

    def visit_Call(self, node: Call):
        if not self._is_add(node.func):
            return self.generic_visit(node)
        head, *tail = node.args
        if not (isinstance(head, Call) and self._is_add(head.func)):
            return self.generic_visit(node)
        # todo: validate domains match
        merged = Call(BuiltInFunction("__add__"), args=(*head.args, *tail), kwargs={})
        # Re-visit: the inner call's own first argument may be another addition.
        return self.visit(merged)


trc7 = NAryOpsTransformer().visit(trc6)
from gt4py_lazy_fields.tracing.pass_helper.scope_visitor import ScopeTranslator
from gt4py_lazy_fields.tracing.pass_helper.conversion import beta_reduction
from gt4py_lazy_fields.tracing.tracerir_utils import resolve_symbol
class InlineOnce(ScopeTranslator):
    """Inline ``Field.__getitem__`` applied to a Field constructed in the IR.

    ``field[idx]`` is rewritten when ``field`` resolves to a
    ``DataModelConstruction``. The rewrite beta-reduces the first argument of the
    field's ``image`` expression (the stencil) with the elements of the index tuple.
    """
    def visit_Call(self, node: Call, symtable, **kwargs):
        if isinstance(node.func, PyExternalFunction) and node.func.func == Field.__getitem__:
            field = resolve_symbol(node.args[0], symtable)
            if isinstance(field, DataModelConstruction): # todo: no typing for datamodels yet...
                # NOTE(review): assumes `image` is a `_map(stencil, domain)` call and
                # that the index argument is a `Tuple_`; neither is checked.
                stencil = field.attrs["image"].args[0]
                return beta_reduction(stencil, node.args[1].elts, {}, closure_symbols=symtable)
        return self.generic_visit(node, symtable=symtable, **kwargs)
# Inlining is currently disabled; the IR is passed through unchanged.
#trc8 = InlineOnce.apply(trc7)
trc8 = trc7
trc9 = PassManager([RemoveUnusedSymbols()]).visit(trc8, symtable={})
import gt4py_lazy_fields.gtl2ir.gtl2ir as gtl2ir
import re
class TranslateToGTL2IR(NodeTranslator):
    """Translate the simplified tracer IR into a GTL2IR tree.

    Handles symbols, ``Let``, ``Lambda``, constants, tuples, ``Field``
    constructions and calls (``_map``, builtin operators and lambda calls).
    Any other ``GenericIRNode`` raises ``ValueError``.
    """

    # Strips the dunders from magic-method names ("__add__" -> "add"); group 2 is
    # the bare name. Compiled once instead of on every visited call.
    _MAGIC_METHOD_RE = re.compile(r"__(?!(__))(.*)__")

    def _translate_symbolname(self, node: Symbol):
        """Translate a symbol declaration into a ``gtl2ir.SymbolName``.

        ``Field`` maps to a reference to the ``Field`` struct. Tuple types,
        ``int`` and ``ProductSet`` pass through unchanged. ``np.ndarray``
        subclasses map to ``gtl2ir.Array``.

        Raises:
            ValueError: if the symbol's type has no GTL2IR counterpart.
        """
        type_ = node.type_
        if type_ == Field:
            return gtl2ir.SymbolName(name=node.name, type_=gtl2ir.SymbolRef("Field"))
        if typing_inspect.is_tuple_type(type_) or type_ in [int, ProductSet]:
            return gtl2ir.SymbolName(name=node.name, type_=type_)
        # `issubclass` raises TypeError for non-class annotations (e.g. typing
        # aliases). Guard it so every unsupported type is reported the same way.
        if isinstance(type_, type) and issubclass(type_, np.ndarray):
            return gtl2ir.SymbolName(name=node.name, type_=gtl2ir.Array)
        raise ValueError(f"Cannot translate symbol {node.name!r} of type {type_!r} to GTL2IR")

    def visit_GenericIRNode(self, node: tracing.GenericIRNode, **kwargs):
        # Fallback for every tracer node kind without a dedicated visitor.
        raise ValueError(f"No GTL2IR translation for node of type {type(node).__name__}")

    def visit_DataModelConstruction(self, node: DataModelConstruction, **kwargs):
        # Only `Field` constructions are supported: Construct(Field, domain, image).
        assert node.type_ == Field
        return gtl2ir.Call(func=gtl2ir.Construct(), args=(gtl2ir.SymbolRef("Field"), self.visit(node.attrs["domain"]), self.visit(node.attrs["image"])))

    def visit_Symbol(self, node: tracing.Symbol, **kwargs):
        # Symbol uses become plain references by name.
        return gtl2ir.SymbolRef(node.name)

    def visit_Let(self, node: Let, **kwargs):
        vars_ = tuple(gtl2ir.Var(name=self._translate_symbolname(var.name), value=self.visit(var.value, **kwargs)) for var in node.vars)
        return gtl2ir.Let(vars=vars_, expr=self.visit(node.expr, **kwargs))

    def visit_Lambda(self, node: Lambda, **kwargs):
        return gtl2ir.Lambda(args=tuple(self._translate_symbolname(arg) for arg in node.args), expr=self.visit(node.expr, **kwargs))

    def visit_Constant(self, node: Constant, **kwargs):
        return gtl2ir.Constant(val=node.val)

    def visit_Tuple_(self, node: Tuple_, **kwargs):
        return gtl2ir.Call(func=gtl2ir.ConstructTuple(), args=self.visit(node.elts, **kwargs))

    def visit_Call(self, node: Call, **kwargs):
        """Translate a call: ``_map`` -> ``map``, builtins -> their bare name, lambdas -> direct call.

        Raises:
            ValueError: for any other kind of callee.
        """
        if isinstance(node.func, tracing.PyExternalFunction) and node.func.func == _map:
            return gtl2ir.Call(func=gtl2ir.SymbolRef("map"), args=self.visit(node.args, **kwargs))
        if isinstance(node.func, BuiltInFunction):
            func_name = node.func.name
            magic_method_match = self._MAGIC_METHOD_RE.match(func_name)
            if magic_method_match:
                func_name = magic_method_match.group(2)
            # NOTE: `rmul` is mapped to `mul` without swapping the operands, which
            # is only correct while multiplication is commutative.
            if func_name == "rmul":
                func_name = "mul"
            if func_name == "getattr":
                return gtl2ir.Call(func=gtl2ir.GetStructAttr(), args=self.visit(node.args, **kwargs))
            return gtl2ir.Call(func=gtl2ir.SymbolRef(func_name), args=self.visit(node.args, **kwargs))
        if isinstance(node.func, Lambda):
            assert not node.kwargs
            return gtl2ir.Call(func=self.visit(node.func, **kwargs), args=self.visit(node.args, **kwargs))
        raise ValueError(f"Cannot translate call to {type(node.func).__name__} to GTL2IR")
def translate_gtl2ir(node: Lambda):
    """Translate ``node`` to GTL2IR, wrapped in a ``Let`` that binds the prelude.

    The ``Let`` binds every entry of ``gtl2ir.declarations`` under its name and
    a ``Field`` struct declaration (``domain``/``image`` attributes). Its body is
    the translated ``node``.
    """
    prelude = [
        gtl2ir.Var(name=gtl2ir.SymbolName(name=func_name, type_=None), value=declaration)
        for func_name, declaration in gtl2ir.declarations.items()
    ]
    field_struct = gtl2ir.StructDecl(attr_names=("domain", "image"), attr_types=(gtl2ir.ProductSet, gtl2ir.Array))
    prelude.append(gtl2ir.Var(name=gtl2ir.SymbolName(name="Field", type_=None), value=field_struct))
    body = TranslateToGTL2IR().visit(node)
    return gtl2ir.Let(vars=tuple(prelude), expr=body)


gtl2ir_node = translate_gtl2ir(trc9)
from gt4py_lazy_fields.l2ir.passes.type_inference import TypeInference
class InlineOnce(NodeTranslator):
    """Work-in-progress stub of a ``getitem``-inlining pass (not used in this file).

    NOTE: this rebinds the module-level name ``InlineOnce`` and shadows the
    ``ScopeTranslator``-based pass of the same name defined earlier.
    """

    def visit_Call(self, node, **kwargs):
        # Act as an identity transform until the rewrite is implemented. A bare
        # `pass` returns None for every Call node instead of the node itself.
        # TODO: inline `getitem` calls, starting from e.g.
        #   if isinstance(node.func, BuiltInFunction) and node.func.name == "getitem" and ...
        return self.generic_visit(node, **kwargs)


# NOTE(review): leftover scratch assignment; not referenced in the visible source.
bla=1+1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment