# Gist by @AranKomat, created April 18, 2020 16:16

# Trivial application of scatter_add_ to the Hadamard product and the inner product.
# The following links may be helpful for understanding:
# https://github.com/rusty1s/pytorch_scatter
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
# Generalization to scatter_matmul or scatter_einsum requires a custom CUDA kernel.
# I hope somebody will make it in the future!
# Caveat: I found the current PyTorch implementation of scatter_add_ to be slower with float16, so make the inputs float32.

import torch
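
# A minimal scatter_add_ reminder (an illustration, not part of the original gist): along dim 0,
# out[index[i]] += src[i], so duplicate indices accumulate.
#   src   = torch.tensor([1., 1., 1.])
#   index = torch.tensor([0, 2, 0])
#   out   = torch.zeros(3)
#   out.scatter_add_(0, index, src)   # out is now tensor([2., 0., 1.])
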
def scatter_inner_prod(v, w, index, dim1, dim2):
    '''v: [x1, x2, ..., xm, ..., xn]
    w: [x1, x2, ..., y, ..., xn]
    index: [x1, x2, ..., xm, ..., xn]
    dim1: scalar (the dimension of size xm, i.e. m = dim1)
    dim2: scalar (the dimension to reduce)
    output: [x1, x2, ..., xm, ..., xn] with dim2 removed (i.e. all dimensions except dim2 remain)
    The implementation is simply an application of v \cdot w = (||v||^2 + ||w||^2 - ||v - w||^2) / 2.
    It is also possible to implement this with scatter_hadamard.'''
    v_norm_sq = v.norm(dim=dim2) ** 2      # ||v||^2, reduced over dim2
    w_norm_sq = w.norm(dim=dim2) ** 2      # ||w||^2, reduced over dim2
    w_clone = w.clone()
    w_clone.scatter_add_(dim1, index, -v)  # w - v, with v scattered along dim1
    return (v_norm_sq + w_norm_sq - w_clone.norm(dim=dim2) ** 2) / 2
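
# Hypothetical usage sketch (the sizes below are assumptions, not taken from the gist): each
# element of v is paired, via index, with a slot of w along dim1, and dim2 is the feature
# dimension reduced by the inner product.
if __name__ == "__main__":
    _v = torch.randn(2, 8, 4)
    _w = torch.randn(2, 8, 4)
    _index = torch.randint(0, 8, (2, 8, 4))
    _out = scatter_inner_prod(_v, _w, _index, dim1=1, dim2=2)
    print(_out.shape)  # torch.Size([2, 8]): dim2 has been reduced away
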
# scatter_hadamard is unfinished. If there's a demand for this function, I'll complete it.
def scatter_hadamard(v, w, index, dim):
    '''v: [x1, x2, ..., xm, ..., xn]
    w: [x1, x2, ..., y, ..., xn]
    index: [x1, x2, ..., xm, ..., xn]
    dim: scalar
    output: [x1, x2, ..., xm, ..., xn]
    The implementation is simply an application of scatter_add_ and log(v) + log(w) = log(vw).'''
    v_sign = v.sign()
    w_sign = w.sign()
    zero_mask_v = v == 0                   # remember where v is exactly zero (log would give -inf)
    zero_mask_w = w == 0                   # remember where w is exactly zero
    log_v = v.abs().log()
    log_w = w.abs().log()
    log_w.scatter_add_(dim, index, log_v)  # log|v| + log|w| = log|vw|, scattered along dim
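
# The log-domain route above still needs an exp() plus the signs and zero masks folded back in.
# A possible way to get the same scattered product without that bookkeeping (assuming this is the
# intended semantics, and PyTorch >= 1.12 for Tensor.scatter_reduce_) is sketched below:
def scatter_hadamard_prod(v, w, index, dim):
    '''Multiply each element of v into the slot of w selected by index along dim.'''
    out = w.clone()
    out.scatter_reduce_(dim, index, v, reduce="prod", include_self=True)
    return out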