Skip to content

Instantly share code, notes, and snippets.

@betafcc
Last active October 19, 2022 12:05
Show Gist options
  • Save betafcc/a5d97a89a9f50a1efb4000481d6b9729 to your computer and use it in GitHub Desktop.
Save betafcc/a5d97a89a9f50a1efb4000481d6b9729 to your computer and use it in GitHub Desktop.
Extensible typed records in Python (pyright)
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, Literal, NoReturn, Type, TypeAlias, TypeVar
from typing_extensions import LiteralString
# Type variables shared by Record and Get.
# K: a record key, restricted to literal strings (LiteralString) so each key
# can be tracked individually at the type level.
K = TypeVar("K", bound=LiteralString)
# V: a record value type. Record and Get are not generic over V, so it is only
# ever solved per call (Record.__add__, Get.__call__).
V = TypeVar("V", covariant=True)
# AnyPair: the widest (key, value) pair type. Get.__call__ unions it with
# tuple[K, V] so any record is accepted while V can still be matched.
AnyPair: TypeAlias = tuple[LiteralString, Any]
# M, N: the union of (key, value) pair types a Record holds. Declared covariant
# so that e.g. Record[A] is assignable where Record[A | B] is expected.
M = TypeVar("M", bound=AnyPair, covariant=True)
N = TypeVar("N", bound=AnyPair, covariant=True)
@dataclass(frozen=True)
class Record(Generic[M]):
    """Record whose (key, value) pairs are tracked in its type parameter.

    ``M`` is the union of every ``tuple[key, value]`` pair type added so far.
    At runtime the values are stored in the plain ``dict`` field. The dataclass
    is frozen, so the attribute cannot be rebound, but the dict object itself is
    not copied or protected.
    """

    # Shared empty record; it is assigned right after the class body.
    empty: ClassVar[Record[NoReturn]]
    # Runtime storage mapping each key to its value.
    dict: dict[Any, Any]

    def __add__(self, other: tuple[K, V]) -> Record[M | tuple[K, V]]:
        """Return a new record with the pair ``other`` added.

        An existing entry with the same key is overwritten. ``self`` is left
        untouched.
        """
        # Index access (not unpacking) mirrors reading exactly items 0 and 1.
        key, value = other[0], other[1]
        return Record({**self.dict, key: value})

    def __and__(self, other: Record[N]) -> Record[M | N]:
        """Return a new record holding the entries of both records.

        When both records define the same key, the value from ``other`` wins.
        """
        return Record({**self.dict, **other.dict})

    @staticmethod
    def get(key: K) -> Get[K]:
        """Return a ``Get`` accessor for ``key``.

        Call the accessor with a record to fetch the value stored under ``key``.
        """
        return Get(key=key)
# The shared empty record. It is explicitly specialized with NoReturn (the
# bottom type), so it contributes no (key, value) pair to the unions that
# `+` and `&` build.
# NOTE(review): this single {} is shared by every use of Record.empty. `+` and
# `&` always build new dicts, so this module never mutates it; callers must
# not mutate `Record.empty.dict` either.
Record.empty = Record[NoReturn]({})
@dataclass(frozen=True)
class Get(Generic[K]):
    """Curried key accessor returned by ``Record.get``.

    This is a separate generic class rather than a method on ``Record`` so that
    pyright binds ``K`` when ``Record.get(key)`` is called. ``V`` is then solved
    only when the accessor is applied to a record.
    """

    # The literal key to look up.
    key: K

    def __call__(self, record: Record[tuple[K, V] | AnyPair]) -> V:
        """Return the value stored under ``self.key`` in ``record``.

        The ``| AnyPair`` in the parameter type lets any record match, while
        ``tuple[K, V]`` picks out the value type paired with this key.

        Raises:
            KeyError: if ``self.key`` is absent from ``record.dict``.
        """
        storage = record.dict
        return storage[self.key]
# Extra type-level helpers below.
# These exist only for static analysis. Each metaclass __getitem__ has a `...`
# body, so at runtime Merge[...] and KeyOf[...] simply evaluate to None.
# fmt: off
class MergeMeta(type):
    # Merge[Record[M], Record[N]] -> Type[Record[M | N]]: the type of a record
    # holding the pairs of both operands (the type-level analogue of `&`).
    def __getitem__(cls, args: tuple[Type[Record[M]], Type[Record[N]]]) -> Type[Record[M | N]]: ...
class KeyOfMeta(type):
    # KeyOf[Record[...]] -> Type[K]: the union of the record type's key literals.
    def __getitem__(cls, args: Type[Record[tuple[K, Any]]]) -> Type[K]: ...
# Subscript targets: Merge[A, B] is handled by MergeMeta.__getitem__.
class Merge(metaclass=MergeMeta): ...
# KeyOf[R] is handled by KeyOfMeta.__getitem__.
class KeyOf(metaclass=KeyOfMeta): ...
# fmt: on
# --- Usage demo ---------------------------------------------------------------
# The trailing comments record the types pyright reports for each expression.
# At runtime these lines only build dicts and perform key lookups whose results
# are discarded.

# Build a record one (key, value) pair at a time, starting from the empty record.
result = Record.empty + ("a", 1) + ("b", "hi") + ("c", 3.14)
# (variable) result: Record[tuple[Literal['a'], Literal[1]] | tuple[Literal['b'], Literal['hi']] | tuple[Literal['c'], float]]
Record.get("a")(result)  # Revealed type is "Literal[1]"
Record.get("b")(result)  # Revealed type is "Literal['hi']"
Record.get("c")(result)  # Revealed type is "float"

# Combine two records with `&`; the result's pair union covers both operands.
merged = result & (Record.empty + ("d", True) + ("e", 2))
# (variable) merged: Record[tuple[Literal['a'], Literal[1]] | tuple[Literal['b'], Literal['hi']] | tuple[Literal['c'], float] | tuple[Literal['d'], Literal[True]] | tuple[Literal['e'], Literal[2]]]
Record.get("a")(merged)  # Revealed type is "Literal[1]"
Record.get("b")(merged)  # Revealed type is "Literal['hi']"
Record.get("c")(merged)  # Revealed type is "float"
Record.get("d")(merged)  # Revealed type is "Literal[True]"
Record.get("e")(merged)  # Revealed type is "Literal[2]"

# Type-level merge and key extraction. These are meaningful only to the type
# checker: the metaclass __getitem__ bodies are `...`, so at runtime both
# Merged and Keys are bound to None.
Merged: TypeAlias = Merge[
    Record[tuple[Literal["a"], int] | tuple[Literal["b"], str]],
    Record[tuple[Literal["c"], float] | tuple[Literal["d"], bool]],
]
# (type alias) Merged: Type[Record[tuple[Literal['a'], int] | tuple[Literal['b'], str] | tuple[Literal['c'], float] | tuple[Literal['d'], bool]]]
Keys: TypeAlias = KeyOf[Merged]
# (type alias) Keys: Type[Literal['a', 'b', 'c', 'd']]
@betafcc
Copy link
Author

betafcc commented Oct 15, 2022

The idea is to use a covariant generic type variable to track the type of each 'item' in the record. That type is the union of all the (key, value) tuple pairs added so far.

Because the type variable is covariant, we can then pattern-match the record's pair type against a union with 'AnyPair'. This recovers the value type associated with a given key.

We do have to manipulate pyright's TypeVar binding order for this to work, though. That's why the get function is curried through a generic class (Get), and I couldn't find a way to make it a method on the Record class.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment