Skip to content

Instantly share code, notes, and snippets.

@okhaliavka
Created October 16, 2023 21:21
Show Gist options
  • Save okhaliavka/c9f06bcebfdd247c845e2135a3962be1 to your computer and use it in GitHub Desktop.
Save okhaliavka/c9f06bcebfdd247c845e2135a3962be1 to your computer and use it in GitHub Desktop.
Pydantic performance drop
from __future__ import annotations
import timeit
from statistics import mean, stdev
from typing import Any, Literal
from pydantic import BaseModel
from typing_extensions import TypedDict
class SimpleModel(BaseModel):
    # Small, flat model: one Literal-constrained field, two strings and an int.
    # The script validates it repeatedly to obtain the "before" and "after"
    # timings.
    fruit: Literal['apple', 'banana']  # only these two values are accepted
    name: str
    surname: str
    age: int
class RichTextTypeText(TypedDict):
    # Container level of the rich-text tree: a dict whose only key, 'nodes',
    # holds a list of RichTextType children. The forward reference is legal
    # because the file enables lazy annotations via `from __future__ import
    # annotations`.
    nodes: list[RichTextType]
class RichTextType(TypedDict):
    # A single rich-text node. Its 'text' is either a plain string or a
    # RichTextTypeText dict that carries further nodes, which makes the two
    # TypedDicts mutually recursive.
    text: RichTextTypeText | str
class TableRowModel(BaseModel):
    # One table row: a list of rich-text cells, each shaped like RichTextType.
    cells: list[RichTextType]
class TableBlockModel(BaseModel):
    # Whole table: a list of TableRowModel rows. This is the "complex" model
    # the script validates between the two timing runs of SimpleModel.
    rows: list[TableRowModel]
def generate_cell(length: int, nodes_count: int) -> RichTextType:
    """Build one rich-text cell containing ``nodes_count`` leaf nodes.

    Each leaf is a dict whose 'text' is the letter 'a' repeated ``length``
    times. The leaves are placed in the 'nodes' list of the dict stored
    under the cell's own 'text' key.
    """
    leaves = [{'text': 'a' * length} for _ in range(nodes_count)]
    container = {'nodes': leaves}
    return {'text': container}
def generate_row(cols_count: int, length: int = 100, nodes_count: int = 10) -> dict[str, Any]:
    """Build the input dict for one table row.

    Args:
        cols_count: Number of cells to generate for the row.
        length: Character length of every leaf text inside a cell. Defaults
            to 100, the value that was previously hard-coded.
        nodes_count: Number of leaf nodes per cell. Defaults to 10, the value
            that was previously hard-coded.

    Returns:
        A dict with a single 'cells' key holding ``cols_count`` cells, each
        produced by ``generate_cell``.
    """
    return {
        'cells': [
            generate_cell(length=length, nodes_count=nodes_count)
            for _ in range(cols_count)
        ]
    }
def generate_data(rows_count: int, cols_count: int) -> dict[str, Any]:
    """Build the input dict for a whole table.

    Args:
        rows_count: Number of rows to generate.
        cols_count: Number of cells in every row.

    Returns:
        A dict with a single 'rows' key holding ``rows_count`` row dicts, each
        produced by ``generate_row(cols_count)``.
    """
    return {'rows': [generate_row(cols_count) for _ in range(rows_count)]}
# Benchmark inputs live at module level so the timeit statements below can
# resolve them through globals().
complex_data = generate_data(rows_count=200, cols_count=100)
simple_data = {
    'fruit': 'banana',
    'name': 'Oleksii',
    'surname': 'Khaliavka',
    'age': 26,
}

# Both timers evaluate their statement against this module's globals, which is
# where the models and the input data are defined.
timer_simple = timeit.Timer('SimpleModel.model_validate(simple_data)', globals=globals())
timer_complex = timeit.Timer('TableBlockModel.model_validate(complex_data)', globals=globals())

# Time the simple model, validate the large nested payload, then time the
# simple model again with identical settings so the two samples are comparable.
before = timer_simple.repeat(repeat=10, number=100000)
# Stored as `complex_elapsed` (not `complex`) so the builtin of that name is
# not shadowed in the globals dict shared with the timers.
complex_elapsed = timer_complex.timeit(number=10)  # even a single invocation (number=1) here is enough to bork pydantic
after = timer_simple.repeat(repeat=10, number=100000)

print(f'min (before): {min(before):.3f}s')
print(f'max (before): {max(before):.3f}s')
print(f'mean (before): {mean(before):.3f}s')
print(f'stdev (before): {stdev(before):.3f}s')
print(f'complex: {complex_elapsed:.3f}s')
# The "after" figures are reported next to their ratio against the matching
# "before" statistic.
print(f'min (after): {min(after):.3f}s ({min(after) / min(before):.1f} times slower)')
print(f'max (after): {max(after):.3f}s ({max(after) / max(before):.1f} times slower)')
print(f'mean (after): {mean(after):.3f}s ({mean(after) / mean(before):.1f} times slower)')
print(f'stdev (after): {stdev(after):.3f}s')

# Sample output recorded by the author:
"""
min (before): 0.086s
max (before): 0.086s
mean (before): 0.086s
stdev (before): 0.001s
complex: 0.631s
min (after): 0.705s (8.2 times slower)
max (after): 1.425s (16.5 times slower)
mean (after): 0.940s (10.9 times slower)
stdev (after): 0.265s
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment