Skip to content

Instantly share code, notes, and snippets.

@jtrive84
Created April 23, 2024 20:59
Show Gist options
  • Save jtrive84/a2dd3b9468586de016a26212ef9aaa78 to your computer and use it in GitHub Desktop.
Save jtrive84/a2dd3b9468586de016a26212ef9aaa78 to your computer and use it in GitHub Desktop.
Classifier with softmax
class PreTrainedImageClassifier(nn.Module):
    """
    Transfer learning using ResNet models for map image classification.

    The pretrained backbone is frozen and its final ``fc`` layer is replaced
    by a small trainable head (Linear -> Dropout -> ReLU -> Linear).
    ``forward`` returns softmax probabilities over the classes.

    Parameters
    ----------
    pt_model : nn.Module
        Pretrained backbone exposing a replaceable ``fc`` attribute
        (e.g. a torchvision ResNet). It is modified in place: its parameters
        are frozen and its ``fc`` attribute is overwritten.
    dropout : float, default 0
        Dropout probability used in the head's hidden layer.
    hidden_features : int, default 32
        Width of the head's hidden layer.
    num_classes : int, default 2
        Number of output classes.
    in_features : int or None, default None
        Input width of the head. When None, it is read from
        ``pt_model.fc.in_features`` if ``pt_model.fc`` is an ``nn.Linear``;
        otherwise it falls back to 2048.
    """

    def __init__(self, pt_model, dropout=0, *, hidden_features=32,
                 num_classes=2, in_features=None):
        super().__init__()
        # Freeze the backbone. This happens before the new head is attached,
        # so the fc layers created below remain trainable.
        for param in pt_model.parameters():
            param.requires_grad = False

        if in_features is None:
            # ResNet variants differ in feature width (512 for ResNet-18/34,
            # 2048 for ResNet-50+), so read it from the existing head instead
            # of hard-coding it.
            old_fc = getattr(pt_model, "fc", None)
            in_features = (
                old_fc.in_features if isinstance(old_fc, nn.Linear) else 2048
            )

        self.softmax = nn.Softmax(dim=-1)
        pt_model.fc = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=hidden_features),
            nn.Dropout(p=dropout),
            nn.ReLU(),
            nn.Linear(in_features=hidden_features, out_features=num_classes),
        )
        self.model = pt_model

    def forward(self, input):
        """
        Run the backbone plus head and return class probabilities.

        Softmax is applied over the last dimension, so each row of the output
        sums to 1.
        """
        # NOTE: the output is probabilities, not logits. nn.CrossEntropyLoss
        # applies log-softmax internally and expects raw logits; if that loss
        # is used for training, confirm the softmax here is intended.
        return self.softmax(self.model(input))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment