Skip to content

Instantly share code, notes, and snippets.

@justinttl
Last active February 8, 2024 01:18
Show Gist options
  • Save justinttl/37acb80fdc1b978e8fc6f68d5b1c3aff to your computer and use it in GitHub Desktop.
Save justinttl/37acb80fdc1b978e8fc6f68d5b1c3aff to your computer and use it in GitHub Desktop.
FastAPI marshaling performance analysis
import json
from time import perf_counter, time
from typing import List, Type

import matplotlib.pyplot as plt
import orjson
import pandas as pd
from fastapi.encoders import jsonable_encoder
from fastapi.utils import create_response_field
from pydantic import BaseModel
# Smallest nested model used by the benchmark: one object carrying a single
# string field. Documented with comments (not a docstring) so the generated
# pydantic schema is left untouched.
class ThingSchema(BaseModel):
    # The only payload of a "thing".
    name: str
# Response shape whose payload is a list of nested ThingSchema models, i.e.
# the "list of objects" case in the benchmark plots.
class ManyObjectsSchema(BaseModel):
    # One nested model instance per element.
    things: List[ThingSchema]
# Response shape whose payload is a flat list of strings, i.e. the
# "list of str" case in the benchmark plots.
class ListSchema(BaseModel):
    # Plain strings, no nested models.
    things: List[str]
def build_many_objects_schema(length: int) -> ManyObjectsSchema:
    """Build a ManyObjectsSchema containing ``length`` ThingSchema items.

    Each item is named after its position, rendered as a string
    ("0", "1", ...).
    """
    items = []
    for index in range(length):
        items.append(ThingSchema(name=str(index)))
    return ManyObjectsSchema(things=items)
def build_list_schema(length: int) -> ListSchema:
    """Build a ListSchema whose ``things`` are the strings "0".."length-1"."""
    labels = list(map(str, range(length)))
    return ListSchema(things=labels)
def _collect_timings(header, build_schema, model, lengths):
    """Run ``time_ops`` once per payload length and tabulate the results.

    Args:
        header: Banner printed before the series starts.
        build_schema: Callable taking a length and returning a schema instance.
        model: Model class passed through to ``time_ops``.
        lengths: Payload sizes to benchmark; also used as the frame index.

    Returns:
        A DataFrame indexed by ``lengths`` with one column per timed operation.
    """
    print(header)
    records = []
    for length in lengths:
        print(f"Benchmark for length {length}")
        records.append(time_ops(build_schema(length), model=model))
    return pd.DataFrame.from_records(records, index=lengths)


def _save_plot(frame, filename, *, title, ylabel, area=False, plain_ticks=False):
    """Plot ``frame`` against the payload size, save it, and release the figure.

    Args:
        frame: DataFrame to plot, indexed by number of things.
        filename: Output image path.
        title: Plot title.
        ylabel: Y-axis label.
        area: Draw a stacked area chart instead of a line chart.
        plain_ticks: Disable offset/scientific tick formatting on the axes.
    """
    plotter = frame.plot.area if area else frame.plot
    ax = plotter(title=title, xlabel="Number of things", ylabel=ylabel)
    if plain_ticks:
        ax.ticklabel_format(useOffset=False, style="plain")
    figure = ax.get_figure()
    figure.savefig(filename)
    # Each plot call opens a new figure; close it so figures do not pile up
    # in memory across the many charts produced by one benchmark run.
    plt.close(figure)


def _as_row_shares(frame):
    """Scale each row of ``frame`` so that its columns sum to 1."""
    return frame.divide(frame.sum(axis=1), axis=0)


def benchmark(max_length: int = 50000, length_step: int = 1000) -> None:
    """Benchmark marshaling steps for both schemas and save comparison charts.

    For every payload length in ``range(0, max_length, length_step)`` the
    operations measured by ``time_ops`` are run against a ``ListSchema`` and a
    ``ManyObjectsSchema``. The timings are then plotted and written as PNG
    files using relative file names (so they land in the working directory).

    "Before" refers to the jsonable_encoder + validate + json pipeline and
    "after" to the dict + orjson pipeline.

    Args:
        max_length: Exclusive upper bound on the payload length.
        length_step: Increment between consecutive payload lengths.

    Raises:
        ValueError: If the arguments describe an empty range of lengths.
    """
    lengths = range(0, max_length, length_step)
    if not lengths:
        raise ValueError(
            "max_length and length_step must yield at least one payload length"
        )

    list_times = _collect_timings(
        "=== Benchmark for single object schemas ===",
        build_list_schema,
        ListSchema,
        lengths,
    )
    many_objects_times = _collect_timings(
        "=== Benchmark for many objects schemas ===",
        build_many_objects_schema,
        ManyObjectsSchema,
        lengths,
    )

    # (file-name prefix, human label, timings) for the two payload shapes.
    cases = [
        ("list", "List of str", list_times),
        ("many_objects", "List of objects", many_objects_times),
    ]

    # Encoding: jsonable_encoder vs .dict()
    for prefix, label, times in cases:
        _save_plot(
            times[["jsonable_encoder", "dict"]],
            f"{prefix}_encoding.png",
            title=f"Encoding Times ({label})",
            ylabel="Seconds",
        )

    # Validation, both payload shapes on one chart.
    validate_times = pd.concat(
        [list_times["validate"], many_objects_times["validate"]],
        axis=1,
        keys=["List of str", "List of objects"],
    )
    _save_plot(
        validate_times,
        "validate.png",
        title="Validate Times",
        ylabel="Seconds",
        plain_ticks=True,
    )

    # Serialization: json vs orjson
    for prefix, label, times in cases:
        _save_plot(
            times[["json", "orjson"]],
            f"{prefix}_serialization.png",
            title=f"Serialization Times ({label})",
            ylabel="Seconds",
        )

    # Split each case into its "before" and "after" pipelines once; both the
    # ratio charts and the totals charts below are derived from these.
    pipelines = [
        (
            prefix,
            label,
            times[["jsonable_encoder", "validate", "json"]],
            times[["dict", "orjson"]],
        )
        for prefix, label, times in cases
    ]

    # Time spent percentages (Before vs After)
    for prefix, label, before, after in pipelines:
        _save_plot(
            _as_row_shares(before),
            f"{prefix}_before_ratio.png",
            title=f"Time spent in marshaling phases (Before, {label})",
            ylabel="Percentage",
            area=True,
        )
        _save_plot(
            _as_row_shares(after),
            f"{prefix}_after_ratio.png",
            title=f"Time spent in marshaling phases (After, {label})",
            ylabel="Percentage",
            area=True,
        )

    # Total time taken (Before vs After)
    for prefix, label, before, after in pipelines:
        totals = pd.concat(
            [before.sum(axis=1), after.sum(axis=1)],
            axis=1,
            keys=["before", "after"],
        )
        _save_plot(
            totals,
            f"{prefix}_before_vs_after.png",
            title=f"Time spent in marshaling phases ({label})",
            ylabel="Seconds",
        )
def time_ops(schema: BaseModel, model: Type[BaseModel]) -> dict:
    """Time each marshaling step once for a single schema instance.

    The steps are measured back to back with ``perf_counter``, which is the
    clock intended for measuring short durations (``time()`` is a wall clock
    and may be adjusted or be coarse-grained).

    Args:
        schema: The populated model instance to marshal.
        model: The model class used to build the response field that
            validates ``schema``.

    Returns:
        A dict of elapsed seconds keyed by "jsonable_encoder", "dict",
        "validate", "json" and "orjson".

    Raises:
        AssertionError: If ``jsonable_encoder`` and ``.dict()`` disagree on
            the encoded payload (sanity check; skipped under ``python -O``).
    """
    # Jsonable Encoder
    start = perf_counter()
    jsonable_encoded = jsonable_encoder(schema)
    jsonable_encoder_time = perf_counter() - start

    # .dict()
    start = perf_counter()
    dict_encoded = schema.dict()
    dict_time = perf_counter() - start

    # Both encoders must yield the same payload for the comparison to be fair.
    assert jsonable_encoded == dict_encoded

    # Built outside the timed region so only validation itself is measured.
    response_field = create_response_field(name="Testing", type_=model)

    # Validation (the returned value/errors are intentionally discarded).
    # ``loc`` is written as a real one-element tuple; the previous
    # ``("response")`` was just a parenthesised string.
    start = perf_counter()
    response_field.validate(schema, {}, loc=("response",))
    validate_time = perf_counter() - start

    # OOTB json
    start = perf_counter()
    json.dumps(dict_encoded)
    json_time = perf_counter() - start

    # orjson
    start = perf_counter()
    orjson.dumps(dict_encoded)
    orjson_time = perf_counter() - start

    return {
        "jsonable_encoder": jsonable_encoder_time,
        "dict": dict_time,
        "validate": validate_time,
        "json": json_time,
        "orjson": orjson_time,
    }
# Run the full benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    benchmark()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment