Last active: September 23, 2018 10:45
-
-
Save jalola/5ffdbe67ba806d92894a6ecde606fcf3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class MyAdaptiveMaxPool2d(nn.Module):
    """Fixed-kernel stand-in for ``nn.AdaptiveMaxPool2d`` producing a 1x1 output.

    Wraps ``nn.MaxPool2d`` with a window equal to the spatial size of the
    incoming feature map. Each channel is therefore reduced to its single
    maximum, which is the same result as ``nn.AdaptiveMaxPool2d(1)`` for that
    particular input size.

    ``kernel_size`` must match the feature map's height/width. The original
    author notes that 299x299 input images yield a 10x10 map (hence the
    default) and 224x224 images yield 7x7. For other image sizes, inspect
    ``x.size()`` inside ``forward`` (e.g. with a breakpoint).

    If the map size differs from ``kernel_size``, the result no longer matches
    adaptive pooling. A smaller map makes ``nn.MaxPool2d`` raise. A larger map
    either produces more than 1x1 output or silently ignores trailing
    rows/columns.

    Args:
        sz: Accepted only for signature compatibility with the adaptive
            pooling layers; it is not used.
        kernel_size: Pooling window, as an int or an ``(h, w)`` tuple.
            Defaults to ``(10, 10)``, the previously hard-coded value.
    """

    def __init__(self, sz=None, kernel_size=(10, 10)):
        super().__init__()
        # nn.MaxPool2d's stride defaults to kernel_size, so a map exactly
        # kernel_size in extent collapses to a single 1x1 window.
        self.p = nn.MaxPool2d(kernel_size, padding=0)

    def forward(self, x):
        return self.p(x)
class MyAdaptiveAvgPool2d(nn.Module):
    """Fixed-kernel stand-in for ``nn.AdaptiveAvgPool2d`` producing a 1x1 output.

    Wraps ``nn.AvgPool2d`` with a window equal to the spatial size of the
    incoming feature map. Each channel is therefore reduced to its mean, which
    is the same result as ``nn.AdaptiveAvgPool2d(1)`` for that particular input
    size.

    ``kernel_size`` must match the feature map's height/width. A 10x10 map
    (the default) corresponds to 299x299 input images, and 7x7 to 224x224, per
    the original author's note. If the map size differs from ``kernel_size``,
    the result no longer matches adaptive pooling. A smaller map makes
    ``nn.AvgPool2d`` raise. A larger map either produces more than 1x1 output
    or silently ignores trailing rows/columns.

    Args:
        sz: Accepted only for signature compatibility with the adaptive
            pooling layers; it is not used.
        kernel_size: Pooling window, as an int or an ``(h, w)`` tuple.
            Defaults to ``(10, 10)``, the previously hard-coded value.
    """

    def __init__(self, sz=None, kernel_size=(10, 10)):
        super().__init__()
        # padding=0, so every averaged window contains only real input values;
        # stride defaults to kernel_size.
        self.p = nn.AvgPool2d(kernel_size, padding=0)

    def forward(self, x):
        return self.p(x)
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate max-pooled and average-pooled features along dim 1.

    Built on the fixed-kernel ``MyAdaptiveMaxPool2d`` / ``MyAdaptiveAvgPool2d``
    layers rather than PyTorch's ``nn.AdaptiveMaxPool2d`` /
    ``nn.AdaptiveAvgPool2d``. To switch back to the built-in adaptive layers,
    construct ``nn.AdaptiveAvgPool2d(sz)`` and ``nn.AdaptiveMaxPool2d(sz)`` in
    place of the two ``My*`` layers below.

    Args:
        sz: Requested output size; falls back to ``(1, 1)`` when falsy. It is
            forwarded unchanged to both pooling layers.
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = sz if sz else (1, 1)
        # Registration order (ap, then mp) is kept as-is.
        self.ap = MyAdaptiveAvgPool2d(output_size)
        self.mp = MyAdaptiveMaxPool2d(output_size)

    def forward(self, x):
        # Max-pooled features come first, then average-pooled ones.
        max_features = self.mp(x)
        avg_features = self.ap(x)
        return torch.cat((max_features, avg_features), dim=1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.