Skip to content

Instantly share code, notes, and snippets.

@thistleknot
Created August 13, 2023 18:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thistleknot/d7f1675441874cad3e932d1bef138b49 to your computer and use it in GitHub Desktop.
Save thistleknot/d7f1675441874cad3e932d1bef138b49 to your computer and use it in GitHub Desktop.
GPT2 Batching API
from fastapi import FastAPI, Depends
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from typing import List
import time
from threading import Thread, Lock
import torch
app = FastAPI()
# Model and tokenizer are loaded once at import time and shared by the
# background worker thread and the request handlers.
MODEL_NAME = "gpt2-medium"
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
# GPT-2 ships with no pad token; reuse EOS so batched inputs can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Left-pad so generation continues from the true end of each prompt.
tokenizer.padding_side = 'left'
# For request
class Prompt(BaseModel):
    """Request body for /add_prompt/: a single prompt string to enqueue."""
    # text: the raw prompt passed to the tokenizer/model as-is
    text: str
# Shared queue of pending prompt strings, guarded by queue_lock:
# appended to by the /add_prompt/ handler, drained by the batch worker.
prompt_queue = []
queue_lock = Lock()
def process_batch():
    """Background worker: once per second, drain the prompt queue and run
    one batched GPT-2 generation over everything that was queued.

    Runs forever; intended to be started once in a background thread.
    The lock is held only while snapshotting/clearing the queue, so the
    slow model.generate() call never blocks producers adding new prompts
    (the original held queue_lock for the entire generation).
    """
    while True:
        time.sleep(1)  # fixed-interval batching window
        # Snapshot and clear under the lock; do all heavy work outside it.
        with queue_lock:
            batch = list(prompt_queue)
            prompt_queue.clear()
        if not batch:
            continue
        # Tokenize the whole batch with left padding (set at module level),
        # truncated to GPT-2's 1024-token context window.
        inputs = tokenizer(batch, return_tensors="pt", truncation=True,
                           padding="longest", max_length=1024)
        with torch.no_grad():
            # GPT-2 has no pad token: pass pad_token_id explicitly so
            # generate() does not warn / misbehave on padded batches.
            outputs = model.generate(**inputs,
                                     pad_token_id=tokenizer.eos_token_id)
        # Decode the full batch back to text and log it.
        generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        print(generated_texts)
# Starting batch processing in a background thread
# Start the batch worker in the background. daemon=True lets the
# interpreter exit even though process_batch loops forever; without it
# the process hangs on shutdown waiting for this thread to finish.
thread = Thread(target=process_batch, daemon=True)
thread.start()
@app.post("/add_prompt/")
async def add_prompt(prompt: Prompt):
    """Enqueue one prompt string for the next background generation batch."""
    with queue_lock:
        prompt_queue.append(prompt.text)
    return {"status": "added"}
from fastapi.testclient import TestClient
# Synchronous in-process client for exercising the API without a server.
client = TestClient(app)
def test_prompt_generation():
    """POST three prompts and verify each is accepted by the queue endpoint."""
    prompts = [
        "Once upon a time,",
        "In a galaxy far, far away,",
        "Deep in the heart of the jungle,",
    ]
    for text in prompts:
        response = client.post("/add_prompt/", json={"text": text})
        assert response.status_code == 200
        assert response.json() == {"status": "added"}
    # Wait for the batch to be processed in the background
    time.sleep(2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment