# Trivial application of scatter_add_ to the Hadamard product and the inner product.
# The following links may be helpful for understanding:
# https://github.com/rusty1s/pytorch_scatter
# https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
# Generalizing to scatter_matmul or scatter_einsum requires a custom CUDA kernel.
# I hope somebody will make it in the future!
# Caveat: I found the current PyTorch implementation of scatter_add_ to be slower
# with float16, so cast the inputs to float32.
import torch
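
# As a reminder (an illustrative note, not part of the original gist): along dim,
# scatter_add_ accumulates src entries that share an index into the same slice of
# self. For a 2-D tensor and dim=0: self[index[i][j]][j] += src[i][j]. For example:
#   >>> base = torch.zeros(3, 4)
#   >>> idx = torch.tensor([[0, 0, 1, 2], [0, 1, 1, 2]])
#   >>> base.scatter_add_(0, idx, torch.ones(2, 4))
#   tensor([[2., 1., 0., 0.],
#           [0., 1., 2., 0.],
#           [0., 0., 0., 2.]])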


def scatter_inner_prod(v, w, index, dim1, dim2):
    '''v: [x1, x2, ..., xm, ..., xn]
    w: [x1, x2, ..., y, ..., xn]
    index: [x1, x2, ..., xm, ..., xn] (same shape as v)
    dim1: scalar (the scatter dim; m = dim1)
    dim2: scalar (the dimension to reduce)
    output: w's shape with dim2 removed (i.e. all dims except dim2 remain)
    The implementation is simply the application of
    v \cdot w = (||v||^2 + ||w||^2 - ||v - w||^2) / 2.
    It's also possible to implement this with scatter_hadamard.'''
    # ||v||^2 must be computed after scattering v into w's layout; otherwise the
    # three terms of the identity are misaligned whenever index is not the identity.
    v_scattered = torch.zeros_like(w).scatter_add_(dim1, index, v)
    v_norm_sq = v_scattered.norm(dim=dim2) ** 2
    w_norm_sq = w.norm(dim=dim2) ** 2
    w_clone = w.clone()
    w_clone.scatter_add_(dim1, index, -v)  # w - v (with v scattered)
    return (v_norm_sq + w_norm_sq - w_clone.norm(dim=dim2) ** 2) / 2
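

# A hedged usage sketch (the shapes and the helper name are my own, not from the
# original gist): scatter six v-rows into three w-rows along dim1=1, reduce over
# dim2=2, and compare against a naive reference that scatter-sums v first and then
# takes the elementwise product with w.
def _demo_scatter_inner_prod():
    torch.manual_seed(0)
    v = torch.randn(2, 6, 5)
    w = torch.randn(2, 3, 5)
    index = torch.randint(0, 3, (2, 6, 5))
    out = scatter_inner_prod(v, w, index, dim1=1, dim2=2)
    v_sum = torch.zeros_like(w).scatter_add_(1, index, v)
    ref = (v_sum * w).sum(dim=2)
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True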


# scatter_hadamard was left unfinished in the original gist ("If there's a demand
# for this function, I'll complete it"). The sign/zero handling below is one
# plausible way to finish it, not necessarily the author's intended one.
def scatter_hadamard(v, w, index, dim):
    '''v: [x1, x2, ..., xm, ..., xn]
    w: [x1, x2, ..., y, ..., xn]
    index: [x1, x2, ..., xm, ..., xn] (same shape as v)
    dim: scalar (the scatter dim)
    output: [x1, x2, ..., y, ..., xn] (w's shape)
    The implementation is simply the application of scatter_add and
    log(v) + log(w) = log(vw), with signs and zeros tracked separately
    because the log only sees absolute values.'''
    v_sign = v.sign()
    w_sign = w.sign()
    zero_mask_v = v == 0
    zero_mask_w = w == 0
    # Replace zeros with ones before the log so that log(0) = -inf never appears;
    # zero entries are restored via the zero-count mask at the end.
    log_v = v.abs().masked_fill(zero_mask_v, 1.0).log()
    log_w = w.abs().masked_fill(zero_mask_w, 1.0).log()
    log_w.scatter_add_(dim, index, log_v)  # log|v| + log|w| = log|vw|
    # The sign of each product is (-1)^(number of negative factors), and the
    # product is zero as soon as any factor is zero.
    neg_total = (w_sign < 0).to(w.dtype).scatter_add_(dim, index, (v_sign < 0).to(v.dtype))
    zero_total = zero_mask_w.to(w.dtype).scatter_add_(dim, index, zero_mask_v.to(v.dtype))
    sign = 1 - 2 * (neg_total % 2)
    return sign * log_w.exp() * (zero_total == 0).to(w.dtype)
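

# A hedged sanity check for the completed scatter_hadamard (shapes are my own
# choice): compare against a naive triple loop that multiplies each v entry into
# the w entry selected by index.
def _demo_scatter_hadamard():
    torch.manual_seed(0)
    v = torch.randn(2, 6, 5)
    w = torch.randn(2, 3, 5)
    index = torch.randint(0, 3, (2, 6, 5))
    out = scatter_hadamard(v, w, index, dim=1)
    ref = w.clone()
    for i in range(2):
        for j in range(6):
            for k in range(5):
                ref[i, index[i, j, k], k] *= v[i, j, k]
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True


if __name__ == '__main__':
    _demo_scatter_inner_prod()
    _demo_scatter_hadamard()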