Skip to content

Instantly share code, notes, and snippets.

@speedcell4
Last active November 23, 2017 11:50
Show Gist options
  • Save speedcell4/9c941723a0efe7e28165a83b91479973 to your computer and use it in GitHub Desktop.
Save speedcell4/9c941723a0efe7e28165a83b91479973 to your computer and use it in GitHub Desktop.
import chainer.functions as F
import numpy as np
from chainer import Variable, Function
from chainer import cuda
class Scatter(Function):
    """Scatter one value per row into a dense ``(row, col)`` matrix.

    For each ``i`` in ``range(len(x))`` the output holds ``a[i]`` at position
    ``(i, x[i])``; every other entry is zero.  Gradients flow only to ``a``;
    the integer index array ``x`` receives ``None``.

    Assumes ``len(x) <= row`` and ``0 <= x[i] < col`` -- TODO confirm callers
    guarantee this (out-of-range rows raise; negative columns wrap around).
    """

    def __init__(self, row: int, col: int):
        # Shape of the dense output matrix.
        self.row = row
        self.col = col

    def forward(self, inputs):
        """Build the dense output from indices ``x`` and values ``a``."""
        xp = cuda.get_array_module(*inputs)
        x, a = inputs
        # Allocate with the values' dtype (was hard-coded to float32) so the
        # gradient returned by ``backward`` has the same dtype as ``a``;
        # Chainer rejects gradients whose dtype differs from their input.
        y = xp.zeros((self.row, self.col), dtype=a.dtype)
        # Row i receives its value in column x[i]; kept for the backward pass.
        self.indexes = xp.arange(x.shape[0]), x
        y[self.indexes] = a
        return y,

    def backward(self, inputs, grad_outputs):
        """Gather the output gradient back at the scattered positions."""
        gy, = grad_outputs
        # ``x`` is an integer index array and is not differentiable.
        return None, gy[self.indexes]
def scatter(row: int, col: int, x, a):
    """Apply :class:`Scatter` to place ``a[i]`` at ``(i, x[i])`` of a zero matrix.

    Args:
        row: number of rows of the output matrix.
        col: number of columns of the output matrix.
        x: integer column index for each row.
        a: value to write for each row.

    Returns:
        The output of calling a fresh ``Scatter(row, col)`` on ``(x, a)``.
    """
    scatter_fn = Scatter(row, col)
    result = scatter_fn(x, a)
    return result
if __name__ == '__main__':
    # Small demo: scatter one value per row into a 3x10 matrix, square it,
    # reduce to a scalar and backpropagate to inspect the gradients.
    x = Variable(np.array([3, 6, 4], dtype=np.int32))  # column index per row
    a = Variable(np.array([0.3, 0.2, 0.5], dtype=np.float32))  # value per row
    n_rows = x.shape[0]  # one output row per index (3 here)
    scattered = scatter(n_rows, 10, x, a)
    z = scattered ** 2
    w = F.sum(z, axis=None)
    # Seed the backward pass explicitly with dw/dw = 1 (0-d float32 array).
    w.cleargrad()
    w.grad = np.ones((), dtype=np.float32)
    w.backward()
    for item in (z, w, x.grad, a.grad):
        print(item)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment