Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
predict_csv endpoint
import pandas as pd
from fastapi import File, UploadFile, HTTPException
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
@app.post("/predict_csv")
def predict_csv(csv_file: UploadFile = File(...), model: Model = Depends(get_model)):
    """Run the model on every row of an uploaded CSV file.

    The upload is parsed with ``pandas.read_csv`` using its default options, so
    the first line of the file is consumed as the header row. Every cell is then
    converted to ``float``. Each remaining row is one data point and must have
    exactly ``n_features`` columns.

    Args:
        csv_file: The uploaded CSV file.
        model: Model instance supplied by the ``get_model`` dependency.

    Returns:
        PredictResponse whose ``data`` is ``model.predict(...)`` converted with
        ``.tolist()`` (presumably one prediction per row -- confirm against
        the ``Model`` implementation).

    Raises:
        HTTPException: 422 if the file cannot be parsed as numeric CSV, has a
            column count different from ``n_features``, or has no data rows.
    """
    try:
        df = pd.read_csv(csv_file.file).astype(float)
    except ValueError as exc:
        # pandas' ParserError and EmptyDataError, UnicodeDecodeError (binary
        # uploads) and the non-numeric failure from astype(float) all subclass
        # ValueError. Anything else is a server-side problem, not bad input.
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY, detail="Unable to process file"
        ) from exc

    df_n_instances, df_n_features = df.shape
    if df_n_features != n_features:
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"Each data point must contain {n_features} features",
        )
    if df_n_instances == 0:
        # A header-only file parses cleanly but would hand an empty batch to
        # the model; reject it here as a client error instead.
        raise HTTPException(
            status_code=HTTP_422_UNPROCESSABLE_ENTITY,
            detail="The file must contain at least one data point",
        )

    # The frame is already (df_n_instances, n_features); no reshape is needed.
    y_pred = model.predict(df.to_numpy())
    return PredictResponse(data=y_pred.tolist())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment