@ashhadulislam
Last active July 26, 2022 18:28
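import torch
import torch.nn as nn
from torchvision import models

# `classes`, `device`, and `PATH` are expected to be defined elsewhere in the module.

def load_model():
    '''
    Load the trained model for inference.
    The architecture is ResNet-18 by default for now.
    '''
    # pretrained=True fetches ImageNet weights; the checkpoint loaded below overwrites them.
    model = models.resnet18(pretrained=True)
    # Replace the final fully connected layer so it outputs one score per class.
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, len(classes))
    model.to(device)
    # Load the saved weights, mapping tensors onto the target device.
    model.load_state_dict(torch.load(PATH, map_location=device))
    # Inference mode: disables dropout and uses running batch-norm statistics.
    model.eval()
    return model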