Skip to content

Instantly share code, notes, and snippets.

@Ankur-singh
Created August 24, 2023 21:20
Show Gist options
  • Save Ankur-singh/4e1b723f997f94c9150e376140428c33 to your computer and use it in GitHub Desktop.
Save Ankur-singh/4e1b723f997f94c9150e376140428c33 to your computer and use it in GitHub Desktop.
A basic working example demonstrating how to add a Gradio gateway to a Jina `Flow`.
import gradio as gr
from pydantic import BaseModel
from jina import Flow, Executor, requests
from docarray import DocList, BaseDoc
from fastapi import FastAPI
from jina.serve.runtimes.gateway.http.fastapi import FastAPIBaseGateway
## helper function to convert pydantic model -> gradio interface
def _field_constraint(model, attr: str, name: str):
    """Return the numeric constraint ``name`` ("ge"/"le") declared on ``model.attr``, or None.

    Supports both pydantic generations:
    - v2 stores constraints as annotated_types objects (``Ge``, ``Le`` ...)
      in ``model.model_fields[attr].metadata``.
    - v1 stores them as attributes of ``model.__fields__[attr].field_info``.
    """
    fields_v2 = getattr(model, "model_fields", None)
    if isinstance(fields_v2, dict) and attr in fields_v2:
        for meta in getattr(fields_v2[attr], "metadata", None) or ():
            value = getattr(meta, name, None)
            if value is not None:
                return value
        return None
    fields_v1 = getattr(model, "__fields__", None)
    if isinstance(fields_v1, dict) and attr in fields_v1:
        field_info = getattr(fields_v1[attr], "field_info", None)
        return getattr(field_info, name, None)
    return None


def generate_gradio_interface(model: BaseModel):
    """Build one Gradio component per supported annotated field of ``model``.

    Only names in ``model.__annotations__`` are considered. For a class that
    declares its own annotations this typically excludes inherited fields
    (e.g. ``BaseDoc.id``). Supported types are str, Path (text box), int and
    float (number box bounded by any ``ge``/``le`` constraint), bool
    (checkbox) and File (file upload). Any other annotation is skipped.

    Args:
        model: a pydantic model class (e.g. a docarray ``BaseDoc`` subclass).

    Returns:
        A list of Gradio components, in annotation order.
    """
    components = []
    for attr, annotation in model.__annotations__.items():
        # Postponed annotations are plain strings, and some typing constructs
        # have no ``__name__``; treat both gracefully instead of raising.
        if isinstance(annotation, str):
            type_name = annotation
        else:
            type_name = getattr(annotation, "__name__", None)
        label = attr.replace("_", " ").capitalize()

        if type_name in ("str", "Path"):
            component = gr.Textbox(label=label)
        elif type_name in ("int", "float"):
            # Constraints come from the pydantic field definition, not from
            # the bare type annotation (which never carries ge/le).
            component = gr.Number(
                label=label,
                minimum=_field_constraint(model, attr, "ge"),
                maximum=_field_constraint(model, attr, "le"),
                step=1 if type_name == "int" else 0.01,
            )
        elif type_name == "bool":
            component = gr.Checkbox(label=label)
        elif type_name == "File":
            component = gr.File(label=label)
        else:
            # Unsupported type: skip this field.
            continue
        components.append(component)
    return components
## Basic executor
class Input(BaseDoc):
    """Request document: the text sent to the Capitalize executor."""

    # Declared directly on the class, so generate_gradio_interface (which
    # iterates __annotations__) renders it as a text box.
    text: str
class Output(BaseDoc):
    """Response document: the schema the gateway requests back (return_type=DocList[Output])."""

    # Same single-string schema as Input; rendered as a text box by
    # generate_gradio_interface.
    text: str
class Capitalize(Executor):
    """Executor that title-cases the ``text`` of every incoming document."""

    @requests
    def func(self, docs: DocList[Input], **kwargs) -> DocList[Output]:
        """Return a ``DocList[Output]`` with each input document's text title-cased.

        The Output docs are built explicitly, instead of mutating the Input
        docs in place and returning None, so that the returned value actually
        matches the declared return type. Each document's ``id`` is carried
        over to its output.
        """
        return DocList[Output](
            [Output(id=doc.id, text=doc.text.title()) for doc in docs]
        )
## Gradio Gateway
class GradioGateway(FastAPIBaseGateway):
    """HTTP gateway that serves a Gradio UI (mounted at ``/demo``) for the Flow."""

    async def post(self, text: str):
        """Send ``text`` to ``capitalize_executor`` and return its ``DocList[Output]`` reply."""
        target = self.executor["capitalize_executor"]
        reply = await target.post(
            on="/",
            inputs=DocList[Input]([Input(text=text)]),
            parameters={"k": "v"},
            return_type=DocList[Output],
        )
        return reply

    @property
    def app(self):
        """Create a FastAPI app with a Gradio interface mounted under ``/demo``."""
        server = FastAPI()

        async def handle(text: str):
            # A single document is sent, so the first returned text is the answer.
            result = await self.post(text)
            return result.text[0]

        ui = gr.Interface(
            handle,
            inputs=generate_gradio_interface(Input),
            outputs=generate_gradio_interface(Output),
        )
        return gr.mount_gradio_app(server, ui, path="/demo")
if __name__ == "__main__":
    # Start the Flow only when run as a script. Without this guard, merely
    # importing the module (e.g. to reuse GradioGateway) would launch the
    # gateway and executor and block forever in flow.block().
    with Flow().config_gateway(
        uses=GradioGateway, protocol="http", port=59771
    ).add(uses=Capitalize, name="capitalize_executor") as flow:
        flow.block()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment