Skip to content

Instantly share code, notes, and snippets.

@williamcaicedo
Created March 11, 2023 21:19
Show Gist options
  • Save williamcaicedo/87d9b53625f6326ff9dee1c53c6e86d1 to your computer and use it in GitHub Desktop.
Custom Kedro dataset for saving and loading PySpark ML `PipelineModel` objects
from kedro.extras.datasets.spark import SparkDataSet
from pyspark.ml import PipelineModel
from pyspark.sql.utils import AnalysisException
class PySparkMLPipelineDataSet(SparkDataSet):
    """Kedro dataset that saves and loads a PySpark ML ``PipelineModel``.

    Reuses ``SparkDataSet``'s path handling (``_fs_prefix`` plus the
    load/save paths) but reads and writes with the Spark ML
    ``PipelineModel`` reader/writer instead of the DataFrame API.
    """

    def _load(self) -> PipelineModel:
        """Load and return the ``PipelineModel`` stored at the load path."""
        load_path = self._fs_prefix + str(self._get_load_path())
        # Return value unused; presumably called so an active SparkSession
        # exists before PipelineModel.load runs -- confirm.
        self._get_spark()
        return PipelineModel.load(load_path)

    def _save(self, pipeline: PipelineModel) -> None:
        """Write ``pipeline`` to the save path, overwriting any existing model."""
        save_path = self._fs_prefix + str(self._get_save_path())
        pipeline.write().overwrite().save(save_path)

    def _exists(self) -> bool:
        """Return whether a ``PipelineModel`` can be loaded from the load path.

        Returns False when loading fails with an ``AnalysisException`` whose
        message reports a missing path; any other ``AnalysisException`` is
        re-raised.
        """
        load_path = self._fs_prefix + str(self._get_load_path())
        # Same session setup as _load, so PipelineModel.load (and str() of a
        # captured exception) has an active Spark context to work with.
        self._get_spark()
        # NOTE(review): this performs a full model load just to test
        # existence, which can be costly for large pipelines. Depending on
        # the Spark version, a missing directory may also surface as a
        # Py4JJavaError ("Input path does not exist") rather than an
        # AnalysisException -- confirm against the Spark version in use.
        try:
            PipelineModel.load(load_path)
        except AnalysisException as exception:
            # Substring match on str(): `.desc` is deprecated in pyspark>=3.4,
            # and 3.4+ prefixes messages with an error class such as
            # "[PATH_NOT_FOUND] Path does not exist: ...", which breaks
            # a startswith() check.
            if "path does not exist" in str(exception).lower():
                return False
            raise
        return True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment