Skip to content

Instantly share code, notes, and snippets.

@SrivastavaKshitij
Created October 19, 2018 14:55
Show Gist options
  • Save SrivastavaKshitij/11972067f20f2ea81f25fb39851a0e75 to your computer and use it in GitHub Desktop.
Save SrivastavaKshitij/11972067f20f2ea81f25fb39851a0e75 to your computer and use it in GitHub Desktop.
import torch
from torch import nn
class autoencoder(nn.Module):
    """Convolutional autoencoder that halves spatial size and doubles channels per stage.

    The encoder stacks ``downsizing_factor`` stages of
    ``Conv2d(c, 2*c, kernel_size=3, stride=2, padding=1)`` + ``ReLU``; the decoder
    mirrors it with ``downsizing_factor`` stages of
    ``ConvTranspose2d(c, c // 2, kernel_size=2, stride=2)`` + ``ReLU``, so the
    reconstruction has ``in_channels`` channels again.

    Each encoder stage maps a spatial size H to ceil(H / 2) and each decoder stage
    maps H to 2 * H, so the input's height/width are reproduced exactly only when
    they are divisible by ``2 ** downsizing_factor``.

    Args:
        downsizing_factor: Number of encoder (and decoder) stages; an int >= 1.
        in_channels: Channel count of the input and of the reconstruction.
        verbose: If True, print per-stage channel counts while building and the
            bottleneck shape on every forward pass (the former debug output).

    Attributes:
        downsize: Number of stages (``downsizing_factor``).
        in_channels: The ``in_channels`` passed to the constructor.
        out_channels: Channels at the bottleneck, ``in_channels * 2 ** downsize``.
        conv: Encoder ``nn.Sequential``.
        deconv: Decoder ``nn.Sequential``.

    Raises:
        TypeError: If ``downsizing_factor`` is None.
        ValueError: If ``downsizing_factor`` < 1 (the decoder would not mirror
            the encoder and the output shape would not match the input).
    """

    def __init__(self, downsizing_factor=None, in_channels=1, verbose=False):
        # nn.Module's internal registries must exist before attributes are set.
        super().__init__()
        if downsizing_factor is None:
            raise TypeError("downsizing_factor is required (an int >= 1)")
        if downsizing_factor < 1:
            raise ValueError(
                f"downsizing_factor must be >= 1, got {downsizing_factor}"
            )
        self.downsize = downsizing_factor
        self.in_channels = in_channels
        self.verbose = verbose

        # Encoder: each stage doubles channels and halves (rounding up) H and W.
        encoder_layers = []
        channels = in_channels
        for _ in range(self.downsize):
            if verbose:
                print(f"in channels , out channel {channels, 2 * channels}")
            encoder_layers += [
                nn.Conv2d(channels, 2 * channels, 3, stride=2, padding=1),
                nn.ReLU(True),
            ]
            channels *= 2
        self.out_channels = channels
        self.conv = nn.Sequential(*encoder_layers)

        # Deconv part: each stage halves channels and exactly doubles H and W.
        decoder_layers = []
        for _ in range(self.downsize):
            if verbose:
                print(f"in channels , out channel {channels, channels // 2}")
            # NOTE(review): the final stage also ends in ReLU, so the
            # reconstruction is non-negative; confirm that suits the input data.
            decoder_layers += [
                nn.ConvTranspose2d(channels, channels // 2, 2, stride=2),
                nn.ReLU(True),
            ]
            channels //= 2
        self.deconv = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the bottleneck and decode it back.

        Args:
            x: Batched NCHW tensor with ``in_channels`` channels.

        Returns:
            The reconstruction, with ``in_channels`` channels; H and W match the
            input when they are divisible by ``2 ** downsize``.
        """
        encoded = self.conv(x)
        if self.verbose:
            print(f"shape of input after encoder part {encoded.shape}")
        return self.deconv(encoded)
@SrivastavaKshitij
Copy link
Author

# Demo: build a 3-stage autoencoder and run one 512x512 single-channel image
# through it. Guarded so importing this module does not run the demo.
if __name__ == "__main__":
    a = autoencoder(downsizing_factor=3, in_channels=1)
    print(a)
    # One batched NCHW sample: batch=1, channels=1, 512x512.
    b = a(torch.randn(1, 1, 512, 512))
    # Use the tensor's own .shape; numpy was never imported in this file.
    print(b.shape)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment