# Gist by @dlibenzi (May 30, 2020): a differentiable dynamic_slice for
# PyTorch/XLA, built from the XLA DynamicSlice / DynamicUpdateSlice operations.
import torch
import torch_xla
import torch_xla.core.xla_builder as xb
import torch_xla.core.xla_op_registry as xor
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
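

# XLA needs static output shapes, so slicing at a runtime start offset is
# expressed via the builder API: DynamicSlice in the forward pass, and
# DynamicUpdateSlice of grad_output into a zeros tensor in the backward pass
# (scattering the gradient back into the sliced window).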
def _split_indices(index):
  # Split a rank-1 index tensor into one scalar start index per dimension, as
  # expected by the XLA DynamicSlice / DynamicUpdateSlice operations.
  ishape = index.shape()
  assert ishape.rank == 1
  indices = []
  for dim in range(0, ishape.sizes[0]):
    indices.append(index.slice_in_dim(dim, dim + 1, 0).reshape([]))
  return indices


def _dynamic_slice_forward(input, start_indices, slice_sizes=None):
  # Extract a window of shape slice_sizes from input, starting at start_indices.
  return input.dynamic_slice(_split_indices(start_indices), slice_sizes)


def _dynamic_slice_backward(grad_output, input, start_indices, slice_sizes=None):
  # The gradient of a dynamic slice is grad_output scattered into a zero tensor
  # shaped like input, at the same start indices used by the forward slice.
  return input.zeros_like().dynamic_update_slice(grad_output,
                                                 _split_indices(start_indices))
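

# Register the builder functions above as named XLA computations; the returned
# callables can be invoked directly on XLA tensors.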
DYNAMIC_SLICE_FORWARD = xor.register('DynamicSliceForward', _dynamic_slice_forward)
DYNAMIC_SLICE_BACKWARD = xor.register('DynamicSliceBackward', _dynamic_slice_backward)
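

# Wrap the registered ops in an autograd.Function so PyTorch can differentiate
# through the slice: slice_sizes is non-tensor state stashed on ctx, while the
# tensor arguments go through save_for_backward.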


class DynamicSlice(torch.autograd.Function):

  @staticmethod
  def forward(ctx, input, start_indices, slice_sizes):
    ctx.slice_sizes = slice_sizes
    output = DYNAMIC_SLICE_FORWARD(input, start_indices, slice_sizes=slice_sizes)
    ctx.save_for_backward(input, start_indices)
    return output

  @staticmethod
  def backward(ctx, grad_output):
    input, start_indices = ctx.saved_tensors
    grad = DYNAMIC_SLICE_BACKWARD(grad_output, input, start_indices,
                                  slice_sizes=ctx.slice_sizes)
    return grad, None, None


def dynamic_slice(input, start_indices, slice_sizes):
  return DynamicSlice.apply(input, start_indices, slice_sizes)
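

# Usage sketch: for x of shape (6, 8) and index = tensor([2, 4]),
# dynamic_slice(x, index, (2, 3)) selects the same window as x[2:4, 4:7];
# XLA clamps the start indices so the window always stays within bounds.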


def _mp_fn(index):
  # Per-process entry point; xmp.spawn() runs one instance per XLA device.
  device = xm.xla_device()
  x = torch.randn(6, 8, device=device, requires_grad=True)
  index = torch.tensor([2, 4], dtype=torch.int32, device=device)
  out = dynamic_slice(x, index, (2, 3))
  loss = out.pow(2).sum()
  loss.backward()
  print(x.grad.cpu())


if __name__ == '__main__':
  xmp.spawn(_mp_fn, nprocs=None)
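
# Expected result (sketch): x.grad is zero everywhere except the 2x3 window at
# rows 2:4, cols 4:7, where it holds 2 * out, the gradient of out.pow(2).sum().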