Skip to content

Instantly share code, notes, and snippets.

@SrivastavaKshitij
Created October 19, 2018 14:56
Show Gist options
  • Save SrivastavaKshitij/e998bc90e0f0aa8875110cc87b047643 to your computer and use it in GitHub Desktop.
Save SrivastavaKshitij/e998bc90e0f0aa8875110cc87b047643 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class autoencoder(nn.Module):
    """Convolutional autoencoder made of mirrored down-/up-sampling stages.

    The encoder (``self.conv``) repeats ``downsizing_factor`` stages of::

        Conv2d(c, 4c, 3, stride=2, padding=1) -> ReLU -> Conv2d(4c, 2c, 1) -> ReLU

    Each stage halves the spatial size (rounding up) and doubles the channel
    count. The decoder (``self.deconv``) repeats the same number of
    ``ConvTranspose2d(c, c // 2, 2, stride=2) -> ReLU`` stages. Each of these
    doubles the spatial size and halves the channel count, so the output has
    ``in_channels`` channels again. The decoder ends with a ReLU, so the
    reconstruction is always non-negative.

    Both halves are flat ``nn.Sequential`` containers. Parameter names
    (``conv.0.weight``, ``conv.2.weight``, ..., ``deconv.0.weight``, ...)
    therefore follow the layer order.

    Args:
        downsizing_factor: Number of encoder (and decoder) stages. It must be
            an integer >= 1. The input height and width should be divisible by
            ``2 ** downsizing_factor`` for the output size to match the input.
        in_channels: Channel count of the input and of the reconstruction.

    Raises:
        TypeError: If ``downsizing_factor`` is left as ``None``.
        ValueError: If ``downsizing_factor`` is less than 1.
    """

    def __init__(self, downsizing_factor=None, in_channels=1):
        # nn.Module.__init__ must run before any attribute is assigned.
        super().__init__()
        if downsizing_factor is None:
            raise TypeError("downsizing_factor is required (an integer >= 1)")
        if downsizing_factor < 1:
            # With 0 stages the decoder would be empty while one encoder stage
            # was still built, so output channels would not match the input.
            raise ValueError(
                f"downsizing_factor must be >= 1, got {downsizing_factor}"
            )
        self.downsize = downsizing_factor
        # Keep the constructor argument, not the width of the last stage.
        self.in_channels = in_channels

        # Encoder: each stage expands to 4x the channels while downsampling,
        # then a 1x1 conv squeezes back to 2x the stage's input channels.
        conv_modules = []
        channels = in_channels
        for _ in range(downsizing_factor):
            expanded = 4 * channels
            conv_modules.extend([
                nn.Conv2d(channels, expanded, 3, stride=2, padding=1),
                nn.ReLU(True),
                nn.Conv2d(expanded, expanded // 2, 1, stride=1),
                nn.ReLU(True),
            ])
            channels = expanded // 2
        self.conv = nn.Sequential(*conv_modules)
        # Bottleneck width: in_channels * 2 ** downsizing_factor.
        self.latent_channels = channels

        # Decoder: each transposed conv halves the channels and doubles H and W.
        # The channel count is in_channels * 2**k, so halving k times is exact.
        deconv_modules = []
        for _ in range(downsizing_factor):
            deconv_modules.extend([
                nn.ConvTranspose2d(channels, channels // 2, 2, stride=2, padding=0),
                nn.ReLU(True),
            ])
            channels //= 2
        self.deconv = nn.Sequential(*deconv_modules)

    def forward(self, x):
        """Encode ``x`` with ``self.conv``, then return the ``self.deconv`` output.

        ``x`` is fed straight into ``nn.Conv2d``, which expects
        ``(N, in_channels, H, W)`` (or unbatched ``(in_channels, H, W)``) input.
        """
        return self.deconv(self.conv(x))
@SrivastavaKshitij
Copy link
Author

if __name__ == "__main__":
    # Demo: five stride-2 stages shrink a 512x512 image to 16x16 and back.
    a = autoencoder(downsizing_factor=5, in_channels=1)
    print(a)
    # Batch of one single-channel 512x512 image, shape (N, C, H, W).
    b = a(torch.randn(1, 1, 512, 512))
    # Use the tensor's own shape; the original called np.shape without
    # importing numpy, which raised NameError.
    print(b.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment