This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MemoryBank(object): | |
'''Memory bank | |
Args: | |
n_vectors (int): Number of vectors the memory bank should hold | |
dim_vector (int): Dimension of the vectors the memory bank should hold | |
memory_mixing_rate (float, optional): Fraction of new vector to add to currently stored vector. The value | |
should be between 0.0 and 1.0, the greater the value the more rapid the update. The mixing rate can be | |
set during calling `update_memory`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, codes, indices): | |
'''Forward pass for the local aggregation loss module''' | |
assert codes.shape[0] == len(indices) | |
codes = codes.type(torch.DoubleTensor) | |
code_data = normalize(codes.detach().numpy(), axis=1) | |
# Compute and collect arrays of indices that define the constants in the loss function. Note that | |
# no gradients are computed for these data values in backward pass | |
self.memory_bank.update_memory(code_data, indices) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
import numpy as np | |
from sklearn.neighbors import NearestNeighbors | |
from sklearn.cluster import KMeans | |
from sklearn.preprocessing import normalize | |
from scipy.spatial.distance import cosine as cosine_distance |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class AutoEncoderVGG(nn.Module): | |
'''Auto-Encoder based on the VGG-16 with batch normalization template model. The class is comprised of | |
an encoder and a decoder. | |
Args: | |
pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters. | |
Defaults to True. | |
''' | |
channels_in = EncoderVGG.channels_in |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x, pool_indices): | |
'''Execute the decoder on the code tensor input | |
Args: | |
x (Tensor): code tensor obtained from encoder | |
pool_indices (list): Pool indices Pytorch tensors in order the pooling modules in the encoder | |
Returns: | |
x (Tensor): decoded image tensor |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DecoderVGG(nn.Module): | |
'''Decoder of code based on the architecture of VGG-16 with batch normalization. | |
Args: | |
encoder: The encoder instance of `EncoderVGG` that is to be inverted into a decoder | |
''' | |
channels_in = EncoderVGG.channels_code | |
channels_out = 3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def forward(self, x): | |
'''Execute the encoder on the image input | |
Args: | |
x (Tensor): image tensor | |
Returns: | |
x_code (Tensor): code tensor | |
pool_indices (list): Pool indices tensors in order of the pooling modules |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _encodify_(self, encoder): | |
'''Create list of modules for encoder based on the architecture in VGG template model. | |
In the encoder-decoder architecture, the unpooling operations in the decoder require pooling | |
indices from the corresponding pooling operation in the encoder. In VGG template, these indices | |
are not returned. Hence the need for this method to extent the pooling operations. | |
Args: | |
encoder : the template VGG model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torchvision import models | |
class EncoderVGG(nn.Module): | |
'''Encoder of image based on the architecture of VGG-16 with batch normalization. | |
Args: | |
pretrained_params (bool, optional): If the network should be populated with pre-trained VGG parameters. | |
Defaults to True. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class StandardTransform(object): | |
'''Standard Image Transforms, typically instantiated and provided to the DataSet class | |
''' | |
def __init__(self, min_dim=299, to_tensor=True, square=False, | |
normalize=True, norm_mean=[0.485, 0.456, 0.406], norm_std=[0.229, 0.224, 0.225]): | |
self.transforms = [] | |
self.transforms.append(transforms.ToPILImage()) | |
self.transforms.append(transforms.Resize(min_dim)) |