Forked from linar-jether/PySpark DataFrame from many small pandas DataFrames.ipynb
Created May 17, 2019
Convert an RDD of pandas DataFrames to a single Spark DataFrame using Arrow, without collecting all of the data on the driver.
import pandas as pd


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, parallelism=1):
    """
    Slice a pandas.DataFrame into `parallelism` chunks and convert each chunk into a
    serialized Arrow record batch. If a schema is passed in, its data types are used
    to coerce the data in the pandas-to-Arrow conversion.

    NOTE: relies on pyspark internals (ArrowSerializer, _create_batch) as they exist
    in Spark 2.3, running under Python 2.
    """
    from pyspark.serializers import ArrowSerializer, _create_batch
    from pyspark.sql.types import to_arrow_type, TimestampType, DataType, StructType
    from pyspark.sql.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    # Determine Arrow types to coerce data when creating batches
    if isinstance(schema, StructType):
        arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
    elif isinstance(schema, DataType):
        raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_types = [to_arrow_type(TimestampType())
                       if is_datetime64_dtype(t) or is_datetime64tz_dtype(t) else None
                       for t in pdf.dtypes]

    # Slice the DataFrame to be batched (ceiling division rounds the step up)
    step = -(-len(pdf) // parallelism)
    pdf_slices = (pdf[start:start + step] for start in xrange(0, len(pdf), step))

    # Create Arrow record batches, pairing each column with its coerced Arrow type
    batches = [_create_batch([(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)],
                             timezone)
               for pdf_slice in pdf_slices]

    return map(bytearray, map(ArrowSerializer().dumps, batches))


def createFromPandasDataframesRDD(self, prdd, schema=None, timezone=None):
    from pyspark.sql.types import from_arrow_schema
    from pyspark.sql.dataframe import DataFrame
    from pyspark.serializers import ArrowSerializer

    # Keep only pandas DataFrames; cache because the RDD is traversed more than once
    prdd = prdd.filter(lambda x: isinstance(x, pd.DataFrame)).cache()

    # If a schema is not given, take the column names from the first dataframe
    if schema is None:
        schema = [str(x) if not isinstance(x, basestring) else
                  (x.encode('utf-8') if not isinstance(x, str) else x)
                  for x in prdd.map(lambda x: x.columns).first()]

    prdd = prdd.flatMap(lambda x: _dataframe_to_arrow_record_batch(x, schema=schema, timezone=timezone))

    # Create the Spark schema from the first Arrow batch (there is always at least
    # one batch after slicing), then restore the original column names
    struct = from_arrow_schema(ArrowSerializer().loads(prdd.first()).schema)
    for i, name in enumerate(schema):
        struct.fields[i].name = name
        struct.names[i] = name
    schema = struct

    # Create the Spark DataFrame directly from the Arrow data and schema
    jrdd = prdd._to_java_object_rdd()
    jdf = self._jvm.PythonSQLUtils.arrowPayloadToDataFrame(
        jrdd, schema.json(), self._wrapped._jsqlContext)
    df = DataFrame(jdf, self._wrapped)
    df._schema = schema

    return df


from pyspark.sql import SparkSession
SparkSession.createFromPandasDataframesRDD = createFromPandasDataframesRDD
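A minimal usage sketch (assuming a live SparkSession `spark` and SparkContext `sc`, as in the notebook below):

import numpy as np

# Build an RDD whose elements are pandas DataFrames...
prdd = sc.range(0, 64, numSlices=64) \
    .map(lambda x: pd.DataFrame(np.random.randint(0, 100, size=(1000, 4)), columns=list('ABCD')))

# ...and turn it into a single Spark DataFrame without collecting on the driver
df = spark.createFromPandasDataframesRDD(prdd)
df.agg({c: "avg" for c in df.columns}).show()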
The notebook below benchmarks the conversion (PySpark kernel, Python 2.7.14, Spark 2.3.0 on YARN).

In [1]:
import pandas as pd
import numpy as np
spark

Out[1]:
SparkSession - hive
  SparkContext
    Spark UI: http://cluster-1-m.c.jether-trader.internal:4040
    Version:  v2.3.0
    Master:   yarn
    AppName:  PySparkShell
In [14]:
# Cluster of 2 n1-standard-16 VMs - 32 cores total.
# Generate an RDD of 64 pd.DataFrame objects of ~20 MB each - 1.25 GB in total
# (20<<15 = 655,360 rows x 4 int64 columns x 8 bytes = 20 MiB per DataFrame).
prdd = sc.range(0, 64, numSlices=64).\
    map(lambda x: pd.DataFrame(np.random.randint(0, 100, size=(20 << 15, 4)), columns=list('ABCD')))
In [15]:
%%time
print 'Total DFs size: %.2fgb' % (prdd.map(lambda x: x.memory_usage(deep=True).sum()).sum() / (1 << 20) / 1024.0)

Total DFs size: 1.25gb
CPU times: user 16 ms, sys: 4 ms, total: 20 ms
Wall time: 1.89 s
In [ ]:
# Calculate the average value of each column across all partitions.
# Three methods are compared:
# - pandas-native: apply .mean() to each partition and reduce on the driver
# - row serialization: pd.DataFrame -> rdd<Row> -> spark.DataFrame
# - Arrow serialization: pd.DataFrame -> ArrowRecordBatches -> spark.DataFrame
In [28]:
%%timeit
# pandas-native method
# NOTE: taking the mean of the per-partition means is only correct here because
# every DataFrame in the RDD has the same number of rows.
print reduce(lambda x, y: pd.concat([x, y], axis=1), prdd.map(lambda x: x.mean()).collect()).mean(axis=1)

Output (printed once per timeit run; one of four identical blocks shown):
A    49.497162
B    49.512371
C    49.502610
D    49.511418
dtype: float64
1 loop, best of 3: 1.52 s per loop
In [6]:
# Old method (without Arrow): convert each pandas.DataFrame to a list of Rows
def _get_numpy_record_dtype(rec):
    """
    Used when converting a pandas.DataFrame to Spark using to_records(); corrects
    the dtypes of fields in a record so they can be properly loaded into Spark.
    :param rec: a numpy record to check field dtypes
    :return: corrected dtype for a numpy.record, or None if no correction is needed
    """
    import numpy as np
    cur_dtypes = rec.dtype
    col_names = cur_dtypes.names
    record_type_list = []
    has_rec_fix = False
    for i in xrange(len(cur_dtypes)):
        curr_type = cur_dtypes[i]
        # If type is a datetime64 timestamp, convert to microseconds
        # NOTE: if dtype is datetime[ns] then np.record.tolist() will output values as longs;
        # conversion to [us] or lower yields py datetime objects, see SPARK-22417
        if curr_type == np.dtype('datetime64[ns]'):
            curr_type = 'datetime64[us]'
            has_rec_fix = True
        record_type_list.append((str(col_names[i]), curr_type))
    return np.dtype(record_type_list) if has_rec_fix else None


def pandas_to_rows(pdf):
    """
    Convert a pandas.DataFrame to a list of records that can be used to make a DataFrame
    :return: list of records
    """
    schema = [str(x) if not isinstance(x, basestring) else
              (x.encode('utf-8') if not isinstance(x, str) else x)
              for x in pdf.columns]

    from pyspark.sql import Row
    Record = Row(*schema)

    # Convert pandas.DataFrame to a list of numpy records
    np_records = pdf.to_records(index=False)

    # Check if any columns need to be fixed for Spark to infer types properly
    if len(np_records) > 0:
        record_dtype = _get_numpy_record_dtype(np_records[0])
        if record_dtype is not None:
            return [Record(*r.astype(record_dtype).tolist()) for r in np_records]

    # Convert list of numpy records to python lists
    return [Record(*r.tolist()) for r in np_records]
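An aside on the dtype correction above (a hypothetical sketch; the `ts` column name is made up): the same record renders differently through tolist() depending on the datetime64 unit, which is the behavior SPARK-22417 works around.

import numpy as np
import pandas as pd

rec = pd.DataFrame({'ts': pd.date_range('2019-01-01', periods=2)}).to_records(index=False)[0]

# datetime64[ns]: tolist() yields a raw nanosecond long, which Spark would
# mis-infer as LongType
print rec.tolist()

# datetime64[us]: tolist() yields a datetime.datetime, which Spark infers
# as TimestampType
print rec.astype(np.dtype([('ts', 'datetime64[us]')])).tolist()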
In [29]:
%%timeit
df = prdd.flatMap(pandas_to_rows).toDF()
df.agg({x: "avg" for x in df.columns}).show()

Output (printed once per timeit run; one of four identical blocks shown):
+------------------+------------------+------------------+------------------+
|            avg(A)|            avg(B)|            avg(C)|            avg(D)|
+------------------+------------------+------------------+------------------+
|49.497162437438966|49.512370872497556|49.502609729766846|49.511417865753174|
+------------------+------------------+------------------+------------------+

1 loop, best of 3: 36.1 s per loop
In [21]:
# New method - using Arrow for serialization.
# This cell defines _dataframe_to_arrow_record_batch and
# createFromPandasDataframesRDD exactly as in the file above and attaches the
# latter to SparkSession; the code is not repeated here.
In [27]:
%%timeit
df = spark.createFromPandasDataframesRDD(prdd)
df.agg({x: "avg" for x in df.columns}).show()

Output (printed once per timeit run; one of four identical blocks shown):
+------------------+------------------+------------------+------------------+
|            avg(A)|            avg(B)|            avg(C)|            avg(D)|
+------------------+------------------+------------------+------------------+
|49.497162437438966|49.512370872497556|49.502609729766846|49.511417865753174|
+------------------+------------------+------------------+------------------+

1 loop, best of 3: 6.03 s per loop
In [ ]:
# Summary:
# pandas-native method -  1.52 s per loop
# Row serialization    - 36.1  s per loop
# Arrow serialization  -  6.03 s per loop
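For comparison, when a single pandas DataFrame already fits on the driver, Spark 2.3's built-in Arrow path does the conversion without any of the code above (a sketch; note the config key below is the Spark 2.3 name, renamed to spark.sql.execution.arrow.pyspark.enabled in Spark 3):

# Built-in Arrow conversion requires the DataFrame to be driver-local,
# which is exactly the constraint the RDD-based approach avoids.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
pdf = pd.DataFrame(np.random.randint(0, 100, size=(20 << 15, 4)), columns=list('ABCD'))
df = spark.createDataFrame(pdf)  # pandas -> Arrow -> Spark, all on the driver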