Skip to content

Instantly share code, notes, and snippets.

@hussainsultan
Created July 5, 2025 14:49
Show Gist options
  • Select an option

  • Save hussainsultan/6146a6c57b573256121e860c2cb1d4bb to your computer and use it in GitHub Desktop.

Select an option

Save hussainsultan/6146a6c57b573256121e860c2cb1d4bb to your computer and use it in GitHub Desktop.
xorq_ipc_bench.py
import argparse
import time
import numpy as np
import pyarrow as pa
import pandas as pd
import xorq as xo
from xorq.flight.client import FlightClient
from attr import frozen, field
from attr.validators import instance_of
import pyarrow.parquet as pq
import io
import gcsfs
@frozen
class FlightStorage:
    """Minimal storage facade over a xorq ``FlightClient`` for the benchmark.

    ``client`` is validated to be a ``FlightClient`` when the instance is
    constructed. The class is frozen, so attributes cannot be reassigned
    through normal attribute assignment.
    """

    # Flight client used for uploads and query execution.
    client: FlightClient = field(validator=instance_of(FlightClient))
    # Lazily created GCS filesystem handle. It is a mutable cache, so it is
    # excluded from the generated __eq__/__hash__; otherwise the hash of an
    # instance would change the first time ``fs`` is accessed.
    _fs: gcsfs.GCSFileSystem = field(init=False, default=None, eq=False)

    @property
    def fs(self):
        """Return the GCS filesystem, creating it on first access."""
        if self._fs is None:
            # The instance is frozen, so regular assignment would raise;
            # object.__setattr__ bypasses that guard to fill the cache.
            object.__setattr__(self, "_fs", gcsfs.GCSFileSystem())
        return self._fs

    def _put(self, key, table):
        """Upload ``table`` under ``key`` through the Flight client."""
        # Only the Flight upload happens here; nothing is written to GCS.
        print(f"Uploading {key} to Flight")
        self._put_to_flight(key, table)

    def _put_to_flight(self, key, table):
        """Send ``table`` to the Flight server under the name ``key``."""
        self.client.upload_data(key, table)

    def _get(self, key, schema):
        """Query the table named ``key`` and return the client's result.

        Builds an unbound xorq table expression from ``schema`` and ``key``
        and hands it to ``FlightClient.execute_query``.
        """
        print(f"Fetching {key} using serialized XORQ expression")
        expr = xo.table(schema=schema, name=key)
        return self.client.execute_query(expr)
def create_batches(total_rows, cols, batch_size):
    """Build random Arrow record batches for the benchmark.

    Args:
        total_rows: Total number of rows across all returned batches.
        cols: Total number of columns. The first ``cols // 2`` are int32
            columns named ``int_<i>``; the remaining ones are float64
            columns named ``float_<i>``. For an odd ``cols`` the extra
            column is a float column, so exactly ``cols`` columns are made.
        batch_size: Maximum number of rows per record batch.

    Returns:
        A list of ``pyarrow.RecordBatch`` objects; every batch has
        ``batch_size`` rows except possibly the last one.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # A non-positive step would either fail inside range() (zero) or silently
    # produce no batches at all (negative), so reject it up front.
    if batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    num_int_cols = cols // 2
    num_float_cols = cols - num_int_cols
    schema = pa.schema(
        [pa.field(f"int_{i}", pa.int32()) for i in range(num_int_cols)]
        + [pa.field(f"float_{i}", pa.float64()) for i in range(num_float_cols)]
    )

    batches = []
    for start_row in range(0, total_rows, batch_size):
        # The final batch may be shorter than batch_size.
        rows = min(batch_size, total_rows - start_row)
        arrays = [
            pa.array(np.random.randint(0, 100000, size=rows), pa.int32())
            for _ in range(num_int_cols)
        ] + [
            pa.array(np.random.random(size=rows), pa.float64())
            for _ in range(num_float_cols)
        ]
        # The second positional parameter of from_arrays is ``names``, so the
        # schema has to be passed by keyword.
        batches.append(pa.RecordBatch.from_arrays(arrays, schema=schema))
    return batches
def benchmark_xorq_flight(storage, batches, key):
    """Time one upload and one fetch of ``batches`` through ``storage``.

    Args:
        storage: Object exposing ``_put(key, table)`` and
            ``_get(key, schema)``.
        batches: Record batches to combine into a single table and upload.
        key: Name under which the table is uploaded and then queried.

    Returns:
        A tuple ``(put_time, get_time, ipc_size)``: the elapsed seconds of
        the upload, the elapsed seconds of the fetch, and the summed size in
        bytes of the serialized batches.
    """
    table = pa.Table.from_batches(batches)

    # perf_counter is monotonic and high resolution, so the measured
    # intervals are not affected by wall-clock adjustments.
    start_put = time.perf_counter()
    storage._put(key, table)
    put_time = time.perf_counter() - start_put

    start_get = time.perf_counter()
    # The result is deliberately discarded; only the elapsed time is used.
    storage._get(key, table.schema)
    get_time = time.perf_counter() - start_get

    # Buffer.size reports the serialized length without copying the payload
    # into a Python bytes object first.
    ipc_size = sum(batch.serialize().size for batch in batches)
    return put_time, get_time, ipc_size
def _throughput_mb_s(size_bytes, seconds):
    """Return throughput in MB/s, or infinity when no time was measured."""
    if seconds <= 0:
        return float("inf")
    return size_bytes / (seconds * 1024 ** 2)


def main():
    """Run the Flight PUT/GET benchmark sweep and print a summary table.

    Command-line options select the Flight server (``--host``, ``--port``),
    the record batch size (``--batch_size``) and the row/column counts to
    sweep (``--rows``, ``--cols``). The defaults reproduce the original
    fixed sweep.
    """
    parser = argparse.ArgumentParser(description="xorq Flight Client IPC benchmark")
    parser.add_argument("--host", default="localhost")
    parser.add_argument("--port", type=int, default=50051)
    parser.add_argument("--batch_size", type=int, default=8192, help="RecordBatch size")
    parser.add_argument(
        "--rows",
        type=int,
        nargs="+",
        default=[10_000, 50_000, 100_000, 500_000, 1_000_000],
        help="Total row counts to benchmark",
    )
    parser.add_argument(
        "--cols",
        type=int,
        nargs="+",
        default=[10, 50, 100],
        help="Column counts to benchmark",
    )
    args = parser.parse_args()

    flight_client = FlightClient(host=args.host, port=args.port)
    storage = FlightStorage(client=flight_client)

    results = []
    for total_rows in args.rows:
        for cols in args.cols:
            batches = create_batches(total_rows, cols, args.batch_size)
            key = f"xorq_batch_{total_rows}_{cols}"
            put_time, get_time, ipc_size = benchmark_xorq_flight(storage, batches, key)
            size_mb = ipc_size / (1024 ** 2)
            results.append({
                "total_rows": total_rows,
                "cols": cols,
                "batch_size": args.batch_size,
                "num_batches": len(batches),
                "ipc_size_mb": size_mb,
                "put_time_sec": put_time,
                "get_time_sec": get_time,
                # Guarded so a zero elapsed time cannot abort the sweep.
                "put_throughput_mb_s": _throughput_mb_s(ipc_size, put_time),
                "get_throughput_mb_s": _throughput_mb_s(ipc_size, get_time),
            })
            print(f"[{total_rows} rows x {cols} cols, batch size {args.batch_size}] PUT: {put_time:.2f}s, GET: {get_time:.2f}s, Size: {size_mb:.2f}MB")

    df = pd.DataFrame(results)
    print("\nBenchmark Summary:")
    print(df.to_string(index=False))
# Run the benchmark only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment