Last active
November 11, 2020 16:58
-
-
Save public-daniel/ccc12bbe2df16a46f509764af0d773e9 to your computer and use it in GitHub Desktop.
A few ideas for Greg on ML serving with FastAPI
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime, timezone
from enum import Enum
import time
from typing import Optional
from uuid import UUID, uuid4

from fastapi import BackgroundTasks, FastAPI, Request
from pydantic import BaseModel, Field
# The single application object; the routes, middleware and startup hook
# defined below are all registered on it via its decorators.
app: FastAPI = FastAPI()
class ModelClasses(str, Enum):
    # Closed set of labels the model can predict. Mixing in `str` makes each
    # member compare equal to its plain string value (e.g. "car").
    # Intentionally no docstring: it would be picked up as the enum's
    # schema description.
    car = "car"
    person = "person"
    animal = "animal"
class ModelInput(BaseModel):
    # Request body accepted by the /predict route. Documented with comments
    # rather than a docstring so the model's schema description is unchanged.

    # Required (`...` = no default).
    id: int = Field(default=..., title="Primary key from source system database")
    feature_1: str = Field(default=..., title="Super important feature")
    # Optional: omitted values default to None.
    feature_2: Optional[str] = Field(
        default=None,
        title="Kinda important feature, won't always have this",
    )

    class Config:
        # `schema_extra` attaches an example payload to the model's JSON schema.
        schema_extra = {
            "example": {
                "id": 1,
                "feature_1": "example input of feature 1",
                "feature_2": "example input of feature 2",
            },
        }
class ModelOutput(BaseModel):
    # Response body of the /predict route (used as its `response_model`).
    # Documented with comments rather than a docstring so the model's schema
    # description is unchanged.

    id: int = Field(default=..., title="Primary key from source system database")
    prediction_uuid: UUID = Field(
        default=...,
        title="Unique identifier for prediction",
    )
    prediction_createdAt: datetime = Field(
        default=...,
        title="UTC datetime prediction was generated",
    )
    prediction_prob: float = Field(
        default=...,
        title="Probability prediction belongs to class",
    )
    # Restricted to the members of ModelClasses.
    prediction_class: ModelClasses = Field(default=..., title="Model Classes")

    class Config:
        # `schema_extra` attaches an example response to the model's JSON schema.
        schema_extra = {
            "example": {
                "id": 1,
                "prediction_uuid": "3fa85f64-5717-4562-b3fc-2c963f66afa6",
                "prediction_createdAt": "2020-11-11T15:59:02.500Z",
                "prediction_prob": 0.71,
                "prediction_class": "car",
            },
        }
@app.on_event("startup")
async def load_model():
    """Startup hook: bind the module-level ``model`` used by the /predict route.

    The dict below is a stand-in; in practice the model would be loaded into
    memory here (e.g. pulled from S3).
    """
    global model
    print("Loading model")
    placeholder_weights = [0.31, 1.23, 1.31]
    model = {"name": "Best Model Ever v3", "weights": placeholder_weights}
@app.middleware("http")
async def add_headers(request: Request, call_next):
    """Tag every HTTP response with a request ID and its processing time.

    A fresh UUID4 is stored on ``request.state.request_id`` before the request
    is passed on, so downstream handlers (e.g. ``single_predict``) can reuse
    it. The same ID is returned in the ``X-Request-ID`` header.
    ``X-Process-Time`` holds the elapsed seconds, as a float string.

    Args:
        request: The incoming request.
        call_next: Callable that forwards the request down the stack and
            returns the response.

    Returns:
        The downstream response with both headers added.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP, manual changes) and produce wrong or
    # even negative durations.
    start = time.perf_counter()
    request_id = uuid4()
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = str(request_id)
    response.headers["X-Process-Time"] = str(time.perf_counter() - start)
    return response
def model_monitoring(request_id, model_name, model_process_time, model_input, model_output):
    """Log inputs and prediction to model monitoring service.

    Currently this only prints one space-separated line to stdout. In practice
    you'd log this to some other service (e.g. CloudWatch). It is scheduled by
    ``single_predict`` through FastAPI's ``BackgroundTasks``.

    Args:
        request_id: Per-request ID set by the ``add_headers`` middleware.
        model_name: Name of the model that produced the prediction.
        model_process_time: Seconds spent producing the prediction.
        model_input: The validated request body.
        model_output: The prediction returned to the caller.
    """
    # Bug fix: this previously printed the undefined name `process_time`, so
    # every background monitoring task raised NameError.
    print(request_id, model_name, model_process_time, model_input, model_output)
@app.post(
    "/predict",
    response_model=ModelOutput,
    tags=["predict"],
    summary="Single Prediction, Optimized for Low Latency",
    response_description="This is the expected model output",
)
def single_predict(
    model_input: ModelInput, background_tasks: BackgroundTasks, request: Request
):
    """
    Make a prediction from input features:
    - **id**: This is the source system's primary key
    - **feature_1**: We have to have this because blah blah
    - **feature_2**: We expect this in the following circumstances...
    """
    # NOTE: the docstring above is user-facing (FastAPI renders it as the
    # endpoint description), so implementation notes live in comments.
    start = time.perf_counter()  # monotonic; matches add_headers' timing
    model_output = ModelOutput(
        # Echo the caller's primary key so the prediction can be joined back to
        # the source record (previously hard-coded to 1 for every request).
        id=model_input.id,
        prediction_uuid=uuid4(),
        # The field is documented as the UTC time the prediction was generated;
        # previously this was a fixed 2020 timestamp.
        prediction_createdAt=datetime.now(timezone.utc),
        # Placeholder model output; replace with real inference using `model`.
        prediction_prob=0.87,
        prediction_class=ModelClasses.person,
    )
    model_process_time = time.perf_counter() - start
    # Monitoring is handed to a background task so it is not done inline
    # before the response is returned.
    background_tasks.add_task(
        model_monitoring,
        request.state.request_id,  # set by the add_headers middleware
        model["name"],  # `model` is bound by the load_model startup hook
        model_process_time,
        model_input,
        model_output,
    )
    return model_output
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment