@KerryHalupka
Last active August 16, 2020 09:52
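import torch
import torch.nn as nn
import torchvision

# use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 0.0001
# load MobileNetV2 with ImageNet-pretrained weights
# (torchvision >= 0.13 deprecates pretrained=True in favor of
#  weights=torchvision.models.MobileNet_V2_Weights.IMAGENET1K_V1)
model = torchvision.models.mobilenet_v2(pretrained=True)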
# freeze all pretrained parameters (the convolutional feature extractor and the original classifier)
for param in model.parameters():
    param.requires_grad = False
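# replace the final linear layer with a new single-output layer; newly created
# parameters have requires_grad=True, so this is the only trainable layer
model.classifier[1] = nn.Linear(in_features=model.classifier[1].in_features, out_features=1)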
# send the model to the device
model = model.to(device)
# only optimize the classification layer
optimizer = torch.optim.Adam(model.classifier[1].parameters(), lr=learning_rate)