Skip to content

Instantly share code, notes, and snippets.

@parajain
Created June 3, 2021 14:13
Show Gist options
  • Save parajain/023d0a9ddeef999f0f47c2fb2839d55b to your computer and use it in GitHub Desktop.
Save parajain/023d0a9ddeef999f0f47c2fb2839d55b to your computer and use it in GitHub Desktop.
GlobalMaxPooling1D GlobalAvgPooling1D
class GlobalMaxPooling1D(nn.Module):
    '''
    Global max pooling over the steps (temporal) axis of a 3D tensor.

    PyTorch counterpart of the Keras ``GlobalMaxPooling1D`` layer:
    https://keras.io/api/layers/pooling_layers/global_max_pooling1d/
    Code: https://discuss.pytorch.org/t/equivalent-of-keras-globalmaxpooling1d/45770/5

    Args:
        data_format: Either 'channels_last' (default) or 'channels_first'.
            Selects which axis of the input holds the steps.

    Input:
        * If data_format='channels_last': 3D tensor with shape: (batch_size, steps, features)
        * If data_format='channels_first': 3D tensor with shape: (batch_size, features, steps)
    Output:
        * 2D tensor with shape (batch_size, features).

    Raises:
        ValueError: If ``data_format`` is not one of the two supported values.
    '''

    def __init__(self, data_format='channels_last'):
        super(GlobalMaxPooling1D, self).__init__()
        # Reject unknown layouts up front; otherwise a typo would silently be
        # treated as 'channels_first' and the wrong axis would be pooled.
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', "
                "got {!r}".format(data_format)
            )
        self.data_format = data_format
        # Axis holding the steps: 1 for channels_last, 2 for channels_first.
        self.step_axis = 1 if data_format == 'channels_last' else 2

    def forward(self, input):
        '''Return the maximum of ``input`` along the steps axis.'''
        # torch.max with a dim returns (values, indices); only values are needed.
        return torch.max(input, dim=self.step_axis).values

    def extra_repr(self):
        '''Show the configured layout when the module is printed.'''
        return 'data_format={!r}'.format(self.data_format)
class GlobalAvgPooling1D(nn.Module):
    '''
    Global average pooling over the steps (temporal) axis of a 3D tensor.

    PyTorch counterpart of the Keras ``GlobalAveragePooling1D`` layer:
    https://keras.io/api/layers/pooling_layers/global_average_pooling1d/
    Code: https://discuss.pytorch.org/t/equivalent-of-keras-globalmaxpooling1d/45770/5

    Args:
        data_format: Either 'channels_last' (default) or 'channels_first'.
            Selects which axis of the input holds the steps.

    Input:
        * If data_format='channels_last': 3D tensor with shape: (batch_size, steps, features)
        * If data_format='channels_first': 3D tensor with shape: (batch_size, features, steps)
    Output:
        * 2D tensor with shape (batch_size, features).

    Raises:
        ValueError: If ``data_format`` is not one of the two supported values.
    '''

    def __init__(self, data_format='channels_last'):
        super(GlobalAvgPooling1D, self).__init__()
        # Reject unknown layouts up front; otherwise a typo would silently be
        # treated as 'channels_first' and the wrong axis would be averaged.
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', "
                "got {!r}".format(data_format)
            )
        self.data_format = data_format
        # Axis holding the steps: 1 for channels_last, 2 for channels_first.
        self.step_axis = 1 if data_format == 'channels_last' else 2

    def forward(self, input):
        '''Return the mean of ``input`` along the steps axis.'''
        return torch.mean(input, dim=self.step_axis)

    def extra_repr(self):
        '''Show the configured layout when the module is printed.'''
        return 'data_format={!r}'.format(self.data_format)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment