@andrewgross
Created February 4, 2019 20:46
PySpark code to take a DataFrame and repartition it into an optimal number of partitions for generating 300 MB to 1 GB parquet files. See the usage sketch after the code.
import re
from math import ceil

import pyspark.sql.functions as F
import pyspark.sql.types as T

def repartition_for_writing(df):
    count = df.count()
    sampled_df = get_sampled_df(df, count=count)
    string_column_sizes = get_string_column_sizes(sampled_df)
    num_files = get_num_files(count, df.schema, string_column_sizes)
    print(num_files)
    return df.repartition(num_files)

def get_sampled_df(df, count=None):
    if not count:
        count = df.count()
    sample_size = 100000.0
    raw_fraction = sample_size / count
    clamped_fraction = min(raw_fraction, 1.0)
    return df.sample(withReplacement=False, fraction=clamped_fraction)

def get_string_column_sizes(df):
    ddf = df
    string_cols = []
    for column in ddf.schema:
        if isinstance(column.dataType, T.StringType):
            ddf = ddf.withColumn('{}__length'.format(column.name), F.length(column.name))
            string_cols.append('{}__length'.format(column.name))
    # No string columns means there is nothing to measure
    if not string_cols:
        return {}
    sizes = ddf.groupBy().avg(*string_cols).first().asDict()
    cleaned_sizes = {}
    for k, v in sizes.items():
        col_name = re.search(r'avg\((.+)__length\)', k).groups()[0]
        cleaned_sizes[col_name] = v
    return cleaned_sizes

def get_num_files(rows, schema, string_column_sizes):
    record_size = _get_record_size(schema, string_column_sizes)
    return _get_files_based_on_file_size(rows, record_size)

def _get_record_size(schema, string_field_sizes):
    size_mapping = get_size_mapping()
    record_size = 0
    for field in schema:
        _type = field.dataType.typeName()
        if _type == "string":
            # Fetch our avg size for a string field, convert from bytes to bits
            field_size = string_field_sizes[field.name] * 8
        else:
            field_size = size_mapping[_type]
        record_size = record_size + field_size
    return record_size

def _get_files_based_on_file_size(rows, record_size):
    compression_ratio = get_compression_ratio("snappy")
    data_size_in_bits = record_size * rows
    data_size_in_bytes = data_size_in_bits / 8
    data_size_mb = data_size_in_bytes / (1024 * 1024)
    data_size_compressed = data_size_mb * compression_ratio
    # Aim for 1 GB files
    num_files = data_size_compressed / 1024
    return int(ceil(num_files))

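# Worked example (hypothetical numbers) for the sizing math above:
#   record_size = 800 bits per row, rows = 50,000,000
#   data_size_in_bits    = 800 * 50,000,000        = 4.0e10
#   data_size_in_bytes   = 4.0e10 / 8              = 5.0e9
#   data_size_mb         = 5.0e9 / (1024 * 1024)  ~= 4768.4
#   data_size_compressed = 4768.4 * 0.6           ~= 2861.0
#   num_files            = ceil(2861.0 / 1024)     = 3
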
def get_size_mapping():
    """
    Size mapping (in bits) for non-string fields in Parquet files
    """
    PARQUET_SIZE_MAPPING = {
        "short": 32,
        "integer": 32,
        "long": 64,
        "boolean": 1,
        "float": 32,
        "double": 64,
        "decimal": 64,
        "date": 32,  # Assume no date64
        "timestamp": 96,  # Assume legacy timestamp
    }
    return PARQUET_SIZE_MAPPING

def get_compression_ratio(compression):
    """
    Return a floating point scalar for the size after compression.
    """
    if compression == "snappy":
        return 0.6
    return 1.0
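A minimal usage sketch, not part of the original gist: it assumes an active SparkSession named spark, and the table name and output path are placeholders. The repartitioned DataFrame is written as snappy-compressed parquet, which is what the 0.6 compression ratio above assumes.

# Usage sketch (hypothetical table name and output path)
df = spark.table("my_database.my_table")
repartitioned = repartition_for_writing(df)
repartitioned.write.parquet("s3://my-bucket/my_table/", compression="snappy")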