
@fg91
Last active December 17, 2018 15:26
How to truncate a base network
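import torch.nn as nn
from torchvision.models import resnet34
# fastai building blocks (fastai.layers provides both in the 0.7, v1 and v2 releases)
from fastai.layers import AdaptiveConcatPool2d, Flatten


class YourCustomModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Truncated base network: pretrained=True loads ImageNet weights
        # (older torchvision API; recent releases deprecate it in favour of weights=...).
        # children()[:8] keeps conv1 ... layer4 and drops the final avgpool and fc.
        self.backbone = nn.Sequential(*list(resnet34(pretrained=True).children())[:8])
        # Backbone followed by your custom layers
        self.features = nn.Sequential(
            self.backbone,
            # custom layers:
            AdaptiveConcatPool2d(1),  # max + avg pooling: 512 -> 1024 channels
            Flatten(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.1),
            nn.Linear(in_features=1024, out_features=1024, bias=True),
            # … more layers
        )

    def forward(self, x):
        return self.features(x)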