Skip to content

Instantly share code, notes, and snippets.

@kasperjunge
Created December 8, 2022 07:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kasperjunge/4ec27724272109b9a413e4d1477ef7cd to your computer and use it in GitHub Desktop.
Save kasperjunge/4ec27724272109b9a413e4d1477ef7cd to your computer and use it in GitHub Desktop.
ChatGPT answer to: Implement me a hyper-optimized service for serving predictions by transformer models
from fastapi import FastAPI
import transformers
import torch
app = FastAPI()

# Checkpoint to serve. "model_name" is a placeholder: replace it with a real
# Hugging Face model id or a local path.
MODEL_NAME = "model_name"

# Choose the device once so the model and the tracing example agree on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and model.
# torchscript=True makes the model return plain tuples instead of a
# dict-like ModelOutput, which torch.jit.trace needs.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
model = transformers.AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME, torchscript=True
)
# Move to the target device BEFORE tracing so that any device-dependent
# constants recorded in the trace point at the right device.
model.to(device)
model.eval()  # disable dropout so the traced graph is deterministic

# Example input for tracing (the original referenced `inputs` before it was
# ever defined, which raised NameError at import time).
# NOTE(review): a traced graph may specialise on the example's shape; if long
# requests misbehave, trace with a longer / padded example — TODO confirm.
inputs = tokenizer(
    "This is an example sentence used to trace the model.",
    return_tensors="pt",
)["input_ids"].to(device)

# Use torch.jit.trace to create a TorchScript module from the model.
with torch.no_grad():
    traced_model = torch.jit.trace(model, example_inputs=inputs)
@app.post("/predict")
def predict(input: str):
    """Return one predicted label per token of ``input``.

    The text is tokenized, run through the traced token-classification
    model, and each position's highest-scoring class index is mapped to its
    label name via ``model.config.id2label``.

    Args:
        input: Raw text to classify. The name shadows the builtin, but it is
            kept because FastAPI exposes it as the public request parameter.

    Returns:
        ``{"labels": [...]}`` with one label string per input position,
        including any special tokens the tokenizer adds.
    """
    # Calling the tokenizer yields numeric IDs directly; they must not be fed
    # through convert_tokens_to_ids again. truncation=True keeps over-long
    # requests within the tokenizer's configured maximum length.
    encoded = tokenizer(input, return_tensors="pt", truncation=True)
    input_ids = encoded["input_ids"]

    # Mirror the module-level placement: the model lives on the GPU whenever
    # one is available, so the inputs must too.
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    # Inference only: do not track gradients.
    with torch.no_grad():
        predictions = traced_model(input_ids)

    # predictions[0] is the logits tensor; index [0] again to drop the batch
    # dimension so the argmax runs over labels, not over sequence positions.
    # assumes logits are (batch, seq_len, num_labels) — TODO confirm for the
    # chosen checkpoint.
    logits = predictions[0][0]
    label_ids = logits.argmax(dim=-1).tolist()  # .tolist() copies to host

    # Class indices map to label names through the model config, not through
    # the tokenizer vocabulary.
    id2label = model.config.id2label
    predicted_labels = [id2label[label_id] for label_id in label_ids]

    # Return the predicted labels as the response
    return {"labels": predicted_labels}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment