Created
October 28, 2021 08:55
-
-
Save lissahyacinth/ee676de671c18c36515e7682e74e3aa0 to your computer and use it in GitHub Desktop.
Stripped down PyArrow Metatest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
import base64 | |
import itertools | |
import os | |
import signal | |
import struct | |
import tempfile | |
import threading | |
import time | |
import traceback | |
import json | |
import numpy as np | |
import pytest | |
import pyarrow as pa | |
from pyarrow.lib import tobytes | |
from pyarrow.util import pathlib, find_free_port | |
from pyarrow.tests import util | |
from pyarrow import flight | |
from pyarrow.flight import ( | |
FlightClient, FlightServerBase, | |
ServerAuthHandler, ClientAuthHandler, | |
ServerMiddleware, ServerMiddlewareFactory, | |
ClientMiddleware, ClientMiddlewareFactory, | |
) | |
class MetadataFlightServer(FlightServerBase):
    """Flight server that attaches a sequence number to each batch.

    ``do_get`` sends every outgoing batch together with its packed index,
    and ``do_put`` checks that every incoming batch carries the matching
    index before acknowledging it.
    """

    def __init__(self, options=None, **kwargs):
        """Keep ``options`` for later streams; forward ``kwargs`` to the base.

        ``options`` is handed unchanged to ``flight.GeneratorStream`` in
        ``do_get``.
        """
        super().__init__(**kwargs)
        self.options = options

    def do_get(self, context, ticket):
        """Return a stream of the fixed one-column table ``a``.

        Each batch is paired with its index packed as ``'<i'``; the
        ``ticket`` and ``context`` arguments are not inspected.
        """
        column = pa.array([-10, -5, 0, 5, 10])
        table = pa.Table.from_arrays([column], names=['a'])
        numbered = self.number_batches(table)
        return flight.GeneratorStream(
            table.schema, numbered, options=self.options)

    def do_put(self, context, descriptor, reader, writer):
        """Validate uploaded single-value batches and echo their index.

        For every chunk the batch must equal the next expected value, the
        metadata buffer must be present and must decode (``'<i'``) to the
        running position. The position is then written back to the client.
        The loop ends when ``reader.read_chunk()`` raises ``StopIteration``.
        """
        expected_values = [-10, -5, 0, 5, 10]
        for position in itertools.count():
            # Only the read can signal end-of-stream, so guard just that.
            try:
                batch, buf = reader.read_chunk()
            except StopIteration:
                return
            assert batch.equals(pa.RecordBatch.from_arrays(
                [pa.array([expected_values[position]])],
                ['a']
            ))
            assert buf is not None
            (client_position,) = struct.unpack('<i', buf.to_pybytes())
            assert position == client_position
            writer.write(struct.pack('<i', position))

    @staticmethod
    def number_batches(table):
        """Yield ``(batch, packed_index)`` for each batch of ``table``.

        The index is the batch's position in ``table.to_batches()``, packed
        with ``struct`` format ``'<i'``.
        """
        position = 0
        for batch in table.to_batches():
            yield batch, struct.pack('<i', position)
            position += 1
class ExchangeFlightServer(FlightServerBase):
    """A server for testing DoExchange.

    ``do_exchange`` dispatches on the command bytes of the descriptor to one
    of the ``exchange_*`` handlers (``echo``, ``get``, ``put``,
    ``transform``).
    """

    def __init__(self, options=None, **kwargs):
        """Keep ``options`` for ``exchange_echo``; forward ``kwargs`` to the base."""
        super().__init__(**kwargs)
        self.options = options

    def do_exchange(self, context, descriptor, reader, writer):
        """Route a DoExchange call to the handler named by the command.

        Raises:
            pa.ArrowInvalid: if the descriptor is not a command descriptor
                or the command is not one of the known handlers.
        """
        if descriptor.descriptor_type != flight.DescriptorType.CMD:
            raise pa.ArrowInvalid("Must provide a command descriptor")
        handlers = {
            b"echo": self.exchange_echo,
            b"get": self.exchange_do_get,
            b"put": self.exchange_do_put,
            b"transform": self.exchange_transform,
        }
        handler = handlers.get(descriptor.command)
        if handler is None:
            raise pa.ArrowInvalid(
                "Unknown command: {}".format(descriptor.command))
        return handler(context, reader, writer)

    def exchange_do_get(self, context, reader, writer):
        """Emulate DoGet with DoExchange.

        Writes a single-column table ``a`` holding ``range(0, 10 * 1024)``.
        """
        data = pa.Table.from_arrays([
            pa.array(range(0, 10 * 1024))
        ], names=["a"])
        writer.begin(data.schema)
        writer.write_table(data)

    def exchange_do_put(self, context, reader, writer):
        """Emulate DoPut with DoExchange.

        Counts the uploaded chunks and replies with the count as UTF-8
        encoded decimal metadata.

        Raises:
            pa.ArrowInvalid: if a chunk arrives without a data batch.
        """
        num_batches = 0
        for chunk in reader:
            # Compare against None rather than relying on truthiness, so a
            # zero-row batch is not mistaken for a missing one.
            # NOTE(review): assumes a chunk without data exposes
            # ``chunk.data`` as None -- confirm against the pyarrow docs.
            if chunk.data is None:
                raise pa.ArrowInvalid("All chunks must have data.")
            num_batches += 1
        writer.write_metadata(str(num_batches).encode("utf-8"))

    def exchange_echo(self, context, reader, writer):
        """Run a simple echo server.

        The writer is started with the schema of the first chunk that
        carries data; every chunk is then written back with whichever of
        data and application metadata it contained.

        Raises:
            AssertionError: if a chunk has neither data nor metadata.
        """
        started = False
        for chunk in reader:
            has_data = chunk.data is not None
            if not started and has_data:
                writer.begin(chunk.data.schema, options=self.options)
                started = True
            if chunk.app_metadata and has_data:
                writer.write_with_metadata(chunk.data, chunk.app_metadata)
            elif chunk.app_metadata:
                writer.write_metadata(chunk.app_metadata)
            elif has_data:
                writer.write_batch(chunk.data)
            else:
                # Raised explicitly (not via ``assert``) so the check is
                # not stripped when Python runs with -O.
                raise AssertionError("Should not happen")

    def exchange_transform(self, context, reader, writer):
        """Sum rows in an uploaded table.

        Every field of the uploaded schema must be an integer type. The
        reply is a one-column table ``sum`` with one total per input row.

        Raises:
            pa.ArrowInvalid: if any field is not an integer type.
        """
        for field in reader.schema:
            if not pa.types.is_integer(field.type):
                raise pa.ArrowInvalid("Invalid field: " + repr(field))
        table = reader.read_all()
        sums = [0] * table.num_rows
        for column in table:
            for row, value in enumerate(column):
                sums[row] += value.as_py()
        result = pa.Table.from_arrays([pa.array(sums)], names=["sum"])
        writer.begin(result.schema)
        writer.write_table(result)
def test_flight_list_flights():
    """List flights with no criteria, then with the server's criteria."""
    # NOTE(review): ConstantFlightServer is neither defined nor imported in
    # the visible part of this file -- confirm it is provided elsewhere,
    # otherwise this test fails with NameError.
    with ConstantFlightServer() as server:
        location = ('localhost', server.port)
        client = flight.connect(location)
        unfiltered = list(client.list_flights())
        assert unfiltered == []
        matching = client.list_flights(ConstantFlightServer.CRITERIA)
        assert len(list(matching)) == 1
def test_flight_do_get_metadata():
    """Fetch batches with do_get and check the index sent with each one."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    received = []
    with MetadataFlightServer() as server:
        client = FlightClient(('localhost', server.port))
        stream = client.do_get(flight.Ticket(b''))
        for position in itertools.count():
            # Only the read can signal end-of-stream, so guard just that.
            try:
                batch, metadata = stream.read_chunk()
            except StopIteration:
                break
            received.append(batch)
            (server_position,) = struct.unpack('<i', metadata.to_pybytes())
            assert position == server_position
        result = pa.Table.from_batches(received)
        assert result.equals(expected)
def test_flight_do_get_metadata_v4():
    """Read the whole do_get stream from a server using V4 IPC metadata."""
    expected = pa.Table.from_arrays(
        [pa.array([-10, -5, 0, 5, 10])], names=['a'])
    v4_options = pa.ipc.IpcWriteOptions(
        metadata_version=pa.ipc.MetadataVersion.V4)
    with MetadataFlightServer(options=v4_options) as server:
        location = ('localhost', server.port)
        client = FlightClient(location)
        result = client.do_get(flight.Ticket(b'')).read_all()
        assert result.equals(expected)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment