Skip to content

Instantly share code, notes, and snippets.

@sam-goodwin
Last active April 2, 2024 01:06
Show Gist options
  • Save sam-goodwin/85c44d0241f6848e4a183a39c1abfb58 to your computer and use it in GitHub Desktop.
from typing import Any, Optional, get_args, get_origin
import numpy as np
import pandas as pd
import pandera.typing as pdt
import pyarrow as pa
from pandera import DataFrameModel, Index, MultiIndex, dtypes
from pandera.engines import numpy_engine, pandas_engine
from pandera.engines.engine import _is_namedtuple, _is_typeddict
from pandera.typing.common import SeriesBase
# see: https://github.com/unionai-oss/pandera/issues/689 - "generate pyarrow schema from pandera schema"
# forked from (un-merged): https://github.com/unionai-oss/pandera/pull/1047
def to_pyarrow_schema(
    model: type[DataFrameModel],
    preserve_index: Optional[bool] = None,
) -> pa.Schema:
    """Convert a :class:`~pandera.api.dataframe.model.DataFrameModel` to `pa.Schema`.

    :param model: pandera model whose ``to_schema()`` output is converted
    :param preserve_index: whether to store the index as an additional column
        (or columns, for MultiIndex) in the resulting Table. The default of
        None will store the index as a column, except for RangeIndex which is
        stored as metadata only. Use `preserve_index=True` to force it to be
        stored as a column.
    :returns: `pa.Schema` representation of the model's DataFrameSchema
    """
    dataframe_schema = model.to_schema()
    # Copy so adding index columns below does not mutate the schema object's
    # own ``columns`` dict.
    columns: dict[str, SeriesBase] = dict(dataframe_schema.columns)  # type: ignore[assignment]
    # pyarrow schema metadata
    metadata: dict[str, bytes] = {}
    index = dataframe_schema.index
    if index is None:
        if preserve_index:
            # Create an explicit int64 column for the implicit RangeIndex.
            name = _get_index_name(0)
            columns[name] = Index(dtypes.Int64, nullable=False, name=name)
        else:
            # Only preserve metadata of the index. The value must be valid
            # JSON: "name" is the JSON literal ``null`` (the original embedded
            # the text "pa.null", which no pandas/pyarrow reader can parse).
            meta_val = b'[{"kind": "range", "name": null, "step": 1}]'
            metadata["index_columns"] = meta_val
    elif preserve_index is not False:
        # Add column(s) for index(es), prepended so they sit before the data
        # columns in the resulting schema.
        if isinstance(index, Index):
            name = index.name or _get_index_name(0)
            columns = {name: index, **columns}
        elif isinstance(index, MultiIndex):
            n_levels = len(index.indexes)
            # Iterate in reverse so repeated prepending yields level order;
            # default names must still use the *original* level number.
            for i, value in enumerate(reversed(index.indexes)):
                name = value.name or _get_index_name(n_levels - 1 - i)
                columns = {name: value, **columns}
    return pa.schema(
        [to_pyarrow_field(k, v) for k, v in columns.items()],
        metadata=metadata,
    )
# Mapping from pandas nullable/extension dtype instances to their pyarrow
# equivalents. NOTE(review): not referenced by any live code visible in this
# file — only by commented-out lines in ``to_pyarrow_type``.
pandas_types = dict(
    [
        (pd.BooleanDtype(), pa.bool_()),
        (pd.Int8Dtype(), pa.int8()),
        (pd.Int16Dtype(), pa.int16()),
        (pd.Int32Dtype(), pa.int32()),
        (pd.Int64Dtype(), pa.int64()),
        (pd.UInt8Dtype(), pa.uint8()),
        (pd.UInt16Dtype(), pa.uint16()),
        (pd.UInt32Dtype(), pa.uint32()),
        (pd.UInt64Dtype(), pa.uint64()),
        (pd.Float32Dtype(), pa.float32()),  # type: ignore[attr-defined]
        (pd.Float64Dtype(), pa.float64()),  # type: ignore[attr-defined]
        (pd.StringDtype(), pa.string()),
    ]
)
def to_pyarrow_field(
    name: str,
    pandera_field: SeriesBase,
) -> pa.Field:
    """Build the `pa.Field` equivalent of a pandera Index or Column.

    :param name: name given to the resulting field
    :param pandera_field: pandera Index or Column
    :returns: `pa.Field` representation of `pandera_field`
    """
    arrow_type = to_pyarrow_type(pandera_field.dtype)
    return pa.field(name, arrow_type, pandera_field.nullable)
def to_pyarrow_type(pandera_dtype: Any) -> pa.DataType:
    """Convert a pandera DataType to a `pa.DataType`.

    :param pandera_dtype: pandera DataType
    :returns: `pa.DataType` representation of `pandera_dtype`
    :raises TypeError: via `type_to_arrow` when the underlying python type
        has no pyarrow equivalent
    """
    pandas_dtype = pandas_engine.Engine.dtype(pandera_dtype)
    pandas_dtype_type = pandera_dtype.type
    # Date-like dtypes map to a 64-bit arrow date.
    if isinstance(pandas_dtype_type, pandas_engine.Date | numpy_engine.DateTime64):
        return pa.date64()
    if isinstance(pandas_dtype_type, dtypes.Category):
        # Categorical data: dictionary-encode with int8 indices.
        return pa.dictionary(
            pa.int8(),
            pandera_dtype.type.categories.inferred_type,
            ordered=pandera_dtype.ordered,  # type: ignore[attr-defined]
        )
    # Annotated python types (e.g. Series[list[int]] or Series[SomeNamedTuple])
    # expose the python type on ``generic_type``.
    if isinstance(pandas_dtype_type, type):
        return type_to_arrow(pandera_dtype.generic_type)
    # Some pandera dtypes carry the underlying python type on ``special_type``
    # instead; fall back to it when present and actually a type.
    special_type = getattr(pandera_dtype, "special_type", None)
    if isinstance(special_type, type):
        return type_to_arrow(special_type)
    if pandas_dtype.type == np.object_:
        # Plain object columns are assumed to hold strings — TODO confirm.
        return pa.string()
    return pa.from_numpy_dtype(pandas_dtype_type)
def type_to_arrow(python_type: type) -> pa.DataType:
    """Map a python / pandera / numpy scalar or container type to `pa.DataType`.

    Supports builtin scalars, pandera dtype aliases, pandas datetime-like
    scalars, numpy scalar types, ``list[T]``, and NamedTuple / TypedDict
    classes (which become `pa.struct`).

    :param python_type: type to convert
    :returns: `pa.DataType` equivalent of `python_type`
    :raises TypeError: if `python_type` has no pyarrow equivalent
    """
    # builtin python scalars
    if python_type is str:
        return pa.string()
    elif python_type is int:
        return pa.int64()
    elif python_type is float:
        return pa.float64()
    elif python_type is bool:
        return pa.bool_()
    # pandera types
    elif python_type is pdt.UInt8:
        return pa.uint8()
    elif python_type is pdt.UInt16:
        return pa.uint16()
    elif python_type is pdt.UInt32:
        return pa.uint32()
    elif python_type is pdt.UInt64:
        return pa.uint64()
    elif python_type is pdt.Int8:
        return pa.int8()
    elif python_type is pdt.Int16:
        return pa.int16()
    elif python_type is pdt.Int32:
        return pa.int32()
    elif python_type is pdt.Int64:
        return pa.int64()
    elif python_type is pdt.Float32:
        return pa.float32()
    elif python_type is pdt.Float64:
        return pa.float64()
    elif python_type is pdt.String:
        return pa.string()
    elif python_type is pdt.Bool:
        return pa.bool_()
    # pandas scalar types
    elif python_type is pd.Timestamp:
        return pa.timestamp("ns")
    elif python_type is pd.Timedelta:
        return pa.duration("ns")
    elif python_type is pd.Categorical:
        return pa.dictionary(pa.int8(), pa.string())
    # NOTE(review): mapping Interval/Period to duration("ns") loses
    # information — confirm this is intended. (The original also tested
    # ``pd.Interval`` twice; the duplicate is removed.)
    elif python_type is pd.Interval or python_type is pd.Period:
        return pa.duration("ns")
    # numpy types
    elif python_type is np.datetime64:
        return pa.timestamp("ns")
    elif python_type is np.timedelta64:
        return pa.duration("ns")
    elif python_type is np.int8:
        return pa.int8()
    elif python_type is np.int16:
        return pa.int16()
    elif python_type is np.int32:
        return pa.int32()
    elif python_type is np.int64:
        return pa.int64()
    elif python_type is np.uint8:
        return pa.uint8()
    elif python_type is np.uint16:
        return pa.uint16()
    elif python_type is np.uint32:
        return pa.uint32()
    elif python_type is np.uint64:
        return pa.uint64()
    elif python_type is np.float32:
        return pa.float32()
    elif python_type is np.float64:
        return pa.float64()
    elif python_type is np.bool_:
        return pa.bool_()
    elif python_type is np.object_:
        # Presumably object cells hold strings — TODO confirm.
        return pa.string()
    # containers
    elif get_origin(python_type) is list:
        return pa.list_(type_to_arrow(get_args(python_type)[0]))
    elif _is_namedtuple(python_type) or _is_typeddict(python_type):
        # Both become a struct of their annotated fields. (The original's
        # TypedDict branch computed the annotations and then fell through,
        # implicitly returning None.)
        fields = [
            pa.field(
                key,
                type_to_arrow(value),
                # TODO(sgoodwin): determine this from Optional[T] / T | None
                nullable=False,
            )
            for key, value in python_type.__annotations__.items()
        ]
        return pa.struct(fields)
    error = f"Unsupported type: {python_type}"
    raise TypeError(error)
def _get_index_name(level: int) -> str:
"""Generate an index name for pyarrow if none is specified"""
return f"__index_level_{level}__"
# --- Second gist file: example / test usage of ``to_pyarrow_schema`` ---
import os
from typing import List, NamedTuple
import pandera.typing as pdt
import pyarrow
from pandera.typing import Series
from noetik_pipeline_methods.io.pandera import to_pyarrow_schema
# Presumably silences pyarrow's timezone warning on import — must be set
# before pandera imports pyarrow internals; TODO confirm.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"
# NOTE: in this file ``pa`` is pandera (the converter module above used
# ``pa`` for pyarrow; here pyarrow is imported under its full name).
import pandera as pa
class TodoItem(NamedTuple):
    """Nested record stored per row of the ``named_tuple`` column below."""

    name: str
    priority: int
    # these aren't supported by pandera's to_schema()
    # np_uint8: np.uint8
    pd_uint8: pdt.UInt8
class TodoList(pa.DataFrameModel):
    """Model exercising every scalar/list/struct mapping in ``type_to_arrow``.

    Field declaration order is the expected column order in the generated
    pyarrow schema.
    """

    # TODO(sgoodwin): remove the redundant Series wrapper when fixed:
    # https://github.com/unionai-oss/pandera/issues/1546
    bool_: Series[pdt.Bool] = pa.Field()
    bool_list: Series[list[pdt.Bool]] = pa.Field()
    float32_list: Series[list[pdt.Float32]] = pa.Field()
    float32: Series[pdt.Float32] = pa.Field()
    float64_list: Series[list[pdt.Float64]] = pa.Field()
    float64: Series[pdt.Float64] = pa.Field()
    # builtin int maps to int64
    int_list: Series[list[int]] = pa.Field()
    # typing.List and builtin list are both exercised
    int16_List: Series[List[pdt.Int16]] = pa.Field()
    int16: Series[pdt.Int16] = pa.Field()
    int32_list: Series[list[pdt.Int32]] = pa.Field()
    int32: Series[pdt.Int32] = pa.Field()
    int64_list: Series[list[pdt.Int64]] = pa.Field()
    int64: Series[pdt.Int64] = pa.Field()
    int8_list: Series[list[pdt.Int8]] = pa.Field()
    int8: Series[pdt.Int8] = pa.Field()
    # builtin str and pdt.String both map to arrow string
    str_list: Series[list[str]] = pa.Field()
    string_list: Series[list[pdt.String]] = pa.Field()
    string: Series[pdt.String] = pa.Field()
    uint16_list: Series[list[pdt.UInt16]] = pa.Field()
    uint16: Series[pdt.UInt16] = pa.Field()
    uint32_list: Series[list[pdt.UInt32]] = pa.Field()
    uint32: Series[pdt.UInt32] = pa.Field()
    uint64_list: Series[list[pdt.UInt64]] = pa.Field()
    uint64: Series[pdt.UInt64] = pa.Field()
    uint8_list: Series[list[pdt.UInt8]] = pa.Field()
    uint8: Series[pdt.UInt8] = pa.Field()
    # NamedTuple column — becomes a pyarrow struct
    named_tuple: Series[TodoItem] = pa.Field()
def test_to_arrow():
    """The schema generated from TodoList must match a hand-built expectation."""
    schema = to_pyarrow_schema(TodoList)
    # (column name, element type, wrap-in-list?) — order matches TodoList.
    specs = [
        ("bool_", pyarrow.bool_(), False),
        ("bool_list", pyarrow.bool_(), True),
        ("float32_list", pyarrow.float32(), True),
        ("float32", pyarrow.float32(), False),
        ("float64_list", pyarrow.float64(), True),
        ("float64", pyarrow.float64(), False),
        ("int_list", pyarrow.int64(), True),
        ("int16_List", pyarrow.int16(), True),
        ("int16", pyarrow.int16(), False),
        ("int32_list", pyarrow.int32(), True),
        ("int32", pyarrow.int32(), False),
        ("int64_list", pyarrow.int64(), True),
        ("int64", pyarrow.int64(), False),
        ("int8_list", pyarrow.int8(), True),
        ("int8", pyarrow.int8(), False),
        ("str_list", pyarrow.string(), True),
        ("string_list", pyarrow.string(), True),
        ("string", pyarrow.string(), False),
        ("uint16_list", pyarrow.uint16(), True),
        ("uint16", pyarrow.uint16(), False),
        ("uint32_list", pyarrow.uint32(), True),
        ("uint32", pyarrow.uint32(), False),
        ("uint64_list", pyarrow.uint64(), True),
        ("uint64", pyarrow.uint64(), False),
        ("uint8_list", pyarrow.uint8(), True),
        ("uint8", pyarrow.uint8(), False),
    ]
    fields = [
        pyarrow.field(name, pyarrow.list_(dtype) if is_list else dtype, nullable=False)
        for name, dtype, is_list in specs
    ]
    # The NamedTuple column maps to a struct of its annotated fields.
    todo_item_struct = pyarrow.struct(
        [
            pyarrow.field("name", pyarrow.string(), nullable=False),
            pyarrow.field("priority", pyarrow.int64(), nullable=False),
            pyarrow.field("pd_uint8", pyarrow.uint8(), nullable=False),
        ],
    )
    fields.append(pyarrow.field("named_tuple", todo_item_struct, nullable=False))
    expected_schema = pyarrow.schema(fields)
    assert schema == expected_schema, "Generated schema does not match expected schema"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment