@ruotianluo
Created August 3, 2020 21:19
pthread
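import torch
import torch.nn as nn
import torchvision


class X(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(3, 4)

    def forward(self, x):
        # Ignores its input and only fetches the module's first parameter (returns None).
        x = next(self.parameters())


# CIFAR-10 test split without transforms: each sample is a (PIL image, label) pair.
data = torchvision.datasets.CIFAR10(root='./', train=False, download=True)
# The identity collate_fn keeps each batch as a plain list of 4 samples.
loader = torch.utils.data.DataLoader(data,
                                     batch_size=4,
                                     shuffle=True,
                                     num_workers=4,
                                     collate_fn=lambda x: x)

x = X()
x.cuda()
# With more than one GPU, DataParallel replicates the module and runs each replica in its own thread.
x = nn.DataParallel(x)
for tmp in loader:
    x(tmp)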
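Because the script passes `collate_fn=lambda x: x`, each `tmp` is a plain Python list of samples rather than a batch of stacked tensors. Without a transform, CIFAR-10 samples are PIL images, which the default collation cannot stack. The sketch below contrasts the two collation modes on a small synthetic dataset; `samples`, `identity_loader` and `default_loader` are hypothetical names, and the synthetic list stands in for CIFAR-10 so that no download or GPU is needed.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical stand-in for CIFAR-10: a plain list of (features, label) pairs.
samples = [(torch.randn(3), i % 10) for i in range(10)]

# Identity collate: each batch stays a Python list of (tensor, int) tuples.
identity_loader = DataLoader(samples, batch_size=4, shuffle=False,
                             collate_fn=lambda batch: batch)

# Default collate: features are stacked into one tensor, labels into another.
default_loader = DataLoader(samples, batch_size=4, shuffle=False)

identity_batch = next(iter(identity_loader))
default_batch = next(iter(default_loader))

print(type(identity_batch).__name__, len(identity_batch))  # list 4
print(default_batch[0].shape)                              # torch.Size([4, 3])
print(default_batch[1].tolist())                           # [0, 1, 2, 3]
```

With the identity collate, any per-batch conversion to tensors is left to the code that consumes the batch.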
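The `forward` method in this script calls `next(self.parameters())`, and a plain `next()` raises `StopIteration` when the iterator is empty. As an assumption to verify against your PyTorch version, some releases do not expose parameters through `parameters()` on the replicas that `nn.DataParallel` creates, which would make this call fail inside the replica threads. The sketch below only shows the Python-level behavior on CPU and a defensive variant using `next()` with a default; `NoParams`, `WithParams` and `first_param_or_none` are hypothetical helpers, not part of the original gist.

```python
import torch
import torch.nn as nn


class NoParams(nn.Module):
    # Hypothetical module with no registered parameters.
    def forward(self, x):
        return x


class WithParams(nn.Module):
    # Mirrors the gist's module: a single nn.Linear(3, 4).
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(3, 4)


def first_param_or_none(module):
    # next() with a default returns None instead of raising StopIteration
    # when the parameter iterator is empty.
    return next(module.parameters(), None)


empty = NoParams()
try:
    next(empty.parameters())
    raised = False
except StopIteration:
    raised = True

print(raised)                                   # True
print(first_param_or_none(empty))               # None
print(first_param_or_none(WithParams()).shape)  # torch.Size([4, 3])
```

The first parameter of `nn.Linear(3, 4)` is its weight, whose shape is `(out_features, in_features)`.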