Skip to content

Instantly share code, notes, and snippets.

@imiraoui
Created April 12, 2024 19:29
Show Gist options
  • Save imiraoui/8fa88654de7ed7e9ef6805a7cf814b73 to your computer and use it in GitHub Desktop.
Save imiraoui/8fa88654de7ed7e9ef6805a7cf814b73 to your computer and use it in GitHub Desktop.
from modal import Image, Mount, Stub, asgi_app, gpu, method
from PIL import Image as Image2
from typing import List
from io import BytesIO
from fastapi import FastAPI, Request
from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.ocr import run_ocr
import base64
def download_models():
    """Instantiate every Surya processor and model once, discarding the results.

    Passed to ``Image.run_function`` while the container image is built;
    presumably this makes the loaders fetch their weights so they are cached
    in the image (verify against the surya loaders).
    """
    loaders = (
        load_det_processor,
        load_det_model,
        load_rec_model,
        load_rec_processor,
    )
    # Same call order as before; only the side effect of loading matters here.
    for load in loaders:
        load()
# Stage 1: Debian base plus the system packages requested via apt.
_ocr_system_image = Image.debian_slim().apt_install(
    "libglib2.0-0", "libsm6", "libxrender1", "libxext6", "ffmpeg", "libgl1", "git"
)

# Stage 2: Python dependencies, installed in three separate steps.
_ocr_python_image = (
    _ocr_system_image.pip_install("fastapi==0.110.1")
    .pip_install("transformers~=4.36.2", "accelerate~=0.23", "safetensors~=0.3")
    .pip_install("surya-ocr")
)

# Stage 3: clone the surya sources and run the model loaders during the build.
# NOTE(review): the standalone "cd surya" command probably does not affect
# later build steps, and the clone looks redundant with the pip-installed
# "surya-ocr" package -- confirm before removing either.
ocr_image = (
    _ocr_python_image.run_commands(
        "git clone https://github.com/VikParuchuri/surya.git"
    )
    .run_commands("cd surya")
    .run_function(download_models)
)
# Modal application handle; the class, entrypoint and web function below
# register themselves on it through its decorators.
stub = Stub("run-ocr")
# FastAPI application that carries the HTTP routes; it is returned by
# `fastapi_app` so Modal can serve it as an ASGI app.
web_app = FastAPI()
@stub.cls(gpu=gpu.T4(), container_idle_timeout=180, image=ocr_image)
class Model:
    """Surya OCR worker registered on the Modal stub with a T4 GPU."""

    def __enter__(self):
        """Load the detection and recognition models and processors.

        The loaded objects are stored on the instance so ``inference`` can
        reuse them across calls.
        """
        # NOTE(review): newer Modal releases replace the ``__enter__`` hook
        # with the ``@modal.enter()`` decorator -- confirm against the Modal
        # version this app is deployed with.
        self.det_processor, self.det_model = load_det_processor(), load_det_model()
        self.rec_model, self.rec_processor = load_rec_model(), load_rec_processor()

    @method()
    def inference(self, langs, images):
        """Run OCR over ``images`` using the same languages for every image.

        Args:
            langs: Language codes to recognise. Either an iterable of codes
                (e.g. ``["en", "fr"]``) or a single string, optionally
                comma-separated (``"en"`` / ``"en,fr"``).
            images: List of PIL images to process.

        Returns:
            Whatever ``run_ocr`` returns for the given images.
        """
        # A bare string would otherwise be handed to run_ocr as if it were a
        # list of codes; split it into real codes first.
        if isinstance(langs, str):
            lang_codes = [code.strip() for code in langs.split(",") if code.strip()]
        else:
            lang_codes = list(langs)

        # run_ocr takes one language list per image; the previous ``[langs]``
        # only ever supplied a single list regardless of how many images
        # were passed.
        langs_per_image = [list(lang_codes) for _ in images]

        predictions = run_ocr(
            images,
            langs_per_image,
            self.det_model,
            self.det_processor,
            self.rec_model,
            self.rec_processor,
        )
        return predictions
@stub.local_entrypoint()
def main(image_base64: str, langs: str):
    """Run OCR on one base64-encoded image from the command line.

    Args:
        image_base64: The image file contents, base64-encoded.
        langs: Language code(s) for recognition; several codes may be given
            comma-separated (e.g. ``"en,fr"``).

    Prints the predictions returned by the remote ``Model.inference`` call.
    """
    # The CLI delivers a single string, while the OCR call works with a list
    # of language codes.
    lang_codes = [code.strip() for code in langs.split(",") if code.strip()]

    image_data = base64.b64decode(image_base64)
    image = Image2.open(BytesIO(image_data))
    predictions = Model().inference.remote(lang_codes, [image])
    # Previously the result was computed and silently dropped.
    print(predictions)
@stub.function(image=ocr_image, container_idle_timeout=45)
@asgi_app()
def fastapi_app():
    """Return the module-level FastAPI app so Modal can serve it over ASGI."""
    return web_app
@web_app.post("/runOcr")
async def runOcr(request: Request):
    """Handle ``POST /runOcr``: OCR a single base64-encoded image.

    Expects a JSON object with:
        langs: list of language codes (a single, optionally comma-separated
            string is also accepted).
        image_base64: the image file contents, base64-encoded.

    Returns:
        The predictions produced by the remote ``Model.inference`` call.

    Raises:
        HTTPException: 400 when the body is not valid JSON, a required field
            is missing, or the image cannot be decoded.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi import HTTPException

    try:
        body = await request.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON"
        ) from err

    try:
        langs = body["langs"]
        image_base64 = body["image_base64"]
    except (KeyError, TypeError) as err:
        # TypeError covers a JSON body that is not an object (e.g. a list).
        raise HTTPException(
            status_code=400,
            detail="JSON body must contain 'langs' and 'image_base64'",
        ) from err

    if isinstance(langs, str):
        langs = [code.strip() for code in langs.split(",") if code.strip()]

    try:
        image_data = base64.b64decode(image_base64)
        image = Image2.open(BytesIO(image_data))
    except (ValueError, TypeError, OSError) as err:
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError is an
        # OSError.
        raise HTTPException(
            status_code=400,
            detail="'image_base64' is not a valid base64-encoded image",
        ) from err

    # Use Modal's async variant so the event loop is not blocked while the
    # GPU worker runs.
    predictions = await Model().inference.remote.aio(langs, [image])
    return predictions
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment