Last active: February 7, 2024 15:45
-
-
Save lupko/8b6f165a6574ef830c531c8056b20957 to your computer and use it in GitHub Desktop.
Flight RPC + ADBC crash
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyarrow.flight | |
import adbc_driver_postgresql.dbapi | |
# Connection settings for the local PostgreSQL instance. The ".." values are
# placeholders and must be filled in before running the reproduction.
_USERNAME = ".."
_PASSWORD = ".."
_DATABASE = ".."

# Statements executed once when the server starts: they (re)create and
# populate the test table. ``IF EXISTS`` keeps the very first run from
# failing (and aborting start-up) when numeric_test does not exist yet.
_INIT = [
    "DROP TABLE IF EXISTS numeric_test;",
    "CREATE TABLE numeric_test (col numeric(16, 10));",
    """INSERT INTO numeric_test VALUES
    (0.0),
    (1.0),
    (1.01),
    (1.012),
    (1.0123),
    (1.01234),
    (1.012345),
    (1.0123456),
    (1.01234567),
    (1.012345678),
    (1.0123456789),
    (1.0123456789);""",
]
class SampleMiddlewareFactory(pyarrow.flight.ServerMiddlewareFactory):
    """Factory that hands out a brand-new ``SampleMiddleware`` per call."""

    def start_call(self, info, headers):
        """Build the middleware instance for one call.

        ``info`` and ``headers`` are accepted to satisfy the factory
        interface but are not inspected.
        """
        per_call_middleware = SampleMiddleware()
        return per_call_middleware
class SampleMiddleware(pyarrow.flight.ServerMiddleware):
    """Per-call middleware that closes a registered cursor when the call ends.

    A cursor registered through ``set_cursor_to_close`` is closed from
    ``call_completed``; if none was registered, completion is a no-op.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Nothing to close until set_cursor_to_close() is called.
        self._cursor = None

    def set_cursor_to_close(self, cursor) -> None:
        """Remember ``cursor`` so it gets closed when this call completes."""
        self._cursor = cursor

    def call_completed(self, exception):
        """Close the registered cursor, if any (``exception`` is not used)."""
        pending = self._cursor
        if pending is None:
            return
        pending.close()
class SampleFlightServer(pyarrow.flight.FlightServerBase):
    """Flight server that streams the ``numeric_test`` table via ADBC.

    Construction binds the gRPC location, opens one PostgreSQL connection,
    prints the server version and runs the ``_INIT`` statements.
    """

    def __init__(self):
        from urllib.parse import quote

        super().__init__(
            location="grpc://localhost:11666",
            middleware={"mw": SampleMiddlewareFactory()},
        )
        # Percent-encode each URI component so characters such as "@", ":"
        # or "/" in the credentials or database name cannot corrupt the URI.
        user = quote(_USERNAME, safe="")
        password = quote(_PASSWORD, safe="")
        database = quote(_DATABASE, safe="")
        self._c = adbc_driver_postgresql.dbapi.connect(
            f"postgresql://{user}:{password}@localhost:5432/{database}"
        )
        with self._c.cursor() as c:
            # PostgreSQL 16.1 on x86_64-redhat-linux-gnu, compiled by gcc (GCC) 13.2.1 20231011 (Red Hat 13.2.1-4), 64-bit
            c.execute("SELECT version();")
            print(c.fetchone()[0])
            for stmt in _INIT:
                c.execute(stmt)
        # DB-API connections start with autocommit disabled. Commit so the
        # set-up DROP/CREATE/INSERT is not left in an open transaction that
        # keeps numeric_test locked against other sessions.
        self._c.commit()
        print("Server is up.")

    def do_get(self, context, ticket):
        """Stream all rows of ``numeric_test``; ``ticket`` is not inspected."""
        print("Handling do get")
        mw = context.get_middleware("mw")
        cursor = self._c.cursor()
        # Hand the cursor to the middleware *before* running the query, so it
        # is closed on call completion even if execute()/fetch raises.
        mw.set_cursor_to_close(cursor)
        cursor.execute("SELECT * FROM numeric_test")
        reader = cursor.fetch_record_batch()
        print("Returning stream")
        # NOTE(review): the middleware closes the cursor on call completion
        # while the returned stream still owns ``reader``; whether the reader
        # may outlive its cursor is unverified here.
        return pyarrow.flight.RecordBatchStream(reader)
if __name__ == "__main__":
    # Keep a reference to the server for the rest of the script rather than
    # discarding the instance right after construction.
    server = SampleFlightServer()
    client = pyarrow.flight.FlightClient(location="grpc://localhost:11666")
    client.do_get(pyarrow.flight.Ticket(ticket=b""))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.