Created
January 4, 2024 22:33
-
-
Save mbillingr/d5a89c703680fe94c20775e09c47d0d0 to your computer and use it in GitHub Desktop.
Python ADT Prototype
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Abstract Data Types in Python
=============================

(Python metaprogramming at its ugliest)

The implementation of `adt` is horrible. It could be improved by a few abstractions
for generating the different flavors of variant arguments.

I'm no longer sure whether the use of exec is as bad as I initially thought. At least
it's relatively readable.

However, I think the result is actually usable. Declaring an ADT is quite readable, and
it even looks good in pattern matching. I might actually start to use this for real :)

Known Limitations
-----------------

- does not work with MyPy (I have not tried other type checkers)
"""
from __future__ import annotations # this is important | |
from typing import Any, assert_never, TypeVar | |
def adt(name: str, **variants: list[type | str]) -> Any: | |
globals: dict[str, Any] = {} | |
exec(f"class {name}: pass", globals) | |
for v, ats in variants.items(): | |
params = ", ".join( | |
f"a{i}: {t if isinstance(t, str) else t.__name__}" | |
for i, t in enumerate(ats) | |
) | |
param_names = ", ".join(f"'a{i}'" for i in range(len(ats))) | |
fields = ", ".join(f"self.a{i}" for i in range(len(ats))) | |
if ats: | |
self_field_assignment = "; ".join( | |
f"self.a{i} = a{i}" for i in range(len(ats)) | |
) | |
else: | |
self_field_assignment = "pass" | |
exec( | |
f""" | |
class {v}({name}): | |
__match_args__ = ({param_names}) | |
def __init__(self, {params}): | |
{self_field_assignment} | |
def __repr__(self) -> str: | |
args = ', '.join(map(repr, [{fields}])) | |
args = args and '('+args+')' | |
return '{name}.{v}' + args | |
def __str__(self) -> str: | |
args = ', '.join(map(str, [{fields}])) | |
args = args and '('+args+')' | |
return '{v}' + args | |
""", | |
globals, | |
) | |
if not ats: | |
exec(f"{name}.{v} = {v}()", globals) | |
else: | |
exec(f"{name}.{v} = {v}", globals) | |
return globals[name] | |
# Example usage: a generic singly linked list.
# `Nil` has no fields, so adt() exposes it as a ready-made singleton value;
# `Cons` holds an element of type T and the (forward-referenced) tail list.
T = TypeVar("T")
LinkList = adt("LinkList", **{"Nil": [], "Cons": [T, "LinkList"]})
def length(xs: LinkList) -> int: | |
match xs: | |
case LinkList.Nil: | |
return 0 | |
case LinkList.Cons(_, xs): | |
return 1 + length(xs) | |
case _: | |
assert_never(xs) | |
nil = LinkList.Nil
cons = LinkList.Cons
abc = cons(1, cons("X", cons(3, nil)))
print(abc)  # Cons(1, Cons(X, Cons(3, Nil)))
print(length(abc))  # 3
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment