Skip to content

Instantly share code, notes, and snippets.

@afiodorov
Last active October 28, 2023 11:51
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save afiodorov/f0214e317bd82fa610d6172d190896f6 to your computer and use it in GitHub Desktop.
Save afiodorov/f0214e317bd82fa610d6172d190896f6 to your computer and use it in GitHub Desktop.
Run your own LLM and create an API endpoint for predictions
Docker Image : pytorch/pytorch
Image Runtype : jupyter_direc ssh_direc ssh_proxy
Environment : [["JUPYTER_DIR", "/"], ["-p 41654:41654", "1"]]
pip install torch bitsandbytes sentencepiece "protobuf<=3.20.2" git+https://github.com/huggingface/transformers flask python-dotenv Flask-HTTPAuth accelerate
!mv /opt/conda/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda116.so /opt/conda/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face model identifier; the same name is used to fetch both the
# weights and the matching tokenizer.
model_name = "ehartford/WizardLM-13B-Uncensored"
# Load the causal LM with 8-bit loading enabled (load_in_8bit=True) and
# float16 as the torch dtype; device_map="auto" leaves device placement to
# the library. Presumably relies on the bitsandbytes install from the setup
# step above - confirm.
model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype=torch.float16, device_map="auto", load_in_8bit=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
@torch.inference_mode()
def generate_text(input_text):
    """Sample a completion for ``input_text`` from the module-level model.

    The prompt is tokenized as a batch of one, moved to the GPU and passed to
    ``model.generate`` with sampling enabled (temperature 0.8, at most 2048
    new tokens). Only the positions after the prompt are decoded, so the
    returned string holds just the continuation.
    """
    prompt_ids = tokenizer([input_text]).input_ids
    prompt_length = len(prompt_ids[0])

    batch = torch.as_tensor(prompt_ids).cuda()
    generated = model.generate(
        batch,
        do_sample=True,
        temperature=0.8,
        max_new_tokens=2048,
    )

    # Skip the first ``prompt_length`` positions so the prompt is not echoed
    # back in the decoded text.
    continuation = generated[0][prompt_length:]
    return tokenizer.decode(
        continuation,
        skip_special_tokens=True,
        spaces_between_special_tokens=False,
    )
# Module-level registries: run_flask_app() appends each server it creates to
# `servers`, the launcher code appends the serving thread to `threads`, and
# shutdown() drains both lists.
servers = []
threads = []
import os
import threading
from dotenv import load_dotenv
from flask import Flask, jsonify, make_response, request
from flask_httpauth import HTTPBasicAuth
from werkzeug.serving import make_server
# Read settings from a .env file; override=True lets values from that file
# replace variables that already exist in the process environment.
load_dotenv(override=True)
# HTTP Basic credentials accepted by the endpoint. Both variables are
# required: a missing one raises KeyError right here.
# NOTE(review): USER is commonly pre-set by the shell to the login name, so a
# .env file without USER would presumably fall back to that value - confirm.
user = os.environ["USER"]
pass_ = os.environ["PASS"]
# Flask application and the Basic-auth helper used to protect its route.
app = Flask(__name__)
auth = HTTPBasicAuth()
@auth.verify_password
def verify_password(username, password):
    """Check HTTP Basic credentials against the configured ``user``/``pass_``.

    Args:
        username: User name supplied by the client.
        password: Password supplied by the client.

    Returns:
        ``username`` when both values match the configured credentials,
        otherwise ``None``.
    """
    import hmac

    # compare_digest() only accepts str arguments when they are ASCII-only,
    # so compare UTF-8 bytes instead. ``or ""`` guards against None.
    supplied_user = (username or "").encode("utf-8")
    supplied_pass = (password or "").encode("utf-8")

    # Constant-time comparisons, both always evaluated (no short-circuit), so
    # response timing does not reveal which field was wrong or how much of it
    # matched.
    user_ok = hmac.compare_digest(supplied_user, user.encode("utf-8"))
    pass_ok = hmac.compare_digest(supplied_pass, pass_.encode("utf-8"))

    if user_ok and pass_ok:
        return username
    return None
@auth.error_handler
def unauthorized():
    """Build the JSON 401 response; registered as the auth error handler."""
    body = jsonify({"error": "Unauthorized access"})
    return make_response(body, 401)
@app.route("/", methods=["POST"])
@auth.login_required
def completion():
    """Generate a completion for the ``prompt`` field of the JSON request body.

    Returns:
        200 with ``{"completion": <text>}`` on success.
        400 with an ``error`` message when the body is not a JSON object or
        has no non-empty ``prompt``.
        500 with an ``error`` message when text generation raises.
    """
    # silent=True yields None for a missing or malformed JSON body instead of
    # aborting the request, so every bad input reaches the JSON 400 below.
    data = request.get_json(silent=True)

    # .get() plus the dict check: a missing "prompt" key or a non-object body
    # (string, number, ...) must be a client error, not an unhandled
    # exception.
    prompt = data.get("prompt") if isinstance(data, dict) else None
    if not prompt:
        return jsonify({"error": "No JSON data received"}), 400

    try:
        generated = generate_text(prompt)
    except Exception:
        # Top-level boundary: log the traceback and hide details from the
        # client.
        app.logger.exception("couldn't query the model")
        return jsonify({"error": "Couldn't run the inference"}), 500

    return jsonify({"completion": generated})
def run_flask_app():
    """Serve ``app`` on 0.0.0.0:41654 until the server is shut down.

    The server object is appended to ``servers`` before serving starts so
    that ``shutdown()`` can find and stop it. This call blocks in
    ``serve_forever()``.
    """
    http_server = make_server(host="0.0.0.0", port=41654, app=app)
    servers.append(http_server)
    http_server.serve_forever()
# Serve from a separate thread so the blocking serve_forever() call inside
# run_flask_app() does not hold up the calling code; the thread is recorded
# in `threads` so shutdown() can join it.
thread = threading.Thread(target=run_flask_app)
threads.append(thread)
thread.start()
def shutdown():
    """Stop every registered server, join its thread, and reset the registries.

    All servers are asked to stop first; only then are the serving threads
    joined. Both module-level lists are emptied in place afterwards.
    """
    for running_server in servers:
        running_server.shutdown()

    for worker in threads:
        worker.join()

    # Empty in place so other references to these lists see the change.
    del servers[:]
    del threads[:]
# Stop every registered server and wait for its serving thread to finish.
shutdown()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment