-
-
Save tkeyo/ddb5fce50704d1b5a7fc622e6c713494 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
import torchvision
import torchvision.transforms as transforms
from fastai.vision.all import *

# Load the exported fastai ResNet Learner.
# NOTE: load_learner unpickles the file, so only load .pkl files from a
# trusted source. Any custom functions referenced by the Learner when it was
# exported (e.g. a labelling function) must be importable before this call.
learn = load_learner('models/hot_dog_model_resnet18_256_256.pkl')

# Get the underlying PyTorch model.
# The .model attribute stores the torch.nn.Module. .eval() switches layers
# such as BatchNorm and Dropout to inference behaviour and returns the same
# module. It does NOT disable autograd: wrap inference calls in
# torch.no_grad() / torch.inference_mode() to skip gradient tracking.
pytorch_model = learn.model.eval()

# Softmax over dim=1 turns the model's logits into class probabilities.
# Assumes the model output is shaped (batch, n_classes) — TODO confirm.
softmax_layer = torch.nn.Softmax(dim=1)

# Normalize with the standard ImageNet per-channel mean/std.
# Presumably the Learner normalized its inputs as a DataLoader batch
# transform (which is not part of learn.model), so it is baked in here to
# make final_model self-contained — verify against the training setup.
# Assumes inputs are float RGB tensors scaled to [0, 1] and shaped
# (batch, 3, H, W) — TODO confirm against the serving/export code.
normalization_layer = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Assemble the final model: normalize -> ResNet -> softmax.
# torch.nn is referenced explicitly instead of relying on `nn` leaking out of
# the fastai star import. The new container is also put in eval mode so its
# own `training` flag agrees with its children.
final_model = torch.nn.Sequential(
    normalization_layer,
    pytorch_model,
    softmax_layer,
).eval()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment