Created
July 13, 2021 15:52
-
-
Save tehrengruber/2495f4f372434644cb90a651df8fbaf9 to your computer and use it in GitHub Desktop.
playground11_mr.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from copy import copy | |
import types | |
from types import LambdaType | |
from typing import List, Callable, Tuple, Any, Union, Literal | |
import inspect | |
import textwrap | |
import tempfile | |
import eve.datamodels as datamodels | |
import typing_inspect | |
from eve.datamodels import DataModel, field | |
from gt4py_lazy_fields.utils.index_space import UnitRange, ProductSet, CartesianSet | |
import gt4py_lazy_fields | |
import numpy as np | |
from gt4py_lazy_fields.tracing.tracing import trace, tracable, isinstance_, if_, tuple_, zip_, OpaqueCall, PyClosureVar, PyExternalFunction, DataModelConstruction | |
from gt4py_lazy_fields.generic import Lambda, BuiltInFunction, Constant, Call, Symbol, Tuple_, GenericIRNode, Function, Let, Var | |
class Stencil(DataModel):
    """A stencil function paired with its extent.

    Calling a ``Stencil`` forwards all positional arguments unchanged to
    ``impl``.
    """

    # The stencil body; receives whatever positional arguments the Stencil is called with.
    impl: Callable
    # Per-dimension extent of the stencil. Callers in this module pass (1, 1) for the
    # five-point laplacian and (0, 0) for point-wise stencils; _fmap_transparent_view
    # uses it to shrink the computable domain.
    extent: Tuple[int, ...]

    @tracable
    def __call__(self, *args):
        return self.impl(*args)
@tracable
def stencil_decorator(*, extent: Tuple[int, ...]):
    """Return a decorator that wraps a function into a :class:`Stencil`.

    Args:
        extent: Per-dimension stencil extent, e.g. ``(1, 1)``. It is stored
            as ``Stencil.extent``, which is declared ``Tuple[int, ...]``.

    Returns:
        A callable mapping a stencil function ``func`` to
        ``Stencil(func, extent)``.
    """
    return lambda func: Stencil(func, extent)
class AbstractField:
    """Common marker base class of ``Field`` and ``TransparentView``."""
class LazyMap(DataModel):
    """Lazily evaluated image: values are computed by ``func`` on access.

    Note that ``func`` receives the domain index *tuple* as a single
    argument (see ``__getitem__`` and ``materialize``), whereas ``map_``
    unpacks the index into separate arguments.
    """

    func: Callable
    domain: ProductSet

    def __getitem__(self, memory_idx):
        """Evaluate ``func`` at the domain index corresponding to ``memory_idx``.

        The domain index is ``memory_idx`` shifted componentwise by the
        domain's first point ``self.domain[0, 0]``.
        """
        origin = self.domain[0, 0]
        image_idx = tuple(m + o for m, o in zip(memory_idx, origin))
        # Validate against this map's own domain, not any module-level `domain`.
        assert image_idx in self.domain
        return self.func(image_idx)

    def materialize(self):
        """Evaluate ``func`` at every domain point into an array of ``domain.shape``."""
        return np.array([self.func(idx) for idx in self.domain]).reshape(self.domain.shape)
# Representation used for Field.image: an eagerly evaluated numpy array.
# A lazy alternative is LazyMap, with map_ returning LazyMap(stencil, domain)
# instead of materializing the values.
image_type = np.ndarray
def map_(stencil, domain):
    """Evaluate ``stencil`` at every point of ``domain`` into a dense array.

    Each point (an index tuple) is unpacked into positional arguments of
    ``stencil``. The flat list of results is reshaped to ``domain.shape``.
    """
    values = [stencil(*point) for point in domain]
    flat = np.array(values)
    return flat.reshape(domain.shape)
# Identity relating map and getitem:
#   getitem(map(lambda i, j: getitem(field, (i, j)), domain), (i, j)) = field(i, j)
class Field(DataModel, AbstractField):
    """Values stored in ``image`` over the index set ``domain``.

    ``image`` is indexed by memory index, i.e. the domain index minus the
    domain's first point ``domain[0, 0]``.
    """

    domain: ProductSet
    image: image_type

    # NOTE: __getitem__ is not decorated with @tracable; the InlineOnce pass
    # matches calls to Field.__getitem__ directly.
    def __getitem__(self, image_idx):
        """Return the value stored at domain index ``image_idx``."""
        assert image_idx in self.domain
        first_point = self.domain[0, 0]
        offsets = zip_(image_idx, first_point)
        memory_idx = tuple(idx - o for idx, o in offsets)
        return self.image[memory_idx]

    def view(self, domain):
        """Return a new Field over ``domain`` (a subset of this field's domain) read from this field."""
        assert domain.issubset(self.domain)
        return apply_stencil(lambda *idx: self[idx], domain)

    def transparent_view(self, domain):
        """Wrap this field in a :class:`TransparentView` nominally restricted to ``domain``."""
        return TransparentView(domain, self)
class TransparentView(DataModel, AbstractField):
    """View of ``field`` nominally restricted to ``domain``.

    Reads are not restricted to ``domain``. Indexing only checks that the
    position lies in the underlying field's domain, so stencils applied
    through a view may read points outside ``domain``.
    """

    domain: ProductSet
    field: Field

    def __getitem__(self, position):
        backing = self.field
        assert position in backing.domain
        return backing[position]
@tracable
def new_accessor(field, position):
    """Return ``accessor(*shift)`` that reads ``field`` at ``position + shift``.

    The shift is added componentwise. The tracing helpers ``zip_`` and
    ``tuple_`` are used (presumably so the index arithmetic is captured
    during tracing).
    """
    def accessor(*shift):
        shifted = tuple_(p + s for p, s in zip_(position, shift))
        return field[shifted]
    return accessor
@tracable
def apply_stencil(stencil: "Callable", domain, *fields):
    """Evaluate ``stencil`` over ``domain`` and return the resulting :class:`Field`.

    Without ``fields``, ``stencil`` is called with the unpacked index of each
    point of ``domain``.

    With ``fields``, ``stencil`` is called as ``stencil(pos, *accessors)``.
    Here ``pos`` is the index tuple and each accessor (see ``new_accessor``)
    reads the corresponding field relative to ``pos``. A field whose domain
    equals ``domain`` is read through ``field.view(domain)``; any other field
    is read through ``field.transparent_view(domain)``.
    """
    if len(fields) > 0:
        # Capture the caller's stencil under its own name. Rebinding
        # `stencil` to a lambda that calls `stencil` made the closure resolve
        # to the lambda itself, recursing without end.
        user_stencil = stencil
        wrapped = lambda *pos: user_stencil(pos, *(  # noqa: E731 - kept a lambda like the rest of the traced code
            new_accessor(field.view(domain) if field.domain == domain else field.transparent_view(domain), pos)
            for field in fields))
        return apply_stencil(wrapped, domain)
    return Field(domain, map_(stencil, domain))
@tracable
def _fmap_field(stencil, field: "Field"):
    """Apply ``stencil`` at every point of ``field.domain``.

    At each position the stencil is called with an accessor for ``field``
    anchored there. The stencil's extent is carried over.
    """
    per_point = Stencil(
        lambda *pos: stencil(new_accessor(field, pos)),
        stencil.extent,
    )
    # NOTE(review): the output domain is the full field.domain. For a non-zero
    # extent, accessor reads at the boundary fall outside field.domain, where
    # Field.__getitem__ asserts membership. Confirm this is intended.
    return apply_stencil(per_point, field.domain)
@tracable
def _fmap_transparent_view(stencil, view: "TransparentView"):
    """Apply ``stencil`` through ``view`` and return a view of the result on ``view.domain``.

    The stencil is evaluated on the underlying field's domain shrunk by the
    stencil extent, i.e. the halo lines are consumed.

    Raises:
        ValueError: if ``view.domain`` is not contained in that shrunk domain.
    """
    per_point = Stencil(lambda *pos: stencil(new_accessor(view, pos)), stencil.extent)
    # `extend` with negated extents: consume halo lines.
    computable = view.field.domain.extend(*(-e for e in stencil.extent))
    if not view.domain.issubset(computable):
        raise ValueError("Not enough halo lines.")
    result = apply_stencil(per_point, computable)
    return result.transparent_view(view.domain)
@tracable
def fmap(stencil, field):
    """Apply ``stencil`` to ``field``, dispatching on whether it is a TransparentView.

    Dispatch goes through the tracing helpers ``isinstance_`` and ``if_``
    rather than Python's ``isinstance``/``if`` (presumably so that both
    branches are representable in a trace).
    """
    is_view = isinstance_(field, TransparentView)
    return if_(
        is_view,
        lambda: _fmap_transparent_view(stencil, field),
        lambda: _fmap_field(stencil, field),
    )
@tracable
def laplacian(field: "Field"):
    """Five-point laplacian of ``field`` via :func:`fmap`.

    The stencil receives an accessor ``f``. ``f(di, dj)`` reads the field at
    the current position shifted by ``(di, dj)``.

    NOTE: this module later redefines ``laplacian`` (index-based variant).
    """
    @stencil_decorator(extent=(1, 1))
    def five_point(f):
        return -4 * f(0, 0) + f(-1, 0) + f(1, 0) + f(0, -1) + f(0, 1)

    return fmap(five_point, field)
@tracable
def laplap(field):
    """Apply ``laplacian`` twice to ``field``.

    ``laplacian`` is looked up at call time, so after the later
    redefinition in this module that version is used.
    """
    return laplacian(laplacian(field))
#
# Example 1 - zero origin
#
domain = UnitRange(0, 5)*UnitRange(0, 5)
i_image = map_(lambda i, j: i, domain)
j_image = map_(lambda i, j: j, domain)
# Images are indexed by memory index. With a zero origin, memory and domain
# indices coincide. Check both images at the first and last point.
assert (i_image[0, 0], j_image[0, 0]) == (0, 0)
assert (i_image[4, 4], j_image[4, 4]) == (4, 4)
i_field = Field(domain, i_image)
j_field = Field(domain, j_image)
# Fields are indexed by domain index.
assert (i_field[0, 0], j_field[0, 0]) == (0, 0)
assert (i_field[domain[-1, -1]], j_field[domain[-1, -1]]) == (4, 4)
#
# Example 2 - non-zero origin
#
domain = UnitRange(1, 6)*UnitRange(1, 6)  # equal to domain.translate(1, 1)
i_image = map_(lambda i, j: i, domain)
j_image = map_(lambda i, j: j, domain)
# Images are indexed by memory index (starting at 0) but hold domain index
# values, which start at 1. Check both images.
assert (i_image[0, 0], j_image[0, 0]) == (1, 1)
assert (i_image[4, 4], j_image[4, 4]) == (5, 5)
i_field = Field(domain, i_image)
j_field = Field(domain, j_image)
# Fields are indexed by domain index. Check both fields.
assert (i_field[1, 1], j_field[1, 1]) == (1, 1)
assert (i_field[5, 5], j_field[5, 5]) == (5, 5)
#
# Example 3 - apply_stencil (frontend feature)
# Rather than building a field by hand from a domain and an image array,
# apply_stencil automates that step.
domain = UnitRange(0, 5)*UnitRange(0, 5)
# Manual construction:
one_field = Field(
    domain,
    map_(lambda i, j: 1, domain),
)
# Equivalent automatic construction (rebinds one_field):
one_field = apply_stencil(
    lambda i, j: 1,
    domain,
)
#
# Example - virtual assignment (frontend feature)
def virtual_assign(field, domain, stencil):
    """Return a new field equal to ``field`` except on ``domain``, where ``stencil`` is used.

    Functional stand-in for the stateful ``field[domain] = ...``. Every point
    of ``field.domain`` is recomputed: ``stencil(i, j)`` inside ``domain`` and
    the existing ``field[i, j]`` elsewhere.
    """
    return apply_stencil(
        lambda i, j: stencil(i, j) if (i, j) in domain else field[i, j],
        field.domain,
    )
zero_field = apply_stencil(lambda i, j: 0, domain)
# Stateful pseudo-code for the line below:
#   zero_field[domain[1:-1, 1:-1]] = map(lambda i, j: 1, domain[1:-1, 1:-1])
modified_field = virtual_assign(zero_field, domain[1:-1, 1:-1], lambda i, j: 1)
# Conclusion: translating from stateful to functional code takes no more
# effort than translating from dusk to gt4py.
# Question: what is the effort to remove the conditional inside the "loop"?
# Answer: the same as in a stateful inlining pass, but now it is a separate
# step that can be debugged in isolation.
# (@tehrengruber see notes for algorithm)
#
# Example 4 - laplacian of f(x, y) = 1/6*x^3
#
def f(x, y):
    """Test function f(x, y) = x**3 / 6."""
    return 1/6*x**3


def lap_f(x, y):
    """Analytical laplacian of ``f``: d2f/dx2 = x and d2f/dy2 = 0."""
    return x


domain = UnitRange(0, 5)*UnitRange(0, 5)
interior_domain = domain[1:-1, 1:-1]
input = apply_stencil(f, domain)
def lap_sten(i, j):
    """Five-point discrete laplacian of the module-level ``input`` field at ``(i, j)``."""
    center = input[i, j]
    i_minus, i_plus = input[i-1, j], input[i+1, j]
    j_minus, j_plus = input[i, j-1], input[i, j+1]
    return -4 * center + i_minus + i_plus + j_minus + j_plus
# Without boundary conditions: the stencil needs all four neighbours, so the
# result is only computed on the interior and the domain gets smaller
# (kept that way to keep the example simple).
lap_interior = apply_stencil(lap_sten, interior_domain)
# With boundary conditions: use the analytical solution on the boundary.
lap = apply_stencil(lambda i, j: lap_sten(i, j) if (i, j) in interior_domain else lap_f(i, j), domain)
#
# Example - laplacian closer to the frontend
#
@tracable
def laplacian(input: "Field"):
    """Five-point laplacian of ``input``, indexing the field directly.

    The result is computed on ``input.domain[1:-1, 1:-1]``, i.e. the
    interior where all four neighbours exist.
    """
    @stencil_decorator(extent=(1, 1))
    def five_point(i, j):
        return -4 * input[i, j] + input[i-1, j] + input[i+1, j] + input[i, j-1] + input[i, j+1]

    interior = input.domain[1:-1, 1:-1]
    return apply_stencil(five_point, interior)
@tracable
def laplap(field):
    """Apply ``laplacian`` twice (each application shrinks the domain by one layer)."""
    return laplacian(laplacian(field))
domain = UnitRange(0, 5)*UnitRange(0, 5)
# Sample x**3 / 6 on the grid, then take the bi-laplacian.
input = apply_stencil(
    lambda i, j: 1/6*i**3,
    domain,
)
result = laplap(input)
#
# Example - in-out fields
# Proposition: we don't want to waste memory bandwidth on point-wise stencils.
# Simple example: identity.
io_field = apply_stencil(lambda i, j: 1, domain)
# The lambda reads the global io_field when called. apply_stencil evaluates
# it eagerly (via map_) before the rebinding, so it reads the previous field.
io_field = apply_stencil(lambda i, j: io_field[i, j], domain)
# Consider something stateful like this (similar to the fencil in Anton's model):
#   `run_program(identity, input=(io_field,), output=(io_field,))`
# Central question: how do we avoid the unnecessary copy?
# Answer: inline everything, then check whether io_field is read with an
# offset. If not, the copy can be avoided; if so, the copy is required.
# Done.
#
# Example - ease of inlining in a functional model
#
input = apply_stencil(lambda i, j: i, domain)
# A field on the interior whose value at (i, j) is input[i+1, j+1].
field = Field(
    domain[1:-1, 1:-1],
    map_(lambda i, j: input[i+1, j+1], domain[1:-1, 1:-1]),
)
assert field[1, 1] == input[2, 2]  # i.e. field[i, j] == input[i+1, j+1]
# Inlining rule this illustrates:
#   Field(domain, map_(lambda i, j: input[i+1, j+1], domain))[i, j] == input[i+1, j+1]
@tracable
def shift(input):
    """Return a field on ``input.domain`` whose value at (i, j) is ``input[i+1, j+1]``."""
    @stencil_decorator(extent=(1, 1))
    def shifted(i, j):
        return input[i+1, j+1]

    # The natural output domain is input.domain[1:-1, 1:-1], but that is not
    # supported during tracing yet. NOTE(review): when evaluated eagerly,
    # reads at the upper boundary presumably fall outside input.domain.
    return apply_stencil(shifted, input.domain)
@tracable
def identity(input):
    """Return a point-wise copy of ``input`` on ``input.domain``."""
    @stencil_decorator(extent=(0, 0))
    def copy_point(i, j):
        return input[i, j]

    return apply_stencil(copy_point, input.domain)
@tracable
def shift_with_identity(input):
    """Compose ``shift`` followed by ``identity`` (a small inlining test case)."""
    shifted = shift(input)
    return identity(shifted)
# Trace shift_with_identity on a symbolic Field argument named INPUT_FIELD.
# Alternative, to trace the laplacian example instead:
#   trc = trace(laplap, (Symbol(name="INPUT_FIELD", type_=Field),))
trc = trace(shift_with_identity, (Symbol(name="INPUT_FIELD", type_=Field),))
from eve import NodeTranslator | |
import gt4py_lazy_fields.tracing.tracing as tracing | |
from gt4py_lazy_fields.utils.uid import uid | |
from gt4py_lazy_fields.tracing.pass_helper.pass_manager import PassManager | |
class SymbolicEvalStencil(NodeTranslator):
    """Wrap non-lambda stencils passed to ``map_`` into an explicit ``Lambda``.

    ``map_(stencil, *rest)``, where ``stencil`` is not already a ``Lambda``
    node, becomes ``map_(Lambda((I_n, J_m), OpaqueCall(stencil, (I_n, J_m))), *rest)``.
    The stencil is thus applied to two fresh ``int`` index symbols. The
    dimensionality is fixed at two ("I" and "J").
    """

    def is_applicable(self, node, *, symtable):
        if not isinstance(node, Call):
            return False
        if not isinstance(node.func, tracing.PyExternalFunction):
            return False
        return node.func.func == map_ and not isinstance(node.args[0], Lambda)

    def transform(self, node: Call, *, symtable):
        assert not node.kwargs
        stencil_expr = node.args[0]
        remaining_args = node.args[1:]
        # Fresh, uniquely named index symbols, created in I, J order.
        index_symbols = tuple(
            [Symbol(name=f"{dim}_{uid(dim)}", type_=int) for dim in ("I", "J")]
        )
        wrapper = Lambda(
            args=index_symbols,
            expr=OpaqueCall(func=stencil_expr, args=index_symbols, kwargs={}),
        )
        return tracing.Call(node.func, args=(wrapper, *remaining_args), kwargs={})
from gt4py_lazy_fields.tracing.passes.constant_folding import ConstantFold | |
from gt4py_lazy_fields.tracing.passes.datamodel import DataModelConstructionResolver, DataModelMethodResolution, \ | |
DataModelGetAttrResolver, DataModelCallOperatorResolution, DataModelExternalGetAttrInliner | |
from gt4py_lazy_fields.tracing.passes.opaque_call_resolution import OpaqueCallResolution1, OpaqueCallResolution2 | |
from gt4py_lazy_fields.tracing.passes.fix_call_type import FixCallType | |
from gt4py_lazy_fields.tracing.passes.tuple_getitem_resolver import TupleGetItemResolver | |
from gt4py_lazy_fields.tracing.passes.remove_constant_refs import RemoveConstantRefs | |
from gt4py_lazy_fields.tracing.passes.tracable_function_resolver import TracableFunctionResolver | |
from gt4py_lazy_fields.tracing.passes.remove_unused_symbols import RemoveUnusedSymbols | |
from gt4py_lazy_fields.tracing.passes.single_use_inliner import SingleUseInliner | |
# Lowering pipeline, stage 1: resolve data-model constructs, stencil wrapping,
# opaque calls, call types, tuple indexing, constant refs and tracable
# functions in the traced expression. Pass order matters.
pass_manager = PassManager([
    DataModelConstructionResolver(),
    DataModelCallOperatorResolution(),
    DataModelGetAttrResolver(),
    DataModelMethodResolution(),
    SymbolicEvalStencil(),
    OpaqueCallResolution1(),
    OpaqueCallResolution2(),
    FixCallType(),
    TupleGetItemResolver(),
    RemoveConstantRefs(),
    TracableFunctionResolver()
])
trc2 = pass_manager.visit(trc.expr, symtable={})
# Stage 2: constant folding.
trc3 = ConstantFold().visit(trc2, symtable={})
# Stage 3: drop symbols that are no longer referenced.
trc4 = PassManager([RemoveUnusedSymbols()]).visit(trc3, symtable={})
from gt4py_lazy_fields.tracing.passes.global_symbol_collision_resolver import GlobalSymbolCollisionCollector
# Fail early if any symbol name is declared more than once globally
# (presumably required by the inliner below).
collisions = GlobalSymbolCollisionCollector.apply(trc4)
assert not collisions
# Stage 4: inline single-use symbols, then inline external data-model
# attribute reads and drop the now-unused symbols.
trc5 = SingleUseInliner.apply(trc4)
trc6 = PassManager([DataModelExternalGetAttrInliner(), RemoveUnusedSymbols()]).visit(trc5, symtable={})
class NAryOpsTransformer(NodeTranslator):
    """Flatten left-nested additions into a single n-ary ``__add__`` call.

    ``__add__(__add__(a, b), c)`` becomes ``__add__(a, b, c)``. Only the
    first argument is inspected, so right-nested additions are kept as-is.
    """

    def visit_Call(self, node: Call):
        is_add = isinstance(node.func, BuiltInFunction) and node.func.name == "__add__"
        if not is_add:
            return self.generic_visit(node)
        head, *tail = node.args
        head_is_add = (
            isinstance(head, Call)
            and isinstance(head.func, BuiltInFunction)
            and head.func.name == "__add__"
        )
        if not head_is_add:
            return self.generic_visit(node)
        # todo: validate domains match
        flattened = Call(BuiltInFunction("__add__"), args=(*head.args, *tail), kwargs={})
        # Revisit: the new head may itself be a nested addition.
        return self.visit(flattened)
# Flatten left-nested additions into n-ary __add__ calls.
trc7 = NAryOpsTransformer().visit(trc6)
from gt4py_lazy_fields.tracing.pass_helper.scope_visitor import ScopeTranslator | |
from gt4py_lazy_fields.tracing.pass_helper.conversion import beta_reduction | |
from gt4py_lazy_fields.tracing.tracerir_utils import resolve_symbol | |
class InlineOnce(ScopeTranslator):
    """Inline reads from fields that are constructed inside the IR.

    A call ``Field.__getitem__(f, idx)`` whose ``f`` resolves to a
    ``DataModelConstruction`` is replaced by beta-reducing the first argument
    of that construction's ``image`` call with the elements of ``idx``. That
    argument is presumably the stencil of ``map_(stencil, domain)``, as
    built by ``apply_stencil``; ``idx`` is assumed to be a ``Tuple_`` node.
    """

    def visit_Call(self, node: Call, symtable, **kwargs):
        is_field_getitem = (
            isinstance(node.func, PyExternalFunction)
            and node.func.func == Field.__getitem__
        )
        if is_field_getitem:
            resolved = resolve_symbol(node.args[0], symtable)
            if isinstance(resolved, DataModelConstruction):  # todo: no typing for datamodels yet...
                stencil_lambda = resolved.attrs["image"].args[0]
                return beta_reduction(stencil_lambda, node.args[1].elts, {}, closure_symbols=symtable)
        return self.generic_visit(node, symtable=symtable, **kwargs)
# Inline Field.__getitem__ reads of fields constructed within the IR.
trc8 = InlineOnce.apply(trc7)
# (To skip inlining while debugging, use `trc8 = trc7` instead.)
# Drop symbols that became unused through inlining.
trc9 = PassManager([RemoveUnusedSymbols()]).visit(trc8, symtable={})
import gt4py_lazy_fields.gtl2ir.gtl2ir as gtl2ir | |
import re | |
# Matches dunder method names such as "__add__" and captures the bare name
# ("add"). The negative lookahead rejects names starting with four underscores.
_MAGIC_METHOD_RE = re.compile(r"__(?!__)(.*)__")


class TranslateToGTL2IR(NodeTranslator):
    """Translate a lowered tracer IR tree into GTL2IR nodes.

    Only the node kinds handled below are supported. Anything else raises
    ``ValueError`` with a message naming the offending node.
    """

    def _translate_symbolname(self, node: Symbol):
        """Translate a typed tracer ``Symbol`` into a ``gtl2ir.SymbolName``.

        ``Field`` maps to a reference to the ``Field`` struct declaration.
        Tuple types, ``int`` and ``ProductSet`` are passed through unchanged.
        ``np.ndarray`` (and subclasses) map to ``gtl2ir.Array``.

        Raises:
            ValueError: if the symbol's type has no GTL2IR counterpart.
        """
        type_ = node.type_
        if type_ == Field:
            return gtl2ir.SymbolName(name=node.name, type_=gtl2ir.SymbolRef("Field"))
        if typing_inspect.is_tuple_type(type_) or type_ in [int, ProductSet]:
            return gtl2ir.SymbolName(name=node.name, type_=type_)
        # Guard issubclass: type_ may be a typing construct rather than a class,
        # which would otherwise raise an unrelated TypeError.
        if isinstance(type_, type) and issubclass(type_, np.ndarray):
            return gtl2ir.SymbolName(name=node.name, type_=gtl2ir.Array)
        raise ValueError(f"Cannot translate symbol {node.name!r} of type {type_!r} to GTL2IR.")

    def visit_GenericIRNode(self, node: tracing.GenericIRNode, **kwargs):
        # Presumably reached for IR nodes without a more specific visit_* method.
        raise ValueError(f"No GTL2IR translation for IR node of type {type(node).__name__}.")

    def visit_DataModelConstruction(self, node: DataModelConstruction, **kwargs):
        # Only Field constructions are supported: Field(domain, image).
        assert node.type_ == Field
        return gtl2ir.Call(
            func=gtl2ir.Construct(),
            args=(gtl2ir.SymbolRef("Field"), self.visit(node.attrs["domain"]), self.visit(node.attrs["image"])),
        )

    def visit_Symbol(self, node: tracing.Symbol, **kwargs):
        return gtl2ir.SymbolRef(node.name)

    def visit_Let(self, node: Let, **kwargs):
        vars_ = tuple(
            gtl2ir.Var(name=self._translate_symbolname(var.name), value=self.visit(var.value, **kwargs))
            for var in node.vars
        )
        return gtl2ir.Let(vars=vars_, expr=self.visit(node.expr, **kwargs))

    def visit_Lambda(self, node: Lambda, **kwargs):
        return gtl2ir.Lambda(
            args=tuple(self._translate_symbolname(arg) for arg in node.args),
            expr=self.visit(node.expr, **kwargs),
        )

    def visit_Constant(self, node: Constant, **kwargs):
        return gtl2ir.Constant(val=node.val)

    def visit_Tuple_(self, node: Tuple_, **kwargs):
        return gtl2ir.Call(func=gtl2ir.ConstructTuple(), args=self.visit(node.elts, **kwargs))

    def visit_Call(self, node: Call, **kwargs):
        """Translate calls to ``map_``, built-in operators, and lambdas.

        Raises:
            ValueError: for any other kind of callee.
        """
        func = node.func
        if isinstance(func, tracing.PyExternalFunction) and func.func == map_:
            return gtl2ir.Call(func=gtl2ir.SymbolRef("map"), args=self.visit(node.args, **kwargs))
        if isinstance(func, BuiltInFunction):
            func_name = func.name
            magic_method_match = _MAGIC_METHOD_RE.match(func.name)
            if magic_method_match:
                func_name = magic_method_match.group(1)
            # Reflected multiplication is emitted as plain `mul` (assumes the
            # GTL2IR `mul` is commutative).
            if func_name == "rmul":
                func_name = "mul"
            if func_name == "getattr":
                return gtl2ir.Call(func=gtl2ir.GetStructAttr(), args=self.visit(node.args, **kwargs))
            return gtl2ir.Call(func=gtl2ir.SymbolRef(func_name), args=self.visit(node.args, **kwargs))
        if isinstance(func, Lambda):
            assert not node.kwargs
            return gtl2ir.Call(func=self.visit(func, **kwargs), args=self.visit(node.args, **kwargs))
        raise ValueError(f"Cannot translate call to {type(func).__name__} to GTL2IR.")
def translate_gtl2ir(node: Lambda):
    """Translate ``node`` into a GTL2IR program wrapped in a top-level ``Let``.

    The ``Let`` binds every entry of ``gtl2ir.declarations`` under its name,
    plus a ``Field`` struct declaration with attributes ``domain``
    (``gtl2ir.ProductSet``) and ``image`` (``gtl2ir.Array``). Its body is
    ``node`` translated by :class:`TranslateToGTL2IR`.
    """
    # Named `let_vars` to avoid shadowing the builtin `vars`.
    let_vars = [
        gtl2ir.Var(name=gtl2ir.SymbolName(name=func_name, type_=None), value=declaration)
        for func_name, declaration in gtl2ir.declarations.items()
    ]
    let_vars.append(
        gtl2ir.Var(
            name=gtl2ir.SymbolName(name="Field", type_=None),
            value=gtl2ir.StructDecl(attr_names=("domain", "image"), attr_types=(gtl2ir.ProductSet, gtl2ir.Array)),
        )
    )
    return gtl2ir.Let(vars=tuple(let_vars), expr=TranslateToGTL2IR().visit(node))
# Final step: translate the lowered trace into a GTL2IR program.
gtl2ir_node = translate_gtl2ir(trc9)
# NOTE(review): no-op statement, presumably a debugger breakpoint anchor. TODO: remove.
bla = 1 + 1
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment