Skip to content

Instantly share code, notes, and snippets.

@htahir1
Last active November 3, 2022 10:36
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save htahir1/8e484f7dc55c9bcc5ee0e0fd6addbb49 to your computer and use it in GitHub Desktop.
How to create a custom materializer that reads and writes pandas dataframes in CSV format with ZenML
"""Materializer for Pandas CSV."""
import os
import tempfile
from typing import Any, Type, Union
import pandas as pd
import numpy as np
from zenml.artifacts import DataArtifact, SchemaArtifact, StatisticsArtifact
from zenml.io import fileio
from zenml.materializers.base_materializer import BaseMaterializer
# Name of the CSV file stored inside each artifact's URI directory.
DEFAULT_FILENAME = "df.csv"
class MyPandasMaterializer(BaseMaterializer):
    """Materializer that reads and writes pandas objects as CSV files."""

    ASSOCIATED_TYPES = (pd.DataFrame, pd.Series)
    ASSOCIATED_ARTIFACT_TYPES = (
        DataArtifact,
        StatisticsArtifact,
        SchemaArtifact,
    )

    def handle_input(
        self, data_type: Type[Any]
    ) -> Union[pd.DataFrame, pd.Series]:
        """Reads a pd.DataFrame or pd.Series from a CSV file.

        Args:
            data_type: The type of the data to read.

        Returns:
            The pandas dataframe or series.
        """
        super().handle_input(data_type)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        # Copy from the (possibly remote) artifact store to a local
        # temporary file so pandas can read it from local disk.
        temp_dir = tempfile.mkdtemp(prefix="zenml-temp-")
        temp_file = os.path.join(temp_dir, DEFAULT_FILENAME)
        try:
            fileio.copy(filepath, temp_file)
            # index_col=0 restores the index that handle_return wrote.
            # Without it the saved index is read back as an extra
            # "Unnamed: 0" column, breaking the single-column invariant
            # for series round trips.
            df = pd.read_csv(temp_file, index_col=0)
        finally:
            # Clean up the temporary directory even if the copy or the
            # read fails, so we never leak temp dirs.
            fileio.rmtree(temp_dir)

        if issubclass(data_type, pd.Series):
            # A series is stored as a single-column dataframe (see
            # handle_return); take the first and only column.
            assert len(df.columns) == 1
            df = df[df.columns[0]]
        return df

    def handle_return(self, df: Union[pd.DataFrame, pd.Series]) -> None:
        """Writes a pandas dataframe or series to a CSV file.

        Args:
            df: The pandas dataframe or series to write.
        """
        super().handle_return(df)
        filepath = os.path.join(self.artifact.uri, DEFAULT_FILENAME)

        if isinstance(df, pd.Series):
            # Store series as single-column dataframes so the CSV layout
            # is uniform for both associated types.
            df = df.to_frame(name="series")

        # Reserve a local temporary path; write and copy happen after the
        # `with` block so the handle is closed before pandas re-opens the
        # path (required on Windows, harmless elsewhere).
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".csv", delete=False
        ) as f:
            temp_path = f.name
        df.to_csv(temp_path)
        # Copy into the artifact store, then remove the temporary file.
        fileio.copy(temp_path, filepath)
        fileio.remove(temp_path)
# How to use in a zenml pipeline: https://docs.zenml.io/advanced-guide/pipelines/materializers
import logging
from zenml.steps import step
from zenml.pipelines import pipeline
@step
def my_first_step() -> pd.DataFrame:
    """Step that returns an object of type pd.DataFrame"""
    values = np.random.randint(0, 100, size=(100, 4))
    return pd.DataFrame(values, columns=list("ABCD"))
@step
def my_second_step(my_obj: pd.DataFrame) -> None:
    """Step that logs the input object and returns nothing."""
    message = f"The following dataframe was passed to this step: `{my_obj}`"
    logging.info(message)
@pipeline
def first_pipeline(step_1, step_2):
    """Two-step pipeline: step_1's output feeds step_2's input."""
    step_2(step_1())
# Guard the pipeline run so importing this module (e.g. to reuse the
# materializer) does not trigger a pipeline execution as a side effect.
if __name__ == "__main__":
    # Instantiate the pipeline with concrete steps, attaching the custom
    # CSV materializer to the first step's output, then run it.
    first_pipeline(
        step_1=my_first_step().configure(
            output_materializers=MyPandasMaterializer
        ),
        step_2=my_second_step(),
    ).run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment