Skip to content

Instantly share code, notes, and snippets.

@vlasenkov
Last active March 19, 2019 15:26
Show Gist options
  • Save vlasenkov/b3aa7c12570fe0056fca3421453470ca to your computer and use it in GitHub Desktop.
Save vlasenkov/b3aa7c12570fe0056fca3421453470ca to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn

# Globally ask cuDNN to use deterministic convolution algorithms. This is set
# at import time, before the model below is built or run.
torch.backends.cudnn.deterministic = True
class UNet3d(nn.Module):
def __init__(self, activation=nn.ReLU, pooling=nn.MaxPool3d):
super().__init__()
self.down0 = nn.Sequential(
nn.Conv3d(1, 64, 3, padding=1),
nn.BatchNorm3d(64),
activation(),
nn.Conv3d(64, 64, 3, padding=1),
nn.BatchNorm3d(64),
activation(),
)
self.down1 = nn.Sequential(
pooling(2),
nn.Conv3d(64, 128, 3, padding=1),
nn.BatchNorm3d(128),
activation(),
nn.Conv3d(128, 128, 3, padding=1),
nn.BatchNorm3d(128),
activation(),
)
self.down2 = nn.Sequential(
pooling(2),
nn.Conv3d(128, 256, 3, padding=1),
nn.BatchNorm3d(256),
activation(),
nn.Conv3d(256, 256, 3, padding=1),
nn.BatchNorm3d(256),
activation(),
)
self.down3 = nn.Sequential(
pooling(2),
nn.Conv3d(256, 512, 3, padding=1),
nn.BatchNorm3d(512),
activation(),
nn.Conv3d(512, 512, 3, padding=1), # seems that the problem is here
)
def forward(self, inputs):
out0 = self.down0(inputs)
out1 = self.down1(out0)
out2 = self.down2(out1)
out3 = self.down3(out2)
return out3
# One training step of UNet3d on the GPU (forward, CrossEntropy loss,
# backward, Adam update) with cudnn.deterministic enabled. Presumably this is
# a reproduction for a problem near down3's final Conv3d -- confirm.
device = torch.device("cuda")

model = UNet3d()
model.to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# Random single-channel volume. The batch size (4) is stated only here; the
# label shape below is derived from the model output instead of repeating it.
img = torch.randn(4, 1, 44, 60, 48).to(device)

optimizer.zero_grad()
out = model(img)

# For an (N, C, d1, ..., dk) input, CrossEntropyLoss expects an integer
# target of shape (N, d1, ..., dk). Labels are drawn from {0, 1, 2}.
# torch.randint already returns int64, so no extra `.long()` cast is needed.
lbl = torch.randint(0, 3, (out.shape[0], *out.shape[2:])).to(device)

loss = loss_fn(input=out, target=lbl)
loss.backward()
optimizer.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment