import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
    to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
    data types will be used to coerce the data in Pandas to Arrow conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType, Row, DataType, StringType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched
    step = -(-len(pdf) // parallelism)  # round int up
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))


def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Map rdd of pandas dataframes to arrow record batches
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If schema is not defined, get from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema
    return df


from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
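For context, here is a minimal usage sketch of the patched method. It assumes the Spark 2.x internals used by the gist (ArrowSerializer, PythonSQLUtils.arrowPayloadToDataFrame) are available, and the sample data and column names are made up for illustration:

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pandas-rdd-to-spark-df").getOrCreate()

# Hypothetical RDD whose elements are pandas DataFrames, one per partition
pdfs = [pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']}),
        pd.DataFrame({'a': [3, 4], 'b': ['z', 'w']})]
prdd = spark.sparkContext.parallelize(pdfs, len(pdfs))

# Calls the monkey-patched method defined above
df = spark.createFromPandasDataframesRDD(prdd)
df.show()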
Hi @tahashmi, this seems to work for me.
The issue was using flatMap on the record batch, which caused it to iterate over the individual arrays instead of keeping each serialized batch intact:
from pyspark.sql import SparkSession
import pyspark
import pyarrow as pa
from pyspark.serializers import ArrowSerializer


def _arrow_record_batch_dumps(rb):
    # Fix for interoperability between pyarrow version >=0.15 and Spark's arrow version
    # Streaming message protocol has changed, remove setting when upgrading spark.
    import os
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'

    return bytearray(ArrowSerializer().dumps(rb))


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)


if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0, 1, 2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Filter out and cache arrow record batches
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()
    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.arrowPayloadToDataFrame(jrdd, schema.json(), spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema
    df.show()
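As a side note on the flatMap issue mentioned above, here is a minimal sketch of why map must be used on the serialized batches. The byte payloads below are stand-ins for real Arrow data: map keeps one RDD element per batch, while flatMap iterates over the payload and destroys the batch boundary:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

# Dummy payloads standing in for serialized record batches
payloads = [bytearray(b'batch-1'), bytearray(b'batch-2')]
rdd = sc.parallelize(payloads, 2)

print(rdd.map(len).collect())            # [7, 7] -> one element per batch
print(rdd.flatMap(lambda b: b).take(5))  # [98, 97, 116, 99, 104] -> individual bytes, schema lost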
Hi @linar-jether, thank you so much for your time. :-)
I have just checked, and it works perfectly.
Hi @linar-jether, I got stuck again when migrating the above code to Spark 3.0.0.
Can you please help me? I created a JIRA a few days back explaining the whole issue. Thanks.
Hi @tahashmi, this can be done as shown below.
Also, this PR (SPARK-32846) might be useful, as it uses user-facing APIs (but requires conversion to a pandas rdd):
from pyspark.sql import SparkSession
import pyarrow as pa


def _arrow_record_batch_dumps(rb):
    return bytearray(rb.serialize())


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)


if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0, 1, 2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.pandas.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame

    # Filter out and cache arrow record batches
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()
    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(),
                                                spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema
    df.show()
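Regarding the pandas-rdd route mentioned above (SPARK-32846): the conversion step itself is straightforward with pyarrow, since RecordBatch.to_pandas() exists. A rough sketch, where rb_rdd is a hypothetical RDD of pyarrow.RecordBatch objects (e.g. the result of ardd.map(rb_return) above, before serialization); how the resulting pandas RDD is then turned into a Spark DataFrame depends on the API from that PR and is not shown here:

# rb_rdd: hypothetical RDD of pyarrow.RecordBatch objects
pandas_rdd = rb_rdd.map(lambda rb: rb.to_pandas())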
Thank you so much, @linar-jether! It's very helpful.
I want to avoid the pandas conversion because my data is already in Arrow RecordBatches on all worker nodes.
@linar-jether @tahashmi, while running your code I am facing the error below:
Traceback (most recent call last):
  File "recordbatch.py", line 48, in <module>
    spark._wrapped._jsqlContext)
AttributeError: 'SparkSession' object has no attribute '_wrapped'
Am I missing something?
Thank you for your reply.
I tried something like this for my Arrow RecordBatches RDD, as per your suggestions above.
With ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) enabled, I got errors. It seems like the Arrow dumper is only taking the first array in the Arrow RecordBatch, which is of type pa.int16(), and is not able to read the schema of the RecordBatch.
Similarly, by commenting out ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) and passing ardd directly to ._to_java_object_rdd(), I got errors as well.
Any help getting this code running would be highly appreciated! Thanks.
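If it helps, the symptoms described above match the flatMap pitfall mentioned earlier in this thread: flatMap iterates over whatever the function returns, so the batch boundary (and with it the schema) is lost. A sketch of the change, reusing the names from the Spark 3 snippet above:

# Use map, not flatMap, so each serialized RecordBatch stays one RDD element
ardd = ardd.map(lambda x: _arrow_record_batch_dumps(x))
jrdd = ardd._to_java_object_rdd()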