Skip to content

Instantly share code, notes, and snippets.

@public-daniel
Last active November 11, 2020 16:58
Show Gist options
  • Save public-daniel/ccc12bbe2df16a46f509764af0d773e9 to your computer and use it in GitHub Desktop.
Save public-daniel/ccc12bbe2df16a46f509764af0d773e9 to your computer and use it in GitHub Desktop.
A few ideas for Greg on ML serving with FastAPI
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, FastAPI, Request
from pydantic import BaseModel, Field
# FastAPI application instance; the startup hook, middleware and route below
# all register themselves on it via decorators.
app = FastAPI()
class ModelClasses(str, Enum):
    """Closed set of class labels a prediction can carry.

    The ``str`` mixin makes every member a real string, so a member compares
    equal to its plain value (``ModelClasses.car == "car"``).
    """

    car = "car"
    person = "person"
    animal = "animal"
class ModelInput(BaseModel):
    """Request body accepted by the prediction endpoint."""

    # Required (``...`` marks a field with no default).
    id: int = Field(..., title="Primary key from source system database")
    # Required.
    feature_1: str = Field(..., title="Super important feature")
    # Optional: falls back to None when the caller leaves it out.
    feature_2: Optional[str] = Field(
        None, title="Kinda important feature, won't always have this"
    )

    class Config:
        # Example payload added to the model's generated JSON schema.
        schema_extra = {
            "example": {
                "id": 1,
                "feature_1": "example input of feature 1",
                "feature_2": "example input of feature 2",
            }
        }
class ModelOutput(BaseModel):
    """Response body returned by the prediction endpoint.

    Every field is required (``...`` marks a field with no default).
    """

    id: int = Field(..., title="Primary key from source system database")
    prediction_uuid: UUID = Field(..., title="Unique identifier for prediction")
    prediction_createdAt: datetime = Field(
        ..., title="UTC datetime prediction was generated"
    )
    prediction_prob: float = Field(..., title="Probability prediction belongs to class")
    prediction_class: ModelClasses = Field(..., title="Model Classes")

    class Config:
        # Example payload added to the model's generated JSON schema.
        schema_extra = {
            "example": {
                "id": 1,
                "prediction_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
                "prediction_createdAt": "2020-11-11T15:59:02.500Z",
                "prediction_prob": 0.71,
                "prediction_class": "car",
            }
        }
@app.on_event("startup")
async def load_model():
    """Populate the module-level ``model`` global when the app starts.

    Registered as a startup handler so the model is built once and then
    reused, rather than being rebuilt inside a request handler.
    """
    global model
    print("Loading model")
    # Load the model into memory here, could pull from S3, etc.
    # The dict below is a stand-in for a real model artifact.
    model = {"name": "Best Model Ever v3", "weights": [0.31, 1.23, 1.31]}
@app.middleware("http")
async def add_headers(request: Request, call_next):
    """Tag the request with an ID and report how long it took to handle.

    A fresh UUID is stored on ``request.state.request_id`` before the request
    is passed down the chain, so downstream handlers can read it. The same ID
    is returned in the ``X-Request-ID`` response header, and the elapsed time
    in seconds in ``X-Process-Time``.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # (or go negative) if the system clock is adjusted mid-request.
    start_time = time.perf_counter()

    request_id = uuid4()
    request.state.request_id = request_id

    response = await call_next(request)

    response.headers["X-Request-ID"] = str(request_id)
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    return response
def model_monitoring(request_id, model_name, model_process_time, model_input, model_output):
    """Log inputs and prediction to model monitoring service.

    For now this only prints its arguments, space-separated, to stdout.

    Args:
        request_id: Identifier of the request that produced the prediction.
        model_name: Name of the model that served it.
        model_process_time: Time spent producing the prediction.
        model_input: The features the model was given.
        model_output: The prediction that was returned.
    """
    # In practice you'd log this to some other service (e.g. CloudWatch)
    print(request_id, model_name, model_process_time, model_input, model_output)
@app.post(
    "/predict",
    response_model=ModelOutput,
    tags=["predict"],
    summary="Single Prediction, Optimized for Low Latency",
    response_description="This is the expected model output",
)
def single_predict(model_input: ModelInput, background_tasks: BackgroundTasks, request: Request):
    """
    Make a prediction from input features:
    - **id**: This is the source system's primary key
    - **feature_1**: We have to have this because blah blah
    - **feature_2**: We expect this in the following circumstances...
    """
    # Monotonic clock: safe for measuring an interval even if the wall clock
    # is adjusted while the prediction is running.
    start_time = time.perf_counter()

    # Probability and class are hard-coded stand-ins for real inference.
    # The id is echoed from the request so the caller can match the response
    # to its own record, and the timestamp is taken now, in UTC.
    prediction = {
        "id": model_input.id,
        "prediction_uuid": uuid4(),
        "prediction_createdAt": datetime.now(timezone.utc),
        "prediction_prob": 0.87,
        "prediction_class": ModelClasses("person"),
    }
    model_output = ModelOutput(**prediction)
    model_process_time = time.perf_counter() - start_time

    # Monitoring is queued as a background task rather than run inline, so
    # it does not add to the latency of this response.
    background_tasks.add_task(
        model_monitoring,
        request.state.request_id,
        model["name"],
        model_process_time,
        model_input,
        model_output,
    )
    return model_output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment