-
-
Save hussainsultan/6146a6c57b573256121e860c2cb1d4bb to your computer and use it in GitHub Desktop.
xorq_ipc_bench.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import argparse | |
| import time | |
| import numpy as np | |
| import pyarrow as pa | |
| import pandas as pd | |
| import xorq as xo | |
| from xorq.flight.client import FlightClient | |
| from attr import frozen, field | |
| from attr.validators import instance_of | |
| import pyarrow.parquet as pq | |
| import io | |
| import gcsfs | |
@frozen
class FlightStorage:
    """Key/value-style storage facade over a xorq Arrow Flight server.

    Tables are stored with ``FlightClient.upload_data`` under a string key and
    read back by executing a xorq table expression that references that key.
    A ``gcsfs.GCSFileSystem`` handle is available lazily via :attr:`fs`, but
    none of the put/get methods below use it.
    """

    client: FlightClient = field(validator=instance_of(FlightClient))
    # Lazily created GCS filesystem; stays None until ``fs`` is first accessed.
    _fs: "gcsfs.GCSFileSystem | None" = field(init=False, default=None)

    @property
    def fs(self):
        """Return the GCS filesystem, constructing it on first access."""
        if self._fs is None:
            # The class is frozen, so bypass the attrs __setattr__ guard to
            # memoize the lazily constructed filesystem on this instance.
            object.__setattr__(self, "_fs", gcsfs.GCSFileSystem())
        return self._fs

    def _put(self, key, table):
        """Store *table* on the Flight server under the name *key*."""
        # Only the Flight upload happens here; nothing is written to GCS, so
        # the log message must not claim otherwise.
        print(f"Uploading {key} to Flight")
        self._put_to_flight(key, table)

    def _put_to_flight(self, key, table):
        """Upload *table* to the Flight server via the client, keyed by *key*."""
        self.client.upload_data(key, table)

    def _get(self, key, schema):
        """Fetch the table stored under *key* by executing a xorq expression.

        *schema* must describe the stored table; it is used to build the
        unbound ``xo.table`` expression that the server resolves by name.
        Returns whatever ``client.execute_query`` returns.
        """
        print(f"Fetching {key} using serialized XORQ expression")
        expr = xo.table(schema=schema, name=key)
        return self.client.execute_query(expr)
def create_batches(total_rows, cols, batch_size):
    """Build random Arrow RecordBatches totalling *total_rows* rows.

    The schema has ``cols`` fields: ``int_0 .. int_{k-1}`` (int32, values drawn
    from ``np.random.randint(0, 100000)``) followed by ``float_0 .. float_{m-1}``
    (float64, values drawn from ``np.random.random``), where ``m = cols // 2``
    and ``k = cols - m``. For odd *cols* the extra column is an int column.
    Every batch holds *batch_size* rows except possibly the last one.
    Returns an empty list when *total_rows* is 0.

    Raises:
        ValueError: if *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    n_float = cols // 2
    # Give any odd column to the int half so the schema has exactly `cols`
    # fields instead of silently dropping one.
    n_int = cols - n_float

    schema = pa.schema(
        [pa.field(f"int_{i}", pa.int32()) for i in range(n_int)]
        + [pa.field(f"float_{i}", pa.float64()) for i in range(n_float)]
    )

    batches = []
    for start_row in range(0, total_rows, batch_size):
        current_batch_rows = min(batch_size, total_rows - start_row)
        arrays = [
            pa.array(np.random.randint(0, 100000, size=current_batch_rows), pa.int32())
            for _ in range(n_int)
        ]
        arrays += [
            pa.array(np.random.random(size=current_batch_rows), pa.float64())
            for _ in range(n_float)
        ]
        # The second positional parameter of from_arrays is `names`; the schema
        # must be passed by keyword.
        batches.append(pa.RecordBatch.from_arrays(arrays, schema=schema))
    return batches
def benchmark_xorq_flight(storage, batches, key):
    """Time one upload and one download of *batches* through *storage*.

    Args:
        storage: object exposing ``_put(key, table)`` and ``_get(key, schema)``,
            e.g. ``FlightStorage``.
        batches: list of ``pyarrow.RecordBatch``. It must be non-empty, because
            ``pa.Table.from_batches`` needs at least one batch to infer the schema.
        key: name under which the table is stored and then fetched.

    Returns:
        ``(put_time, get_time, ipc_size)``: elapsed seconds for the put and for
        the get, and the summed byte size of each batch's IPC serialization.
    """
    table = pa.Table.from_batches(batches)

    # perf_counter is monotonic and has the highest available resolution;
    # time.time() can jump with clock adjustments and may be coarse on some
    # platforms, which skews short measurements.
    start_put = time.perf_counter()
    storage._put(key, table)
    put_time = time.perf_counter() - start_put

    start_get = time.perf_counter()
    # The fetched result is intentionally discarded; only the elapsed time
    # matters. NOTE(review): if execute_query returns a lazy reader rather than
    # materialized data, get_time would undercount the transfer. Confirm.
    storage._get(key, table.schema)
    get_time = time.perf_counter() - start_get

    # Buffer.size is the serialized length; unlike to_pybytes() it does not
    # copy the whole buffer into a Python bytes object just to measure it.
    ipc_size = sum(batch.serialize().size for batch in batches)
    return put_time, get_time, ipc_size
| def main(): | |
| parser = argparse.ArgumentParser(description="xorq Flight Client IPC benchmark") | |
| parser.add_argument("--host", default="localhost") | |
| parser.add_argument("--port", type=int, default=50051) | |
| parser.add_argument("--batch_size", type=int, default=8192, help="RecordBatch size") | |
| args = parser.parse_args() | |
| flight_client = FlightClient(host=args.host, port=args.port) | |
| storage = FlightStorage(client=flight_client) | |
| results = [] | |
| for total_rows in [10_000, 50_000, 100_000, 500_000, 1_000_000]: | |
| for cols in [10, 50, 100]: | |
| batches = create_batches(total_rows, cols, args.batch_size) | |
| key = f"xorq_batch_{total_rows}_{cols}" | |
| put_time, get_time, ipc_size = benchmark_xorq_flight(storage, batches, key) | |
| results.append({ | |
| "total_rows": total_rows, | |
| "cols": cols, | |
| "batch_size": args.batch_size, | |
| "num_batches": len(batches), | |
| "ipc_size_mb": ipc_size / (1024 ** 2), | |
| "put_time_sec": put_time, | |
| "get_time_sec": get_time, | |
| "put_throughput_mb_s": ipc_size / (put_time * 1024 ** 2), | |
| "get_throughput_mb_s": ipc_size / (get_time * 1024 ** 2) | |
| }) | |
| print(f"[{total_rows} rows x {cols} cols, batch size {args.batch_size}] PUT: {put_time:.2f}s, GET: {get_time:.2f}s, Size: {ipc_size / (1024 ** 2):.2f}MB") | |
| df = pd.DataFrame(results) | |
| print("\nBenchmark Summary:") | |
| print(df.to_string(index=False)) | |
| if __name__ == "__main__": | |
| main() |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment