This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Model:
    """Skeleton interface for a trainable, persistable model.

    Every method below is a placeholder: as written, each one does nothing
    and returns ``None``. The bodies are presumably meant to be filled in by
    a concrete implementation.
    """

    def train(self, X, y):
        """Fit the model on features ``X`` and targets ``y``.

        Placeholder: currently a no-op that returns ``None``.
        """
        pass

    def predict(self, X):
        """Return predictions for the inputs ``X``.

        Placeholder: currently returns ``None``.
        """
        # NOTE(review): a silent ``None`` here can surface far from the call
        # site; consider raising NotImplementedError once real subclasses exist.
        pass

    def save(self):
        """Persist the model.

        Placeholder: currently a no-op that returns ``None``.
        """
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
from fastapi import FastAPI | |
from pydantic import BaseModel | |
class PredictRequest(BaseModel):
    """Request body for a prediction call.

    ``data`` must be a list of lists of floats; pydantic validates the
    incoming payload against this annotation.
    """

    # Presumably one inner list per instance, holding that instance's
    # feature values (callers convert it with ``np.array``) — confirm.
    data: List[List[float]]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import joblib | |
import numpy as np | |
from pathlib import Path | |
from sklearn.ensemble import RandomForestRegressor | |
from sklearn.datasets import load_boston | |
class Model: | |
def __init__(self, model_path: str = None): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from fastapi import Depends | |
from .ml.model import get_model | |
@app.post("/predict", response_model=PredictResponse) | |
def predict(input: PredictRequest, model: Model = Depends(get_model)): | |
X = np.array(input.data) | |
y_pred = model.predict(X) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pydantic import BaseModel, ValidationError, validator | |
from .ml.model import n_features | |
class PredictRequest(BaseModel): | |
data: List[List[float]] | |
@validator("data") | |
def check_dimensionality(cls, v): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
class MockModel: | |
def __init__(self, model_path: str = None): | |
self._model_path = None | |
self._model = None | |
def predict(self, X: np.ndarray) -> np.ndarray: | |
n_instances = len(X) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
from starlette.testclient import TestClient | |
from ..main import app | |
from ..ml.model import get_model | |
from .mocks import MockModel | |
def get_model_override(): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytest | |
import random | |
from starlette.testclient import TestClient | |
from starlette.status import HTTP_200_OK | |
from api.ml.model import n_features | |
@pytest.mark.parametrize("n_instances", range(1, 10)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from itertools import product | |
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY | |
@pytest.mark.parametrize( | |
"n_instances, test_data_n_features", | |
product(range(1, 10), [n for n in range(1, 20) if n != n_features]), | |
) | |
def test_predict_with_wrong_input( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd | |
from fastapi import File, UploadFile, HTTPException | |
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY | |
@app.post("/predict_csv") | |
def predict_csv(csv_file: UploadFile = File(...), model: Model = Depends(get_model)): | |
try: | |
df = pd.read_csv(csv_file.file).astype(float) |
Older Newer