Skip to content

Instantly share code, notes, and snippets.

@AntoineToubhans
Last active August 19, 2021 08:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AntoineToubhans/40c436052e2c5437ae9cb50ad89d0e0a to your computer and use it in GitHub Desktop.
import shutil
from pathlib import Path

import tensorflow as tf
# Warning: this is private internal dvc api, it may change with future version
import dvc.repo.get
# Root folder for downloaded models, relative to the current working
# directory; each revision gets its own sub-directory beneath it.
ROOT_MODEL_CACHE_DIR = Path(".model_cache")
# Create the root eagerly at import time; an already-existing directory is
# accepted as-is.
Path.mkdir(ROOT_MODEL_CACHE_DIR, exist_ok=True)
# NOTE(review): `st` (Streamlit) is not imported in this snippet; the file
# needs `import streamlit as st` for this decorator to resolve.
# `allow_output_mutation=True` stops st.cache from hashing the returned Keras
# model on every call, which is slow and can fail for TensorFlow objects.
@st.cache(allow_output_mutation=True)
def load_model(rev: str):
    """Return the Keras model stored at ``data/train/model`` for git revision *rev*.

    The model is first loaded from ``ROOT_MODEL_CACHE_DIR / rev``. If that
    fails, the stale cache entry is removed, the model is fetched with
    ``dvc get`` from the repository in the current directory at *rev*, and
    the freshly downloaded copy is loaded.

    Args:
        rev: git revision (commit, branch or tag) to fetch the model from.
            Also used as the name of the cache sub-directory.

    Returns:
        The object returned by ``tf.keras.models.load_model``.

    Raises:
        Whatever ``dvc.repo.get.get`` or ``tf.keras.models.load_model``
        raises if the download or the final load fails.
    """
    print(f"Loading model for revision {rev}")
    cache_path = ROOT_MODEL_CACHE_DIR / rev
    model_cache_dir = str(cache_path)

    # 1. Try to load the model directly (if it is already in the cache dir).
    try:
        return tf.keras.models.load_model(model_cache_dir)
    except OSError:
        print(f"Could not find model {rev} in cache")
    except Exception as e:
        # Best effort: a broken cache entry is re-downloaded below instead of
        # crashing, but report why it was rejected.
        print(f"Could not load model {rev} from cache: {e!r}")

    # `dvc get` places the download *inside* `out` when `out` is an existing
    # directory, so a stale/partial cache entry would make the model land in
    # `<model_cache_dir>/model` and the load below would fail again. Clear it
    # first -- but only if it lies strictly inside the cache root, so a rev
    # such as "" or "../x" can never delete anything else.
    if ROOT_MODEL_CACHE_DIR.resolve() in cache_path.resolve().parents:
        shutil.rmtree(cache_path, ignore_errors=True)

    # 2. Download model to the model cache dir using `dvc get`
    # See https://dvc.org/doc/command-reference/get
    dvc.repo.get.get(
        url=".",
        path="data/train/model",
        out=model_cache_dir,
        rev=rev,
    )
    print(f"Model downloaded to {model_cache_dir}")

    # 3. Load the freshly downloaded model with tf.keras.models.load_model.
    return tf.keras.models.load_model(model_cache_dir)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment