# Skip to content
#
# Instantly share code, notes, and snippets.
#
# @lissahyacinth
# Created October 28, 2021 08:55
# Show Gist options
#   • Save lissahyacinth/ee676de671c18c36515e7682e74e3aa0 to your computer and use it in GitHub Desktop.
# Stripped down PyArrow Metatest
import ast
import base64
import itertools
import os
import signal
import struct
import tempfile
import threading
import time
import traceback
import json
import numpy as np
import pytest
import pyarrow as pa
from pyarrow.lib import tobytes
from pyarrow.util import pathlib, find_free_port
from pyarrow.tests import util
from pyarrow import flight
from pyarrow.flight import (
FlightClient, FlightServerBase,
ServerAuthHandler, ClientAuthHandler,
ServerMiddleware, ServerMiddlewareFactory,
ClientMiddleware, ClientMiddlewareFactory,
)
class MetadataFlightServer(FlightServerBase):
    """Flight server that attaches a sequence number to every batch.

    ``do_get`` streams a fixed one-column table and pairs each batch with
    its index packed as a little-endian 32-bit integer.  ``do_put`` checks
    that uploaded batches carry the expected value and index, and writes
    the index back to the client.
    """

    def __init__(self, options=None, **kwargs):
        """Pass ``kwargs`` to the base server and remember ``options``.

        ``options`` is later forwarded to ``flight.GeneratorStream`` by
        ``do_get``.
        """
        super().__init__(**kwargs)
        self.options = options

    def do_get(self, context, ticket):
        """Return a stream of the fixed table with numbered batches."""
        values = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([values], names=['a'])
        numbered = self.number_batches(table)
        return flight.GeneratorStream(
            table.schema, numbered, options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Validate uploaded single-value batches and echo their index.

        Each chunk must hold the next expected value in column ``a`` and
        metadata containing the matching little-endian int32 counter.
        Returns once the reader signals the end with ``StopIteration``.
        """
        expected_values = [-10, -5, 0, 5, 10]
        position = 0
        while True:
            # Only reading the next chunk can signal the end of the stream.
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                return
            assert batch.equals(pa.RecordBatch.from_arrays(
                [pa.array([expected_values[position]])],
                ['a'],
            ))
            assert metadata is not None
            (client_position,) = struct.unpack('<i', metadata.to_pybytes())
            assert position == client_position
            writer.write(struct.pack('<i', position))
            position += 1

    @staticmethod
    def number_batches(table):
        """Yield ``(batch, packed_index)`` for each batch of ``table``."""
        for position, batch in enumerate(table.to_batches()):
            yield batch, struct.pack('<i', position)
class ExchangeFlightServer(FlightServerBase):
    """Server used to exercise DoExchange.

    The command carried by the descriptor selects one of four behaviours:
    ``echo``, ``get``, ``put`` or ``transform``.
    """

    def __init__(self, options=None, **kwargs):
        """Pass ``kwargs`` to the base server and remember ``options``.

        ``options`` is used when the echo handler begins its output stream.
        """
        super().__init__(**kwargs)
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Route the exchange to the handler named by the command.

        Raises ``pa.ArrowInvalid`` when the descriptor is not a command
        descriptor or when the command is not recognised.
        """
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")

        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                "Unknown command: {}".format(descriptor.command))
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet: send a table holding 0 .. 10 * 1024 - 1."""
        column = pa.array(range(0, 10 * 1024))
        table = pa.Table.from_arrays([column], names=["a"])
        writer.begin(table.schema)
        writer.write_table(table)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut: reply to each chunk with the running chunk count.

        Raises ``pa.ArrowInvalid`` for a chunk whose ``data`` is falsy.
        """
        for received, chunk in enumerate(reader, start=1):
            if not chunk.data:
                raise pa.ArrowInvalid("All chunks must have data.")
            writer.write_metadata(str(received).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Send every incoming chunk straight back to the client."""
        stream_open = False
        for chunk in reader:
            batch = chunk.data
            metadata = chunk.app_metadata
            # The output stream is begun with the first data-bearing chunk.
            if batch and not stream_open:
                writer.begin(batch.schema, options=self.options)
                stream_open = True

            if metadata:
                if batch:
                    writer.write_with_metadata(batch, metadata)
                else:
                    writer.write_metadata(metadata)
            elif batch:
                writer.write_batch(batch)
            else:
                assert False, "Should not happen"

    def exchange_transform(self, context, reader, writer):
        """Reply with a one-column table of per-row sums of the upload.

        Raises ``pa.ArrowInvalid`` if any uploaded field is not an integer
        type.
        """
        for field in reader.schema:
            if pa.types.is_integer(field.type):
                continue
            raise pa.ArrowInvalid("Invalid field: " + repr(field))

        uploaded = reader.read_all()
        row_totals = [0] * uploaded.num_rows
        for column in uploaded:
            for index, cell in enumerate(column):
                row_totals[index] += cell.as_py()

        summed = pa.Table.from_arrays([pa.array(row_totals)], names=["sum"])
        writer.begin(summed.schema)
        writer.write_table(summed)
def test_flight_list_flights():
    """Call list_flights with and without the server's criteria."""
    # NOTE(review): ConstantFlightServer is not defined in the visible part
    # of this file; it has to be provided elsewhere for this test to run.
    with ConstantFlightServer() as server:
        location = ('localhost', server.port)
        client = flight.connect(location)

        unfiltered = list(client.list_flights())
        assert unfiltered == []

        matching = list(client.list_flights(ConstantFlightServer.CRITERIA))
        assert len(matching) == 1
def test_flight_do_get_metadata():
    """Fetch the table via do_get and verify per-batch metadata indices."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    received = []

    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        reader = client.do_get(flight.Ticket(b''))

        expected_idx = 0
        while True:
            # The reader raises StopIteration once the stream is exhausted.
            try:
                batch, metadata = reader.read_chunk()
            except StopIteration:
                break
            received.append(batch)
            (server_idx,) = struct.unpack('<i', metadata.to_pybytes())
            assert expected_idx == server_idx
            expected_idx += 1

        result = pa.Table.from_batches(received)
        assert result.equals(expected)
def test_flight_do_get_metadata_v4():
    """Run do_get against a server writing IPC metadata version V4."""
    values = pa.array([-10, -5, 0, 5, 10])
    expected = pa.Table.from_arrays([values], names=['a'])
    write_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)

    with MetadataFlightServer(options=write_options) as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        ticket = flight.Ticket(b'')
        received = client.do_get(ticket).read_all()
        assert received.equals(expected)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment