Convert an RDD of pandas DataFrames to a single Spark DataFrame using Arrow, without collecting all the data on the driver.
import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
    to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
    data types will be used to coerce the data in Pandas to Arrow conversion.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import from_arrow_schema, to_arrow_type, TimestampType, Row, DataType, StringType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched
    step = -(-len(pdf) // parallelism)  # round int up
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))


def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Map rdd of pandas dataframes to arrow record batches
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If schema is not defined, get from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (always at least 1 batch after slicing)
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

    return df


from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
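
A minimal usage sketch, assuming Spark 2.3.x with pyarrow available on the executors, an existing SparkSession named spark, and that the snippet above has already been run (it monkey-patches SparkSession). The prdd below is a hypothetical input that just generates random pandas DataFrames; the benchmark notebook further down exercises the same call at scale:

import numpy as np
import pandas as pd

# Hypothetical input: an RDD whose elements are pandas DataFrames
prdd = spark.sparkContext.range(0, 4, numSlices=4) \
    .map(lambda _: pd.DataFrame(np.random.randint(0, 100, size=(1000, 4)), columns=list('ABCD')))

# Convert the whole RDD of pandas DataFrames into a single Spark DataFrame via Arrow,
# without collecting the data on the driver
df = spark.createFromPandasDataframesRDD(prdd)
df.show(5)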
Benchmark notebook (rendered from the gist's .ipynb; PySpark kernel, Python 2.7.14, Spark 2.3.0 on YARN):

In [1]:
import pandas as pd
import numpy as np
spark

Out[1]:
SparkSession - hive
SparkContext: Spark UI at http://cluster-1-m.c.jether-trader.internal:4040
Version v2.3.0, Master yarn, AppName PySparkShell
<pyspark.sql.session.SparkSession at 0x7fb1ec14ba50>

In [14]:
# Cluster of 2 n1-standard-16 VMs - 32 total cores
# Generate a RDD of 64 pd.DataFrame objects of about 20mb - total size 1.25gb
# df = pd.DataFrame(np.random.randint(0,100,size=(10<<15, 4)), columns=list('ABCD'))
prdd = sc.range(0, 64, numSlices=64).\
    map(lambda x: pd.DataFrame(np.random.randint(0,100,size=(20<<15, 4)), columns=list('ABCD')))

In [15]:
%%time
print 'Total DFs size: %.2fgb' % (prdd.map(lambda x: x.memory_usage(deep=True).sum()).sum() / (1<<20) / 1024.0)

Out[15]:
Total DFs size: 1.25gb
CPU times: user 16 ms, sys: 4 ms, total: 20 ms
Wall time: 1.89 s

In [ ]:
# Try to calculate the average value for each column (for every partition)
# Three methods are used,
# - using pandas to apply .mean() on each partition
# - using pd.DataFrame -> rdd<Row> -> spark.DataFrame
# - using pd.DataFrame -> ArrowRecordBatches -> spark.DataFrame

In [28]:
%%timeit
# Native method
print reduce(lambda x,y: pd.concat([x,y], axis=1), prdd.map(lambda x: x.mean()).collect()).mean(axis=1)

Out[28] (identical output was printed once per timeit run; shown once here):
A    49.497162
B    49.512371
C    49.502610
D    49.511418
dtype: float64
1 loop, best of 3: 1.52 s per loop

In [6]:
# Old method (without arrow)
def _get_numpy_record_dtype(rec):
    """
    Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
    the dtypes of fields in a record so they can be properly loaded into Spark.
    :param rec: a numpy record to check field dtypes
    :return corrected dtype for a numpy.record or None if no correction needed
    """
    import numpy as np
    cur_dtypes = rec.dtype
    col_names = cur_dtypes.names
    record_type_list = []
    has_rec_fix = False
    for i in xrange(len(cur_dtypes)):
        curr_type = cur_dtypes[i]
        # If type is a datetime64 timestamp, convert to microseconds
        # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs,
        # conversion from [us] or lower will lead to py datetime objects, see SPARK-22417
        if curr_type == np.dtype('datetime64[ns]'):
            curr_type = 'datetime64[us]'
            has_rec_fix = True
        record_type_list.append((str(col_names[i]), curr_type))
    return np.dtype(record_type_list) if has_rec_fix else None


def pandas_to_rows(pdf):
    """
    Convert a pandas.DataFrame to list of records that can be used to make a DataFrame
    :return list of records
    """
    schema = [str(x) if not isinstance(x, basestring) else
              (x.encode('utf-8') if not isinstance(x, str) else x)
              for x in pdf.columns]

    from pyspark.sql import Row
    Record = Row(*schema)

    # Convert pandas.DataFrame to list of numpy records
    np_records = pdf.to_records(index=False)

    # Check if any columns need to be fixed for Spark to infer properly
    if len(np_records) > 0:
        record_dtype = _get_numpy_record_dtype(np_records[0])
        if record_dtype is not None:
            return [Record(*r.astype(record_dtype).tolist()) for r in np_records]

    # Convert list of numpy records to python lists
    return [Record(*r.tolist()) for r in np_records]

In [29]:
%%timeit
df = prdd.flatMap(pandas_to_rows).toDF()
df.agg({x: "avg" for x in df.columns}).show()

Out[29] (identical output was printed once per timeit run; shown once here):
+------------------+------------------+------------------+------------------+
|            avg(A)|            avg(B)|            avg(C)|            avg(D)|
+------------------+------------------+------------------+------------------+
|49.497162437438966|49.512370872497556|49.502609729766846|49.511417865753174|
+------------------+------------------+------------------+------------------+

1 loop, best of 3: 36.1 s per loop

In [21]:
# New method - Using Arrow for serialization
# (This cell contains the same _dataframe_to_arrow_record_batch / createFromPandasDataframesRDD
#  code as at the top of this gist, including the SparkSession monkey-patch; omitted here
#  to avoid repeating it.)

In [27]:
%%timeit
df = spark.createFromPandasDataframesRDD(prdd)
df.agg({x: "avg" for x in df.columns}).show()

Out[27] (identical output was printed once per timeit run; shown once here):
+------------------+------------------+------------------+------------------+
|            avg(A)|            avg(B)|            avg(C)|            avg(D)|
+------------------+------------------+------------------+------------------+
|49.497162437438966|49.512370872497556|49.502609729766846|49.511417865753174|
+------------------+------------------+------------------+------------------+

1 loop, best of 3: 6.03 s per loop

In [ ]:
# Native method - 1.52s
# toRow serialization - 36.1s
# Arrow serialization - 6.03s
@tahashmi

Hi Jether,

I have a small question if you can help me.

In this code snippet, you are converting prdd (an RDD of pd.DataFrame objects) to Arrow RecordBatches (slices) and finally to a Spark DataFrame.

The Scala code converts a JavaRDD to a Spark DataFrame directly. (https://github.com/apache/spark/blob/65a189c7a1ddceb8ab482ccc60af5350b8da5ea5/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala#L192-L206)

If I already have ardd (an RDD of pa.RecordBatch objects), how can I convert it to a Spark DataFrame directly in PySpark, without going through pandas, as in Scala? Thanks.

@linar-jether
Author

@tahashmi I believe this is similar to what's done internally in the createFromPandasDataframesRDD function above.
You can try using this to create the DataFrame directly from Arrow:

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

@tahashmi

tahashmi commented Jul 1, 2020

Thank you for your reply.

I tried something like this for my Arrow RecordBatches RDD as per your suggestions above.


def _arrow_record_batch_dumps(rb):
    from pyspark.serializers import ArrowSerializer

    return map(bytearray, map(ArrowSerializer().dumps, rb))

def createFromArrowRecordBatchesRDD(self, ardd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x))

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = ardd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

    return df

def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")

    ardd = spark.sparkContext.parallelize([0,1,2], 3)
    ardd = ardd.map(rb_return)

    SparkSession.createFromArrowRecordBatchesRDD = createFromArrowRecordBatchesRDD
    df = spark.createFromArrowRecordBatchesRDD(ardd)
    df.show()

In the above scenario, with ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) enabled, I got the following errors.
It seems like the Arrow dumper is just taking the first array in the Arrow RecordBatch (which is of type pa.int16()) and is not able to read the schema of the RecordBatch:

2020-07-01 04:13:25 INFO  TaskSchedulerImpl:54 - Adding task set 0.0 with 1 tasks
2020-07-01 04:13:25 INFO  TaskSetManager:54 - Starting task 0.0 in stage 0.0 (TID 0, localhost, executor driver, partition 0, PROCESS_LOCAL, 7858 bytes)
2020-07-01 04:13:25 INFO  Executor:54 - Running task 0.0 in stage 0.0 (TID 0)
2020-07-01 04:13:25 INFO  Executor:54 - Fetching file:/home/tahmad/tahmad/script.py with timestamp 1593576802633
2020-07-01 04:13:25 INFO  Utils:54 - /home/tahmad/tahmad/script.py has been previously copied to /tmp/spark-868994a9-976a-4d8f-884c-d97d37df14ae/userFiles-2ac7c66e-924b-498c-b760-ee25ba956a39/script.py
2020-07-01 04:13:25 INFO  PythonRunner:54 - Times: total = 425, boot = 215, init = 210, finish = 0
2020-07-01 04:13:25 INFO  MemoryStore:54 - Block rdd_1_0 stored as bytes in memory (estimated size 392.0 B, free 366.3 MB)
2020-07-01 04:13:25 INFO  BlockManagerInfo:54 - Added rdd_1_0 in memory on tcn793:45311 (size: 392.0 B, free: 366.3 MB)
2020-07-01 04:13:25 ERROR Executor:91 - Exception in task 0.0 in stage 0.0 (TID 0)
org.apache.spark.api.python.PythonException: Traceback (most recent call last):
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 253, in main
    process()
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/worker.py", line 248, in process
    serializer.dump_stream(func(split_index, iterator), outfile)
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/serializers.py", line 379, in dump_stream
    vs = list(itertools.islice(iterator, batch))
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/util.py", line 55, in wrapper
    return f(*args, **kwargs)
  File "/home/tahmad/tahmad/script.py", line 53, in <lambda>
    ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x))
  File "/home/tahmad/tahmad/script.py", line 42, in _arrow_record_batch_dumps
    return map(bytearray, map(ArrowSerializer().dumps, rb))
  File "/home/tahmad/tahmad/spark-2.3.4-bin-hadoop2.7/python/lib/pyspark.zip/pyspark/serializers.py", line 196, in dumps
    writer = pa.RecordBatchFileWriter(sink, batch.schema)
AttributeError: 'pyarrow.lib.Int16Array' object has no attribute 'schema'

	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:336)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:475)
	at org.apache.spark.api.python.PythonRunner$$anon$1.read(PythonRunner.scala:458)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:290)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:439)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.<init>(ArrowConverters.scala:138)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.fromPayloadIterator(ArrowConverters.scala:135)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:211)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:209)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Similarly, by commenting out ardd = ardd.flatMap(lambda x: _arrow_record_batch_dumps(x)) and passing ardd directly to ._to_java_object_rdd(), I got the following errors:

net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for pyarrow.lib.type_for_alias)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:188)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:187)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.<init>(ArrowConverters.scala:138)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.fromPayloadIterator(ArrowConverters.scala:135)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:211)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:209)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
2020-07-01 04:22:43 WARN  TaskSetManager:66 - Lost task 0.0 in stage 0.0 (TID 0, localhost, executor driver): net.razorvine.pickle.PickleException: expected zero arguments for construction of ClassDict (for pyarrow.lib.type_for_alias)
	at net.razorvine.pickle.objects.ClassDictConstructor.construct(ClassDictConstructor.java:23)
	at net.razorvine.pickle.Unpickler.load_reduce(Unpickler.java:707)
	at net.razorvine.pickle.Unpickler.dispatch(Unpickler.java:175)
	at net.razorvine.pickle.Unpickler.load(Unpickler.java:99)
	at net.razorvine.pickle.Unpickler.loads(Unpickler.java:112)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:188)
	at org.apache.spark.api.python.SerDeUtil$$anonfun$pythonToJava$1$$anonfun$apply$1.apply(SerDeUtil.scala:187)
	at scala.collection.Iterator$$anon$12.nextCur(Iterator.scala:434)
	at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:408)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anon$1.<init>(ArrowConverters.scala:138)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$.fromPayloadIterator(ArrowConverters.scala:135)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:211)
	at org.apache.spark.sql.execution.arrow.ArrowConverters$$anonfun$3.apply(ArrowConverters.scala:209)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:801)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:49)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
	at org.apache.spark.scheduler.Task.run(Task.scala:109)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)

Any help in getting this code running will be highly appreciated! Thanks.

@linar-jether
Author

Hi @tahashmi, this seems to work for me.
The issue was using flatMap on the record batch, which caused it to iterate over its arrays:

from pyspark.sql import SparkSession
import pyspark
import pyarrow as pa
from pyspark.serializers import ArrowSerializer

def _arrow_record_batch_dumps(rb):
    # Fix for interoperability between pyarrow version >=0.15 and Spark's arrow version
    # Streaming message protocol has changed, remove setting when upgrading spark.
    import os
    os.environ['ARROW_PRE_0_15_IPC_FORMAT'] = '1'
    
    return bytearray(ArrowSerializer().dumps(rb))


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)

if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0,1,2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer, PickleSerializer, AutoBatchedSerializer

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.arrowPayloadToDataFrame(jrdd, schema.json(), spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema

    df.show()
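
For context, a short illustration of why the original flatMap version failed (a sketch; the exact iteration behavior depends on the pyarrow version, but it matches the traceback above):

    import pyarrow as pa

    rb = pa.RecordBatch.from_arrays(
        [pa.array(range(5), type='int16'),
         pa.array([-10, -5, 0, None, 10], type='int32')],
        names=['c0', 'c1'])

    # A RecordBatch is a collection of column arrays; mapping/iterating over it walks
    # those arrays, so ArrowSerializer().dumps received a pyarrow.lib.Int16Array
    # (which has no .schema attribute) instead of the whole batch -- hence the
    # AttributeError. Serializing the batch as one unit, one element per RDD record,
    # avoids this.
    print([type(col) for col in rb.columns])  # column arrays: Int16Array, Int32Array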
    

@tahashmi

Hi @linar-jether, thank you so much for your time. :-)
I have just checked; it works perfectly.

@tahashmi

Hi @linar-jether, I got stuck again when migrating the above code to Spark 3.0.0.
Can you please help me? I created a JIRA a few days back explaining the whole issue. Thanks.

@linar-jether
Author

Hi @tahashmi, this can be done using the code below.

Also, this PR (SPARK-32846) might be useful, as it uses user-facing APIs (but requires conversion to a pandas RDD).

from pyspark.sql import SparkSession
import pyarrow as pa


def _arrow_record_batch_dumps(rb):
    return bytearray(rb.serialize())


def rb_return(ardd):
    data = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    return pa.RecordBatch.from_arrays(data, schema=schema)


if __name__ == '__main__':
    spark = SparkSession \
        .builder \
        .appName("Python Arrow-in-Spark example") \
        .getOrCreate()

    # Enable Arrow-based columnar data transfers
    spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    sc = spark.sparkContext

    ardd = spark.sparkContext.parallelize([0, 1, 2], 3)
    ardd = ardd.map(rb_return)

    from pyspark.sql.pandas.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame

    # Filter out and cache arrow record batches 
    ardd = ardd.filter(lambda x: isinstance(x, pa.RecordBatch)).cache()

    ardd = ardd.map(_arrow_record_batch_dumps)

    schema = pa.schema([pa.field('c0', pa.int16()),
                        pa.field('c1', pa.int32())],
                       metadata={b'foo': b'bar'})
    schema = from_arrow_schema(schema)

    jrdd = ardd._to_java_object_rdd()
    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(),
                                                spark._wrapped._jsqlContext)
    df = DataFrame(jdf, spark._wrapped)
    df._schema = schema

    df.show()

@tahashmi

Thank you so much, @linar-jether! It's very helpful.
I want to avoid Pandas conversion because my data is in Arrow RecordBatches on all worker nodes.

@ashish615

@linar-jether @tahashmi, while running your code I am facing the error below:

    Traceback (most recent call last):
      File "recordbatch.py", line 48, in <module>
        spark._wrapped._jsqlContext)
    AttributeError: 'SparkSession' object has no attribute '_wrapped'

Am I missing something?
