Created
October 3, 2018 17:19
-
-
Save justusschock/b054f46d0643b27974f3d66ce1fbaddf to your computer and use it in GitHub Desktop.
Shows an ONNX export failure for a PyTorch U-Net model under PyTorch 1.0rc1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class BaseBlock(torch.nn.Module):
    """Common base class defining the API shared by all layer blocks.

    Subclasses are expected to populate ``self.model`` (typically from their
    ``_build`` override). Until they do, ``self.model`` is an empty
    ``torch.nn.Sequential`` and ``forward`` returns its input unchanged.
    """

    def __init__(self, outer_nc, dropout_value):
        """
        Store the block configuration and create an empty container model.

        :param outer_nc: number of channels seen outside the block (used as
            both block input and block output)
        :param dropout_value: dropout value, stored for use by subclasses
        """
        super(BaseBlock, self).__init__()
        self.outer_nc, self.dropout_value = outer_nc, dropout_value
        self.model = torch.nn.Sequential()

    def _build(self):
        """
        Hook for constructing the block's layers; subclasses override it.

        The base implementation intentionally does nothing.

        :return: None
        """
        return None

    def forward(self, input_tensor):
        """
        Run ``input_tensor`` through ``self.model``.

        :param input_tensor: input tensor of the block
        :return: the output of ``self.model``
        """
        block_model = self.model
        return block_model(input_tensor)
class UnetSkipConnectionBlock(BaseBlock):
    """U-Net skip-connection block: downsample, recurse, upsample, concatenate.

    Depending on its position, ``self.model`` is built as:

    * outermost: ``Dropout -> Conv -> submodule -> ReLU -> Dropout -> ConvT``;
      the result is returned as-is (``outer_nc`` channels).
    * innermost: ``LeakyReLU -> Dropout -> Conv -> ReLU -> Dropout -> ConvT ->
      BatchNorm``; no submodule.
    * intermediate: ``LeakyReLU -> Dropout -> Conv -> BatchNorm -> submodule ->
      ReLU -> ConvT -> BatchNorm -> Dropout``.

    Non-outermost blocks concatenate their input to the model output along
    dim 1, so they emit ``2 * outer_nc`` channels. That is why blocks wrapping
    a submodule (assumed to be built with ``outer_nc == inner_nc``) size their
    transposed convolution for ``inner_nc * 2`` input channels.
    """

    def __init__(self, outer_nc=64, dropout_value=0.25, inner_nc=64,
                 submodule=None,
                 outermost=False, innermost=False):
        """
        Store the block configuration and build ``self.model``.

        :param outer_nc: number of input channels of the block (and of the
            model output before the skip concatenation)
        :param dropout_value: dropout probability passed to every
            ``torch.nn.Dropout`` in the block
        :param inner_nc: number of channels handed to ``submodule``
        :param submodule: nested block placed between the down and up paths;
            required unless the block is innermost (and not outermost)
        :param outermost: True if block is directly connected to input image;
            its output is then not concatenated with its input
        :param innermost: True if block contains no submodule
        :raises ValueError: if ``submodule`` is None for a block that uses it
        """
        # Both the outermost and the intermediate layouts insert the submodule
        # into the Sequential; with None this would only fail on the first
        # forward pass, so reject it at construction time instead.
        if submodule is None and (outermost or not innermost):
            raise ValueError(
                "submodule is required unless the block is innermost "
                "(and not outermost)")
        super(UnetSkipConnectionBlock, self).__init__(outer_nc, dropout_value)
        self.inner_nc = inner_nc
        self.submodule = submodule
        self.outermost = outermost
        self.innermost = innermost
        self._build()

    def _build(self):
        """
        Assemble ``self.model`` for this block's position in the U-Net.

        The layer order determines the ``state_dict`` keys of the resulting
        ``torch.nn.Sequential`` and must not be changed casually. The creation
        order also determines the order in which random initializations are
        drawn.

        :return: None
        """
        downconv = torch.nn.Conv2d(self.outer_nc, self.inner_nc, kernel_size=4,
                                   stride=2, padding=1)
        # inplace=True: in non-outermost blocks this is the first layer, so it
        # overwrites the block input that forward() later concatenates as the
        # skip connection, i.e. the skip carries LeakyReLU(input), not input.
        # Kept as-is to preserve the outputs of existing trained weights.
        downrelu = torch.nn.LeakyReLU(0.2, True)
        downnorm = torch.nn.BatchNorm2d(self.inner_nc, affine=True)
        uprelu = torch.nn.ReLU(True)
        upnorm = torch.nn.BatchNorm2d(self.outer_nc, affine=True)

        if self.outermost:
            # Input is submodule output concatenated with its skip: 2 * inner_nc.
            upconv = torch.nn.ConvTranspose2d(self.inner_nc * 2, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [torch.nn.Dropout(self.dropout_value), downconv]
            # No output activation here; any head is added by the caller
            # (UNet appends a 1x1 conv and a Sigmoid).
            up = [uprelu, torch.nn.Dropout(self.dropout_value), upconv]
            model = down + [self.submodule] + up
        elif self.innermost:
            # No submodule, so the up path sees only inner_nc channels.
            upconv = torch.nn.ConvTranspose2d(self.inner_nc, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [downrelu, torch.nn.Dropout(self.dropout_value), downconv]
            up = [uprelu, torch.nn.Dropout(self.dropout_value), upconv, upnorm]
            model = down + up
        else:
            upconv = torch.nn.ConvTranspose2d(self.inner_nc * 2, self.outer_nc,
                                              kernel_size=4, stride=2,
                                              padding=1)
            down = [downrelu, torch.nn.Dropout(self.dropout_value), downconv,
                    downnorm]
            up = [uprelu, upconv, upnorm]
            model = down + [self.submodule] + up + [
                torch.nn.Dropout(self.dropout_value)]
        self.model = torch.nn.Sequential(*model)

    def forward(self, input_tensor):
        """
        Forward through the block (autograd provides the backward pass).

        :param input_tensor: input tensor with ``outer_nc`` channels in dim 1
        :return: for the outermost block, the output of ``self.model``;
            otherwise that output concatenated along dim 1 with the input
            tensor (which the in-place LeakyReLU has already modified)
        """
        if self.outermost:
            return self.model(input_tensor)
        return torch.cat([self.model(input_tensor), input_tensor], 1)
class UNet(torch.nn.Module):
    """Generic U-Net generator built from nested UnetSkipConnectionBlocks.

    ``n_blocks`` blocks are nested: one innermost, ``n_blocks - 2``
    intermediate and one outermost. A 1x1 convolution and a sigmoid then map
    the outermost block's ``n_input_channels`` outputs to
    ``n_output_channels``.
    """

    def __init__(self, n_input_channels=3, n_output_channels=1, n_blocks=9,
                 initial_filters=64, dropout_value=0.0,
                 gpu_ids=None
                 ):
        """
        Create and initialize the class variables and build the network.

        :param n_input_channels: number of input channels
        :param n_output_channels: number of output channels
        :param n_blocks: total number of U-Net blocks, must be >= 2. Each
            block halves the spatial size once on the way down and doubles it
            on the way up, so input height and width must be divisible by
            ``2 ** n_blocks`` for the skip concatenations to line up
        :param initial_filters: number of filters for the first (outermost)
            layer; deeper blocks use up to 8x this many
        :param dropout_value: dropout value (0 for no dropout)
        :param gpu_ids: list of gpu ids (default: none). If non-empty, the
            model is moved to the current CUDA device and ``forward`` runs
            ``data_parallel`` over these ids for CUDA float inputs
        :raises ValueError: if ``n_blocks`` < 2
        """
        if n_blocks < 2:
            # One innermost plus one outermost block is the minimum; smaller
            # values cannot produce consistent channel counts.
            raise ValueError(
                "n_blocks must be at least 2, got %r" % (n_blocks,))
        super(UNet, self).__init__()
        self.input_nc = n_input_channels
        self.output_nc = n_output_channels
        self.initial_filters = initial_filters
        # None instead of a shared mutable default; copy so the module never
        # aliases the caller's list.
        self.gpu_ids = [] if gpu_ids is None else list(gpu_ids)
        self.dropout_value = dropout_value
        self.model = torch.nn.Sequential()
        self.n_blocks = n_blocks
        self._build()
        if self.gpu_ids:
            self.model.cuda()

    def _build(self):
        """
        Build ``self.model`` from nested U-Net skip connection blocks followed
        by a 1x1 convolution and a sigmoid.

        Channel multipliers (relative to ``initial_filters``) double per level
        from the outermost block inward and are capped at 8. The block built
        at loop index ``i`` uses ``inner_nc = min(8, 2 ** (i - 1))`` multiples,
        so the innermost block (nested directly below ``i = n_blocks - 1``)
        must use ``min(8, 2 ** (n_blocks - 2))``.

        :return: None
        """
        # Must equal the inner multiplier of the deepest intermediate block;
        # for n_blocks == 2 it is 1, matching the outermost block's inner_nc.
        # (min(8, 2 ** n_blocks) only agreed with this for n_blocks >= 5.)
        nf_mult = min(8, 2 ** (self.n_blocks - 2))
        block = UnetSkipConnectionBlock(nf_mult * self.initial_filters,
                                        self.dropout_value,
                                        nf_mult * self.initial_filters,
                                        innermost=True)
        # Wrap outward: each new block's inner_nc equals the outer_nc of the
        # block it wraps.
        for i in range(self.n_blocks - 1, 1, -1):
            nf_mult_out = min(8, 2 ** (i - 2))
            nf_mult_in = min(8, 2 * nf_mult_out)
            block = UnetSkipConnectionBlock(nf_mult_out * self.initial_filters,
                                            self.dropout_value,
                                            nf_mult_in * self.initial_filters,
                                            block)
        layers = [UnetSkipConnectionBlock(self.input_nc, self.dropout_value,
                                          self.initial_filters, block,
                                          outermost=True),
                  torch.nn.Conv2d(self.input_nc, self.output_nc, kernel_size=1),
                  torch.nn.Sigmoid()]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, input_tensor):
        """
        Forward through the model (autograd provides the backward pass).

        :param input_tensor: input batch with ``n_input_channels`` channels
        :return: sigmoid-activated output with ``n_output_channels`` channels
        """
        if self.gpu_ids and isinstance(input_tensor.data,
                                       torch.cuda.FloatTensor):
            return torch.nn.parallel.data_parallel(self.model, input_tensor,
                                                   self.gpu_ids)
        return self.model(input_tensor)
if __name__ == '__main__':
    import torch.onnx

    # 5-block U-Net: 3 input channels, 1 output channel, 96 initial filters.
    # eval() so Dropout and BatchNorm are exported with inference behaviour.
    model = UNet(3, 1, 5, 96)
    model.eval()
    # 480 is divisible by 2 ** 5, so every skip concatenation lines up.
    dummy_input = torch.rand(1, 3, 480, 480)
    # The exported graph is written to "unet.onnx"; the call's return value
    # is not the model output and is not needed here.
    torch.onnx.export(model,
                      dummy_input,
                      "unet.onnx",
                      export_params=True)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment