Skip to content

Instantly share code, notes, and snippets.

@AranKomat
Created November 8, 2021 21:29
Show Gist options
  • Save AranKomat/6c867357808def692426af5259d20ac2 to your computer and use it in GitHub Desktop.
Save AranKomat/6c867357808def692426af5259d20ac2 to your computer and use it in GitHub Desktop.
import jax
import jax.numpy as jnp
from functools import partial
from jax import vmap
def scatter(input, dim, index, src, reduce=None):
    """Functional JAX equivalent of PyTorch's ``Tensor.scatter_``.

    Follows https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html
    but returns a new array instead of mutating ``input``. For a 3-D input and
    ``dim == 1`` the result satisfies::

        out[i][index[i][j][k]][k] = src[i][j][k]    # reduce=None
        out[i][index[i][j][k]][k] += src[i][j][k]   # reduce="add"
        out[i][index[i][j][k]][k] *= src[i][j][k]   # reduce="multiply"

    and every entry not targeted by ``index`` keeps its value from ``input``.
    As in PyTorch, when ``reduce`` is None and ``index`` targets the same
    position more than once, which update wins is not specified.

    Args:
        input: Array to scatter into.
        dim: Axis along which ``index`` selects positions; negative values
            count from the end, as in PyTorch.
        index: Integer array with the same rank as ``input``. Unlike PyTorch,
            its extent must equal ``input``'s along every axis except ``dim``
            (the axes are vmapped together).
        src: Values to scatter; must have the same shape as ``index``.
        reduce: ``None`` (overwrite), ``"add"`` or ``"multiply"``.

    Returns:
        A new array with the same shape as ``input``.

    Raises:
        ValueError: If ``dim`` is out of range for ``input`` or ``reduce`` is
            not one of the supported values.
    """
    input = jnp.asarray(input)
    index = jnp.asarray(index)
    src = jnp.asarray(src)

    ndim = input.ndim
    if not -ndim <= dim < ndim:
        raise ValueError(f"dim {dim} is out of range for an array of rank {ndim}")
    dim %= ndim  # normalize negative dims, e.g. -1 -> ndim - 1

    reducers = {
        None: jax.lax.scatter,
        "add": jax.lax.scatter_add,
        "multiply": jax.lax.scatter_mul,
    }
    try:
        base = reducers[reduce]
    except KeyError:
        raise ValueError(
            f"unsupported reduce {reduce!r}; expected None, 'add' or 'multiply'"
        ) from None

    # 1-D kernel: scatter updates of shape (K,) into an operand of shape (N,)
    # at the positions held in indices of shape (K, 1).
    dnums = jax.lax.ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    fn = partial(base, dimension_numbers=dnums)

    # Lift the 1-D kernel to N-D by vmapping over every axis except `dim`.
    # The trailing-axis wrappers must be the innermost: by the time they run,
    # the outer (leading-axis) wrappers have already stripped the first `dim`
    # axes, so `dim` sits at position 0 and the next trailing axis at
    # position 1. Nesting them the other way round maps over the wrong axis
    # whenever 0 < dim < ndim - 1.
    for _ in range(ndim - dim - 1):
        fn = vmap(fn, in_axes=(1, 1, 1), out_axes=1)
    for _ in range(dim):
        fn = vmap(fn, in_axes=(0, 0, 0), out_axes=0)

    # The trailing length-1 axis is the index-vector dimension the kernel expects.
    return fn(input, jnp.expand_dims(index, axis=-1), src)
# Demo: the 2-D example from PyTorch's Tensor.scatter_ documentation,
# scattering along dim 0 into an all-zero array.
src = jnp.asarray([[1, 2], [3, 4]])
index = jnp.asarray([[0, 1], [1, 0]])
input = jnp.zeros_like(src)  # shadows the builtin; name kept so module globals are unchanged
output = scatter(input, 0, index, src)
print(output)

# Reference run in PyTorch, which produces the same values:
#
#     import torch
#     index = torch.tensor([[0, 1], [1, 0]])
#     src = torch.tensor([[1, 2], [3, 4]])
#     input = torch.zeros(2, 2).long()
#     input.scatter_(0, index, src)
#     print(input)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment