Skip to content

Instantly share code, notes, and snippets.

@lupko
Last active February 7, 2024 15:45
Show Gist options
  • Save lupko/8b6f165a6574ef830c531c8056b20957 to your computer and use it in GitHub Desktop.
Save lupko/8b6f165a6574ef830c531c8056b20957 to your computer and use it in GitHub Desktop.
Flight RPC + ADBC crash
import pyarrow.flight
import adbc_driver_postgresql.dbapi
# Placeholder connection settings for the local PostgreSQL instance; replace
# the ".." values with real credentials before running.
_USERNAME = ".."
_PASSWORD = ".."
_DATABASE = ".."

# Statements executed once at server start-up to (re)create and populate the
# numeric_test table that do_get streams back to clients.
# IF EXISTS lets start-up succeed against a database where the table has never
# been created; a plain DROP TABLE raises there and aborts initialisation.
_INIT = [
    "DROP TABLE IF EXISTS numeric_test;",
    "CREATE TABLE numeric_test (col numeric(16, 10));",
    # Values of increasing scale up to the column's 10 fractional digits; the
    # last value is intentionally inserted twice.
    """INSERT INTO numeric_test VALUES
    (0.0),
    (1.0),
    (1.01),
    (1.012),
    (1.0123),
    (1.01234),
    (1.012345),
    (1.0123456),
    (1.01234567),
    (1.012345678),
    (1.0123456789),
    (1.0123456789);""",
]
class SampleMiddlewareFactory(pyarrow.flight.ServerMiddlewareFactory):
    """Factory registered with the Flight server under the key "mw".

    Every invocation of ``start_call`` hands back a brand-new
    ``SampleMiddleware``, so no state is shared between calls.
    """

    def start_call(self, info, headers):
        """Return a new ``SampleMiddleware``; *info* and *headers* are unused."""
        per_call = SampleMiddleware()
        return per_call
class SampleMiddleware(pyarrow.flight.ServerMiddleware):
    """Middleware that closes a registered cursor once its call finishes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cursor to close in call_completed(); stays None until registered.
        self._cursor = None

    def set_cursor_to_close(self, cursor) -> None:
        """Remember *cursor* so that call_completed() closes it."""
        self._cursor = cursor

    def call_completed(self, exception):
        """Close the registered cursor, if any.

        *exception* is accepted for the middleware interface but not inspected.
        """
        pending = self._cursor
        if pending is None:
            return
        pending.close()
class SampleFlightServer(pyarrow.flight.FlightServerBase):
    """Flight server that streams the ``numeric_test`` table from PostgreSQL.

    The constructor configures the server for ``grpc://localhost:11666`` with
    a ``SampleMiddlewareFactory`` registered as "mw", opens one ADBC
    PostgreSQL connection and runs the ``_INIT`` statements on it.
    """

    def __init__(self):
        super().__init__(
            location="grpc://localhost:11666",
            middleware={"mw": SampleMiddlewareFactory()},
        )
        self._c = adbc_driver_postgresql.dbapi.connect(
            f"postgresql://{_USERNAME}:{_PASSWORD}@localhost:5432/{_DATABASE}"
        )
        with self._c.cursor() as c:
            # PostgreSQL 16.1 on x86_64-redhat-linux-gnu, compiled by gcc (GCC) 13.2.1 20231011 (Red Hat 13.2.1-4), 64-bit
            c.execute("SELECT version();")
            print(c.fetchone()[0])
            for stmt in _INIT:
                c.execute(stmt)
        # ADBC's DB-API connections are documented as non-autocommit by
        # default. Without this commit, the DROP/CREATE/INSERT transaction
        # stays open for the server's lifetime, holding its table locks, and
        # is rolled back at exit.
        self._c.commit()
        print("Server is up.")

    def do_get(self, context, ticket):
        """Stream every row of ``numeric_test``; *ticket* is ignored.

        The cursor backing the stream is handed to this call's "mw"
        middleware, which closes it when the call completes.
        """
        print("Handling do get")
        mw = context.get_middleware("mw")
        cursor = self._c.cursor()
        # Register the cursor before using it. If execute() or
        # fetch_record_batch() raises, call_completed still runs for the
        # failed call (per pyarrow's ServerMiddleware contract) and closes
        # the cursor, instead of leaking it.
        mw.set_cursor_to_close(cursor)
        cursor.execute("SELECT * FROM numeric_test")
        reader = cursor.fetch_record_batch()
        print("Returning stream")
        return pyarrow.flight.RecordBatchStream(reader)
# Start the server, fetch the whole stream once, then tear everything down.
server = SampleFlightServer()
client = pyarrow.flight.FlightClient(location="grpc://localhost:11666")
try:
    # Drain the stream so the server-side call runs to completion (triggering
    # the middleware's cursor cleanup) before the process exits. Previously
    # the returned reader was discarded unread.
    table = client.do_get(pyarrow.flight.Ticket(ticket=b"")).read_all()
    print(table)
finally:
    client.close()
    server.shutdown()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment