Skip to content

Instantly share code, notes, and snippets.

@thundergolfer
Created May 2, 2023 20:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thundergolfer/3e39977d7bb0596933408d44e9ac2a8d to your computer and use it in GitHub Desktop.
import modal
def download_model() -> None:
    """Download and cache the `bert-base-uncased` fill-mask model.

    Intended to run at image-build time (see `run_function` below) so the
    model weights are baked into the container image rather than fetched
    on every cold start. The constructed pipeline is discarded; only the
    side effect of populating the Transformers cache matters.
    """
    import transformers

    # Instantiating the pipeline forces the model/tokenizer download.
    transformers.pipeline("fill-mask", model="bert-base-uncased")
# Directory inside the container where Hugging Face models are cached.
CACHE_PATH = "/root/model_cache"

# Injected as an env var so Transformers writes/reads its cache at CACHE_PATH.
# NOTE(review): passing a dict positionally to modal.Secret matches the
# 2023-era Modal API — confirm against the pinned modal version.
ENV = modal.Secret({"TRANSFORMERS_CACHE": CACHE_PATH})

# Container image: slim Debian + torch/transformers, with the model weights
# pre-downloaded at build time (order matters: pip_install must precede
# run_function so `transformers` is importable during the build step).
image = (
    modal.Image.debian_slim()
    .pip_install("torch", "transformers")
    .run_function(download_model, secret=ENV)
)

# Application entry point; all @stub.function-decorated callables attach here.
stub = modal.Stub(name="hn-demo", image=image)
class Model:
    """Serves fill-mask predictions from a pre-cached BERT model.

    NOTE(review): relies on Modal's container lifecycle — `__enter__` is
    presumably invoked once per container start so the pipeline is loaded
    a single time and reused across calls; confirm against the Modal docs
    for the pinned version.
    """

    def __enter__(self) -> None:
        # Imported here so heavy deps load inside the (GPU) container,
        # not on the local client. `torch` itself is unused directly but
        # presumably ensures CUDA support is initialized — TODO confirm.
        import torch
        from transformers import pipeline

        # device=0 pins inference to the first GPU (the A10G requested
        # on `handler` below).
        self.model = pipeline("fill-mask", model="bert-base-uncased", device=0)

    @stub.function(
        gpu="a10g",
        secret=ENV,
    )
    def handler(self, prompt: str):
        # Returns the pipeline's fill-mask output for `prompt`
        # (a list of candidate completions, each with a "sequence" key
        # per the call site below).
        return self.model(prompt)
if __name__ == "__main__":
    # Local entry point: spin up the Modal app and run one remote inference.
    with stub.run():
        masked_text = "Hello World! I am a [MASK] machine learning model."
        # `.call` executes `handler` remotely inside the GPU container.
        predictions = Model().handler.call(masked_text)
        top_prediction = predictions[0]
        print(top_prediction["sequence"])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment