Skip to content

Instantly share code, notes, and snippets.

@vikashg
Last active August 4, 2023 16:09
Show Gist options
  • Save vikashg/060b3b1adef84c4808519f3ccf2b85a8 to your computer and use it in GitHub Desktop.
Save vikashg/060b3b1adef84c4808519f3ccf2b85a8 to your computer and use it in GitHub Desktop.
This gist shows a Pythonic way to download and use MONAI Model Zoo bundles. I propose using each bundle's .json config files together with MONAI's ConfigParser. The model_def.py module below can be made more general and serve as a middle layer between the Model Zoo and the Python interface.
import torch
from monai.bundle.scripts import download
from monai.bundle.config_parser import ConfigParser
import os
class ModelDef:
    """Thin wrapper around a MONAI Model Zoo bundle.

    On construction the bundle is fetched with ``download`` and its
    ``configs/inference.json`` is read into a ``ConfigParser``, so that the
    network and preprocessing pipeline declared in that config can be built
    on demand.
    """

    def __init__(self, model_name, bundle_dir=None):
        """Download the bundle ``model_name`` and parse its inference config.

        :param model_name: name of the bundle in the MONAI Model Zoo.
        :param bundle_dir: directory the bundle is downloaded into. Defaults
            to the ``bundle`` subfolder of ``torch.hub.get_dir()`` instead of
            a hard-coded, user-specific path.
        """
        if bundle_dir is None:
            bundle_dir = os.path.join(torch.hub.get_dir(), "bundle")
        # Pass the directory explicitly so the download location and the
        # paths read below can never disagree.
        download(name=model_name, bundle_dir=bundle_dir)
        download_dir = os.path.join(bundle_dir, model_name)
        self.model_weights_file = os.path.join(download_dir, "models", "model.pt")
        inference_config_fn = os.path.join(download_dir, "configs", "inference.json")
        self.parser = ConfigParser()
        self.parser.read_config(inference_config_fn)

    def get_model_architecture(self):
        """Return the network built from the config's ``network_def`` entry.

        :return: the parsed ``network_def`` content (the network object).
        """
        return self.parser.get_parsed_content("network_def")

    def get_preprocessing(self):
        """Return the pipeline built from the config's ``preprocessing`` entry.

        :return: the parsed ``preprocessing`` content.
        """
        return self.parser.get_parsed_content("preprocessing")

    def use_pretrained_weights(self, map_location="cpu"):
        """Return the network with the bundle's pretrained weights loaded.

        :param map_location: passed to ``torch.load``. Mapping to CPU by
            default lets GPU-saved checkpoints load on CPU-only machines;
            ``load_state_dict`` then copies the values into the network's own
            parameters, wherever those live.
        :return: the network (``torch.nn.Module``) with weights loaded.
        """
        network = self.get_model_architecture()
        state_dict = torch.load(self.model_weights_file, map_location=map_location)
        # load_state_dict returns a (missing_keys, unexpected_keys) report, not
        # the module itself, so return the network rather than that report.
        network.load_state_dict(state_dict)
        return network
import os
import torch
from model_zoo.model_def import ModelDef
import monai
def main():
    """Download a bundle, load its pretrained network and run one prediction."""
    # Expose only physical GPU 1 (GPUs numbered by PCI bus order) to this process.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"

    bd = ModelDef("breast_density_classification")  # use your model name here
    # With CUDA_VISIBLE_DEVICES="1" the single visible GPU is re-indexed as
    # cuda:0; requesting "cuda:1" would be an invalid device ordinal.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    network = bd.get_model_architecture()
    # Load the pretrained weights straight into the network. The return value
    # of load_state_dict is a key-mismatch report, not a model, so it must not
    # be used for inference.
    network.load_state_dict(torch.load(bd.model_weights_file, map_location="cpu"))
    network = network.to(device)
    network.eval()  # inference mode for dropout/batch-norm layers
    print(network)

    preprocessing = bd.get_preprocessing()  # Get the preprocessing pipeline

    # do prediction using the model
    input_fn = "SomeFile.png"  # Some sample images are included in the bundles folder
    # NOTE(review): assumes this bundle's preprocessing accepts a file path and
    # returns a tensor; if it returns an unbatched (C, H, W) tensor, the network
    # may need a leading batch dimension (image.unsqueeze(0)) -- confirm
    # against the bundle's inference.json.
    image = preprocessing(input_fn)
    image = image.to(device)
    with torch.no_grad():
        output = network(image)
    print(output)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment