@anderzzz
Created October 30, 2020 14:15
First part of an encoder based on VGG16
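```python
import torch
from torch import nn
from torchvision import models


class EncoderVGG(nn.Module):
    '''Encoder of image based on the architecture of VGG-16 with batch normalization.

    Args:
        pretrained_params (bool, optional): Whether the network should be populated with
            pre-trained VGG parameters. Defaults to True.

    '''
    channels_in = 3
    channels_code = 512

    def __init__(self, pretrained_params=True):
        super(EncoderVGG, self).__init__()

        # Load VGG-16 with batch normalization. `pretrained=` is the legacy argument;
        # torchvision 0.13+ prefers `weights=models.VGG16_BN_Weights.IMAGENET1K_V1`.
        vgg = models.vgg16_bn(pretrained=pretrained_params)

        # Only the convolutional feature extractor is kept for the encoder
        del vgg.classifier
        del vgg.avgpool

        # `_encodify_` is not included in this first part of the gist
        self.encoder = self._encodify_(vgg)
```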
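After `classifier` and `avgpool` are deleted, only the convolutional feature extractor of VGG-16 remains, which is why the code has `channels_code = 512` channels. The stdlib-only sketch below traces the shape arithmetic, assuming the standard VGG-16 layer layout (3x3 convolutions with padding 1, 2x2 max-pools with stride 2) and assuming `_encodify_`, which is not shown here, keeps that layout. The names `VGG16_CFG`, `vgg16_feature_shape` and `vgg16_bn_feature_module_count` are hypothetical helpers, not part of the gist or of torchvision.

```python
from typing import List, Tuple, Union

# Hypothetical config mirroring standard VGG-16: conv output channels, "M" = 2x2 max-pool.
VGG16_CFG: List[Union[int, str]] = [
    64, 64, "M",
    128, 128, "M",
    256, 256, 256, "M",
    512, 512, 512, "M",
    512, 512, 512, "M",
]


def vgg16_feature_shape(height: int, width: int, channels_in: int = 3) -> Tuple[int, int, int]:
    """Return (channels, height, width) after the VGG-16 convolutional features."""
    if height < 32 or width < 32:
        # Five halvings need at least 32 px per side, or the last pool gets a 1-px input
        raise ValueError("each side must be at least 32 pixels")
    channels = channels_in
    for layer in VGG16_CFG:
        if layer == "M":
            # MaxPool2d(kernel_size=2, stride=2), ceil_mode=False: odd sizes are floored
            height, width = height // 2, width // 2
        else:
            # Conv2d(kernel_size=3, padding=1) keeps height and width, sets channel count
            channels = int(layer)
    return channels, height, width


def vgg16_bn_feature_module_count() -> int:
    """Modules in the batch-norm variant: Conv2d + BatchNorm2d + ReLU per conv, plus each pool."""
    n_conv = sum(1 for layer in VGG16_CFG if layer != "M")
    n_pool = VGG16_CFG.count("M")
    return 3 * n_conv + n_pool


if __name__ == "__main__":
    print(vgg16_feature_shape(224, 224))    # (512, 7, 7)
    print(vgg16_feature_shape(100, 100))    # (512, 3, 3)
    print(vgg16_bn_feature_module_count())  # 44
```

For the 224x224 ImageNet input size, the code is a 512x7x7 tensor; odd intermediate sizes are floored at each pool, so a 100x100 image ends at 3x3.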