@mibaumgartner, last active March 1, 2019
3d pool examples
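A minimal, fully seeded PyTorch script that trains a small 3D CNN on random (64 x 64 x 4) volumes for two epochs, exercising nn.MaxPool3d and nn.AvgPool3d.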
import random

import numpy as np
import torch
import torch.nn as nn

SEED = 0
DEVICE = 'cuda'

# make the run reproducible: seed every RNG and force deterministic cuDNN
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# fixed constants
num_samples = 50
shape_samples = (1, 1, 64, 64, 4)  # (N, C, y, x, z) in PyTorch's (N, C, D, H, W) layout
epochs = 2
class MiniBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU, keeping 50 channels."""

    def __init__(self):
        super().__init__()
        self.c1 = nn.Conv3d(50, 50, kernel_size=3, stride=1, padding=1)
        self.b1 = nn.BatchNorm3d(50)
        self.r1 = nn.ReLU(inplace=True)

    def forward(self, input):
        out = self.c1(input)
        out = self.b1(out)
        out = self.r1(out)
        return out


class DummyNetwork(nn.Module):
    """Small 3D CNN: conv blocks interleaved with max pooling, then global average pooling."""

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv3d(1, 50, kernel_size=3, stride=1, padding=1)
        self.conv1 = MiniBlock()
        self.pool1 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.conv2 = MiniBlock()
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.conv3 = MiniBlock()
        # collapses the remaining 16 x 16 x 1 volume to one value per channel
        self.pool = nn.AvgPool3d(kernel_size=(16, 16, 1))
        self.conv_out = nn.Conv3d(50, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, input):
        out = self.conv_in(input)
        out = self.conv1(out)
        out = self.pool1(out)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.pool2(out)
        out = self.conv3(out)
        out = self.pool(out)
        out = self.conv_out(out)
        out = out.view(out.size(0), -1)  # flatten to (N, 1) logits
        return out
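# Shape walkthrough (a sketch, assuming the constants above; each pooled
# dimension follows floor((n + 2 * padding - kernel) / stride) + 1):
#   input       (1,  1, 64, 64, 4)
#   conv_in ->  (1, 50, 64, 64, 4)
#   pool1   ->  (1, 50, 32, 32, 2)
#   pool2   ->  (1, 50, 16, 16, 1)
#   pool    ->  (1, 50,  1,  1, 1)
#   conv_out -> (1,  1,  1,  1, 1) -> view -> (1, 1)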
# generate data for the network (deterministic because of the seeds above)
data = []
for _ in range(num_samples):
    data.append(torch.randn(shape_samples, device=DEVICE, dtype=torch.float))
gt = np.random.randint(0, 2, num_samples)  # binary ground-truth labels
# create network, loss function, and optimizer
net = DummyNetwork()
net.train()
net.to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
# train network
for _ in range(epochs):
    for ind, input in enumerate(data):
        # lift the scalar label to a (1, 1) float tensor so it matches the logits
        target = torch.from_numpy(gt[ind].reshape(1)[np.newaxis])
        target = target.to(dtype=torch.float, device=DEVICE)
        pred = net(input)
        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print('Pred: {}'.format(pred.detach().cpu()))
        print('Loss: {}'.format(loss.detach().cpu()))
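# Optional sanity check (a sketch, assuming a CUDA device as the script
# already requires): the network should map each sample to a single logit.
with torch.no_grad():
    sample_out = net(data[0])
print('Output shape: {}'.format(tuple(sample_out.shape)))  # expected: (1, 1)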