@RaczeQ
Last active May 19, 2024 19:41
PyArrow multiprocessing with streamed results
import logging
import multiprocessing
from pathlib import Path
from queue import Queue
from time import sleep
from typing import Callable, Optional

import pyarrow as pa
import pyarrow.parquet as pq
from tqdm import tqdm
def _intersection_worker(
    queue: Queue[tuple[str, int]],
    save_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
) -> None:  # pragma: no cover
    # Each worker streams its results into its own parquet file,
    # named after the process PID to avoid write conflicts.
    current_pid = multiprocessing.current_process().pid
    filepath = save_path / f"{current_pid}.parquet"
    writer = None
    while not queue.empty():
        try:
            file_name = None
            file_name, row_group_index = queue.get(block=True, timeout=1)
            pq_file = pq.ParquetFile(file_name)
            row_group_table = pq_file.read_row_group(row_group_index, columns=columns)
            if len(row_group_table) == 0:
                continue
            result_table = function(row_group_table)
            # Lazily create the writer with the schema of the first result,
            # then append each subsequent result as a new row group.
            if not writer:
                writer = pq.ParquetWriter(filepath, result_table.schema)
            writer.write_table(result_table)
        except Exception as ex:
            logging.exception(ex)
            # Return the unprocessed task to the queue so a worker can retry it.
            if file_name is not None:
                queue.put((file_name, row_group_index))
    if writer:
        writer.close()
def map_parquet_dataset(
    dataset_path: Path,
    destination_path: Path,
    function: Callable[[pa.Table], pa.Table],
    columns: Optional[list[str]] = None,
) -> None:
    """
    Apply a function over a parquet dataset in a multiprocessing environment.

    Results are saved as multiple parquet files in the destination path,
    one file per worker process.

    Args:
        dataset_path (Path): Path of the parquet dataset.
        destination_path (Path): Path of the destination.
        function (Callable[[pa.Table], pa.Table]): Function to apply over a row group table.
            The resulting table will be saved in a new parquet file.
        columns (Optional[list[str]]): List of columns to read. Defaults to `None`.
    """
    # Keep a reference to the manager so its process is not garbage-collected
    # while workers are still using the queue.
    manager = multiprocessing.Manager()
    queue: Queue[tuple[str, int]] = manager.Queue()

    # Enqueue every (file, row group) pair as an independent unit of work.
    dataset = pq.ParquetDataset(dataset_path)
    for pq_file in dataset.files:
        for row_group in range(pq.ParquetFile(pq_file).num_row_groups):
            queue.put((pq_file, row_group))
    total = queue.qsize()

    destination_path.mkdir(parents=True, exist_ok=True)

    processes = [
        multiprocessing.Process(
            target=_intersection_worker,
            args=(queue, destination_path, function, columns),
        )
        for _ in range(multiprocessing.cpu_count())
    ]
    try:
        # Run processes
        for p in processes:
            p.start()
        # Report progress with TQDM
        with tqdm(total=total) as bar:
            while any(process.is_alive() for process in processes):
                bar.n = total - queue.qsize()
                bar.refresh()
                sleep(1)
            bar.n = total
            bar.refresh()
    finally:
        # In case of an exception, stop all remaining processes
        for p in processes:
            if p.is_alive():
                p.terminate()
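
A minimal usage sketch, assuming the gist is saved as map_parquet.py; the dataset path, output path, and the filter_positive transform are hypothetical examples, not part of the gist. Each row group is filtered down to rows whose "value" column is positive, and results are streamed into per-worker files under data/output.

from pathlib import Path

import pyarrow as pa
import pyarrow.compute as pc

from map_parquet import map_parquet_dataset


def filter_positive(table: pa.Table) -> pa.Table:
    # Keep only rows whose "value" column is greater than zero.
    return table.filter(pc.greater(table["value"], 0))


if __name__ == "__main__":  # guard required for multiprocessing on spawn-based platforms
    map_parquet_dataset(
        dataset_path=Path("data/input_dataset"),
        destination_path=Path("data/output"),
        function=filter_positive,
        columns=["id", "value"],
    )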