Skip to content

Instantly share code, notes, and snippets.

@kishida
Created February 6, 2024 05:01
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kishida/655cf703687de1f4979b982c6d6f0e52 to your computer and use it in GitHub Desktop.
Code inference server for HuggingFace VSCode
import asyncio
import json
import logging

import torch
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Model and app setup -----------------------------------------------------
# Load the code-completion model once at import time so every request reuses
# the same weights.
model_name = "stabilityai/stable-code-3b"
# NOTE(review): trust_remote_code executes model-repo Python at load time —
# a supply-chain risk; consider pinning a specific revision.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True, torch_dtype="auto")
# Fix: the original called model.cuda() unconditionally, which raises on a
# CPU-only host. Use the GPU when present, otherwise stay on CPU.
model.to("cuda" if torch.cuda.is_available() else "cpu")

app = FastAPI()

# Log to both a file and the console so requests can be audited afterwards.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler('app.log'),
        logging.StreamHandler()
    ],
)
logger = logging.getLogger('app')
def generate(prompt: str, param: dict) -> str:
    """Run causal-LM completion on *prompt* and return only the new text.

    param: decoding parameters from the request body. 'max_new_tokens' and
    'temperature' are read, with defaults when absent — the original raised
    KeyError if a client omitted either key.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # inference_mode avoids building autograd state, reducing memory/latency.
    with torch.inference_mode():
        tokens = model.generate(
            **inputs,
            max_new_tokens=param.get('max_new_tokens', 128),
            temperature=param.get('temperature', 0.8),
            do_sample=True,
        )
    output = tokenizer.decode(tokens[0], skip_special_tokens=True)
    # Strip the echoed prompt from the decoded output; assumes decode
    # reproduces the prompt text verbatim — TODO confirm for unusual inputs.
    return output[len(prompt):]
@app.post("/api/generate/")
async def api(request: Request):
    """Inference endpoint for the HuggingFace VSCode extension.

    Expects JSON of the form {"inputs": str, "parameters": dict} and returns
    the generated continuation plus a status field.
    """
    json_request: dict = await request.json()
    inputs: str = json_request['inputs']
    param: dict = json_request['parameters']
    # Lazy %-args keep formatting work off the hot path if INFO is disabled;
    # the emitted message is identical to the original f-string.
    logger.info('%s:%s inputs = %s param=%s',
                request.client.host, request.client.port,
                json.dumps(inputs), json.dumps(param))
    # Fix: generate() is synchronous and GPU/CPU-bound; calling it directly in
    # an async handler blocks the event loop for every other client. Offload
    # it to a worker thread instead.
    generated_text = await asyncio.to_thread(generate, inputs, param)
    return {
        "generated_text": generated_text,
        "status": 200
    }
uvicorn.run(app, host="localhost", port=8000)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment