Skip to content

Instantly share code, notes, and snippets.

@pranshuj73
Created July 3, 2020 11:23
Show Gist options
  • Save pranshuj73/a64f42231cf3f3a0519e76eaee568d46 to your computer and use it in GitHub Desktop.
# function to turn photos to tensor
def img2tensor(x):
    """Convert an image into a normalized single-channel tensor.

    The pipeline converts the input to a PIL image, reduces it to one
    grayscale channel, converts it to a tensor and normalizes it with
    mean 0.5 and std 0.5.

    Args:
        x: an image accepted by ``transforms.ToPILImage`` (presumably a
            numpy array or tensor from the caller — TODO confirm).

    Returns:
        The tensor produced by the transform pipeline.
    """
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        # One-element tuples: one mean/std per channel, and there is a
        # single grayscale channel. Note ``(0.5)`` would be a bare float.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    return transform(x)
# the model for predicting
# NOTE(review): arguments presumably mean 1 input channel (grayscale, matching
# img2tensor) and 7 output classes — confirm against FERModel's definition.
model = FERModel(1, 7)
# Softmax over dim=1, i.e. across the class scores of each batch row.
softmax = torch.nn.Softmax(dim=1)
# NOTE(review): torch.load unpickles the checkpoint file; only load trusted files.
model.load_state_dict(torch.load('FER2013-Resnet9.pth', map_location=get_default_device()))
# The model is only used for prediction: switch layers such as BatchNorm or
# Dropout (if FERModel contains them) to their inference behaviour.
model.eval()
def predict(x):
    """Classify a single image and return the most likely label.

    Args:
        x: the image to classify; it is converted with ``img2tensor``.

    Returns:
        A dict with ``'label'`` (the entry of ``classes`` at the arg-max
        index of the softmax scores) and ``'probability'`` (the largest
        softmax score, as a Python float).
    """
    # [None] prepends a batch dimension of size 1.
    batch = img2tensor(x)[None]
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        out = model(batch)
        scaled = softmax(out)
    prob = torch.max(scaled).item()
    label = classes[torch.argmax(scaled).item()]
    return {'label': label, 'probability': prob}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.