@msullivan
Last active December 9, 2023 12:47
A few related schemes for implementing view patterns in Python. view_patterns_explicit.py is probably the least dodgy of them.
"""Hacky implementation of "view patterns" with Python match
A "view pattern" is one that does some transformation on the data
being matched before attempting to match it. This can be super useful,
as it allows writing "helper functions" for pattern matching.
We provide a class, ViewPattern, that can be subclassed with custom
`match` methods that perform a transformation on the scrutinee,
returning a transformed value or raising NoMatch if a match is not
possible.
For example, you could write:
    @dataclasses.dataclass
    class IntPair:
        lhs: int
        rhs: int

    class sum_view(ViewPattern):
        _view_result: int

        def match(obj: object) -> int:
            match obj:
                case IntPair(lhs, rhs):
                    return lhs + rhs
            raise view_patterns.NoMatch

and then write code like:

    match IntPair(lhs=10, rhs=15):
        case sum_view(10):
            print("NO!")
        case sum_view(25):
            print("YES!")
IMPORTANT NOTE: Do not match against a view pattern from within a
__del__ method or a signal handler.
----
To understand how this is implemented, we first discuss how pattern
matching a value `v` against a pattern like `C(<expr>)` is performed:
1. isinstance(v, C) is called. If it is False, the match fails.
2. C.__match_args__ is fetched; it should contain a tuple of
   attribute names to be used for positional matching.
3. In our case, there should be only one attribute in it, `attr`,
   and v.attr is fetched.

Our implementation strategy, then, is:
a. Overload C's isinstance check by implementing `__instancecheck__`
   in a metaclass. It will call the `match` method, returning
   False if `NoMatch` is raised. If `match` succeeds, we stash
   the value and return True.
b. Make C's __match_args__ `('_view_result',)`.
c. Arrange for `_view_result` on the matched object to return
   the value stashed away by `__instancecheck__`.
The stashing and restoring of the value is the dodgiest part of this
whole endeavor. To do it, we rely on the fact that in non-pathological
cases, no other match can be performed between steps 1 and 3 above, so
we simply store the value in a thread-local variable.
We use a @property getter on the target value's class to do this; it
simply grabs the value out of the thread-local variable.
The pathological cases are signal handlers and __del__ methods, and
so we warn against those in the documentation. This is obviously
pretty dubious but probably harmless in practice.
"""
import threading


class NoMatch(Exception):
    pass


@property  # type: ignore
def _view_result_getter(self):
    x = ViewPatternMeta._vals.val
    ViewPatternMeta._vals.val = None
    return x


class ViewPatternMeta(type):
    _vals = threading.local()

    def __instancecheck__(self, instance):
        try:
            val = self.match(instance)
        except NoMatch:
            return False
        # XXX: This is very dodgy... Just install the property getter
        # on the target class.
        # This is pretty funny but we should probably have an opt-in
        # base class instead.
        try:
            type(instance)._view_result
        except AttributeError:
            type(instance)._view_result = _view_result_getter
        ViewPatternMeta._vals.val = val
        return True


class ViewPattern(metaclass=ViewPatternMeta):
    __match_args__ = ('_view_result',)
###### Testing

import dataclasses


@dataclasses.dataclass
class IntPair:
    lhs: int
    rhs: int


class sum_view(ViewPattern):
    _view_result: int

    @staticmethod
    def match(obj: object) -> int:
        match obj:
            case IntPair(lhs, rhs):
                return lhs + rhs
        # NoMatch is defined in this module, so no `view_patterns.` prefix.
        raise NoMatch


match IntPair(lhs=10, rhs=15):
    case sum_view(10):
        print("NO!")
    case sum_view(25):
        print("YES!")
"""Hacky implementation of "view patterns" with Python match
A "view pattern" is one that does some transformation on the data
being matched before attempting to match it. This can be super useful,
as it allows writing "helper functions" for pattern matching.
We provide a class, ViewPattern, that can be subclassed with custom
`match` methods that perform a transformation on the scrutinee,
returning a transformed value or raising NoMatch if a match is not
possible.
For example, you could write:
    @dataclasses.dataclass
    class IntPair:
        lhs: int
        rhs: int

    class sum_view(ViewPattern, targets=(IntPair,)):
        _view_result: int

        def match(obj: object) -> int:
            match obj:
                case IntPair(lhs, rhs):
                    return lhs + rhs
            raise view_patterns.NoMatch

and then write code like:

    match IntPair(lhs=10, rhs=15):
        case sum_view(10):
            print("NO!")
        case sum_view(25):
            print("YES!")
----
To understand how this is implemented, we first discuss how pattern
matching a value `v` against a pattern like `C(<expr>)` is performed:
1. isinstance(v, C) is called. If it is False, the match fails.
2. C.__match_args__ is fetched; it should contain a tuple of
   attribute names to be used for positional matching.
3. In our case, there should be only one attribute in it, `attr`,
   and v.attr is fetched. If fetching v.attr raises AttributeError,
   the match fails.

Our implementation strategy, then, is:
a. Overload C's isinstance check by implementing `__instancecheck__`
   in a metaclass. Return True if the instance is an instance of
   one of the target classes.
b. Make C's __match_args__ `('_view_result_<unique_name>',)`.
c. Arrange for `_view_result_<unique_name>` on the matched object to
   call match and return that value. If match raises NoMatch, transform
   it into AttributeError, so that the match fails.
Calling match from the *getter* lets us avoid the need to save the
value somewhere between steps a and c, but requires us to install one
method per view in the scrutinee's class.
"""
class NoMatch(Exception):
    pass


class ViewPatternMeta(type):
    def __new__(mcls, name, bases, clsdict, *, targets=(), **kwargs):
        cls = super().__new__(mcls, name, bases, clsdict, **kwargs)

        @property  # type: ignore
        def _view_result_getter(self):
            try:
                return cls.match(self)
            except NoMatch:
                raise AttributeError

        fname = f'_view_result_{cls.__module__}.{cls.__qualname__}'
        mangled = fname.replace("___", "___3_").replace(".", "___")
        cls.__match_args__ = (mangled,)  # type: ignore
        cls._view_result_getter = _view_result_getter
        cls._targets = targets
        # Install the getter onto all target classes
        for target in targets:
            setattr(target, mangled, _view_result_getter)
        return cls

    def __instancecheck__(self, instance):
        return isinstance(instance, self._targets)


class ViewPattern(metaclass=ViewPatternMeta):
    __match_args__ = ('_view_result',)

    @classmethod
    def match(cls, obj: object):
        raise NoMatch
###### Testing

import dataclasses


@dataclasses.dataclass
class Base:
    pass


@dataclasses.dataclass
class IntPair(Base):
    lhs: int
    rhs: int


class sum_view(ViewPattern, targets=(Base,)):
    _view_result: int

    @staticmethod
    def match(obj: object) -> int:
        match obj:
            case IntPair(lhs, rhs):
                return lhs + rhs
        raise NoMatch


class product_view(ViewPattern, targets=(Base,)):
    _view_result: int

    @staticmethod
    def match(obj: object) -> int:
        match obj:
            case IntPair(lhs, rhs):
                return lhs * rhs
        raise NoMatch


match IntPair(lhs=10, rhs=15):
    case sum_view(10):
        print("NO!")
    case product_view(10):
        print("NO!")
    case sum_view(25):
        print("YES!")
    case product_view(160):
        print("???")

match 20:
    case sum_view(10):
        print("NO!")
"""Hacky implementation of "view patterns" with Python match (alternate)
A "view pattern" is one that does some transformation on the data
being matched before attempting to match it. This can be super useful,
as it allows writing "helper functions" for pattern matching.
We provide a class, ViewPattern, that can be subclassed with custom
`match` methods that perform a transformation on the scrutinee,
returning a transformed value or raising NoMatch if a match is not
possible.
For example, you could write:
    @dataclasses.dataclass
    class IntPair:
        lhs: int
        rhs: int

    class sum_view(ViewPattern):
        _view_result: int

        def match(obj: object) -> int:
            match obj:
                case IntPair(lhs, rhs):
                    return lhs + rhs
            raise view_patterns.NoMatch

and then write code like:

    match IntPair(lhs=10, rhs=15):
        case sum_view(10):
            print("NO!")
        case sum_view(25):
            print("YES!")
----
To understand how this is implemented, we first discuss how pattern
matching a value `v` against a pattern like `C(<expr>)` is performed:
1. isinstance(v, C) is called. If it is False, the match fails.
2. C.__match_args__ is fetched; it should contain a tuple of
   attribute names to be used for positional matching.
3. In our case, there should be only one attribute in it, `attr`,
   and v.attr is fetched. If fetching v.attr raises AttributeError,
   the match fails.

Our implementation strategy, then, is:
a. Overload C's isinstance check by implementing `__instancecheck__`
   in a metaclass. Install the getter from step c on the scrutinee's
   class if it is not already there, then return True (or False if
   the class cannot be modified, as with built-in types like int).
b. Make C's __match_args__ `('_view_result_<unique_name>',)`.
c. Arrange for `_view_result_<unique_name>` on the matched object to
   call match and return that value. If match raises NoMatch, transform
   it into AttributeError, so that the match fails.
Calling match from the *getter* lets us avoid the need to save the
value somewhere between steps a and c, but requires us to install one
method per view in the scrutinee's class.
"""
class NoMatch(Exception):
    pass


class ViewPatternMeta(type):
    def __new__(mcls, name, bases, clsdict, **kwargs):
        cls = super().__new__(mcls, name, bases, clsdict, **kwargs)
        match = cls.match

        @property  # type: ignore
        def _view_result_getter(self):
            try:
                return match(self)
            except NoMatch:
                raise AttributeError

        fname = f'_view_result_{cls.__module__}.{cls.__qualname__}'
        mangled = fname.replace("___", "___3_").replace(".", "___")
        cls.__match_args__ = (mangled,)  # type: ignore
        cls._view_result_getter = _view_result_getter
        return cls

    def __instancecheck__(self, instance):
        # XXX: This is very dodgy... Just install the property getter
        # on the target class.
        # This is pretty funny but it might be better to list all possible
        # scrutinee base classes as an argument when defining the view pattern.
        getter = self.__match_args__[0]
        try:
            getattr(type(instance), getter)
        except AttributeError:
            try:
                setattr(type(instance), getter, self._view_result_getter)
            except TypeError:
                return False
        return True


class ViewPattern(metaclass=ViewPatternMeta):
    __match_args__ = ('_view_result',)

    @staticmethod
    def match(obj: object):
        raise NoMatch
###### Testing

import dataclasses


@dataclasses.dataclass
class IntPair:
    lhs: int
    rhs: int


class sum_view(ViewPattern):
    _view_result: int

    @staticmethod
    def match(obj: object) -> int:
        match obj:
            case IntPair(lhs, rhs):
                return lhs + rhs
        raise NoMatch


class product_view(ViewPattern):
    _view_result: int

    @staticmethod
    def match(obj: object) -> int:
        match obj:
            case IntPair(lhs, rhs):
                return lhs * rhs
        raise NoMatch


match IntPair(lhs=10, rhs=15):
    case sum_view(10):
        print("NO!")
    case product_view(10):
        print("NO!")
    case sum_view(25):
        print("YES!")
    case product_view(160):
        print("???")

match 20:
    case sum_view(10):
        print("NO!")
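The alternate variant installs the getter on the scrutinee's class the first time it sees that class, and treats a `TypeError` from `setattr` as a failed match. The sketch below isolates just that step; `try_install` and `UserDefined` are hypothetical names, and the placeholder property stands in for the real view getter. It shows that ordinary classes accept the new attribute (and keep it afterwards, so the class is permanently modified), while built-in types such as `int` reject it.

```python
def try_install(cls: type, name: str) -> bool:
    """Attach a placeholder property to cls, mirroring the alternate __instancecheck__."""
    try:
        getattr(cls, name)  # already installed: nothing to do
    except AttributeError:
        try:
            setattr(cls, name, property(lambda self: None))
        except TypeError:
            # Built-in and other immutable types refuse new class attributes.
            return False
    return True


class UserDefined:
    pass


print(try_install(UserDefined, "_view_result_demo"))  # True
print("_view_result_demo" in vars(UserDefined))       # True: the class was modified
print(try_install(int, "_view_result_demo"))          # False
```

This is why `match 20: case sum_view(10)` in the test above simply fails rather than raising: `int` cannot be patched, so `__instancecheck__` returns False.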