Last active
November 3, 2022 10:36
-
-
Save htahir1/8e484f7dc55c9bcc5ee0e0fd6addbb49 to your computer and use it in GitHub Desktop.
How to create a custom materializer that reads and writes pandas dataframes in CSV format with ZenML
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Materializer for Pandas CSV.""" | |
import os | |
import tempfile | |
from typing import Any, Type, Union | |
import pandas as pd | |
import numpy as np | |
from zenml.artifacts import DataArtifact, SchemaArtifact, StatisticsArtifact | |
from zenml.io import fileio | |
from zenml.materializers.base_materializer import BaseMaterializer | |
DEFAULT_FILENAME = "df.csv" | |
class MyPandasMaterializer(BaseMaterializer):
    """Materializer that stores pandas DataFrames and Series as CSV files.

    Objects are written to ``DEFAULT_FILENAME`` inside the artifact URI. The
    index is written as the first CSV column and restored from that column
    on load, so a write followed by a read keeps the index instead of
    turning it into an extra data column.
    """

    ASSOCIATED_TYPES = (pd.DataFrame, pd.Series)
    ASSOCIATED_ARTIFACT_TYPES = (
        DataArtifact,
        StatisticsArtifact,
        SchemaArtifact,
    )

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[pd.DataFrame, pd.Series]:
        """Reads a pd.DataFrame or pd.Series from a CSV file.

        Args:
            data_type: The type of the data to read.

        Returns:
            The pandas dataframe or series.

        Raises:
            ValueError: If a series is requested but the stored CSV does not
                contain exactly one data column.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Work on a local copy of the artifact inside a temporary folder.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        try:
            temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
            fileio.copy(filepath, temp_file)
            # `handle_return` writes the index as the first column, so it is
            # read back as the index rather than as a data column.
            df = pd.read_csv(temp_file, index_col=0)
        finally:
            # Remove the temporary folder even when copying or parsing fails.
            fileio.rmtree(temp_dir)

        if issubclass(data_type, pd.Series):
            # A series is stored as a single-column frame; anything else
            # means the artifact was not written from a series.
            if len(df.columns) != 1:
                raise ValueError(
                    "Expected exactly one data column when loading a "
                    f"pd.Series, found {len(df.columns)}."
                )
            return df[df.columns[0]]
        return df

    def handle_return(self, df: Union[pd.DataFrame, pd.Series]) -> None:
        """Writes a pandas dataframe or series to a CSV file.

        Args:
            df: The pandas dataframe or series to write.
        """
        super().handle_return(df)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        if isinstance(df, pd.Series):
            df = df.to_frame(name="series")

        # Reserve a temporary file name and close the handle right away so
        # pandas can open the path itself.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False
        ) as f:
            temp_path = f.name

        try:
            df.to_csv(temp_path)
            fileio.copy(temp_path, filepath)
        finally:
            # Remove the temporary file even when writing or copying fails.
            fileio.remove(temp_path)
# How to use in a zenml pipeline: https://docs.zenml.io/advanced-guide/pipelines/materializers | |
import logging | |
from zenml.steps import step | |
from zenml.pipelines import pipeline | |
@step
def my_first_step() -> pd.DataFrame:
    """Produce a 100x4 dataframe of random integers in [0, 100).

    Returns:
        A dataframe with columns A, B, C and D.
    """
    random_values = np.random.randint(0, 100, size=(100, 4))
    return pd.DataFrame(random_values, columns=list('ABCD'))
@step
def my_second_step(my_obj: pd.DataFrame) -> None:
    """Log the dataframe received from the previous step.

    Args:
        my_obj: The dataframe passed into this step.
    """
    message = f"The following dataframe was passed to this step: `{my_obj}`"
    logging.info(message)
@pipeline
def first_pipeline(step_1, step_2):
    """Connect two steps: the output of `step_1` is the input of `step_2`.

    Args:
        step_1: Step that produces the object.
        step_2: Step that consumes the object produced by `step_1`.
    """
    step_2(step_1())
# Configure the first step to use the custom CSV materializer for its
# output, then build the pipeline and run it.
configured_first_step = my_first_step().configure(
    output_materializers=MyPandasMaterializer
)
pipeline_instance = first_pipeline(
    step_1=configured_first_step,
    step_2=my_second_step(),
)
pipeline_instance.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment