Created April 8, 2018 11:12
-
-
Save JustinShenk/9c5ba977065baa9534db2d73395f13ef to your computer and use it in GitHub Desktop.
Inception V3 autoencoder implementation for PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.model_zoo as model_zoo | |
from torchvision import models | |
# Names exported by ``from <module> import *``.
__all__ = [
    'Inception3_Autoencoder',
    'inception_v3_autoencoder',
]

# Known checkpoint locations. The Inception v3 weights were ported from TensorFlow.
# NOTE(review): nothing visible in this file reads this mapping; the pretrained
# path in ``inception_v3_autoencoder`` goes through ``torchvision.models``.
model_urls = dict(
    inception_v3_google='https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
)
def inception_v3_autoencoder(pretrained=False, z_dim=1000, **kwargs):
    r"""Build an :class:`Inception3_Autoencoder`.

    The encoder follows the Inception v3 architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    Args:
        pretrained (bool): If True, copy the ImageNet weights of torchvision's
            ``inception_v3`` into the encoder. Every encoder layer is then
            frozen (``requires_grad = False``), including ``AuxLogits`` when
            present. Only the freshly created ``fc`` and the decoder layers
            stay trainable.
        z_dim (int): Size of the latent code produced by ``fc``. The decoder
            reshapes it to ``(z_dim // 25, 5, 5)``, so it must be a positive
            multiple of 25.
        **kwargs: Forwarded to :class:`Inception3_Autoencoder`. When
            ``pretrained`` is True, ``transform_input`` defaults to True.

    Returns:
        Inception3_Autoencoder: the constructed model.

    Raises:
        AssertionError: if ``z_dim`` is not a positive multiple of 25.
    """
    # Validate first, before any expensive work such as downloading weights.
    if z_dim <= 0:
        raise AssertionError("z_dim {} must be positive".format(z_dim))
    if z_dim % 25 != 0:
        raise AssertionError("z_dim {} not divisible by 25".format(z_dim))

    if pretrained:
        kwargs.setdefault('transform_input', True)
        pretrained_model = models.__dict__['inception_v3'](pretrained=True)
        pretrained_state = pretrained_model.state_dict()
        # Build with the default z_dim (1000) so ``fc`` has the same shape as
        # the ImageNet checkpoint; it is swapped for a z_dim-sized layer below.
        model = Inception3_Autoencoder(**kwargs)
        state = model.state_dict()
        # Copy only keys this model actually has (e.g. no AuxLogits.* when
        # aux_logits=False), so the strict load below cannot fail on extras.
        state.update({k: v for k, v in pretrained_state.items() if k in state})
        model.load_state_dict(state)

        # Freeze the pretrained encoder.
        encoder_layers = [
            model.Conv2d_1a_3x3, model.Conv2d_2a_3x3, model.Conv2d_2b_3x3,
            model.Conv2d_3b_1x1, model.Conv2d_4a_3x3,
            model.Mixed_5b, model.Mixed_5c, model.Mixed_5d,
            model.Mixed_6a, model.Mixed_6b, model.Mixed_6c, model.Mixed_6d,
            model.Mixed_6e, model.Mixed_7a, model.Mixed_7b, model.Mixed_7c,
        ]
        if model.aux_logits:
            encoder_layers.append(model.AuxLogits)
        for layer in encoder_layers:
            for param in layer.parameters():
                param.requires_grad = False

        # New latent head. Newly created modules (and the untouched decoder
        # layers) already have requires_grad=True, so nothing else needs
        # re-enabling.
        model.fc = nn.Linear(model.fc.in_features, z_dim)
    else:
        model = Inception3_Autoencoder(z_dim=z_dim, **kwargs)

    # Keep the recorded latent size consistent with the actual ``fc`` output.
    model.z_dim = z_dim
    if z_dim != 1000:
        # The decoder views the latent as (z_dim // 25, 5, 5); resize convT1's input.
        model.convT1 = nn.ConvTranspose2d(z_dim // 25, 8, 5, stride=3, padding=1)
    return model
class Inception3_Autoencoder(nn.Module):
    """Autoencoder with an Inception v3 encoder and a small transposed-conv decoder.

    The encoder reuses the Inception v3 submodule names (``Conv2d_1a_3x3`` ...
    ``Mixed_7c``, ``AuxLogits``, ``fc``), so an Inception v3 ``state_dict``
    can be loaded into it. ``fc`` projects the pooled 2048-d features to a
    ``z_dim``-d latent code. The decoder reshapes that code to
    ``(z_dim // 25, 5, 5)``, produces a 3 x 32 x 32 image squashed by ``tanh``,
    and resizes it bilinearly to 299 x 299.

    Args:
        aux_logits (bool): Build the ``AuxLogits`` head. ``forward`` never
            calls it; presumably it is kept so checkpoints containing it load.
        transform_input (bool): If True, convert input normalised with the
            ImageNet mean/std into ``(img - 0.5) / 0.5`` scaling before encoding.
        z_dim (int): Latent size. It should be a multiple of 25 so the
            decoder's ``view(N, -1, 5, 5)`` is valid.
        record_latent (bool): If True (the default, matching the previous
            behaviour), every forward pass appends a CPU numpy copy of the
            latent code to ``self.latent_history``. That list is never
            cleared, so it grows without bound; pass False for long runs.
    """

    def __init__(self, aux_logits=True, transform_input=False, z_dim=1000,
                 record_latent=True):
        super(Inception3_Autoencoder, self).__init__()
        self.aux_logits = aux_logits
        self.z_dim = z_dim
        self.transform_input = transform_input
        self.record_latent = record_latent
        # One numpy array (latent codes of a batch) per forward pass.
        self.latent_history = []

        # == Encoder ==
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)
        self.fc = nn.Linear(2048, z_dim)

        # == Decoder ==
        # The latent is viewed as (z_dim // 25) x 5 x 5, so convT1's input
        # channel count follows z_dim (40 for the default z_dim=1000).
        self.convT1 = nn.ConvTranspose2d(z_dim // 25, 8, 5, stride=3, padding=1)
        self.convT2 = nn.ConvTranspose2d(8, 3, 8, stride=1, padding=1)
        self.fc1 = nn.Linear(3 * 20 * 20, 3 * 32 * 32, bias=False)
        self.upsample5x = nn.Upsample(scale_factor=5, mode='bilinear')  # not used in forward
        self.upsample = nn.Upsample((299, 299), mode='bilinear')
        self.activation = nn.Tanh()

        # Truncated-normal (+-2 sigma) init for Conv2d / Linear weights, and
        # identity init for BatchNorm. ConvTranspose2d keeps PyTorch's default.
        import scipy.stats as stats
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                stddev = getattr(m, 'stddev', 0.1)
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.Tensor(X.rvs(m.weight.data.numel()))
                # rvs() returns a flat array. Reshape it to the weight's shape:
                # copy_ requires broadcastable shapes, and a 1-D tensor does not
                # broadcast to a 2-D/4-D weight.
                m.weight.data.copy_(values.view_as(m.weight.data))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """Encode ``x`` and return the decoded reconstruction.

        The fixed 8x8 average pool before ``fc`` requires a 299 x 299 input
        (see the shape comments). The output is N x 3 x 299 x 299 with
        values in [-1, 1] (tanh output, then bilinear resize).
        """
        if self.transform_input:
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288
        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 17 x 17 x 768
        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # 1 x 1 x 2048
        x = F.dropout(x, training=self.training)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048
        x = self.fc(x)
        # z_dim
        # getattr default keeps models pickled before this option existed working.
        if getattr(self, 'record_latent', True):
            self.latent_history.append(x.detach().cpu().numpy())
        return self.decoder(x)

    def decoder(self, x):
        """Map a batch of latent codes (N x z_dim) to N x 3 x 299 x 299 images."""
        x = x.view(x.size(0), -1, 5, 5)  # (z_dim // 25) x 5 x 5
        x = F.relu(self.convT1(x))       # 8 x 15 x 15
        x = F.relu(self.convT2(x))       # 3 x 20 x 20
        x = x.view(x.size(0), 3 * 20 * 20)
        x = self.fc1(x)                  # 3 * 32 * 32
        x = self.activation(x)
        x = x.view(x.size(0), 3, 32, 32)
        x = self.upsample(x)             # 3 x 299 x 299
        return x
class InceptionA(nn.Module):
    """Four-branch Inception block: 1x1, 1x1->5x5, 1x1->3x3->3x3, avgpool->1x1.

    Spatial size is preserved. Output channels are
    64 + 64 + 96 + ``pool_features``.

    Args:
        in_channels (int): channels of the input tensor.
        pool_features (int): output channels of the pooling branch.
        reverse (bool): accepted but not used.
    """

    def __init__(self, in_channels, pool_features, reverse=False):
        super(InceptionA, self).__init__()
        conv = BasicConv2d
        # Registration order is kept as-is: a parent's ``modules()`` walk
        # (e.g. for weight initialisation) visits submodules in this order.
        self.branch1x1 = conv(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        # List items are evaluated left to right, i.e. in the same branch order.
        towers = [
            self.branch1x1(x),
            self.branch5x5_2(self.branch5x5_1(x)),
            self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x))),
            self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)),
        ]
        return torch.cat(towers, 1)
class InceptionB(nn.Module):
    """Grid-reduction block: stride-2 3x3, stride-2 double 3x3, and 3x3 max-pool.

    Output channels are 384 + 96 + ``in_channels`` (the pool branch passes
    the input channels through unchanged).
    """

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        conv = BasicConv2d
        # Keep the registration order unchanged (affects ``modules()`` order).
        self.branch3x3 = conv(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv(96, 96, kernel_size=3, stride=2)

    def forward(self, x):
        reduced = self.branch3x3(x)
        deep = self.branch3x3dbl_3(self.branch3x3dbl_2(self.branch3x3dbl_1(x)))
        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([reduced, deep, pooled], 1)
class InceptionC(nn.Module):
    """Inception block with factorised 7x7 convolutions (1x7 / 7x1 pairs).

    Spatial size is preserved and every branch emits 192 channels, so the
    output has 4 * 192 = 768 channels.

    Args:
        in_channels (int): channels of the input tensor.
        channels_7x7 (int): width of the intermediate 7x7-factorised layers.
    """

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        conv = BasicConv2d
        width = channels_7x7
        # Keep the registration order unchanged (affects ``modules()`` order).
        self.branch1x1 = conv(in_channels, 192, kernel_size=1)

        self.branch7x7_1 = conv(in_channels, width, kernel_size=1)
        self.branch7x7_2 = conv(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv(width, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv(in_channels, width, kernel_size=1)
        self.branch7x7dbl_2 = conv(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv(width, width, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv(width, width, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv(width, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv(in_channels, 192, kernel_size=1)

    def forward(self, x):
        direct = self.branch1x1(x)

        single = x
        for stage in (self.branch7x7_1, self.branch7x7_2, self.branch7x7_3):
            single = stage(single)

        double = x
        for stage in (self.branch7x7dbl_1, self.branch7x7dbl_2,
                      self.branch7x7dbl_3, self.branch7x7dbl_4,
                      self.branch7x7dbl_5):
            double = stage(double)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return torch.cat([direct, single, double, pooled], 1)
class InceptionD(nn.Module):
    """Grid-reduction block: 1x1->stride-2 3x3, 1x1->1x7->7x1->stride-2 3x3, max-pool.

    Output channels are 320 + 192 + ``in_channels``.
    """

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        conv = BasicConv2d
        # Keep the registration order unchanged (affects ``modules()`` order).
        self.branch3x3_1 = conv(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv(192, 192, kernel_size=3, stride=2)

    def forward(self, x):
        short_path = self.branch3x3_2(self.branch3x3_1(x))

        long_path = x
        for stage in (self.branch7x7x3_1, self.branch7x7x3_2,
                      self.branch7x7x3_3, self.branch7x7x3_4):
            long_path = stage(long_path)

        pooled = F.max_pool2d(x, kernel_size=3, stride=2)
        return torch.cat([short_path, long_path, pooled], 1)
class InceptionE(nn.Module):
    """Inception block with split 1x3 / 3x1 outputs (expanded filter bank).

    Spatial size is preserved. Output channels are
    320 + (384 + 384) + (384 + 384) + 192 = 2048.

    Args:
        in_channels (int): channels of the input tensor.
        reversed (bool): accepted but not used (note: shadows the builtin
            name inside ``__init__``; kept for interface compatibility).
    """

    def __init__(self, in_channels, reversed=False):
        super(InceptionE, self).__init__()
        conv = BasicConv2d
        # Keep the registration order unchanged (affects ``modules()`` order).
        self.branch1x1 = conv(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = conv(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = conv(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = conv(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = conv(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = conv(in_channels, 192, kernel_size=1)

    def forward(self, x):
        direct = self.branch1x1(x)

        stem = self.branch3x3_1(x)
        split = torch.cat([self.branch3x3_2a(stem), self.branch3x3_2b(stem)], 1)

        stem_dbl = self.branch3x3dbl_2(self.branch3x3dbl_1(x))
        split_dbl = torch.cat(
            [self.branch3x3dbl_3a(stem_dbl), self.branch3x3dbl_3b(stem_dbl)], 1)

        pooled = self.branch_pool(F.avg_pool2d(x, kernel_size=3, stride=1, padding=1))
        return torch.cat([direct, split, split_dbl, pooled], 1)
class InceptionAux(nn.Module):
    """Auxiliary classifier head: avg-pool, 1x1 conv, 5x5 conv, linear to 1000.

    ``stddev`` attributes are attached as init hints for a parent's weight
    initialisation loop.
    NOTE(review): ``conv1`` is a BasicConv2d wrapper, so a loop that only
    matches ``nn.Conv2d`` modules sees the inner conv, which has no
    ``stddev``; confirm whether 0.01 is actually applied.
    """

    def __init__(self, in_channels):
        super(InceptionAux, self).__init__()
        # Registration order (conv0, conv1, fc) is kept unchanged.
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, 1000)
        self.fc.stddev = 0.001

    def forward(self, x):
        # 17 x 17 x 768 -> 5 x 5 x 768 -> 5 x 5 x 128 -> 1 x 1 x 768
        features = self.conv1(self.conv0(F.avg_pool2d(x, kernel_size=5, stride=3)))
        flat = features.view(features.size(0), -1)  # 768
        return self.fc(flat)  # 1000
class BasicConv2d(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d (eps=0.001) -> in-place ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are passed straight to :class:`torch.nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        # bias=False: the following BatchNorm has its own learnable shift.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        normalized = self.bn(self.conv(x))
        return F.relu(normalized, inplace=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.