Skip to content

Instantly share code, notes, and snippets.

@ymtricks
Created August 15, 2023 08:20
Show Gist options
  • Save ymtricks/8eda231ec7630634087ff5dee6138eea to your computer and use it in GitHub Desktop.
Save ymtricks/8eda231ec7630634087ff5dee6138eea to your computer and use it in GitHub Desktop.
Prefect serializer with async dumps/loads
from io import BytesIO
from typing import Literal, Any
import cloudpickle
import pandas as pd
import prefect
from prefect import flow, task
from prefect.context import TaskRunContext, FlowRunContext
from prefect.filesystems import LocalFileSystem, WritableFileSystem
from prefect.results import get_default_result_storage
from prefect.serializers import Serializer
from prefect.utilities.asyncutils import sync
def find_result_storage() -> WritableFileSystem:
    """Resolve the storage that serialized payloads are written to / read from.

    Resolution order:

    1. ``result_storage`` of the task currently running, if it is set;
    2. ``result_storage`` of the flow currently running, if it is set;
    3. ``storage_block`` of the flow run's ``result_factory``, if present;
    4. ``get_default_result_storage()``.

    Outside of any flow run, step 4 is used directly.

    NOTE(review): depending on Prefect version/settings, these values may be
    block slugs (``str``) rather than block instances. Callers invoke
    ``write_path``/``read_path`` on the result directly, so this presumably
    assumes block instances — confirm.
    """
    task_run_ctx = TaskRunContext.get()
    if task_run_ctx is not None and task_run_ctx.task is not None:
        task_storage = task_run_ctx.task.result_storage
        if task_storage is not None:
            return task_storage

    flow_run_ctx = FlowRunContext.get()
    if flow_run_ctx is None:
        return get_default_result_storage()
    if flow_run_ctx.flow.result_storage is not None:
        return flow_run_ctx.flow.result_storage
    if flow_run_ctx.result_factory is not None:
        return flow_run_ctx.result_factory.storage_block
    return get_default_result_storage()
def replace_runtime_vars(template: str) -> str:
    """Render ``template`` with ``str.format`` using values from ``prefect.runtime``.

    Every name listed by ``dir(prefect.runtime)`` (for example ``flow_run`` and
    ``task_run``) is passed as a keyword argument. ``parameters`` is bound to
    ``prefect.runtime.task_run.parameters``. This lets templates such as
    ``"{flow_run.flow_name}/{task_run.id}.json"`` be rendered.
    """
    runtime = prefect.runtime
    namespace = {}
    for attr_name in dir(runtime):
        namespace[attr_name] = getattr(runtime, attr_name)
    return template.format(**namespace, parameters=runtime.task_run.parameters)
def find_result_location() -> str:
    """Return the rendered storage key under which a result should be written.

    Inside a task run, the task's ``result_storage_key`` template is used.
    When the task does not set one, ``"{task_run.id}.json"`` is used so each
    task run still gets its own key. Outside a task run, the key is
    ``"{flow_run.id}.json"``. The chosen template is rendered with
    :func:`replace_runtime_vars`.
    """
    task_run_ctx = TaskRunContext.get()
    if task_run_ctx is not None and task_run_ctx.task is not None:
        # ``result_storage_key`` may be left unset (None) on a task. Passing None
        # on to ``replace_runtime_vars`` would fail on ``None.format``, so fall
        # back to a per-task-run key instead.
        location = task_run_ctx.task.result_storage_key or "{task_run.id}.json"
    else:
        location = "{flow_run.id}.json"
    return replace_runtime_vars(location)
class FileReferenceSerializer(Serializer):
    """Serializer that stores the real payload in a separate file.

    ``dumps`` writes the object to the storage returned by
    ``find_result_storage()``. A :class:`pandas.DataFrame` is written as
    Parquet; anything else is written with ``cloudpickle``. The method then
    returns only the file path, UTF-8 encoded, as the serialized blob.

    ``loads`` reverses this. It decodes the path, reads the file from the
    storage and deserializes it according to the file extension.

    NOTE(review): ``loads`` resolves the storage from the *current* run
    context. If a result is loaded in a different context than the one it was
    written in, the path may be looked up in a different storage — confirm
    this is acceptable.
    """

    type: Literal["file_ref"] = "file_ref"

    @staticmethod
    def _data_path(location: str, extension: str) -> str:
        """Return ``location`` with a trailing ``.json`` replaced by ``extension``.

        Only a trailing ``.json`` is stripped; occurrences elsewhere in the
        path are left intact. If there is no ``.json`` suffix, the extension is
        appended, so ``loads`` can always recognise the payload format.
        """
        if location.endswith(".json"):
            location = location[: -len(".json")]
        return location + extension

    def dumps(self, obj: Any) -> bytes:
        """Synchronously write ``obj`` to result storage and return its path as bytes."""
        return sync(self._async_dumps, obj)

    @classmethod
    async def _async_dumps(cls, obj: Any) -> bytes:
        """Write ``obj`` to result storage and return the UTF-8 encoded file path."""
        storage = find_result_storage()
        location = find_result_location()
        with BytesIO() as buffer:
            if isinstance(obj, pd.DataFrame):
                obj.to_parquet(buffer)
                file_path = cls._data_path(location, ".parquet")
            else:
                cloudpickle.dump(obj, buffer)
                file_path = cls._data_path(location, ".pickle")
            await storage.write_path(path=file_path, content=buffer.getvalue())
        return file_path.encode()

    def loads(self, blob: bytes) -> Any:
        """Synchronously load the object referenced by ``blob`` (an encoded path)."""
        return sync(self._async_loads, blob)

    @classmethod
    async def _async_loads(cls, blob: bytes) -> Any:
        """Read the file whose path is encoded in ``blob`` and deserialize it.

        Returns a :class:`pandas.DataFrame` for ``.parquet`` files and the
        unpickled object for ``.pickle`` files.

        Raises:
            ValueError: if the path has neither a ``.parquet`` nor a
                ``.pickle`` extension. The check happens before any storage I/O.
        """
        file_path = blob.decode()
        if not file_path.endswith((".parquet", ".pickle")):
            raise ValueError(f"Unsupported file type {file_path}")

        storage = find_result_storage()
        content = await storage.read_path(file_path)
        if file_path.endswith(".parquet"):
            return pd.read_parquet(BytesIO(content))
        # SECURITY: unpickling can execute arbitrary code. Only point this
        # serializer at result storage whose contents you trust.
        return cloudpickle.loads(content)
@flow(log_prints=True, result_storage=LocalFileSystem(basepath="./results"))
def sample_flow():
    """Demo flow: run :func:`sample_task` and print the number of rows it returned.

    The flow is declared with ``log_prints=True``. Its result storage is a
    ``LocalFileSystem`` rooted at ``./results``.
    """
    row_count = len(sample_task())
    print(row_count)
@task(
    persist_result=True,
    result_serializer=FileReferenceSerializer(),
    result_storage_key=(
        "{flow_run.flow_name}/"
        "{flow_run.scheduled_start_time:%Y_%m_%d_%H_%M_%S}/"
        "{task_run.task_name}/"
        "{task_run.id}.json"
    ),
)
def sample_task() -> pd.DataFrame:
    """Build a four-row demo DataFrame with ``Name``, ``Age`` and ``City`` columns.

    The result is persisted through ``FileReferenceSerializer`` under a key
    built from the flow name, the scheduled start time, the task name and the
    task-run id.
    """
    people = [
        ("Alice", 25, "New York"),
        ("Bob", 30, "London"),
        ("Charlie", 35, "Paris"),
        ("David", 40, "Berlin"),
    ]
    return pd.DataFrame(people, columns=["Name", "Age", "City"])
if __name__ == "__main__":
    # Run the demo flow only when this file is executed as a script, not on import.
    sample_flow()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment