@shwang
Created October 14, 2020 05:50
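import torch as th
import torch.utils.data as th_data


def main():
    # y is part of an autograd graph (y = 2 * x, and x requires grad).
    x = th.ones([30, 3, 3], requires_grad=True)
    y = x * 2
    # A tensor has __len__ and __getitem__, so it works as a map-style dataset.
    # The default batch_size is 1, so each batch has shape [1, 3, 3].
    # The default collate function combines samples with torch.stack,
    # so the batch keeps a grad_fn.
    dl = th_data.DataLoader(y)
    batch = next(iter(dl))
    print(f"batch: {batch}")
    # th.as_tensor returns its input itself (no copy) when the input is
    # already a tensor with the requested dtype and device.
    batch2 = th.as_tensor(batch)
    print(f"batch2: {batch2}")
    print(f"batch is batch2: {batch is batch2}")


"""
Output:
batch: tensor([[[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]]], grad_fn=<StackBackward>)
batch2: tensor([[[2., 2., 2.],
         [2., 2., 2.],
         [2., 2., 2.]]], grad_fn=<StackBackward>)
batch is batch2: True
"""

if __name__ == "__main__":
    main()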
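A minimal sketch that turns the printed output above into assertions and contrasts it with two conversions that do produce a new tensor object: `th.as_tensor` with a different `dtype`, and `detach()`, which returns a new tensor that shares storage with the batch but has no `grad_fn`. The helper name `inspect_batch` is hypothetical, and the sketch assumes a recent PyTorch on CPU. Newer PyTorch releases print the autograd node as `StackBackward0` rather than `StackBackward`.

```python
import torch as th
import torch.utils.data as th_data


def inspect_batch():
    """Hypothetical helper: load one batch and return several conversions of it."""
    x = th.ones([30, 3, 3], requires_grad=True)
    y = x * 2
    # Default batch_size=1; the default collate stacks samples, keeping the graph.
    batch = next(iter(th_data.DataLoader(y)))

    # Same dtype and device: as_tensor hands back the very same object.
    same = th.as_tensor(batch)
    # Different dtype: as_tensor has to build a new tensor.
    recast = th.as_tensor(batch, dtype=th.float64)
    # detach(): new tensor object, same underlying storage, no grad_fn.
    detached = batch.detach()
    return batch, same, recast, detached


if __name__ == "__main__":
    batch, same, recast, detached = inspect_batch()
    print(f"same is batch: {same is batch}")
    print(f"recast is batch: {recast is batch}, dtype: {recast.dtype}")
    print(f"detached shares storage: {detached.data_ptr() == batch.data_ptr()}")
```

Because `as_tensor` skips the copy only when nothing needs converting, code that relies on getting an independent tensor should call `clone()` or `detach()` explicitly rather than `as_tensor`.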