@jcrubino · Last active June 16, 2023
Simple Flask Web Server for Image Captioning and Document Embedding Generation
import base64
import logging
from io import BytesIO

import torch
import torch.nn.functional as F
from flask import Flask, request
from PIL import Image
from torch import Tensor
from transformers import AutoModel, AutoTokenizer

# Configure logging so the startup messages below are actually emitted.
logging.basicConfig(level=logging.INFO)

app = Flask(__name__)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
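# Illustrative sanity check (hypothetical values, not part of the server):
# with a batch of one 3-token sequence whose third token is padding,
# average_pool returns the mean of the two real token embeddings.
#   hidden = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [9.0, 9.0]]])  # (1, 3, 2)
#   mask = torch.tensor([[1, 1, 0]])
#   average_pool(hidden, mask)  # -> tensor([[2.0, 2.0]])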
TEXT_EMBED_MODEL = "intfloat/e5-large-v2"
logging.info(f"Text Embedding Model: {TEXT_EMBED_MODEL}")
embedding_tokenizer = AutoTokenizer.from_pretrained(TEXT_EMBED_MODEL)
embedding_model = AutoModel.from_pretrained(TEXT_EMBED_MODEL)

PROD_MODEL = "microsoft/git-base-coco"
logging.info(f"Image-to-Text model: {PROD_MODEL}")
def load_model(model_name):
    # 8-bit loading requires the bitsandbytes package and a CUDA GPU.
    if "Salesforce" in model_name:
        from transformers import BlipForConditionalGeneration, BlipProcessor

        processor = BlipProcessor.from_pretrained(model_name)
        model = BlipForConditionalGeneration.from_pretrained(model_name, load_in_8bit=True)
    elif "microsoft/" in model_name:
        from transformers import AutoModelForCausalLM, AutoProcessor

        processor = AutoProcessor.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, load_in_8bit=True)
    else:
        raise Exception(f"Unknown Image-to-Text model: {model_name}")
    return model, processor
model, processor = load_model(PROD_MODEL)
def process_image(pil_image, prompt=""):
    device = "cuda"
    # Branch on the model name rather than the processor class: the processor
    # classes are not imported at module scope, and AutoProcessor is a factory,
    # so an isinstance check against it would never match.
    if "Salesforce" in PROD_MODEL:
        # BLIP accepts the image and an optional text prompt in one call.
        if prompt != "":
            inputs = processor(pil_image, prompt, return_tensors="pt").to(
                device, torch.float16
            )
        else:
            inputs = processor(pil_image, return_tensors="pt").to(device, torch.float16)
    elif "microsoft/" in PROD_MODEL:
        # GIT: prepare pixel values; for a prompted caption, prepend the CLS
        # token to the tokenized prompt.
        pixel_values = (
            processor(images=pil_image, return_tensors="pt").pixel_values.half().to(device)
        )
        if prompt != "":
            input_ids = processor(text=prompt, add_special_tokens=False).input_ids
            input_ids = [processor.tokenizer.cls_token_id] + input_ids
            input_ids = torch.tensor(input_ids).unsqueeze(0).to(device)
            inputs = {"pixel_values": pixel_values, "input_ids": input_ids, "max_length": 50}
        else:
            inputs = {"pixel_values": pixel_values, "max_length": 50}
    else:
        raise Exception(f"Unknown Image-to-Text model: {PROD_MODEL}")
    generated_ids = model.generate(**inputs)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption.strip()
@app.route("/")
def hello_world():
return """
<h1>Base64 Caption Server</h1>
<p>Post an HTML base64 encoded image to /base64caption and receive a description.</p>
"""
@app.route("/base64caption", methods=["POST"])
def base64caption():
request_data = request.get_json()
base64_data = request_data["base64"]
base64_data = base64_data.split(",")[-1]
image_data = base64.b64decode(base64_data)
pil_image = Image.open(BytesIO(image_data)).convert("RGB")
caption = process_image(pil_image)
return {"caption": caption}
@app.route("/embedding", methods=["GET"])
def embedding_base():
return """\
\r<h1>Embeddings</h1>
\r<p></p>
\r<p>Available Endpoints: /embedding/text</p>
\r<p>post data: {"batch":[list, of, strings], "prefix":(query|passage), "normalize":(True|False)}</p>
"""
@app.route("/embedding/text", methods=["POST"])
def embedding_text():
request_data = request.get_json()
logging.info(request_data)
batch_data = request_data["batch"]
prefix = request_data["prefix"]
normalize = request_data["normalize"]
assert normalize in [True, False]
input_texts = [f"{prefix}: {item}" for item in batch_data]
batch_dict = embedding_tokenizer(
input_texts, max_length=512, padding=True, truncation=True, return_tensors="pt"
)
outputs = embedding_model(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
if normalize:
embeddings = F.normalize(embeddings, p=2, dim=1)
return {
"embeddings": embeddings.tolist(),
"batch": batch_data,
"prefix": prefix,
"normalize": normalize,
}
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000, debug=True)
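Example client, a minimal sketch for exercising both endpoints. It assumes the server is running locally on port 5000 and that the third-party requests package is installed; the image path is a placeholder.

import base64
import requests

# Caption an image by POSTing it base64-encoded as JSON.
with open("example.jpg", "rb") as f:  # placeholder path
    b64 = base64.b64encode(f.read()).decode("ascii")
resp = requests.post("http://localhost:5000/base64caption", json={"base64": b64})
print(resp.json()["caption"])

# Embed a batch of texts with the e5 "query" prefix, L2-normalized.
resp = requests.post(
    "http://localhost:5000/embedding/text",
    json={"batch": ["hello world", "flask servers"], "prefix": "query", "normalize": True},
)
print(len(resp.json()["embeddings"][0]))  # e5-large-v2 produces 1024-dim vectors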