# Gist by @heiko-hotz, created March 18, 2023.
# SageMaker Hugging Face inference script for FLAN-T5-XXL: model_fn and
# predict_fn are the hooks the inference toolkit calls at container
# start-up and on each request, respectively.
from transformers import T5ForConditionalGeneration, AutoTokenizer


def model_fn(model_dir):
    # The weights come from the Hugging Face Hub rather than model_dir.
    # load_in_8bit=True quantizes the model at load time (requires
    # bitsandbytes and accelerate); device_map="auto" shards the layers
    # across the available GPUs.
    model = T5ForConditionalGeneration.from_pretrained(
        "google/flan-t5-xxl",
        load_in_8bit=True,
        device_map="auto",
        cache_dir="/tmp/model_cache/",
    )
    tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xxl")
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    model, tokenizer = model_and_tokenizer
    # "inputs" holds the prompt; any remaining payload keys are forwarded
    # to generate() as generation parameters.
    text = data.pop("inputs", data)
    inputs = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
    outputs = model.generate(inputs, **data)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
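

# Optional local smoke test: a minimal sketch of how the two handlers fit
# together (not part of the original gist; in deployment the SageMaker
# toolkit calls them itself). The prompt, the model-dir path, and the
# max_new_tokens setting below are illustrative assumptions.
if __name__ == "__main__":
    model_and_tokenizer = model_fn("/opt/ml/model")  # model_dir is unused above
    payload = {
        "inputs": "Translate English to German: How old are you?",
        "max_new_tokens": 50,
    }
    print(predict_fn(payload, model_and_tokenizer))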