Skip to content

Instantly share code, notes, and snippets.

@jaklinger
Last active June 9, 2024 14:50
Show Gist options
  • Save jaklinger/2b72173c8644f27e9c2b0eabe854ea00 to your computer and use it in GitHub Desktop.
Save jaklinger/2b72173c8644f27e9c2b0eabe854ea00 to your computer and use it in GitHub Desktop.
Read a PyTorch (Hugging Face transformers) model from S3
import torch
from contextlib import contextmanager
import boto3
from io import BytesIO
from transformers import PretrainedConfig, PreTrainedModel
import json
from tempfile import NamedTemporaryFile
BUCKET_NAME = "open-jobs-lake"
@contextmanager
def s3_fileobj(filename, bucket=None):
    """Yield an in-memory file object holding the S3 object ``{bucket}/{filename}``.

    The whole object body is read into memory before the ``with`` body runs.

    Args:
        filename: Key of the object within the bucket.
        bucket: Bucket to read from. ``None`` (the default) means the
            module-level ``BUCKET_NAME``.

    Yields:
        A ``BytesIO`` positioned at the start of the object's bytes. It is
        closed when the ``with`` block exits.
    """
    # Resolve at call time, not def time, so BUCKET_NAME is looked up lazily
    # exactly as in the original implementation.
    if bucket is None:
        bucket = BUCKET_NAME
    s3 = boto3.client("s3")
    obj = s3.get_object(Bucket=bucket, Key=filename)
    buffer = BytesIO(obj["Body"].read())
    try:
        yield buffer
    finally:
        # Release the in-memory copy deterministically instead of waiting for GC.
        buffer.close()
def load_model(s3_path, model_class=None):
    """Download a transformers model's config and weights from S3 and load it.

    Reads ``{s3_path}/config.json`` and ``{s3_path}/pytorch_model.bin`` via
    ``s3_fileobj``. The config is built with ``PretrainedConfig.from_dict``.
    The weights are staged in a temporary file on disk and handed to
    ``model_class.from_pretrained``. The temporary file is deleted afterwards.

    Args:
        s3_path: Key prefix (without trailing slash) under which
            ``config.json`` and ``pytorch_model.bin`` are stored.
        model_class: Class whose ``from_pretrained`` builds the model.
            ``None`` (the default) uses ``PreTrainedModel``, matching the
            original behaviour. NOTE(review): ``PreTrainedModel`` is the
            generic base class; passing the concrete architecture class the
            weights belong to is likely needed for them to be loaded into
            real layers — confirm against the model that was saved.

    Returns:
        The model returned by ``model_class.from_pretrained``.
    """
    import os

    if model_class is None:
        model_class = PreTrainedModel

    with s3_fileobj(f'{s3_path}/config.json') as f:
        config = PretrainedConfig.from_dict(json.load(f))

    # delete=False so the file can be closed (flushing every byte to disk)
    # before from_pretrained reopens it by name; reopening a still-open
    # NamedTemporaryFile fails on Windows, and unflushed data may be missing.
    weights_file = NamedTemporaryFile(delete=False)
    try:
        with weights_file:
            with s3_fileobj(f'{s3_path}/pytorch_model.bin') as f:
                weights_file.write(f.read())
        model = model_class.from_pretrained(weights_file.name, config=config)
    finally:
        os.remove(weights_file.name)
    return model
# NOTE: executes at import time — fetches the model stored under
# 'labs/sic/sic4_classifier' from S3 and loads it into memory.
model = load_model(s3_path='labs/sic/sic4_classifier')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment