Created
July 9, 2021 21:11
-
-
Save tehrengruber/51df8283ad5f5501d6ad9f7af74d1d6b to your computer and use it in GitHub Desktop.
Field/Local execution model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from copy import copy | |
import types | |
from types import LambdaType | |
from typing import List, Callable, Tuple, Any, Union, Literal | |
import inspect | |
import textwrap | |
import tempfile | |
import eve.datamodels as datamodels | |
import typing_inspect | |
from eve.datamodels import DataModel, field | |
from gt4py_lazy_fields.utils.index_space import UnitRange, ProductSet, CartesianSet | |
import gt4py_lazy_fields | |
import numpy as np | |
from gt4py_lazy_fields.tracing.tracing import trace, tracable, isinstance_, if_, tuple_, zip_, OpaqueCall, PyClosureVar, PyExternalFunction, DataModelConstruction | |
from gt4py_lazy_fields.generic import Lambda, BuiltInFunction, Constant, Call, Symbol, Tuple_, GenericIRNode, Function, Let, Var | |
class Stencil(DataModel):
    """Callable data model pairing a stencil implementation with its extent.

    Calling an instance forwards all positional arguments to ``impl``.
    """

    # Function invoked when the stencil instance is called.
    impl: Callable
    # Extent of the stencil; presumably one entry per dimension -- confirm.
    extent: Tuple[int, ...]

    @tracable
    def __call__(self, *args):
        """Call ``impl`` with ``args`` and return whatever it returns."""
        implementation = self.impl
        return implementation(*args)
@tracable
def stencil_decorator(*, extent: Tuple[int, ...]):
    """Return a decorator that wraps a function into a ``Stencil``.

    Args:
        extent: Extent stored on the resulting ``Stencil``. It is a tuple of
            ints (matching ``Stencil.extent``), e.g. ``(1, 1)``.

    Returns:
        A one-argument callable mapping ``func`` to ``Stencil(func, extent)``.
    """
    return lambda func: Stencil(func, extent)
class AbstractField:
    """Empty base class used as a common ancestor for field-like types."""
class LazyMap(DataModel):
    """Image whose values are computed by ``func`` on access instead of being stored."""

    # Function evaluated for an index of ``domain``.
    func: Callable
    # Index set over which ``func`` is defined.
    domain: ProductSet

    def __getitem__(self, memory_idx):
        """Return ``func`` evaluated at the domain index matching ``memory_idx``.

        ``memory_idx`` is shifted component-wise by ``self.domain[0, 0]`` to
        obtain the index handed to ``func``. An assertion fails when the
        shifted index is not contained in this map's own domain.
        """
        origin = self.domain[0, 0]
        image_idx = tuple(offset + start for offset, start in zip(memory_idx, origin))
        # Check against this instance's domain, not a module-level ``domain``.
        assert image_idx in self.domain
        return self.func(image_idx)

    def materialize(self):
        """Evaluate ``func`` on every domain index and return a NumPy array of ``domain.shape``."""
        values = [self.func(idx) for idx in self.domain]
        return np.array(values).reshape(self.domain.shape)
#image_type = LazyMap | |
#def _map(stencil, domain): | |
# return LazyMap(stencil, domain) | |
# Storage type used for the ``image`` attribute of ``Field``.
image_type = np.ndarray


def _map(stencil, domain):
    """Eagerly evaluate ``stencil`` at every index of ``domain``.

    Each index produced by iterating ``domain`` is unpacked into positional
    arguments of ``stencil``. The results are collected into a NumPy array
    that is reshaped to ``domain.shape``.
    """
    flat_values = np.array([stencil(*point) for point in domain])
    return flat_values.reshape(domain.shape)
#getitem(map(lambda i, j: getitem(field, (i, j)), domain), (i, j)) = field(i, j) | |
class Field(DataModel, AbstractField):
    """Field whose values live in ``image`` and are addressed by indices of ``domain``."""

    # Index set on which the field is defined.
    domain: ProductSet
    # Backing storage holding the field values.
    image: image_type

    def __getitem__(self, image_idx):
        """Return the stored value for domain index ``image_idx``.

        The index is translated component-wise by ``self.domain[0, 0]`` to get
        the position inside ``image``. An assertion fails when ``image_idx``
        is not contained in the domain.
        """
        assert image_idx in self.domain
        origin = self.domain[0, 0]
        paired = zip_(image_idx, origin)
        memory_idx = tuple(coord - start for coord, start in paired)
        return self.image[memory_idx]

    def view(self, domain):
        """Return a new field over ``domain`` whose values are read from this field.

        An assertion fails when ``domain`` is not a subset of this field's domain.
        """
        assert domain.issubset(self.domain)
        return apply_stencil(lambda *idx: self[idx], domain)

    def transparent_view(self, domain):
        """Return a ``TransparentView`` of this field associated with ``domain``."""
        return TransparentView(domain, self)
class TransparentView(DataModel, AbstractField):
    """View that records a domain but reads values straight from the wrapped field."""

    # Domain associated with the view.
    domain: ProductSet
    # Field the view reads from.
    field: Field

    def __getitem__(self, position):
        """Return the wrapped field's value at ``position``.

        The assertion checks ``position`` against the wrapped field's domain,
        not against the view's own ``domain``.
        """
        underlying = self.field
        assert position in underlying.domain
        return underlying[position]
@tracable
def new_accessor(field, position):
    """Return a function that reads ``field`` relative to ``position``.

    The returned accessor takes the components of a shift, adds them
    component-wise to ``position`` and returns ``field`` at the resulting index.
    """
    def accessor(*shift):
        shifted = tuple_(base + delta for base, delta in zip_(position, shift))
        return field[shifted]

    return accessor
@tracable
def apply_stencil(stencil: "Callable", domain, *fields):
    """Evaluate ``stencil`` over ``domain`` and return the result as a ``Field``.

    Without ``fields`` the stencil is called with the unpacked components of
    each domain index and the values are stored in a new ``Field``.

    With ``fields`` the stencil is called with the position tuple followed by
    one accessor per field (see ``new_accessor``). A field is accessed through
    ``field.view(domain)`` when its domain equals ``domain`` and through
    ``field.transparent_view(domain)`` otherwise.

    Args:
        stencil: Callable to evaluate.
        domain: Index set to evaluate on.
        *fields: Optional fields handed to the stencil as accessors.

    Returns:
        A ``Field`` on ``domain`` holding the stencil results.
    """
    if len(fields) > 0:
        # Use a new name for the wrapper: assigning the lambda back to
        # ``stencil`` would make its body call itself and recurse forever.
        positional_stencil = lambda *pos: stencil(pos, *(
            new_accessor(field.view(domain) if field.domain == domain else field.transparent_view(domain), pos)
            for field in fields))
        return apply_stencil(positional_stencil, domain)
    return Field(domain, _map(stencil, domain))
@tracable
def _fmap_field(stencil, field: "Field"):
    """Apply ``stencil`` to ``field`` on the field's whole domain.

    At each position the stencil receives an accessor for ``field`` created
    by ``new_accessor``.
    """
    wrapped_stencil = Stencil(
        lambda *pos: stencil(new_accessor(field, pos)),
        stencil.extent,
    )
    target_domain = field.domain
    return apply_stencil(wrapped_stencil, target_domain)
@tracable
def _fmap_transparent_view(stencil, view: "TransparentView"):
    """Apply ``stencil`` through a transparent view, consuming halo lines.

    The stencil is evaluated on the wrapped field's domain extended by the
    negated stencil extent. The result is returned as a transparent view on
    the original view's domain.

    Raises:
        ValueError: If the view's domain is not a subset of that reduced domain.
    """
    wrapped_stencil = Stencil(
        lambda *pos: stencil(new_accessor(view, pos)),
        stencil.extent,
    )
    # consume halo lines
    valid_domain = view.field.domain.extend(*(-width for width in stencil.extent))
    if not view.domain.issubset(valid_domain):
        raise ValueError("Not enough halo lines.")
    return apply_stencil(wrapped_stencil, valid_domain).transparent_view(view.domain)
@tracable
def fmap(stencil, field):
    """Apply ``stencil`` to ``field``, dispatching on whether it is a ``TransparentView``.

    Transparent views go through ``_fmap_transparent_view``, everything else
    through ``_fmap_field``. Both branches are passed to ``if_`` as thunks.
    """
    is_view = isinstance_(field, TransparentView)
    return if_(
        is_view,
        lambda: _fmap_transparent_view(stencil, field),
        lambda: _fmap_field(stencil, field),
    )
@tracable
def laplacian(field: "Field"):
    """Apply a five-point stencil with extent ``(1, 1)`` to ``field`` via ``fmap``."""
    # todo: remove extend use apply stencil temporarily
    @stencil_decorator(extent=(1, 1))
    def stencil(value_at):
        # ``value_at(di, dj)`` reads the field at the given offset.
        return -4 * value_at(0, 0) + value_at(-1, 0) + value_at(1, 0) + value_at(0, -1) + value_at(0, 1)

    return fmap(stencil, field)
@tracable
def laplap(field):
    """Return ``laplacian`` applied twice to ``field``."""
    return laplacian(laplacian(field))
# Build a 5x5 index domain and an input field that is 1 where the first
# index component equals 2 and 0 elsewhere.
domain = UnitRange(0, 5) * UnitRange(0, 5)
input = apply_stencil(lambda *pos: 1 if 1 < pos[0] < 3 else 0, domain)

# View on the input domain extended by (-2, -2), evaluated directly.
input_view = input.transparent_view(input.domain.extend(-2, -2))
result = laplap(input_view)

# Trace the same computation symbolically for a field-typed input symbol.
trc = trace(laplap, (Symbol(name="INPUT_FIELD", type_=Field),))
from eve import NodeTranslator | |
import gt4py_lazy_fields.tracing.tracing as tracing | |
from gt4py_lazy_fields.utils.uid import uid | |
from gt4py_lazy_fields.tracing.pass_helper.pass_manager import PassManager | |
class SymbolicEvalStencil(NodeTranslator):
    """Wrap the stencil argument of a ``_map`` call into an explicit ``Lambda``."""

    def is_applicable(self, node, *, symtable):
        """Return whether ``node`` is a ``_map`` call whose first argument is not yet a ``Lambda``."""
        if not isinstance(node, Call):
            return False
        if not isinstance(node.func, tracing.PyExternalFunction):
            return False
        if not node.func.func == _map:
            return False
        return not isinstance(node.args[0], Lambda)

    def transform(self, node: Call, *, symtable):
        """Return a copy of the call with the stencil replaced by a two-argument ``Lambda``.

        The lambda takes fresh ``I``/``J`` integer symbols and forwards them
        to the original stencil through an ``OpaqueCall``.
        """
        assert not node.kwargs
        stencil_arg, *other_args = node.args
        idx_symbs = tuple(Symbol(name=f"{dim}_{uid(dim)}", type_=int) for dim in ["I", "J"])
        forwarded_call = OpaqueCall(func=stencil_arg, args=idx_symbs, kwargs={})
        pos_stencil_expr = Lambda(args=idx_symbs, expr=forwarded_call)
        return tracing.Call(node.func, args=(pos_stencil_expr, *other_args), kwargs={})
from gt4py_lazy_fields.tracing.passes.constant_folding import ConstantFold | |
from gt4py_lazy_fields.tracing.passes.datamodel import DataModelConstructionResolver, DataModelMethodResolution, \ | |
DataModelGetAttrResolver, DataModelCallOperatorResolution, DataModelExternalGetAttrInliner | |
from gt4py_lazy_fields.tracing.passes.opaque_call_resolution import OpaqueCallResolution1, OpaqueCallResolution2 | |
from gt4py_lazy_fields.tracing.passes.fix_call_type import FixCallType | |
from gt4py_lazy_fields.tracing.passes.tuple_getitem_resolver import TupleGetItemResolver | |
from gt4py_lazy_fields.tracing.passes.remove_constant_refs import RemoveConstantRefs | |
from gt4py_lazy_fields.tracing.passes.tracable_function_resolver import TracableFunctionResolver | |
from gt4py_lazy_fields.tracing.passes.remove_unused_symbols import RemoveUnusedSymbols | |
from gt4py_lazy_fields.tracing.passes.single_use_inliner import SingleUseInliner | |
# Main pass pipeline applied to the traced expression.
pass_manager = PassManager(
    [
        DataModelConstructionResolver(),
        DataModelCallOperatorResolution(),
        DataModelGetAttrResolver(),
        DataModelMethodResolution(),
        SymbolicEvalStencil(),
        OpaqueCallResolution1(),
        OpaqueCallResolution2(),
        FixCallType(),
        TupleGetItemResolver(),
        RemoveConstantRefs(),
        TracableFunctionResolver(),
    ]
)

# Each stage feeds the next: pipeline, constant folding, unused-symbol removal.
trc2 = pass_manager.visit(trc.expr, symtable={})
trc3 = ConstantFold().visit(trc2, symtable={})
trc4 = PassManager([RemoveUnusedSymbols()]).visit(trc3, symtable={})
from gt4py_lazy_fields.tracing.passes.global_symbol_collision_resolver import GlobalSymbolCollisionCollector | |
# The collision collector must report nothing before single-use inlining runs.
collisions = GlobalSymbolCollisionCollector.apply(trc4)
assert not collisions

trc5 = SingleUseInliner.apply(trc4)
trc6 = PassManager(
    [DataModelExternalGetAttrInliner(), RemoveUnusedSymbols()]
).visit(trc5, symtable={})
class NAryOpsTransformer(NodeTranslator):
    """Flatten left-nested ``__add__`` calls into a single call with more arguments."""

    @staticmethod
    def _is_add_call(candidate):
        # True for a ``Call`` whose callee is the built-in ``__add__``.
        return (
            isinstance(candidate, Call)
            and isinstance(candidate.func, BuiltInFunction)
            and candidate.func.name == "__add__"
        )

    def visit_Call(self, node: Call):
        """Merge an ``__add__`` call whose first argument is itself an ``__add__`` call.

        The merged call is visited again, so deeper nesting is flattened as
        well. All other calls are handled by ``generic_visit``.
        """
        outer_is_add = isinstance(node.func, BuiltInFunction) and node.func.name == "__add__"
        if outer_is_add:
            first_arg, *rem_args = node.args
            if self._is_add_call(first_arg):
                # todo: validate domains match
                merged = Call(BuiltInFunction("__add__"), args=(*first_arg.args, *rem_args), kwargs={})
                return self.visit(merged)
        return self.generic_visit(node)
# Flatten nested additions in the inlined trace.
trc7 = NAryOpsTransformer().visit(trc6)
from gt4py_lazy_fields.tracing.pass_helper.scope_visitor import ScopeTranslator | |
from gt4py_lazy_fields.tracing.pass_helper.conversion import beta_reduction | |
from gt4py_lazy_fields.tracing.tracerir_utils import resolve_symbol | |
class InlineOnce(ScopeTranslator):
    """Inline ``Field.__getitem__`` calls on fields constructed inside the trace."""

    def visit_Call(self, node: Call, symtable, **kwargs):
        """Replace a field access by the beta-reduced stencil of the accessed field.

        Applies only when the callee is ``Field.__getitem__`` and the first
        argument resolves to a ``DataModelConstruction``. Every other call is
        handled by ``generic_visit``.
        """
        is_field_getitem = (
            isinstance(node.func, PyExternalFunction)
            and node.func.func == Field.__getitem__
        )
        if not is_field_getitem:
            return self.generic_visit(node, symtable=symtable, **kwargs)

        field = resolve_symbol(node.args[0], symtable)
        # todo: no typing for datamodels yet...
        if not isinstance(field, DataModelConstruction):
            return self.generic_visit(node, symtable=symtable, **kwargs)

        stencil = field.attrs["image"].args[0]
        return beta_reduction(stencil, node.args[1].elts, {}, closure_symbols=symtable)
# Inlining is currently disabled; the previous stage is passed through unchanged.
#trc8 = InlineOnce.apply(trc7)
trc8 = trc7
trc9 = PassManager([RemoveUnusedSymbols()]).visit(trc8, symtable={})
import gt4py_lazy_fields.gtl2ir.gtl2ir as gtl2ir | |
import re | |
class TranslateToGTL2IR(NodeTranslator):
    """Translate tracer IR nodes into their ``gtl2ir`` counterparts."""

    def _translate_symbolname(self, node: Symbol):
        """Return the ``gtl2ir.SymbolName`` declaring ``node``.

        Raises:
            ValueError: If the symbol's type is not one of the supported types.
        """
        if node.type_ == Field:
            new_node = gtl2ir.SymbolName(name=node.name, type_=gtl2ir.SymbolRef("Field"))
        elif typing_inspect.is_tuple_type(node.type_) or node.type_ in [int, ProductSet]:
            new_node = gtl2ir.SymbolName(name=node.name, type_=node.type_)
        elif issubclass(node.type_, np.ndarray):
            new_node = gtl2ir.SymbolName(name=node.name, type_=gtl2ir.Array)
        else:
            raise ValueError(f"Cannot translate symbol {node.name!r}: unsupported type {node.type_!r}.")
        return new_node

    def visit_GenericIRNode(self, node: tracing.GenericIRNode, **kwargs):
        """Reject the node; presumably the fallback for node types without a dedicated visitor -- confirm."""
        raise ValueError(f"No GTL2IR translation defined for node type {type(node).__name__}.")

    def visit_DataModelConstruction(self, node: DataModelConstruction, **kwargs):
        """Translate the construction of a ``Field`` into a ``Construct`` call."""
        assert node.type_ == Field
        return gtl2ir.Call(
            func=gtl2ir.Construct(),
            args=(gtl2ir.SymbolRef("Field"), self.visit(node.attrs["domain"]), self.visit(node.attrs["image"])),
        )

    def visit_Symbol(self, node: tracing.Symbol, **kwargs):
        """Translate a symbol use into a reference by name."""
        return gtl2ir.SymbolRef(node.name)

    def visit_Let(self, node: Let, **kwargs):
        """Translate a let-expression, converting every bound variable and the body."""
        vars_ = tuple(
            gtl2ir.Var(name=self._translate_symbolname(var.name), value=self.visit(var.value, **kwargs))
            for var in node.vars
        )
        return gtl2ir.Let(vars=vars_, expr=self.visit(node.expr, **kwargs))

    def visit_Lambda(self, node: Lambda, **kwargs):
        """Translate a lambda, declaring each argument via ``_translate_symbolname``."""
        return gtl2ir.Lambda(
            args=tuple(self._translate_symbolname(arg) for arg in node.args),
            expr=self.visit(node.expr, **kwargs),
        )

    def visit_Constant(self, node: Constant, **kwargs):
        """Translate a constant by copying its value."""
        return gtl2ir.Constant(val=node.val)

    def visit_Tuple_(self, node: Tuple_, **kwargs):
        """Translate a tuple into a ``ConstructTuple`` call over its elements."""
        return gtl2ir.Call(func=gtl2ir.ConstructTuple(), args=self.visit(node.elts, **kwargs))

    def visit_Call(self, node: Call, **kwargs):
        """Translate a call to ``_map``, to a built-in function, or to a lambda.

        Raises:
            ValueError: If the callee is none of the supported kinds.
        """
        if isinstance(node.func, tracing.PyExternalFunction) and node.func.func == _map:
            return gtl2ir.Call(func=gtl2ir.SymbolRef("map"), args=self.visit(node.args, **kwargs))
        elif isinstance(node.func, BuiltInFunction):
            func_name = node.func.name
            # Strip the surrounding double underscores, e.g. "__add__" -> "add".
            magic_method_match = re.match("__(?!(__))(.*)__", node.func.name)
            if magic_method_match:
                func_name = magic_method_match.groups()[1]
            # "rmul" is emitted under the name "mul".
            if func_name == "rmul":
                func_name = "mul"

            if func_name == "getattr":
                return gtl2ir.Call(func=gtl2ir.GetStructAttr(), args=self.visit(node.args, **kwargs))
            else:
                return gtl2ir.Call(func=gtl2ir.SymbolRef(func_name), args=self.visit(node.args, **kwargs))
        elif isinstance(node.func, Lambda):
            assert not node.kwargs
            return gtl2ir.Call(func=self.visit(node.func, **kwargs), args=self.visit(node.args, **kwargs))

        raise ValueError(f"Unsupported call target of type {type(node.func).__name__}.")
def translate_gtl2ir(node: Lambda):
    """Translate ``node`` to gtl2ir and wrap it in a ``Let`` holding the global declarations.

    The ``Let`` binds every entry of ``gtl2ir.declarations`` followed by a
    struct declaration named ``Field`` with attributes ``domain`` and ``image``.
    """
    declared = [
        gtl2ir.Var(name=gtl2ir.SymbolName(name=func_name, type_=None), value=declaration)
        for func_name, declaration in gtl2ir.declarations.items()
    ]
    field_struct = gtl2ir.StructDecl(
        attr_names=("domain", "image"),
        attr_types=(gtl2ir.ProductSet, gtl2ir.Array),
    )
    declared.append(gtl2ir.Var(name=gtl2ir.SymbolName(name="Field", type_=None), value=field_struct))
    return gtl2ir.Let(vars=tuple(declared), expr=TranslateToGTL2IR().visit(node))
# Lower the final trace to gtl2ir.
gtl2ir_node = translate_gtl2ir(trc9)
from gt4py_lazy_fields.l2ir.passes.type_inference import TypeInference | |
class InlineOnce(NodeTranslator):
    """Unfinished placeholder translator; ``visit_Call`` does nothing and returns ``None``.

    NOTE(review): this rebinds the name ``InlineOnce``, which an earlier class
    in this file also uses.
    """

    def visit_Call(self, node):
        # Draft of the intended condition, kept from the original:
        #if isinstance(node.func, BuiltInFunction) and node.func.name == "getitem" and
        return None
# Throwaway statement at the end of the script (looks like a debugger anchor).
bla = 1 + 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment