Skip to content

Instantly share code, notes, and snippets.

@BryanCutler
Last active March 14, 2018 05:47
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save BryanCutler/2d2ae04e81fa96ba4b61dc095726419f to your computer and use it in GitHub Desktop.
Save BryanCutler/2d2ae04e81fa96ba4b61dc095726419f to your computer and use it in GitHub Desktop.
Vectorized UDFs in Python SPARK-21190
class DataFrame(object):
    # (remaining DataFrame members elided in this sketch)
    ...

    def asPandas(self):
        """Return this DataFrame wrapped in an ``ArrowDataFrame``.

        The wrapper exposes group/apply style processing with ``pandas.DataFrame``.
        """
        wrapper = ArrowDataFrame(self)
        return wrapper
class ArrowDataFrame(object):
    """
    Wraps a Python DataFrame to group/window, then apply using ``pandas.DataFrame``.
    """

    def __init__(self, data_frame):
        # The wrapped DataFrame. Its JVM handle (``_jdf``), ``_sc`` and
        # ``sql_ctx`` are always reached through it; this wrapper has none.
        self.df = data_frame
        # Populated on first access of ``rdd`` and reused afterwards.
        self._lazy_rdd = None

    @property
    def rdd(self):
        """Return the wrapped DataFrame's data as an ``ArrowRDD``.

        The ``ArrowRDD`` is built once, on first access, and cached.
        """
        if self._lazy_rdd is None:
            # ``_jdf`` belongs to the wrapped DataFrame, not to this wrapper.
            jrdd = self.df._jdf.javaToPython()
            self._lazy_rdd = ArrowRDD(jrdd, self.df._sc)
        return self._lazy_rdd

    def groupBy(self, *cols):
        """Group the wrapped DataFrame by ``cols``.

        :param cols: column names/expressions, forwarded to the DataFrame's ``_jcols``.
        :return: an ``ArrowGroupedData`` over the JVM grouped data.
        """
        # Both ``_jdf`` and ``_jcols`` live on the wrapped DataFrame.
        jgd = self.df._jdf.groupBy(self.df._jcols(*cols))
        return ArrowGroupedData(jgd, self.df.sql_ctx)

    def windowOver(self, window_spec):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()
class ArrowGroupedData(GroupedData):
    """
    Wraps a Python GroupedData object to process groups as ``pandas.DataFrame``.
    """

    def __init__(self, jgd, sql_ctx):
        super(ArrowGroupedData, self).__init__(jgd, sql_ctx)

    def agg(self, f):
        """Apply ``f`` to each group and return the result as a DataFrame.

        Presumably ``f`` receives each group as a ``pandas.DataFrame``
        (TODO confirm once implemented).

        Not implemented yet; always raises ``NotImplementedError``.
        """
        # The previous placeholder ``DataFrame(...)`` built a DataFrame from
        # ``Ellipsis``. Fail explicitly instead, like ``ArrowDataFrame.windowOver``.
        raise NotImplementedError()
class ArrowRDD(object):
    """
    Wraps a Python RDD to deserialize using Arrow into ``pandas.DataFrame`` for processing.

    Construct it in one of two ways:
    - from a JVM RDD handle ``jrdd`` and context ``ctx``, deserialized with
      ``ArrowPandasSerializer``;
    - by wrapping an already-built ``pipelined_rdd`` as-is.
    """

    def __init__(self, jrdd, ctx, pipelined_rdd=None):
        if pipelined_rdd is None:
            self._rdd = RDD(jrdd, ctx, jrdd_deserializer=ArrowPandasSerializer())
        else:
            self._rdd = pipelined_rdd
        # ``toDF`` needs a context. When none is given (e.g. wrapping a derived
        # RDD), fall back to the wrapped RDD's own ``ctx`` attribute, if any.
        self.ctx = ctx if ctx is not None else getattr(self._rdd, "ctx", None)

    def _wrap_rdd(self, rdd):
        """Wrap a derived ``rdd``.

        The derived RDD reuses this RDD's deserializer and context.
        """
        # Keep the derived RDD deserializing the same way as this one.
        rdd._jrdd_deserializer = self._rdd._jrdd_deserializer
        return ArrowRDD(jrdd=None, ctx=self.ctx, pipelined_rdd=rdd)

    def map(self, f, preservesPartitioning=False):
        """Apply ``f`` to each element and return the result as a new ``ArrowRDD``."""
        rdd = self._rdd.map(f, preservesPartitioning=preservesPartitioning)
        return self._wrap_rdd(rdd)

    def reduce(self, f):
        """Reduce the elements with the binary function ``f`` (delegates to the wrapped RDD)."""
        return self._rdd.reduce(f)

    def count(self):
        """Return the number of elements (delegates to the wrapped RDD)."""
        return self._rdd.count()

    def collect(self):
        """Return all elements as a list (delegates to the wrapped RDD)."""
        return self._rdd.collect()

    def toDF(self):
        """Convert this RDD back into an ``ArrowDataFrame``."""
        schema = convert_arrow_schema()
        # NOTE(review): ``ctx`` is whatever the creator supplied.
        # ``ArrowDataFrame.rdd`` passes the DataFrame's ``_sc``; confirm that
        # object actually provides ``createDataFrame``.
        return ArrowDataFrame(self.ctx.createDataFrame(self, schema))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment