Last active
August 4, 2023 16:09
-
-
Save vikashg/060b3b1adef84c4808519f3ccf2b85a8 to your computer and use it in GitHub Desktop.
This gist shows a Pythonic way to download and use MONAI model zoo bundles. I propose using the bundle's .json config files together with the ConfigParser. This model_def.py can be made more general and would serve as a middle layer between the model zoo and the Python interface.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from monai.bundle.scripts import download | |
from monai.bundle.config_parser import ConfigParser | |
import os | |
class ModelDef():
    """
    Thin wrapper around a model-zoo bundle.

    Downloads the named bundle, reads its ``configs/inference.json`` with a
    ``ConfigParser`` and exposes the network, the preprocessing pipeline and
    the pretrained weights through plain Python methods.
    """

    def __init__(self, model_name, bundle_dir=None):
        """
        Download the bundle and read its inference config

        :param model_name: name of the bundle to download
        :param bundle_dir: directory in which bundles are stored. When None,
            ``<torch.hub.get_dir()>/bundle`` is used, so the location no
            longer depends on one particular user's home directory.
        """
        if bundle_dir is None:
            bundle_dir = os.path.join(torch.hub.get_dir(), "bundle")
        # Pass the directory explicitly so the files are read back from the
        # same place they were downloaded to.
        download(name=model_name, bundle_dir=bundle_dir)
        download_dir = os.path.join(bundle_dir, model_name)
        self.model_weights_file = os.path.join(download_dir, "models", "model.pt")
        inference_config_fn = os.path.join(download_dir, "configs", "inference.json")
        self.parser = ConfigParser()
        self.parser.read_config(inference_config_fn)

    def get_model_architecture(self):
        """
        Returns the model architecture

        :return: the parsed content of the ``network_def`` config entry
        """
        network = self.parser.get_parsed_content("network_def")
        return network

    def get_preprocessing(self):
        """
        Returns the preprocessing pipeline

        :return: the parsed content of the ``preprocessing`` config entry
        """
        preprocessing = self.parser.get_parsed_content("preprocessing")
        return preprocessing

    def use_pretrained_weights(self):
        """
        Returns the network with the pretrained weights loaded

        :return: the network from ``get_model_architecture`` after the state
            dict stored in ``self.model_weights_file`` was loaded into it
        """
        network = self.get_model_architecture()
        # Load onto the CPU so a checkpoint saved on a GPU also works on a
        # machine without one; load_state_dict copies the values into the
        # network's existing parameters.
        state_dict = torch.load(self.model_weights_file, map_location="cpu")
        # load_state_dict returns a report of missing/unexpected keys, not the
        # module, so hand back the module itself.
        network.load_state_dict(state_dict)
        return network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
from model_zoo.model_def import ModelDef | |
import monai | |
# Example: download a model-zoo bundle through ModelDef and run one prediction.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1 to this process
bd = ModelDef("breast_density_classification")  # use your model name here
# Only one GPU is visible after the setting above, and visible devices are
# renumbered from 0, so the device to select is "cuda:0" ("cuda:1" does not exist).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
network = bd.get_model_architecture()
loaded = bd.use_pretrained_weights()  # Load the pretrained weights
print(loaded)
# use_pretrained_weights may return the network itself or only the
# load_state_dict report; in the latter case fall back to `network`.
# NOTE(review): the fallback assumes the parser hands out the same network
# instance on every call - confirm against the ConfigParser in use.
model = loaded if isinstance(loaded, torch.nn.Module) else network
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm
preprocessing = bd.get_preprocessing()  # Get the preprocessing pipeline
print(network)
# do prediction using the model
input_fn = "SomeFile.png"  # Some sample images are included in the bundles folder
image = preprocessing(input_fn)
image = image.to(device)
with torch.no_grad():  # no gradients are needed for a prediction
    output = model(image)
print(output)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment