Skip to content

Instantly share code, notes, and snippets.

@sergei-mironov
Created April 11, 2023 13:16
Show Gist options
  • Save sergei-mironov/a40d20edf2bf8f7debc046ac41a08cd9 to your computer and use it in GitHub Desktop.
Save sergei-mironov/a40d20edf2bf8f7debc046ac41a08cd9 to your computer and use it in GitHub Desktop.
Reference implementation of MHLO::Scatter in Python
from copy import deepcopy
from typing import Any, Callable, Dict, List

from frozendict import frozendict
# Type vocabulary used throughout this reference implementation.

Value = complex  # Element type stored in a tensor.

Dim = int  # A dimension "name"; dimensions are identified by integers.

# `Index :: Dimension -> Coordinate`: one coordinate per dimension.
# NOTE: a plain dict is unhashable, so concrete indices are meant to be
# `frozendict`s, which lets them serve as `Tensor` keys.
Index = Dict[Dim, int]

# `Tensor :: Index -> Value`: a tensor is a mapping from indices to values.
Tensor = Dict[Index, Value]
def make_tensor(shape:List[int], lst:list) -> Tensor:
    """Construct a tensor of the given `shape` from the values in `lst`.

    Not implemented yet. Whether `lst` is flat or nested is not specified
    here -- TODO decide and document the expected layout.

    Raises:
      NotImplementedError: always, until an implementation is provided.
        (Raising instead of returning None keeps callers from receiving a
        non-tensor value that only fails much later.)
    """
    raise NotImplementedError("make_tensor is not implemented yet")
def tensor_slice(t:Tensor, i:Index) -> List[Value]:
    """Return the slice of `t` selected by index `i` as a list of values.

    Aka `t[i[0], ... , : , ... , i[N-1]]`: the coordinates in `i` fix every
    dimension except one, which is taken whole. Presumably the free (`:`)
    dimension is the one absent from `i` -- TODO confirm.

    Raises:
      NotImplementedError: always, until an implementation is provided.
    """
    raise NotImplementedError("tensor_slice is not implemented yet")
def tensor_update(t:Tensor, i:Index, v:List[Value]) -> Tensor:
    """Return a copy of `t` whose slice selected by `i` is replaced by `v`.

    Aka `t2=copy(t); t2[i[0], ... , : , ... , i[N-1]] = v[:]`; the input
    tensor `t` itself is meant to stay unmodified.

    Note: the element type annotation is `Value` (the previous `Values` was an
    undefined name and raised NameError when this function was defined).

    Raises:
      NotImplementedError: always, until an implementation is provided.
    """
    raise NotImplementedError("tensor_update is not implemented yet")
def sctter(inputs:Tensor,
           scatter_indices:Tensor,
           updates:Tensor,
           update_computation:Callable[[List[Value],List[Value]],List[Value]],
           attrs) -> Tensor:
    """A reference `MHLO::Scatter` implementation in simple Python.

    Ref. https://github.com/openxla/stablehlo/blob/main/docs/spec.md#scatter

    Args:
      inputs: Operand tensor; the result starts out equal to it.
      scatter_indices: Start indices; consumed by `_result_index`.
      updates: Update tensor; every index of `updates` is visited once.
      update_computation: Combines the current result slice with the update
        slice and returns the new result slice.
      attrs: Scatter dimension attributes; consumed by `_result_index`.

    Returns:
      The tensor obtained by folding every update into `inputs`. Per its
      docstring, `tensor_update` returns a fresh tensor, so `inputs` is meant
      to stay untouched.

    Raises:
      NotImplementedError: while `_result_index` (or the tensor helpers)
        remain unimplemented.
    """
    def _result_index(update_index: Index) -> Index:
        # Computes the `result_index` based on `update_index`,
        # `scatter_indices` and `attrs`. TODO: implement.
        raise NotImplementedError("_result_index is not implemented yet")

    def _exec(update_indices: List[Index], results: Tensor) -> Tensor:
        # Fold each update into `results` in order. This is an explicit loop
        # because Python has no tail-call elimination: recursing once per
        # update would exceed the recursion limit on realistic tensor sizes.
        for update_index in update_indices:
            result_index = _result_index(update_index)
            updated_values = update_computation(
                tensor_slice(results, result_index),
                tensor_slice(updates, update_index))
            results = tensor_update(results, result_index, updated_values)
        return results

    # The loop runs over the index space of `updates` (as in the StableHLO
    # spec), because each scheduled index is used to slice `updates` and to
    # derive a result index. Sorting each index's (dim, coordinate) pairs
    # gives a deterministic order: by coordinate along dimension 0 first,
    # then dimension 1, and so on.
    schedule = sorted(updates.keys(), key=lambda i: sorted(i.items()))
    return _exec(schedule, inputs)


# Correctly spelled alias; `sctter` is kept so existing callers still work.
scatter = sctter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment