Skip to content

Instantly share code, notes, and snippets.

@iamkhush
Last active March 27, 2021 19:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save iamkhush/8612c62be915e554d9430b65fcd7d2d9 to your computer and use it in GitHub Desktop.
Save iamkhush/8612c62be915e554d9430b65fcd7d2d9 to your computer and use it in GitHub Desktop.
Python implementation of a minimal PostgreSQL (PG) wire-protocol client
#! python3
from io import BytesIO
import struct
import socket
from enum import auto, Enum
from contextlib import closing
class PGEntities(Enum):
    """Kinds of backend messages (plus one status value) this client knows.

    Values are numbered consecutively from 1 in declaration order, exactly
    as ``auto()`` would assign them.
    """

    AuthenticationOk = 1
    ParameterStatus = 2
    ErrorResponse = 3
    ReadyForQuery = 4
    BackendKeyData = 5
    BackendIdle = 6
    CommandComplete = 7
    ParseComplete = 8
    BindComplete = 9
    DataRow = 10
    RowDescription = 11
    PortalSuspended = 12
# One-byte message-type tag sent by the backend -> PGEntities member.
# NOTE: b'R' is labelled AuthenticationOk, but the Int32 code in the
# message body is what says whether authentication actually succeeded.
MessageFormatMapping = dict([
    (b'S', PGEntities.ParameterStatus),
    (b'R', PGEntities.AuthenticationOk),
    (b'E', PGEntities.ErrorResponse),
    (b'Z', PGEntities.ReadyForQuery),
    (b'K', PGEntities.BackendKeyData),
    (b'C', PGEntities.CommandComplete),
    (b'D', PGEntities.DataRow),
    (b'T', PGEntities.RowDescription),
    (b'1', PGEntities.ParseComplete),
    (b'2', PGEntities.BindComplete),
    (b's', PGEntities.PortalSuspended),
])

# ReadyForQuery status byte -> PGEntities member; only idle (b'I') is known.
PGConstants = dict([(b'I', PGEntities.BackendIdle)])

# Unix-domain socket the __main__ block connects to.
socket_file = "/run/postgresql/.s.PGSQL.5432"
def _recv_exact(sock, size):
    """Read exactly *size* bytes from *sock*.

    ``socket.recv`` may legally return fewer bytes than requested, so keep
    reading until the full amount has arrived.

    Raises:
        ConnectionError: if the peer closes the connection before *size*
            bytes have been received.
    """
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            raise ConnectionError(
                f"connection closed with {remaining} of {size} bytes unread")
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def read_int32(sock):
    """Read one big-endian signed 32-bit integer (protocol ``Int32``)."""
    return struct.unpack("!i", _recv_exact(sock, 4))[0]


def read_int16(sock):
    """Read one big-endian signed 16-bit integer (protocol ``Int16``)."""
    return struct.unpack("!h", _recv_exact(sock, 2))[0]


def read_string(sock):
    """Read a NUL-terminated string and return its bytes without the NUL.

    Raises:
        ConnectionError: if the connection closes before the terminator
            (instead of spinning forever on empty ``recv`` results).
    """
    collected = bytearray()
    while True:
        byte = _recv_exact(sock, 1)
        if byte == b"\0":
            # Return the raw bytes directly: the "p" (Pascal string) struct
            # format expects a length-prefixed 1-byte field and cannot decode
            # a C-style string.
            return bytes(collected)
        collected += byte
def write_int32(buffer, data):
    """Append *data* as a big-endian signed 32-bit integer.

    Returns whatever ``buffer.write`` returns (the byte count for BytesIO).
    """
    encoded = struct.pack("!i", data)
    return buffer.write(encoded)
def write_int16(buffer, data):
    """Append *data* as a big-endian signed 16-bit integer.

    Returns whatever ``buffer.write`` returns (the byte count for BytesIO).
    """
    encoded = struct.pack("!h", data)
    return buffer.write(encoded)
def write_string(buffer, data):
    """Append *data* followed by a NUL terminator (protocol ``String``)."""
    for piece in (data, b'\0'):
        buffer.write(piece)
def write_message_type(buffer, mesage_type):
    """Append the one-byte message-type tag that starts a frontend message.

    Args:
        buffer: writable binary buffer (e.g. BytesIO).
        mesage_type: the tag, e.g. b"Q" or b"P". The parameter name keeps
            its original spelling so keyword callers keep working.

    Raises:
        ValueError: if the tag is not exactly one byte. Anything longer or
            shorter would shift every following field and corrupt framing.
    """
    if len(mesage_type) != 1:
        raise ValueError(f"message type must be one byte, got {mesage_type!r}")
    buffer.write(mesage_type)
def send_startup_packet(sock, connection_options=None):
    """Send the StartupMessage that opens a session.

    Args:
        sock: connected stream socket.
        connection_options: mapping of bytes parameter names to bytes values
            (e.g. b'user', b'database'). Defaults to the demo settings this
            script was written against.
    """
    if connection_options is None:
        connection_options = {
            b'user': b'ankush',
            b'database': b'postgres',
            b'application_name': b'this is my own client',
        }
    # 196608 == 3 << 16: protocol version 3.0 (major in the high 16 bits).
    version_message = 196608
    # Each option is name NUL value NUL; write_string adds the final NUL that
    # terminates the whole list.
    data = b"".join(k + b"\0" + v + b"\0" for k, v in connection_options.items())
    with closing(BytesIO()) as startup_message:
        # Length counts itself (4) + version (4) + options + final NUL (1).
        write_int32(startup_message, 4 + 4 + len(data) + 1)
        write_int32(startup_message, version_message)
        write_string(startup_message, data)
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(startup_message.getvalue())
def send_simple_query(sock, query):
    """Send a simple-protocol Query ('Q') message.

    Args:
        sock: connected stream socket.
        query: SQL text as bytes; may hold several ';'-separated statements.

    Raises:
        ValueError: if *query* contains a NUL byte. The text travels as a
            NUL-terminated string, so an embedded NUL would truncate it and
            desynchronise the stream.
    """
    if b"\0" in query:
        raise ValueError("query must not contain NUL bytes")
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"Q")
        # Length counts itself (4), the query text and its terminator (1).
        write_int32(query_buffer, 4 + len(query) + 1)
        write_string(query_buffer, query)
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
    print("Sending query", query)
def send_parse_query(sock, query, prepared_statement_name=b"", count_params=0, param_oids=None):
    """Send a Parse ('P') message that creates a prepared statement.

    Args:
        sock: connected stream socket.
        query: SQL text as bytes, using $1, $2, ... placeholders.
        prepared_statement_name: bytes name for the statement (b"" by default).
        count_params: number of parameter type OIDs being sent; must equal
            the number of entries in *param_oids*.
        param_oids: iterable of parameter type OIDs (e.g. 25 for TEXT).
            None means no types are specified.

    Raises:
        ValueError: if *count_params* disagrees with *param_oids*. A mismatch
            would make the declared length wrong and corrupt the stream.
    """
    # None is the default, so it must not be iterated directly.
    param_oids = [] if param_oids is None else list(param_oids)
    if count_params != len(param_oids):
        raise ValueError(
            f"count_params={count_params} but {len(param_oids)} param_oids given")
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"P")
        # Length: itself + name + NUL + query + NUL + Int16 count + Int32 per OID.
        length = 4 + len(prepared_statement_name) + 1 + len(query) + 1 + 2 + (4 * count_params)
        write_int32(query_buffer, length)
        write_string(query_buffer, prepared_statement_name)
        write_string(query_buffer, query)
        write_int16(query_buffer, count_params)
        for oid in param_oids:
            write_int32(query_buffer, oid)
        print("Sending parse query", query)
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
def send_bind_message(sock, portal_name, prepared_statement_name, parameter_values):
    """Send a Bind ('B') message that binds text parameters to a statement.

    Args:
        sock: connected stream socket.
        portal_name: bytes name of the portal to create.
        prepared_statement_name: bytes name given earlier to Parse.
        parameter_values: iterable of bytes values in text format. None is
            accepted and sent as SQL NULL (length -1).
    """
    values = list(parameter_values)
    length = (
        4                                     # length field itself
        + len(portal_name) + 1                # portal name + NUL
        + len(prepared_statement_name) + 1    # statement name + NUL
        + 2                                   # parameter format-code count
        + 2                                   # parameter value count
        + 4 * len(values)                     # Int32 length per value
        + sum(len(v) for v in values if v is not None)
        + 2                                   # result format-code count
    )
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"B")
        write_int32(query_buffer, length)
        write_string(query_buffer, portal_name)
        write_string(query_buffer, prepared_statement_name)
        write_int16(query_buffer, 0)  # zero format codes: all parameters are text
        write_int16(query_buffer, len(values))
        for value in values:
            if value is None:
                write_int32(query_buffer, -1)  # -1 length encodes NULL, no bytes follow
            else:
                write_int32(query_buffer, len(value))
                query_buffer.write(value)
        write_int16(query_buffer, 0)  # zero result format codes: text columns
        print('Sending Bind message')
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
def send_execute_message(sock, portal_name, limit_of_rows):
    """Send an Execute ('E') message for a portal created by Bind.

    Args:
        sock: connected stream socket.
        portal_name: bytes name of the portal to run.
        limit_of_rows: maximum number of rows to return. Per the protocol
            docs, 0 means "no limit"; hitting the limit yields PortalSuspended.
    """
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"E")
        # Length: itself + portal name + NUL + Int32 row limit.
        length = 4 + len(portal_name) + 1 + 4
        write_int32(query_buffer, length)
        write_string(query_buffer, portal_name)
        write_int32(query_buffer, limit_of_rows)
        print('Sending Execute message')
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
def send_flush(sock):
    """Send a Flush ('H') message: a bare type byte plus Int32 length 4.

    Flush asks the server to deliver any output it has buffered. Unlike
    Sync, it does not produce a ReadyForQuery reply.
    """
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"H")
        write_int32(query_buffer, 4)
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
    print('Sending Flush message')
def send_sync(sock):
    """Send a Sync ('S') message: a bare type byte plus Int32 length 4.

    The server answers each Sync with ReadyForQuery, which is what
    receive_response waits for.
    """
    with closing(BytesIO()) as query_buffer:
        write_message_type(query_buffer, b"S")
        write_int32(query_buffer, 4)
        # sendall: plain send() may transmit only part of the buffer.
        sock.sendall(query_buffer.getvalue())
    print('Sending Sync message')
def read_first_byte(func):
    """Decorator that supplies the message-type byte to a reader function.

    The wrapped reader is always called as ``func(sock, data)``. A caller
    that has already consumed the type byte passes it in, positionally or
    as ``data=``. Otherwise (or when None is passed) one byte is read from
    *sock* first.
    """
    from functools import wraps

    @wraps(func)  # keep the reader's name and docstring for debugging
    def wrapper(sock, data=None):
        if data is None:
            data = sock.recv(1)
        return func(sock, data)
    return wrapper
@read_first_byte
def read_ready_for_query(sock, data):
    """Consume a ReadyForQuery ('Z') message and return its status byte.

    Args:
        sock: connected stream socket.
        data: the message-type byte, supplied by read_first_byte.

    Returns:
        The transaction-status payload (b'I' is what PGConstants maps to
        BackendIdle).

    Raises:
        Exception: if the message is not ReadyForQuery. Previously this fell
            through to an UnboundLocalError or a KeyError instead.
    """
    if MessageFormatMapping.get(data) != PGEntities.ReadyForQuery:
        raise Exception("Expected ReadyForQuery but received ", data)
    message_length = read_int32(sock)
    # The length counts its own 4 bytes; only the status byte remains.
    # (Reading length - 1 bytes could swallow part of the next message.)
    backend_status = sock.recv(message_length - 4)
    print('-------------------------------------------------------')
    return backend_status
@read_first_byte
def read_parameter_status(sock, data=None):
    """Consume the ParameterStatus and BackendKeyData messages after auth.

    Prints each parameter/value pair and the cancellation key.

    Args:
        sock: connected stream socket.
        data: type byte of the first message, supplied by read_first_byte.

    Returns:
        The type byte of the first message that is neither of those, so the
        caller can continue parsing from it.
    """
    # .get(): an unmapped tag (or b'' at EOF) ends the loop and is returned
    # to the caller, instead of raising a bare KeyError here.
    while MessageFormatMapping.get(data) == PGEntities.ParameterStatus:
        message_length = read_int32(sock)
        parameter_and_value = sock.recv(message_length - 4)
        print('Param and Value', parameter_and_value.split(b'\0'))
        data = sock.recv(1)
    if MessageFormatMapping.get(data) == PGEntities.BackendKeyData:
        read_int32(sock)  # message length; the body is the two Int32s below
        process_id = read_int32(sock)
        secret = read_int32(sock)
        print('Cancellation Key Data , process and secret are', process_id, secret)
        data = sock.recv(1)
    return data
@read_first_byte
def read_authentication_message(sock, data=None):
    """Read the server's reply to the startup packet.

    Args:
        sock: connected stream socket.
        data: the message-type byte, supplied by read_first_byte.

    Raises:
        Exception: if the server reports an error, asks for a form of
            authentication this client does not implement (non-zero code),
            or sends an unexpected message.
    """
    entity = MessageFormatMapping.get(data)
    if entity == PGEntities.AuthenticationOk:
        read_int32(sock)  # message length
        is_auth_successful = read_int32(sock)
        print('result for auth', is_auth_successful)
        if is_auth_successful != 0:
            # Every authentication message shares tag 'R'; only code 0 means
            # "accepted". Other codes request a password/SASL exchange.
            raise Exception('Unsupported authentication request', is_auth_successful)
    elif entity == PGEntities.ErrorResponse:
        message_length = read_int32(sock)
        error_string = sock.recv(message_length - 4)
        raise Exception('Auth Failed', error_string.split(b'\0'))
    else:
        # Report the raw byte: PGEntities[data] raised KeyError (Enum names
        # are not bytes) and hid this error.
        raise Exception("Expected auth OK or error but received ", data)
@read_first_byte
def read_data_row(sock, data=None):
    """Read a DataRow ('D') message, then print and return its column values.

    Each column is an Int32 length followed by that many bytes. A length of
    -1 denotes NULL, returned as None. Columns are not NUL-terminated, so
    read_string cannot parse them.

    Args:
        sock: connected stream socket.
        data: the message-type byte, supplied by read_first_byte.

    Returns:
        A list with one bytes value (or None) per column.

    Raises:
        ConnectionError: if the connection closes in the middle of a column.
    """
    print(data, PGEntities.DataRow)
    read_int32(sock)  # message length; the columns are self-delimiting
    column_count = read_int16(sock)
    values = []
    for _ in range(column_count):
        value_length = read_int32(sock)
        if value_length == -1:
            value = None
        else:
            buffer = bytearray()
            # Loop because recv may return fewer bytes than requested.
            while len(buffer) < value_length:
                chunk = sock.recv(value_length - len(buffer))
                if not chunk:
                    raise ConnectionError('Server closed the connection mid-row')
                buffer += chunk
            value = bytes(buffer)
        print('Col Value', value)
        values.append(value)
    return values
@read_first_byte
def receive_response(sock, data=None):
    """Print backend messages up to and including the next ReadyForQuery.

    Args:
        sock: connected stream socket.
        data: type byte of the first message, supplied by read_first_byte.

    Raises:
        ConnectionError: if the server closes the connection mid-response.
    """
    while True:
        if not data:
            raise ConnectionError('Server closed the connection')
        entity = MessageFormatMapping.get(data)
        print('Message Type Received ->', data, entity)
        message_length = read_int32(sock)
        remaining = message_length - 4  # the length counts its own 4 bytes
        payload = bytearray()
        # Loop because recv may return fewer bytes than requested.
        while len(payload) < remaining:
            chunk = sock.recv(remaining - len(payload))
            if not chunk:
                raise ConnectionError('Server closed the connection')
            payload += chunk
        if entity == PGEntities.ReadyForQuery:
            # The payload is the transaction-status byte; the server is ready
            # for a new query.
            print('-------------------------------------------------------')
            break
        if entity != PGEntities.RowDescription:
            # backslashreplace: a plain ascii decode raised UnicodeDecodeError
            # on non-ASCII data (e.g. UTF-8 text or binary length fields).
            print("Received Data ", payload.decode('ascii', errors='backslashreplace'))
        # RowDescription is intentionally not printed: it is column metadata.
        data = sock.recv(1)
if __name__ == '__main__':
    # The with-block closes the socket even if a protocol step raises
    # (e.g. an authentication failure).
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        sock.connect(socket_file)
        send_startup_packet(sock)
        read_authentication_message(sock)
        last_byte_read = read_parameter_status(sock)
        backend_status = read_ready_for_query(sock, last_byte_read)
        # .get(): any status other than b'I' would otherwise raise KeyError.
        if PGConstants.get(backend_status) == PGEntities.BackendIdle:
            send_simple_query(sock, b"CREATE TEMPORARY TABLE my_table (id int Primary key, data TEXT);")
            receive_response(sock)
            send_simple_query(sock, b"""INSERT INTO my_table (id, data) VALUES (1, 'ABCD'), (2, 'EFGH');""")
            receive_response(sock)

            # -- Simple Query --
            # Sending one query.
            send_simple_query(sock, b"""SELECT id from my_table where data = 'ABCD';""")
            receive_response(sock)
            # Sending multiple simple queries back to back: the responses
            # arrive sequentially, one ReadyForQuery per query.
            send_simple_query(sock, b"""SELECT id from my_table where data = 'ABCD';""")
            send_simple_query(sock, b"""SELECT * from my_table;""")
            receive_response(sock)
            receive_response(sock)
            # -- End Simple Query --

            # -- Parameterized Query (extended protocol) --
            # Send Parse, Bind, Execute and Sync sequentially.
            send_parse_query(sock, b"SELECT id from my_table where data = $1;", b"Ankush", 1, [25])
            send_bind_message(sock, b"testPortal", b"Ankush", [b'EFGH'])
            send_execute_message(sock, b"testPortal", 1)
            send_sync(sock)
            receive_response(sock)
            # Reuse the same parsed statement with different Bind/Execute pairs.
            # NOTE: each Sync yields its own ReadyForQuery, hence one
            # receive_response per Sync. Flush would only flush the server's
            # output buffer and produce no reply.
            # TODO: demonstrate where Flush is useful.
            send_bind_message(sock, b"testPortal2", b"Ankush", [b'EFGH'])
            send_bind_message(sock, b"testPortal3", b"Ankush", [b'ABCD'])
            send_execute_message(sock, b"testPortal2", 1)
            send_execute_message(sock, b"testPortal3", 1)
            send_sync(sock)
            send_sync(sock)
            receive_response(sock)
            receive_response(sock)
            # -- End Parameterized Query --
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment