Created November 3, 2020 15:39
-
-
Save anderzzz/37d14736477353001a9d6c5361f1552c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class EncoderVGGMerged(EncoderVGG):
    '''Special case of the VGG Encoder wherein the code is merged along its last two dimensions (presumably
    height/width). This is a thin child class of `EncoderVGG`.

    Args:
        merger_type (str or None, optional): Defines how the code is merged. One of:
            * ``'mean'`` (default): average the code over its last two dimensions.
            * ``'flatten'``: flatten every dimension from dimension 1 onwards into a single dimension.
            * ``None``: return the code unchanged.
        pretrained_params (bool, optional): Forwarded unchanged to ``EncoderVGG.__init__``.

    Raises:
        ValueError: If ``merger_type`` is not one of the options above.

    '''
    @staticmethod
    def _identity_merge(x):
        '''Return the code unchanged.

        A named function is used instead of a lambda so that instances remain picklable
        (e.g. via ``torch.save`` of the whole module).

        '''
        return x

    def __init__(self, merger_type='mean', pretrained_params=True):
        # Resolve and validate the merger before building the parent encoder, so that an invalid
        # merger_type fails fast instead of after the (possibly expensive) parent construction.
        if merger_type is None:
            post_process = EncoderVGGMerged._identity_merge
            post_process_kwargs = {}
        elif merger_type == 'mean':
            post_process = torch.mean
            post_process_kwargs = {'dim': (-2, -1)}
        elif merger_type == 'flatten':
            post_process = torch.flatten
            post_process_kwargs = {'start_dim': 1, 'end_dim': -1}
        else:
            raise ValueError('Unknown merger type for the encoder code: {}'.format(merger_type))

        super().__init__(pretrained_params=pretrained_params)

        # Callable applied to the encoder code in `forward`, with its keyword arguments.
        self.code_post_process = post_process
        self.code_post_process_kwargs = post_process_kwargs

    def forward(self, x):
        '''Execute the encoder on the image input and merge the resulting code.

        Args:
            x (Tensor): image tensor

        Returns:
            x_code (Tensor): merged code tensor

        '''
        # The parent forward returns a pair; only the first element (the code) is merged here.
        # The second element is discarded (presumably auxiliary data such as pooling indices; confirm in EncoderVGG).
        x_current, _ = super().forward(x)
        x_code = self.code_post_process(x_current, **self.code_post_process_kwargs)
        return x_code
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.