Skip to content

Instantly share code, notes, and snippets.

@pabloalcain
Last active November 4, 2021 20:42
Show Gist options
  • Save pabloalcain/de79938507ad2d823a866238b3c8a66e to your computer and use it in GitHub Desktop.
Dynamic DataFrame
import typing
from pyspark.sql import DataFrame
T = typing.TypeVar("T", bound="DynamicDataFrame")
class DynamicDataFrame(DataFrame):
    """A ``DataFrame`` wrapper whose transformations preserve the subclass type.

    pyspark's built-in ``DataFrame`` methods return plain ``DataFrame``
    objects, so a subclass loses its type (and any extra business methods)
    after the first transformation.  Every DataFrame-returning method here is
    overridden to re-wrap the result in ``self.__class__``, so chained calls
    on any subclass keep returning that subclass.
    """

    def __init__(self, df: DataFrame):
        # Reuse the wrapped DataFrame's JVM handle and SQL context.
        # NOTE(review): relies on the private ``_jdf`` attribute and on
        # ``sql_ctx``, which is deprecated in recent pyspark releases --
        # confirm against the pyspark version in use.
        super().__init__(df._jdf, df.sql_ctx)

    def alias(self: T, *args, **kwargs) -> T:
        return self.__class__(super().alias(*args, **kwargs))

    def checkpoint(self: T, *args, **kwargs) -> T:
        return self.__class__(super().checkpoint(*args, **kwargs))

    def coalesce(self: T, *args, **kwargs) -> T:
        return self.__class__(super().coalesce(*args, **kwargs))

    def crossJoin(self: T, *args, **kwargs) -> T:
        return self.__class__(super().crossJoin(*args, **kwargs))

    def distinct(self: T) -> T:
        return self.__class__(super().distinct())

    def drop(self: T, *args, **kwargs) -> T:
        return self.__class__(super().drop(*args, **kwargs))

    def dropDuplicates(self: T, *args, **kwargs) -> T:
        return self.__class__(super().dropDuplicates(*args, **kwargs))

    def dropna(self: T, *args, **kwargs) -> T:
        return self.__class__(super().dropna(*args, **kwargs))

    def fillna(self: T, *args, **kwargs) -> T:
        return self.__class__(super().fillna(*args, **kwargs))

    def filter(self: T, *args, **kwargs) -> T:
        return self.__class__(super().filter(*args, **kwargs))

    def exceptAll(self: T, *args, **kwargs) -> T:
        return self.__class__(super().exceptAll(*args, **kwargs))

    def hint(self: T, *args, **kwargs) -> T:
        return self.__class__(super().hint(*args, **kwargs))

    def intersect(self: T, *args, **kwargs) -> T:
        return self.__class__(super().intersect(*args, **kwargs))

    def intersectAll(self: T, *args, **kwargs) -> T:
        return self.__class__(super().intersectAll(*args, **kwargs))

    def join(self: T, *args, **kwargs) -> T:
        return self.__class__(super().join(*args, **kwargs))

    def limit(self: T, *args, **kwargs) -> T:
        return self.__class__(super().limit(*args, **kwargs))

    def localCheckpoint(self: T, *args, **kwargs) -> T:
        return self.__class__(super().localCheckpoint(*args, **kwargs))

    def orderBy(self: T, *args, **kwargs) -> T:
        return self.__class__(super().orderBy(*args, **kwargs))

    def repartition(self: T, *args, **kwargs) -> T:
        return self.__class__(super().repartition(*args, **kwargs))

    def repartitionByRange(self: T, *args, **kwargs) -> T:
        return self.__class__(super().repartitionByRange(*args, **kwargs))

    def replace(self: T, *args, **kwargs) -> T:
        return self.__class__(super().replace(*args, **kwargs))

    def sample(self: T, *args, **kwargs) -> T:
        return self.__class__(super().sample(*args, **kwargs))

    def sampleBy(self: T, *args, **kwargs) -> T:
        return self.__class__(super().sampleBy(*args, **kwargs))

    def select(self: T, *args) -> T:
        return self.__class__(super().select(*args))

    def selectExpr(self: T, *args) -> T:
        return self.__class__(super().selectExpr(*args))

    def sort(self: T, *args, **kwargs) -> T:
        return self.__class__(super().sort(*args, **kwargs))

    def sortWithinPartitions(self: T, *args, **kwargs) -> T:
        return self.__class__(super().sortWithinPartitions(*args, **kwargs))

    def subtract(self: T, *args, **kwargs) -> T:
        return self.__class__(super().subtract(*args, **kwargs))

    def transform(self: T, *args, **kwargs) -> T:
        return self.__class__(super().transform(*args, **kwargs))

    def union(self: T, *args, **kwargs) -> T:
        return self.__class__(super().union(*args, **kwargs))

    def unionByName(self: T, *args, **kwargs) -> T:
        return self.__class__(super().unionByName(*args, **kwargs))

    def withColumn(self: T, *args, **kwargs) -> T:
        return self.__class__(super().withColumn(*args, **kwargs))

    def withColumnRenamed(self: T, *args, **kwargs) -> T:
        return self.__class__(super().withColumnRenamed(*args, **kwargs))

    def withWatermark(self: T, *args, **kwargs) -> T:
        return self.__class__(super().withWatermark(*args, **kwargs))

    def randomSplit(self: T, *args, **kwargs) -> typing.List[T]:
        # Was annotated as a bare ``typing.List``; tightened to ``List[T]``
        # since each split is re-wrapped in the caller's class.
        return [self.__class__(df) for df in super().randomSplit(*args, **kwargs)]

    def toDataFrame(self) -> DataFrame:
        """Escape hatch: return a plain pyspark ``DataFrame`` view of self."""
        return DataFrame(self._jdf, self.sql_ctx)
import typing
import pyspark
from pyspark.sql import DataFrame
from pyspark.sql import functions as F
T = typing.TypeVar("T", bound="DynamicDataFrame")
class MyBusinessDataFrameVanilla(DataFrame):
    """Business DataFrame inheriting straight from pyspark's ``DataFrame``.

    Demonstrates the problem: any built-in transformation returns a plain
    ``DataFrame``, so ``my_business_query`` is lost after the first call.
    """

    def __init__(self, df: DataFrame):
        super().__init__(df._jdf, df.sql_ctx)

    def my_business_query(self, factor: float = 2.0):
        # Scale the ``price`` column by ``factor``.
        scaled_price = F.col("price") * factor
        return self.withColumn("price", scaled_price)
class DynamicDataFrame(DataFrame):
    """Minimal dynamic DataFrame: overridden transformations re-wrap their
    result in ``self.__class__`` so subclasses survive chained calls."""

    def __init__(self, df: DataFrame):
        super().__init__(df._jdf, df.sql_ctx)

    def select(self: T, *args) -> T:
        selected = super().select(*args)
        return self.__class__(selected)

    def withColumn(self: T, *args) -> T:
        extended = super().withColumn(*args)
        return self.__class__(extended)
class MyBusinessDataFrame(DynamicDataFrame):
    """Business DataFrame mediated by ``DynamicDataFrame``: the subclass
    type (and this method) survive chained transformations."""

    def my_business_query(self, factor: float = 2.0):
        # Scale the ``price`` column by ``factor``.
        scaled_price = F.col("price") * factor
        return self.withColumn("price", scaled_price)
# Demo: contrast direct DataFrame inheritance with DynamicDataFrame mediation.
spark = pyspark.sql.SparkSession.builder.getOrCreate()
base_dataframe = spark.createDataFrame(
    data=[['product_1', 2], ['product_2', 4]],
    schema=["name", "price"],
)

# Case 1: direct inheritance -- the subclass type is lost after one query.
print("Doing a direct inheritance from DataFrame")
van_business = MyBusinessDataFrameVanilla(base_dataframe)
van_business_updated = van_business.my_business_query(2.0)
van_business_updated.show()
print("After one use of the query we have a plain dataframe again")
print(type(van_business_updated))
# This would raise an AttributeError
# van_business_updated.my_business_query(5.0)

print("=" * 80)

# Case 2: DynamicDataFrame mediation -- the subclass type survives chaining.
print("Doing an inheritance mediated by DynamicDataFrame")
dyn_business = MyBusinessDataFrame(base_dataframe)
dyn_business_updated = dyn_business.my_business_query(2.0).my_business_query(5.0)
dyn_business_updated.show()
print("After multiple uses of the query we still have the desired type")
print(type(dyn_business_updated))
print("And we can still use the usual dataframe methods")
dyn_business_updated.filter(F.col('price') > 25).show()
Doing a direct inheritance from DataFrame
+---------+-----+
| name|price|
+---------+-----+
|product_1| 4.0|
|product_2| 8.0|
+---------+-----+
After one use of the query we have a plain dataframe again
<class 'pyspark.sql.dataframe.DataFrame'>
================================================================================
Doing an inheritance mediated by DynamicDataFrame
+---------+-----+
| name|price|
+---------+-----+
|product_1| 20.0|
|product_2| 40.0|
+---------+-----+
After multiple uses of the query we still have the desired type
<class '__main__.MyBusinessDataFrame'>
And we can still use the usual dataframe methods
+---------+-----+
| name|price|
+---------+-----+
|product_2| 40.0|
+---------+-----+
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment