Last active
November 4, 2021 20:42
-
-
Save pabloalcain/de79938507ad2d823a866238b3c8a66e to your computer and use it in GitHub Desktop.
Dynamic DataFrame
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from pyspark.sql import DataFrame | |
# Type variable bound to DynamicDataFrame: annotating methods as (self: T) -> T
# makes wrappers return the *caller's* subclass type, not the base class.
T = typing.TypeVar("T", bound="DynamicDataFrame")
class DynamicDataFrame(DataFrame):
    """A ``DataFrame`` whose chainable transformations preserve the subclass type.

    Plain ``pyspark.sql.DataFrame`` methods return a new ``DataFrame``, so a
    subclass "loses" its type after a single transformation.  Every wrapper
    below re-wraps the result via :meth:`_wrap`, so subclasses of this class
    keep their own type (and their extra methods) across method chains.

    Construction shares the wrapped JVM plan: ``__init__`` re-uses ``df._jdf``
    and ``df.sql_ctx`` rather than copying any data.
    """

    def __init__(self, df: DataFrame):
        # Share the underlying JVM DataFrame and SQL context with the source
        # frame; no data is materialized or copied here.
        super().__init__(df._jdf, df.sql_ctx)

    def _wrap(self: T, df: DataFrame) -> T:
        # Single point that rebuilds a result as the caller's (sub)class.
        # Centralizing this removes ~30 copies of ``self.__class__(...)``.
        return self.__class__(df)

    def alias(self: T, *args, **kwargs) -> T:
        return self._wrap(super().alias(*args, **kwargs))

    def checkpoint(self: T, *args, **kwargs) -> T:
        return self._wrap(super().checkpoint(*args, **kwargs))

    def coalesce(self: T, *args, **kwargs) -> T:
        return self._wrap(super().coalesce(*args, **kwargs))

    def crossJoin(self: T, *args, **kwargs) -> T:
        return self._wrap(super().crossJoin(*args, **kwargs))

    def distinct(self: T) -> T:
        return self._wrap(super().distinct())

    def drop(self: T, *args, **kwargs) -> T:
        return self._wrap(super().drop(*args, **kwargs))

    def dropDuplicates(self: T, *args, **kwargs) -> T:
        return self._wrap(super().dropDuplicates(*args, **kwargs))

    def dropna(self: T, *args, **kwargs) -> T:
        return self._wrap(super().dropna(*args, **kwargs))

    def fillna(self: T, *args, **kwargs) -> T:
        return self._wrap(super().fillna(*args, **kwargs))

    def filter(self: T, *args, **kwargs) -> T:
        return self._wrap(super().filter(*args, **kwargs))

    def exceptAll(self: T, *args, **kwargs) -> T:
        return self._wrap(super().exceptAll(*args, **kwargs))

    def hint(self: T, *args, **kwargs) -> T:
        return self._wrap(super().hint(*args, **kwargs))

    def intersect(self: T, *args, **kwargs) -> T:
        return self._wrap(super().intersect(*args, **kwargs))

    def intersectAll(self: T, *args, **kwargs) -> T:
        return self._wrap(super().intersectAll(*args, **kwargs))

    def join(self: T, *args, **kwargs) -> T:
        return self._wrap(super().join(*args, **kwargs))

    def limit(self: T, *args, **kwargs) -> T:
        return self._wrap(super().limit(*args, **kwargs))

    def localCheckpoint(self: T, *args, **kwargs) -> T:
        return self._wrap(super().localCheckpoint(*args, **kwargs))

    def orderBy(self: T, *args, **kwargs) -> T:
        return self._wrap(super().orderBy(*args, **kwargs))

    def repartition(self: T, *args, **kwargs) -> T:
        return self._wrap(super().repartition(*args, **kwargs))

    def repartitionByRange(self: T, *args, **kwargs) -> T:
        return self._wrap(super().repartitionByRange(*args, **kwargs))

    def replace(self: T, *args, **kwargs) -> T:
        return self._wrap(super().replace(*args, **kwargs))

    def sample(self: T, *args, **kwargs) -> T:
        return self._wrap(super().sample(*args, **kwargs))

    def sampleBy(self: T, *args, **kwargs) -> T:
        return self._wrap(super().sampleBy(*args, **kwargs))

    # select/selectExpr intentionally take only *args, mirroring the
    # positional-only signature of the parent methods.
    def select(self: T, *args) -> T:
        return self._wrap(super().select(*args))

    def selectExpr(self: T, *args) -> T:
        return self._wrap(super().selectExpr(*args))

    def sort(self: T, *args, **kwargs) -> T:
        return self._wrap(super().sort(*args, **kwargs))

    def sortWithinPartitions(self: T, *args, **kwargs) -> T:
        return self._wrap(super().sortWithinPartitions(*args, **kwargs))

    def subtract(self: T, *args, **kwargs) -> T:
        return self._wrap(super().subtract(*args, **kwargs))

    def transform(self: T, *args, **kwargs) -> T:
        # Re-wrapping is harmless even if the transform function already
        # returned a DynamicDataFrame subclass.
        return self._wrap(super().transform(*args, **kwargs))

    def union(self: T, *args, **kwargs) -> T:
        return self._wrap(super().union(*args, **kwargs))

    def unionByName(self: T, *args, **kwargs) -> T:
        return self._wrap(super().unionByName(*args, **kwargs))

    def withColumn(self: T, *args, **kwargs) -> T:
        return self._wrap(super().withColumn(*args, **kwargs))

    def withColumnRenamed(self: T, *args, **kwargs) -> T:
        return self._wrap(super().withColumnRenamed(*args, **kwargs))

    def withWatermark(self: T, *args, **kwargs) -> T:
        return self._wrap(super().withWatermark(*args, **kwargs))

    def randomSplit(self: T, *args, **kwargs) -> typing.List[T]:
        # Returns several frames, so each split is wrapped individually.
        return [self._wrap(df) for df in super().randomSplit(*args, **kwargs)]

    def toDataFrame(self) -> DataFrame:
        """Escape hatch: downcast back to a plain ``pyspark.sql.DataFrame``."""
        return DataFrame(self._jdf, self.sql_ctx)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
import pyspark | |
from pyspark.sql import DataFrame | |
from pyspark.sql import functions as F | |
# Type variable bound to DynamicDataFrame: lets wrapper methods declare that
# they return the same subclass they were called on.
T = typing.TypeVar("T", bound="DynamicDataFrame")
class MyBusinessDataFrameVanilla(DataFrame):
    """Business frame inheriting straight from ``DataFrame``.

    Demonstrates the drawback of direct inheritance: the first call to
    :meth:`my_business_query` goes through ``DataFrame.withColumn`` and hands
    back a plain ``DataFrame``, so the business methods are lost after one use.
    """

    def __init__(self, df: DataFrame):
        # Re-use the wrapped JVM frame and SQL context of the source frame.
        super().__init__(df._jdf, df.sql_ctx)

    def my_business_query(self, factor: float = 2.0):
        # Scale the price column; the result is a *plain* DataFrame.
        scaled_price = F.col("price") * factor
        return self.withColumn("price", scaled_price)
class DynamicDataFrame(DataFrame):
    """Minimal type-preserving ``DataFrame`` wrapper.

    Only ``select`` and ``withColumn`` are overridden here; each one re-wraps
    the parent's result in ``self.__class__`` so that subclasses keep their
    own type (and their extra methods) across chained calls.
    """

    def __init__(self, df: DataFrame):
        # Share the underlying JVM DataFrame and SQL context; nothing is copied.
        super().__init__(df._jdf, df.sql_ctx)

    def select(self: T, *args) -> T:
        selected = super().select(*args)
        return self.__class__(selected)

    def withColumn(self: T, *args) -> T:
        extended = super().withColumn(*args)
        return self.__class__(extended)
class MyBusinessDataFrame(DynamicDataFrame):
    """Business frame built on ``DynamicDataFrame``, so queries stay chainable."""

    def my_business_query(self, factor: float = 2.0):
        # withColumn is the DynamicDataFrame override, so the result keeps
        # this class's type and my_business_query can be called again.
        scaled_price = F.col("price") * factor
        return self.withColumn("price", scaled_price)
# Demo: contrast direct DataFrame inheritance with the DynamicDataFrame approach.
spark = pyspark.sql.SparkSession.builder.getOrCreate()
base_dataframe = spark.createDataFrame(
    data=[["product_1", 2], ["product_2", 4]],
    schema=["name", "price"],
)

# Case 1: vanilla subclass — the business type survives only until the first query.
print("Doing a direct inheritance from DataFrame")
van_business = MyBusinessDataFrameVanilla(base_dataframe)
van_business_updated = van_business.my_business_query(2.0)
van_business_updated.show()
print("After one use of the query we have a plain dataframe again")
print(type(van_business_updated))
# This would raise an AttributeError
# van_business_updated.my_business_query(5.0)

print("=" * 80)

# Case 2: DynamicDataFrame-mediated subclass — the business type survives chaining.
print("Doing an inheritance mediated by DynamicDataFrame")
dyn_business = MyBusinessDataFrame(base_dataframe)
dyn_business_updated = (
    dyn_business.my_business_query(2.0).my_business_query(5.0)
)
dyn_business_updated.show()
print("After multiple uses of the query we still have the desired type")
print(type(dyn_business_updated))
print("And we can still use the usual dataframe methods")
dyn_business_updated.filter(F.col("price") > 25).show()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Doing a direct inheritance from DataFrame | |
+---------+-----+ | |
| name|price| | |
+---------+-----+ | |
|product_1| 4.0| | |
|product_2| 8.0| | |
+---------+-----+ | |
After one use of the query we have a plain dataframe again | |
<class 'pyspark.sql.dataframe.DataFrame'> | |
================================================================================ | |
Doing an inheritance mediated by DynamicDataFrame | |
+---------+-----+ | |
| name|price| | |
+---------+-----+ | |
|product_1| 20.0| | |
|product_2| 40.0| | |
+---------+-----+ | |
After multiple uses of the query we still have the desired type | |
<class '__main__.MyBusinessDataFrame'> | |
And we can still use the usual dataframe methods | |
+---------+-----+ | |
| name|price| | |
+---------+-----+ | |
|product_2| 40.0| | |
+---------+-----+ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment