Skip to content

Instantly share code, notes, and snippets.

@cpuhrsch
Created July 18, 2019 16:42
Show Gist options
  • Save cpuhrsch/deb4f33d09cb00fa1f2adf43ae43493c to your computer and use it in GitHub Desktop.
Save cpuhrsch/deb4f33d09cb00fa1f2adf43ae43493c to your computer and use it in GitHub Desktop.
import torch
def pack_tensors(tensors):
    """Copy *tensors* into one flat buffer and re-point each tensor at its slice.

    After the call, every tensor in *tensors* aliases a contiguous,
    non-overlapping region of the returned 1-D buffer, laid out in list order.
    Each tensor keeps its size and values. In-place changes to the buffer are
    visible through the tensors, and vice versa.

    Args:
        tensors: Sequence of tensors, each modified in place via
            ``Tensor.set_``. They are joined with ``torch.cat``, so the
            sequence must be non-empty. NOTE(review): presumably all tensors
            should share dtype and device; mixed inputs are untested — confirm.

    Returns:
        The 1-D tensor whose storage now backs every element of *tensors*.
    """
    # reshape (not view): non-contiguous inputs are flattened by copying
    # instead of raising. The result is always in row-major element order.
    buffer = torch.cat([tensor.reshape(-1) for tensor in tensors])
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        # The buffer stores this tensor's values row-major, so the alias must
        # use contiguous strides for tensor.size(). Reusing the tensor's own
        # (possibly non-contiguous) stride would read the values out of order.
        # set_(Tensor) shares the source's storage, offset, size and stride,
        # which avoids the deprecated Tensor.storage() accessor.
        tensor.set_(buffer.narrow(0, offset, numel).view(tensor.size()))
        offset += numel
    return buffer


# Demo: pack two random float tensors into one buffer.
# NOTE(review): an earlier draft created these with requires_grad=True and was
# immediately overwritten. Presumably that was dropped because autograd rejects
# in-place set_() on leaf tensors that require grad — confirm before re-adding.
self_tensors = [torch.rand(2, 3), torch.rand(4, 5)]
buffer_ = pack_tensors(self_tensors)
print('self_tensors 1')
print(self_tensors)
# The tensors alias buffer_, so this in-place add is visible through them.
buffer_.add_(2)
print('self_tensors 2')
print(self_tensors)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment