Created
April 12, 2022 12:21
-
-
Save ertaquo/91619a9c021fad7d7b0b6219ebce9f93 to your computer and use it in GitHub Desktop.
Enum and enum array support for psycopg 3.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools
from enum import Enum
from typing import Generic, List, Optional, Type, TypeVar, Union

import psycopg.adapt
from psycopg.adapt import Buffer
from psycopg.types import TypeInfo
# Type variable for the Python Enum subclass that is mapped to a Postgres enum type;
# used to annotate the enum class passed in and the members returned by the loaders.
E = TypeVar("E", bound=Enum)
class TypeNotFound(RuntimeError):
    """
    Raised when a Postgres type cannot be found in any of the searched namespaces.

    :param pg_type_name: name of the Postgres type that was looked up.
    :param pg_namespaces: namespaces (schemas) that were searched, or None if unknown.
    """

    def __init__(self, pg_type_name: str, pg_namespaces: Optional[List[str]] = None):
        self.pg_namespaces = pg_namespaces
        self.pg_type_name = pg_type_name
        # Only mention the namespaces when there are some: joining None used to raise
        # a TypeError from inside the exception constructor itself.
        if pg_namespaces:
            message = f'Postgres type {pg_type_name} not found within namespaces {",".join(pg_namespaces)}'
        else:
            message = f"Postgres type {pg_type_name} not found"
        super().__init__(message)
def psycopg_register_enum_type(
    connection: psycopg.Connection,
    enum_type: Type[E],
    pg_type_name: str,
    pg_namespaces: Union[str, List[str], None] = None,
) -> None:
    """
    Register a Python enum as the loader for a Postgres enum type and its array type.

    After registration, values of the Postgres enum read through ``connection`` are
    returned as members of ``enum_type``. Values of its array type are returned as
    lists of members, with NULL elements loaded as None.

    :param connection: psycopg connection. Any row factory may be configured on it;
        the catalog queries issued here always use tuple rows.
    :param enum_type: enum class. It is called with each label loaded from the
        database, so its values must be the Postgres enum labels.
    :param pg_type_name: enum type name in the Postgres database.
    :param pg_namespaces: namespace (schema) or list of namespaces to search, in
        order of precedence. If None, the schemas of the current search path are used.
    :raises TypeNotFound: if no enum named ``pg_type_name`` exists in the searched
        namespaces.
    """
    from psycopg.rows import tuple_row

    # A dedicated tuple_row cursor makes row access independent of the row factory
    # configured on the connection (dict_row, namedtuple_row, class_row, ...).
    with connection.cursor(row_factory=tuple_row) as cursor:
        if pg_namespaces is None:
            # current_schemas() resolves "$user" and skips schemas that do not exist.
            # Splitting the raw SHOW search_path string kept quotes such as '"$user"'.
            cursor.execute("SELECT current_schemas(false)::text[]")
            pg_namespaces = list(cursor.fetchone()[0])
        elif isinstance(pg_namespaces, str):
            pg_namespaces = [pg_namespaces]
        else:
            pg_namespaces = list(pg_namespaces)

        # One query for all namespaces. ORDER BY array_position() makes the namespace
        # precedence explicit; UNION ALL without ORDER BY gives no ordering guarantee.
        # An empty namespace list now yields no rows instead of an empty SQL statement.
        cursor.execute(
            """
            SELECT pt.oid, pt.typarray
            FROM pg_catalog.pg_type pt
            JOIN pg_catalog.pg_namespace pn ON pn.oid = pt.typnamespace
            WHERE pt.typname = %s
              AND pt.typtype = 'e'
              AND pn.nspname::text = ANY(%s::text[])
            ORDER BY array_position(%s::text[], pn.nspname::text)
            LIMIT 1
            """,
            (pg_type_name, pg_namespaces, pg_namespaces),
        )
        type_data = cursor.fetchone()

    if type_data is None:
        raise TypeNotFound(pg_type_name, pg_namespaces)
    oid, array_oid = type_data

    class _PsycopgEnumLoader(psycopg.adapt.Loader):
        """Load one text-format enum label as a member of ``enum_type``."""

        def load(self, data: Buffer) -> E:
            # bytes() accepts every Buffer variant (bytes, bytearray, memoryview).
            # bytes has no tobytes() method, and array elements arrive as bytes.
            return enum_type(bytes(data).decode())

    # Element loader: the array loader below delegates each element to the loader
    # registered for ``oid``.
    connection.adapters.register_loader(oid, _PsycopgEnumLoader)

    # TypeInfo.register() adds the type to the connection's type registry. It also
    # registers psycopg's array loader for ``array_oid``, which handles quoted labels,
    # NULL elements and empty arrays. The old hand-rolled split of "{...}" did not.
    TypeInfo(pg_type_name, oid, array_oid).register(connection)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Example schema:
# CREATE TYPE some_namespace.some_enum AS ENUM('ONE', 'TWO');
# CREATE TYPE public.some_enum AS ENUM('OTHER_ONE', 'OTHER_TWO');
# CREATE TABLE some_namespace.some_table( id UUID NOT NULL PRIMARY KEY, value some_namespace.some_enum );
# CREATE TABLE public.some_table_arr( id UUID NOT NULL PRIMARY KEY, "values" public.some_enum[] );
# ("values" must be double-quoted because VALUES is a reserved word in Postgres.)
from uuid import UUID

# NOTE(review): BaseModel is not imported in this example; it presumably comes from
# pydantic (`from pydantic import BaseModel`) - confirm before running.


class SomeEnum(str, Enum):
    ONE = "ONE"
    TWO = "TWO"


class SomeOtherEnum(str, Enum):
    OTHER_ONE = "OTHER_ONE"
    OTHER_TWO = "OTHER_TWO"


class SomeModel(BaseModel):
    id: UUID
    value: SomeEnum


class SomeModelArr(BaseModel):
    id: UUID
    values: List[SomeOtherEnum]


# The `with` block closes the connection when the example finishes.
with psycopg.connect(
    conninfo="postgresql://username:password@localhost/database",
    options="-c search_path=some_namespace,public",
) as connection:
    # Resolved through the search path: some_namespace.some_enum is found first.
    psycopg_register_enum_type(connection, SomeEnum, pg_type_name="some_enum")
    # public.some_enum holds the OTHER_* labels, so it maps to SomeOtherEnum.
    psycopg_register_enum_type(
        connection, SomeOtherEnum, pg_type_name="some_enum", pg_namespaces="public"
    )
    connection.execute("SELECT * FROM some_namespace.some_table").fetchall()
    # The array table lives in the public schema (see the schema above).
    connection.execute("SELECT * FROM public.some_table_arr").fetchall()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I can send you some comments, but it would be better if you opened a PR against psycopg 3. Eyeballing:
CompositeTypeInfo
(knowing the attributes of a composite type). More info in psycopg/psycopg#273.