Convert an RDD of pandas DataFrames to a single Spark DataFrame using Arrow, without collecting all data on the driver.
import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Slice a pandas.DataFrame into partitions, convert each slice to Arrow data and
    return the serialized record batches. If a schema is passed in, its data types
    will be used to coerce the data in the pandas-to-Arrow conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType, Row, DataType, StringType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched
    step = -(-len(pdf) // parallelism)  # round int up
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))


def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Map rdd of pandas dataframes to arrow record batches
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If schema is not defined, get from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema
    return df


from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
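A minimal usage sketch, assuming the patch above has already been applied, a Spark 2.x session (the helpers rely on pyspark.serializers.ArrowSerializer and PythonSQLUtils.arrowPayloadToDataFrame, which later Spark releases removed), and Python 2 (the helper uses xrange and basestring):

from pyspark.sql import SparkSession
import pandas as pd

spark = SparkSession.builder.appName("pandas-rdd-to-spark-df").getOrCreate()

# An RDD whose elements are pandas DataFrames, e.g. one DataFrame per partition.
def make_pdf(i):
    return pd.DataFrame({'id': range(i * 10, i * 10 + 10),
                         'value': [float(v) for v in range(10)]})

prdd = spark.sparkContext.parallelize(range(4), 4).map(make_pdf)

# Convert to a Spark DataFrame without collecting the pandas data on the driver.
df = spark.createFromPandasDataframesRDD(prdd)
df.show()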
@tahashmi

Hi Jether,

I have a small question if you can help me.

In this code snippet, you convert a prdd (an RDD of pd.DataFrame objects) to Arrow RecordBatches (slices) and then finally to a Spark DataFrame.

The Scala code converts a JavaRDD to a Spark DataFrame directly. (https://github.com/apache/spark/blob/65a189c7a1ddceb8ab482ccc60af5350b8da5ea5/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala#L192-L206)

If I already have an ardd (an RDD of pa.RecordBatch objects), how can I convert it to a Spark DataFrame directly in PySpark, without going through pandas, as in Scala? Thanks.

@linar-jether
Author

@tahashmi I believe this is similar to what's done internally in the createFromPandasDataframesRDD function above.
You can try using this to create the dataframe directly from arrow:

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

@tahashmi

tahashmi commented Jul 1, 2020

Thank you for your reply.

I tried something like this for my Arrow RecordBatches RDD as per your suggestions above.


def _arrow_record_batch_dumps(rb):
    from pyspark.serializers import ArrowSerializer

    return map(bytearray, map(ArrowSerializer().dumps, rb))

def createFromArrowRecordBatchesRDD(self, ardd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x))

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = ardd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

    return df

def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    ardd = spark.sparkContext.parallelize([0,1,2], 3)
    ardd = ardd.map(rb_return)

    SparkSession.createFromArrowRecordBatchesRDD = createFromArrowRecordBatchesRDD
    df = spark.createFromArrowRecordBatchesRDD(ardd)
    df.show()

In the above scenario, with ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) enabled, I got the following errors.
It seems the Arrow dumper is only taking the first array in the Arrow RecordBatch (which is of type pa.int16()) and cannot read the schema of the RecordBatch.

2020-07-01 04:13:25 INFO  TaskSchedulerImpl:54 - Adding task set 0.0 with 1 tasks
2020-07-01 04:13:25 INFO  TaskSetManager:54 - Starting task 0.0 in stage 0.0 (TID 0, localhost, executor driver, partition 0, PROCESS_LOCAL, 7858 bytes)
2020-07-01 04:13:25 INFO  Executor:54 - Running task 0.0 in stage 0.0 (TID 0)
2020-07-01 04:13:25 INFO  Executor:54 - Fetching file:/home/tahmad/tahmad/script.py with timestamp 1593576802633
2020-07-01 04:13:25 INFO  Utils:54 - /home/tahmad/tahmad/script.py has been previously copied to /tmp/spark-868994a9-976a-4d8f-884c-d97d37df14ae/userFiles-2ac7c66e-924b-498c-b760-ee25ba956a39/script.py
2020-07-01 04:13:25 INFO  PythonRunner:54 - Times: total = 425, boot = 215, init = 210, finish = 0
2020-07-01 04:13:25 INFO  MemoryStore:54 - Block rdd_1_0 stored as bytes in memory (estimated size 392.0 B, free 366.3 MB)
2020-07-01 04:13:25 INFO  BlockManagerInfo:54 - Added rdd_1_0 in memory on tcn793:45311 (size: 392.0 B, free: 366.3 MB)
2020-07-01 04:13:25 ERROR Executor:91 - Exception in task 0.0 in stage 0.0 (TID 0)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 253, in main
    process()
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 248, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/serializers.py", line 379, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/util.py", line 55, in wrapper
    return f(*args, **kwargs)
  File "/home/tahmad/tahmad/script.py", line 53, in <lambda>
    ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x))
  File "/home/tahmad/tahmad/script.py", line 42, in _arrow_record_batch_dumps
    return map(bytearray, map(ArrowSerializer().dumps, rb))
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/serializers.py", line 196, in dumps
    writer = pa.RecordBatchFileWriter(sink, batch.schema)
AttributeError: 'pyarrow.lib.Int16Array' object has no attribute 'schema'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:336)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:475)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:458)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:290)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:439)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.<init>(ArrowConverters.scala:138)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.fromPayloadIterator(ArrowConverters.scala:135)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:211)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:209)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Similarly, by commenting out ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) and passing ardd directly to ._to_java_object_rdd(), I got the following errors:

net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for pyarrow.lib.type_for_alias)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:188)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:187)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.<init>(ArrowConverters.scala:138)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.fromPayloadIterator(ArrowConverters.scala:135)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:211)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:209)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
2020-07-01 04:22:43 WARN  TaskSetManager:66 - Lost task 0.0 in stage 0.0 (TID 0, localhost, executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for pyarrow.lib.type_for_alias)

Any help getting this code to run would be highly appreciated! Thanks

@linar-jether
Author

Hi @tahashmi, this seems to work for me.
The issue was calling flatMap on the record batch, which makes the serializer iterate over the batch's column arrays instead of receiving the batch itself.
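A small pyarrow-only sketch of that failure mode (it assumes only pa.RecordBatch.from_arrays and RecordBatch.column; the Int16Array matches the traceback above). The full working Spark 2.x example follows right after it:

import pyarrow as pa

rb = pa.RecordBatch.from_arrays(
    [pa.array(range(5), type='int16'),
     pa.array([-10, -5, 0, None, 10], type='int32')],
    names=['c0', 'c1'])

# The batch itself carries a schema and can be serialized as one unit...
print(hasattr(rb, 'schema'))            # True
# ...but its individual columns are plain Arrow arrays without a schema,
# which is what ArrowSerializer().dumps ended up receiving via map(dumps, rb).
print(type(rb.column(0)).__name__)      # Int16Array
print(hasattr(rb.column(0), 'schema'))  # False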

from pyspark.sql import SparkSession
import pyspark
import pyarrow as pa
from pyspark.serializers import ArrowSerializer

def _arrow_record_batch_dumps(rb):
    # Fix for interoperability between pyarrow version >=0.15 and Spark's arrow version
    # Streaming message protocol has changed, remove setting when upgrading spark.
    import os
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
    
    return bytearray(ArrowSerializer().dumps(rb))


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0,1,2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.arrowPayloadToDataFrame(jrdd, schema.json(), spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema

    df.show()
    

@tahashmi

Hi @linar-jether, Thank you so much for your time. :-)
I have just checked, it works perfectly.

@tahashmi

Hi @linar-jether, I got stuck again when migrating the above code to Spark 3.0.0.
Can you please help me? I created a JIRA a few days back explaining the whole issue. Thanks.

@linar-jether
Author

Hi @tahashmi, this can be done as shown in the snippet below:

Also, this PR (SPARK-32846) might be useful, as it uses user-facing APIs (but requires conversion to a pandas RDD; a sketch of that conversion step is shown after the snippet).

from pyspark.sql import SparkSession
import pyarrow as pa


def _arrow_record_batch_dumps(rb):
    return bytearray(rb.serialize())


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)


if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0, 1, 2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.pandas.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(),
                                                spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema

    df.show()
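As for the pandas-RDD route mentioned above, a minimal sketch of just the conversion step, assuming RecordBatch.to_pandas(); how the resulting pandas RDD is then fed through the user-facing API of SPARK-32846 is not shown here:

import pyarrow as pa
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

def make_batch(_):
    return pa.RecordBatch.from_arrays(
        [pa.array(range(5), type='int16'),
         pa.array([-10, -5, 0, None, 10], type='int32')],
        names=['c0', 'c1'])

ardd = spark.sparkContext.parallelize(range(3), 3).map(make_batch)

# Convert each Arrow RecordBatch to a pandas DataFrame on the workers;
# this pandas RDD is the input shape the user-facing route would need.
pdf_rdd = ardd.map(lambda rb: rb.to_pandas())
print(pdf_rdd.first().dtypes)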

@tahashmi

Thank you so much @linar-jether ! It's very helpful.
I want to avoid Pandas conversion because my data is in Arrow RecordBatches on all worker nodes.

@ashish615

@linar-jether @tahashmi, while running your code I am facing the error below:

    Traceback (most recent call last):
      File "recordbatch.py", line 48, in <module>
        spark._wrapped._jsqlContext)
    AttributeError: 'SparkSession' object has no attribute '_wrapped'

Am I missing something?
