Skip to content

Instantly share code, notes, and snippets.

@Winand
Created July 26, 2022 12:04
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save Winand/3060a70c706cf83ac873f35753c91c87 to your computer and use it in GitHub Desktop.
Load parquet dataset and save with new partitioning
"""
https://stackoverflow.com/questions/68708477/repartition-large-parquet-dataset-by-ranges-of-values
"""
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.compute as pc
import pyarrow.parquet as pq
from pathlib import Path
# ------------------------------ SETTINGS ------------------------------
# Source parquet dataset to read (for example data/foo.parquet).
src_path = Path(
    r'C:\Users\Andrey.makarov\Documents\projects_ocrv\_jupyter_remote\data\axis'
)
# Directory the repartitioned dataset is written to.
dst_path = Path('data/new')
# Name of the column the partition key is derived from.
col_key = 'invNum'
# Number of low bits of the key that are cleared to form the partition key.
items_per_group_exponent = 17
# Truthy: write a small sample table to src_path before reading it.
init_data = 0
# Partitioning flavor: None -> DirectoryPartitioning, 'hive' -> HivePartitioning.
flavor = None
# ----------------------------------------------------------------------
# Optionally seed src_path with a 20-row sample table: the key column holds
# 0..19 and column 'y' is constant 1.
if init_data:
    sample_columns = {col_key: range(20), 'y': [1] * 20}
    table = pa.Table.from_pydict(sample_columns)
    pq.write_table(table, src_path)
# Output partitioning: a single int64 field named partition_key, laid out
# according to the configured flavor.
part = ds.partitioning(
    pa.schema([pa.field("partition_key", pa.int64())]),
    flavor=flavor,
)
# Open the source dataset and get an iterator over its scanned batches.
dataset = ds.dataset(src_path)
scanner = dataset.scanner()
scanner_iter = scanner.scan_batches()
# Arrow had no modulo / integer-division kernel when this was written
# (ARROW-12755), so the partition key is approximated with bit masking:
# clearing the low `items_per_group_exponent` bits of the key maps every
# run of 2**items_per_group_exponent consecutive key values to the same
# partition key. Raise items_per_group_exponent for more key values
# (and therefore more rows) per output partition.
_INT64_MAX = (2 ** 63) - 1
items_per_group_mask = (2 ** items_per_group_exponent) - 1
# Keep bits 0..62 except the low bits selected by items_per_group_mask.
# NOTE(review): bit 63 is not set in the mask, so this presumably expects
# non-negative keys -- confirm against the data.
mask = _INT64_MAX ^ items_per_group_mask
def projector():
    """Yield every scanned batch with a ``partition_key`` column appended.

    Batches are pulled from the module-level ``scanner_iter``. For each one,
    the partition key is the ``col_key`` column bit-wise AND-ed with ``mask``;
    it is appended after the original columns. ``Iter <n>`` is printed before
    each batch is yielded and ``STOP`` once the source is exhausted.

    Yields:
        pa.RecordBatch: the original columns followed by ``partition_key``.
    """
    # A plain ``for`` loop stops on exhaustion by itself. The previous
    # ``while True`` / ``next()`` / ``except StopIteration`` form wrapped the
    # compute calls and the ``yield`` as well, so a stray StopIteration from
    # any of them would have been silently treated as end of data.
    for iterations, tagged_batch in enumerate(scanner_iter, start=1):
        next_batch = tagged_batch.record_batch
        partition_key_arr = pc.bit_wise_and(next_batch.column(col_key), mask)
        batch_with_part = pa.RecordBatch.from_arrays(
            [*next_batch.columns, partition_key_arr],
            names=[*next_batch.schema.names, 'partition_key'],
        )
        print(f'Iter {iterations}')
        yield batch_with_part
    print('STOP')
# Output schema: the source schema plus the derived int64 partition_key.
full_schema = dataset.schema.append(pa.field('partition_key', pa.int64()))
# Debugging hint: the projected batches can be inspected by looping over
# projector() directly instead of writing them; note that doing so pulls
# batches from scanner_iter.
parquet_write_options = ds.ParquetFileFormat().make_write_options(
    compression='snappy'
)
# Write the projected batches to dst_path as snappy-compressed parquet,
# partitioned according to `part`.
ds.write_dataset(
    projector(),
    dst_path,
    schema=full_schema,
    format='parquet',
    partitioning=part,
    file_options=parquet_write_options,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment