Skip to content

Instantly share code, notes, and snippets.

@eddyxu
Created January 27, 2022 03:41
Show Gist options
  • Save eddyxu/f407bdc46e6ee3e604f3a4fdb66c5adf to your computer and use it in GitHub Desktop.
Save eddyxu/f407bdc46e6ee3e604f3a4fdb66c5adf to your computer and use it in GitHub Desktop.
Register SSD and feature extraction models
import mlflow
import rikai
from rikai.contrib.torch.detections import OUTPUT_SCHEMA
from rikai.contrib.torch.inspect.ssd import SSDClassScoresExtractor
from torchvision.models.detection.ssd import ssd300_vgg16
# Log and register two models with MLflow through Rikai's PyTorch flavor:
#   1. an SSD300-VGG16 detector, registered as "ssd";
#   2. a class-score extractor built on that same detector, registered as
#      "class_scores".

# SSD model.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in
# favour of the `weights=` argument -- confirm against the pinned version.
ssd = ssd300_vgg16(pretrained=True)

# Class-score extraction model; it wraps the very same `ssd` instance.
# NOTE(review): `topk_candidates=90` presumably bounds how many candidates
# are considered -- confirm against SSDClassScoresExtractor.
class_scores_extractor = SSDClassScoresExtractor(ssd, topk_candidates=90)

# Each model is logged inside its own MLflow run so the two artifacts are
# tracked separately. The pre/post-processing hooks are passed as dotted
# import paths (strings), not as callables.
with mlflow.start_run():
    rikai.mlflow.pytorch.log_model(
        ssd,
        "model",
        OUTPUT_SCHEMA,
        pre_processing="rikai.contrib.torch.transforms.ssd.pre_processing",
        post_processing="rikai.contrib.torch.transforms.ssd.post_processing",
        registered_model_name="ssd",
    )

with mlflow.start_run():
    rikai.mlflow.pytorch.log_model(
        class_scores_extractor,
        "model_scores",
        SSDClassScoresExtractor.SCHEMA,
        pre_processing="rikai.contrib.torch.inspect.ssd.class_scores_extractor_pre_processing",
        post_processing="rikai.contrib.torch.inspect.ssd.class_scores_extractor_post_processing",
        registered_model_name="class_scores",
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment