Skip to content

Instantly share code, notes, and snippets.

@anderzzz
Created November 3, 2020 15:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save anderzzz/37d14736477353001a9d6c5361f1552c to your computer and use it in GitHub Desktop.
Save anderzzz/37d14736477353001a9d6c5361f1552c to your computer and use it in GitHub Desktop.
class EncoderVGGMerged(EncoderVGG):
    '''Special case of the VGG Encoder wherein the code is merged along the height/width dimension. This is a thin child
    class of `EncoderVGG`.

    The parent encoder is run unchanged. The first element of its output (the code) is then passed through a merge
    operation selected at construction time.

    Args:
        merger_type (str or None, optional): Defines how the code is merged. One of:
            ``'mean'`` (default) -- average over the last two dimensions (presumably height/width for NCHW input).
            ``'flatten'`` -- flatten every dimension from dim 1 onward into one (this merges channels as well).
            ``None`` -- return the code unchanged.
        pretrained_params (bool, optional): Forwarded unchanged to ``EncoderVGG.__init__``.

    Raises:
        ValueError: If ``merger_type`` is not one of the values listed above.
    '''

    @staticmethod
    def _identity(x):
        '''Return ``x`` unchanged. A named function (unlike a lambda) can be pickled, e.g. by ``torch.save``.'''
        return x

    def __init__(self, merger_type='mean', pretrained_params=True):
        # Resolve and validate the merger before building the parent encoder, so an invalid
        # argument fails fast instead of only after the (possibly expensive) parent construction.
        if merger_type is None:
            post_process, post_process_kwargs = self._identity, {}
        elif merger_type == 'mean':
            post_process, post_process_kwargs = torch.mean, {'dim' : (-2, -1)}
        elif merger_type == 'flatten':
            post_process, post_process_kwargs = torch.flatten, {'start_dim' : 1, 'end_dim' : -1}
        else:
            raise ValueError('Unknown merger type for the encoder code: {}'.format(merger_type))

        super().__init__(pretrained_params=pretrained_params)

        # Plain (non-Module, non-Parameter) attributes: they are not registered as submodules
        # and do not appear in the state_dict.
        self.code_post_process = post_process
        self.code_post_process_kwargs = post_process_kwargs

    def forward(self, x):
        '''Execute the encoder on the image input.

        Args:
            x (Tensor): image tensor, passed straight to ``EncoderVGG.forward``.

        Returns:
            x_code (Tensor): code tensor after applying the merge operation chosen at construction.
        '''
        # The parent forward returns a pair; only the first element (the code) is merged and returned.
        # NOTE(review): the discarded second element is presumably pooling indices -- confirm against EncoderVGG.
        x_current, _ = super().forward(x)
        x_code = self.code_post_process(x_current, **self.code_post_process_kwargs)
        return x_code
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment