@justusschock
Created October 3, 2018 17:19
Shows export failure of torch to ONNX with pytorch 1.0rc1
import torch


class BaseBlock(torch.nn.Module):
    """class to define an API for all Layer Blocks"""

    def __init__(self, outer_nc, dropout_value):
        """
        function to create and initialize class variables
        :param outer_nc: number of channels outside the block (used as block-input and -output)
        :param dropout_value: dropout-value
        """
        super(BaseBlock, self).__init__()
        self.outer_nc = outer_nc
        self.dropout_value = dropout_value
        self.model = torch.nn.Sequential()

    def _build(self):
        """
        Abstract function to build the block, must be implemented in subclasses
        :return: None
        """
        pass

    def forward(self, input_tensor):
        """
        Abstract function to forward through a block (necessary for implicit backward function),
        must be implemented in subclasses
        :param input_tensor: input tensor of the block
        :return: forwarded result
        """
        return self.model(input_tensor)


class UnetSkipConnectionBlock(BaseBlock):
    """class containing a U-Net Skip Connection Block implementation"""

    def __init__(self, outer_nc=64, dropout_value=0.25, inner_nc=64,
                 submodule=None,
                 outermost=False, innermost=False):
        """
        function to create and initialize the class variables
        :param outer_nc: number of input and output channels for the block
        :param dropout_value: dropout value
        :param inner_nc: number of channels given to the submodule
        :param submodule: submodule inside the unet block
        :param outermost: True if block is directly connected to the input image, False otherwise
        :param innermost: True if block contains no submodule
        """
        super(UnetSkipConnectionBlock, self).__init__(outer_nc, dropout_value)
        self.inner_nc = inner_nc
        self.submodule = submodule
        self.outermost = outermost
        self.innermost = innermost
        self._build()

    def _build(self):
        """
        function to build the block
        :return: None
        """
        downconv = torch.nn.Conv2d(self.outer_nc, self.inner_nc, kernel_size=4,
                                   stride=2, padding=1)
        downrelu = torch.nn.LeakyReLU(0.2, True)
        downnorm = torch.nn.BatchNorm2d(self.inner_nc, affine=True)
        uprelu = torch.nn.ReLU(True)
        upnorm = torch.nn.BatchNorm2d(self.outer_nc, affine=True)

        if self.outermost:
            upconv = torch.nn.ConvTranspose2d(self.inner_nc * 2, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [torch.nn.Dropout(self.dropout_value), downconv]
            # up = [uprelu, torch.nn.Dropout(self.dropout_value), upconv, torch.nn.Tanh()]
            up = [uprelu, torch.nn.Dropout(self.dropout_value), upconv]
            model = down + [self.submodule] + up
        elif self.innermost:
            upconv = torch.nn.ConvTranspose2d(self.inner_nc, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [downrelu, torch.nn.Dropout(self.dropout_value), downconv]
            up = [uprelu, torch.nn.Dropout(self.dropout_value), upconv, upnorm]
            model = down + up
        else:
            upconv = torch.nn.ConvTranspose2d(self.inner_nc * 2, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [downrelu, torch.nn.Dropout(self.dropout_value), downconv,
                    downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [self.submodule] + up + [
                torch.nn.Dropout(self.dropout_value)]

        self.model = torch.nn.Sequential(*model)

    def forward(self, input_tensor):
        """
        Function to forward through a block (necessary for implicit backward function)
        :param input_tensor: input tensor
        :return: forwarded result (concatenated with the input along the channel
            dimension unless this is the outermost block)
        """
        if self.outermost:
            return self.model(input_tensor)
        else:
            out = torch.cat([self.model(input_tensor), input_tensor], 1)
            return out
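

# Explanatory note (comment added for clarity, not part of the original gist):
# every non-outermost UnetSkipConnectionBlock returns
# torch.cat([self.model(input_tensor), input_tensor], 1), so the enclosing
# block receives 2 * inner_nc channels back from its submodule. That is why
# the outermost and intermediate blocks build their ConvTranspose2d with
# self.inner_nc * 2 input channels, while the innermost block (which has no
# submodule to concatenate with) uses self.inner_nc directly.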


class UNet(torch.nn.Module):
    """class containing a generic generator implementation"""

    def __init__(self, n_input_channels=3, n_output_channels=1, n_blocks=9,
                 initial_filters=64, dropout_value=0.0,
                 gpu_ids=[]
                 ):
        """
        function to create and initialize the class variables
        :param n_input_channels: number of input channels
        :param n_output_channels: number of output channels
        :param n_blocks: number of unet skip connection blocks
        :param initial_filters: number of filters for the first layer
        :param dropout_value: dropout value (0 for no dropout)
        :param gpu_ids: list of gpu ids
        """
        super(UNet, self).__init__()
        self.input_nc = n_input_channels
        self.output_nc = n_output_channels
        self.initial_filters = initial_filters
        self.gpu_ids = gpu_ids
        self.dropout_value = dropout_value
        self.model = torch.nn.Sequential()
        self.n_blocks = n_blocks
        self._build()

        if len(gpu_ids):
            # self.model.cuda(device_id=gpu_ids[0])
            self.model.cuda()

    def _build(self):
        """
        Build model with Unet Skip Connection Blocks
        :return: None
        """
        nf_mult = min(8, 2 ** self.n_blocks)
        # innermost block
        model = UnetSkipConnectionBlock(nf_mult * self.initial_filters,
                                        self.dropout_value,
                                        nf_mult * self.initial_filters,
                                        innermost=True)

        # intermediate blocks, wrapped around each other from the inside out
        for i in range(self.n_blocks - 1, 1, -1):
            nf_mult_out = min(8, 2 ** (int(i - 2)))
            nf_mult_in = 8 if nf_mult_out == 8 else 2 * nf_mult_out
            model = UnetSkipConnectionBlock(nf_mult_out * self.initial_filters,
                                            self.dropout_value,
                                            nf_mult_in * self.initial_filters,
                                            model)

        # outermost block followed by a 1x1 convolution to the output channels
        model = [UnetSkipConnectionBlock(self.input_nc, self.dropout_value,
                                         self.initial_filters, model,
                                         outermost=True)]
        model += [
            torch.nn.Conv2d(self.input_nc, self.output_nc, kernel_size=1),
            torch.nn.Sigmoid()]
        # model += [torch.nn.Conv2d(self.input_nc, self.output_nc, kernel_size=1)]
        self.model = torch.nn.Sequential(*model)

    def forward(self, input_tensor):
        """
        Function to forward through the model (necessary for implicit back-propagation)
        :param input_tensor: input tensor
        :return: forwarded result
        """
        if self.gpu_ids and isinstance(input_tensor.data,
                                       torch.cuda.FloatTensor):
            tmp = torch.nn.parallel.data_parallel(self.model, input_tensor,
                                                  self.gpu_ids)
            # return torch.nn.parallel.data_parallel(self.model, input_tensor)
        else:
            tmp = self.model(input_tensor)
        return tmp


if __name__ == '__main__':
    import torch.onnx

    model = UNet(3, 1, 5, 96)
    model.eval()

    x = torch.rand(1, 3, 480, 480)
    torch_out = torch.onnx.export(model,
                                  x,
                                  "unet.onnx",
                                  export_params=True)
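
    # Optional sanity check (a sketch added here, not part of the original gist):
    # if the export above succeeds (it is reported to fail on pytorch 1.0rc1),
    # the resulting file can be validated with the `onnx` package, assuming it
    # is installed.
    try:
        import onnx

        onnx_model = onnx.load("unet.onnx")
        onnx.checker.check_model(onnx_model)  # raises if the exported graph is invalid
        print("unet.onnx passed the ONNX checker")
    except ImportError:
        print("`onnx` package not installed, skipping the checker step")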