torch.repeat_interleave alternative
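import torch

def tile_along_axis(x, dim, n_tile):
    # Repeat each element of x n_tile times along dim, like torch.repeat_interleave.
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    # Tile the whole axis: (a, b, c) -> (a, b, c, a, b, c).
    x = x.repeat(*repeat_idx)
    # Gather all copies of element i together: (a, a, b, b, c, c).
    # torch.arange yields int64, the dtype index_select expects.
    order_index = torch.cat(
        [init_dim * torch.arange(n_tile, device=x.device) + i for i in range(init_dim)])
    return torch.index_select(x, dim, order_index)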
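
The reordering step is the least obvious part of the function, so here is a small torch-free sketch of the same index arithmetic on a plain Python list. The helper name `interleave_order` is hypothetical and only for illustration. It assumes, as in the function above, that after tiling, copy `k` of element `i` sits at position `k * init_dim + i`.

```python
def interleave_order(init_dim, n_tile):
    """Indices that reorder a tiled axis (a b c a b c) into an interleaved one (a a b b c c)."""
    # For each original element i, collect the positions of all its tiled copies.
    return [k * init_dim + i for i in range(init_dim) for k in range(n_tile)]


tiled = ["a", "b", "c"] * 2           # stands in for x.repeat(...) along one axis
order = interleave_order(3, 2)        # stands in for order_index
interleaved = [tiled[j] for j in order]  # stands in for torch.index_select

print(order)        # [0, 3, 1, 4, 2, 5]
print(interleaved)  # ['a', 'a', 'b', 'b', 'c', 'c']
```

For a tensor `x`, the result of `tile_along_axis(x, dim, n_tile)` should match `torch.repeat_interleave(x, n_tile, dim=dim)` on PyTorch versions that provide it, which makes for a quick sanity check.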