# 3d pool examples
#
# Source: GitHub gist mibaumgartner/59f0be2e0a667bb9c1bec1106e49989e
# (last active March 1, 2019 10:26).
import random

import numpy as np
import torch
import torch.nn as nn

# --- reproducibility settings -----------------------------------------------
# All RNGs below are seeded so that separate runs of this script can be
# compared; any remaining run-to-run difference presumably comes from
# nondeterministic kernels (e.g. pooling backward passes) - verify per setup.
set_seed = 0
# NOTE(review): the script always targets CUDA and does not fall back to CPU;
# it appears intended to exercise the GPU code path.
DEVICE = 'cuda'

torch.manual_seed(set_seed)  # seeds the CPU and all CUDA generators
random.seed(set_seed)
np.random.seed(set_seed)
# Request deterministic cuDNN algorithms and disable cuDNN autotuning, which
# may select different kernels from one run to the next.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# --- fixed experiment constants ---------------------------------------------
num_samples = 50
# (N, C, y, x, z). Conv3d / pooling layers treat the last three dims as
# (D, H, W), so y, x, z map onto them in that order.
shape_samples = (1, 1, 64, 64, 4)
epochs = 2
class DummyNetwork(nn.Module):
    """Small 3D CNN that maps a single-channel volume to one logit per sample.

    Layer order:
        Conv3d(1 -> 50) -> MiniBlock -> MaxPool3d(stride 2)
        -> MiniBlock -> MaxPool3d(stride 2) -> MiniBlock
        -> AvgPool3d((16, 16, 1)) -> Conv3d(50 -> 1, 1x1x1) -> flatten

    The AvgPool3d kernel is hard-coded for inputs of shape (N, 1, 64, 64, 4).
    Each ``MaxPool3d(kernel_size=3, stride=2, padding=1)`` maps 64 -> 32 -> 16
    and 4 -> 2 -> 1. The average pool therefore collapses the remaining
    (16, 16, 1) volume to a single voxel, and ``forward`` returns shape (N, 1).
    Other input sizes give a different (or invalid) output shape.
    """

    def __init__(self):
        super().__init__()
        self.conv_in = nn.Conv3d(1, 50, kernel_size=3, stride=1, padding=1)
        self.conv1 = MiniBlock()
        self.pool1 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.conv2 = MiniBlock()
        self.pool2 = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.conv3 = MiniBlock()
        # Sized to the (16, 16, 1) feature map left after pool1 and pool2
        # (see class docstring).
        self.pool = nn.AvgPool3d(kernel_size=(16, 16, 1))
        self.conv_out = nn.Conv3d(50, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, input):
        """Run the network and return flattened logits.

        Args:
            input: 5-D tensor (N, 1, D, H, W). See the class docstring for the
                spatial size the pooling layers assume.

        Returns:
            Tensor of shape (N, K), where K == 1 for (N, 1, 64, 64, 4) inputs.
        """
        out = self.conv_in(input)
        out = self.conv1(out)
        out = self.pool1(out)
        out = self.conv2(out)
        out = self.pool2(out)
        out = self.conv3(out)
        out = self.pool(out)
        out = self.conv_out(out)
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        return out
class MiniBlock(nn.Module):
    """Conv3d (3x3x3, padding 1) -> BatchNorm3d -> ReLU.

    Keeps both the channel count and the spatial size unchanged, so it can be
    stacked between pooling stages.

    Args:
        channels: number of input and output channels. Defaults to 50, the
            width used throughout this script.
    """

    def __init__(self, channels: int = 50):
        super().__init__()
        self.c1 = nn.Conv3d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.b1 = nn.BatchNorm3d(channels)
        self.r1 = nn.ReLU(inplace=True)

    def forward(self, input):
        """Apply conv -> batch norm -> ReLU; output shape equals input shape."""
        out = self.c1(input)
        out = self.b1(out)
        out = self.r1(out)
        return out
# Generate the training data. The values are reproducible across runs because
# the torch and NumPy RNGs are seeded at the top of the script.
data = [
    torch.randn(shape_samples, device=DEVICE, dtype=torch.float)
    for _ in range(num_samples)
]
# One binary label (0 or 1) per sample.
gt = np.random.randint(0, 2, num_samples)
# Create the network on the target device, in training mode (BatchNorm layers
# use batch statistics).
net = DummyNetwork()
net.train()
net.to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)

# Train with batch size 1: one optimizer step per sample.
for _ in range(epochs):
    for sample, label in zip(data, gt):
        # (1, 1) float target, matching the (N, 1) logits returned by the net.
        target = torch.tensor([[float(label)]], dtype=torch.float, device=DEVICE)
        pred = net(sample)
        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Report the last sample's prediction and loss for each epoch; presumably
    # used to compare separate runs for reproducibility.
    print('Pred: {}'.format(pred.detach().cpu()))
    print('Loss: {}'.format(loss.detach().cpu()))
# End of gist.