GPT2 Batching API
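A minimal example of micro-batching GPT-2 generation behind a FastAPI endpoint: incoming prompts are appended to a shared queue, and a background thread drains the queue once per second, tokenizes the accumulated prompts with left padding, and runs a single batched `model.generate` call. Note that generated texts are only printed server-side; the endpoint does not return completions to the caller.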
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import time
from threading import Thread, Lock
import torch

app = FastAPI()

MODEL_NAME = "gpt2-medium"
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
model.eval()
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS
tokenizer.padding_side = 'left'  # decoder-only models need left padding for batched generation
# Request body
class Prompt(BaseModel):
    text: str

# Queue of pending prompt texts, guarded by a lock
prompt_queue = []
queue_lock = Lock()
def process_batch():
    while True:
        time.sleep(1)
        # Snapshot and clear the queue under the lock, then release it,
        # so generation does not block incoming requests
        with queue_lock:
            batch = list(prompt_queue)
            prompt_queue.clear()
        if not batch:
            continue
        # Convert texts to model inputs, left-padded to the longest prompt
        inputs = tokenizer(batch, return_tensors="pt", truncation=True, padding="longest", max_length=1024)
        # Generate outputs; pad_token_id is set explicitly since GPT-2 reuses EOS,
        # and max_new_tokens makes the generation budget explicit
        with torch.no_grad():
            outputs = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50)
        # Convert model outputs back to texts
        generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        print(generated_texts)
# Start batch processing in a daemon thread so it exits with the main process
thread = Thread(target=process_batch, daemon=True)
thread.start()
@app.post("/add_prompt/")
async def add_prompt(prompt: Prompt):
    with queue_lock:
        prompt_queue.append(prompt.text)
    return {"status": "added"}
from fastapi.testclient import TestClient

client = TestClient(app)

def test_prompt_generation():
    response1 = client.post("/add_prompt/", json={"text": "Once upon a time,"})
    response2 = client.post("/add_prompt/", json={"text": "In a galaxy far, far away,"})
    response3 = client.post("/add_prompt/", json={"text": "Deep in the heart of the jungle,"})

    assert response1.status_code == 200
    assert response1.json() == {"status": "added"}
    assert response2.status_code == 200
    assert response2.json() == {"status": "added"}
    assert response3.status_code == 200
    assert response3.json() == {"status": "added"}

    # Wait for the batch to be processed in the background
    time.sleep(2)
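
# Not part of the original gist: a minimal sketch of serving the app directly,
# assuming this file is saved as main.py (the "main:app" import string below
# is that assumption). Once running, prompts can be POSTed to /add_prompt/.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="127.0.0.1", port=8000)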