Convert an RDD of pandas DataFrames to a single Spark DataFrame using Arrow, without collecting all of the data on the driver.
import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Slice a given pandas.DataFrame into ``parallelism`` chunks, convert each chunk
    to an Arrow record batch, and return the serialized batches. If a schema is
    passed in, its data types are used to coerce the data in the pandas-to-Arrow
    conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import to_arrow_type, TimestampType, DataType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version
    require_minimum_pandas_version()
    require_minimum_pyarrow_version()
    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine Arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched; -(-a // b) is ceiling division
    # (illustrated right after this function)
    step = -(-len(pdf) // parallelism)
    pdf_slices = (pdf[start:start + step] for start in range(0, len(pdf), step))

    # Create Arrow record batches, then serialize each batch to bytes
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]
    return [bytearray(ArrowSerializer().dumps(batch)) for batch in batches]
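As an aside, the step computation above uses the negative-floor-division idiom for rounding integer division up; a quick illustration with made-up values:

# -(-n // k) rounds n // k up instead of down, without importing math.ceil
n, k = 10, 3
step = -(-n // k)                                            # 4 rows per slice
slices = [(s, min(s + step, n)) for s in range(0, n, step)]
assert step == 4 and slices == [(0, 4), (4, 8), (8, 10)]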
def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    """
    Build a single Spark DataFrame from an RDD of pandas DataFrames by converting
    each element to Arrow record batches on the executors, rather than collecting
    all of the data on the driver.
    """
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer

    # Keep only the pandas DataFrames; cache because the RDD is traversed twice
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If a schema is not given, take the column names from the first DataFrame
    if schema is None:
        schema = [x if isinstance(x, str) else str(x)
                  for x in prdd.map(lambda x: x.columns).first()]

    # Map the RDD of pandas DataFrames to an RDD of serialized Arrow record batches
    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema
    return df


from pyspark.sql import SparkSession

# Monkey-patch the helper onto SparkSession so it can be called as a method
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
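For reference, a minimal usage sketch; the session name and sample frames are illustrative, and the helper relies on PySpark internals (ArrowSerializer, _create_batch, PythonSQLUtils.arrowPayloadToDataFrame) that exist in Spark 2.3/2.4 but were removed in Spark 3.0:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pandas-rdd-to-spark-df").getOrCreate()

# An RDD whose elements are pandas DataFrames, e.g. one per source file
pdfs = [pd.DataFrame({"a": [1, 2], "b": ["x", "y"]}),
        pd.DataFrame({"a": [3, 4], "b": ["z", "w"]})]
prdd = spark.sparkContext.parallelize(pdfs, numSlices=2)

# Convert to a single Spark DataFrame without collecting the data to the driver
df = spark.createFromPandasDataframesRDD(prdd)
df.show()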