Skip to content

Instantly share code, notes, and snippets.

@Solonarv
Last active November 26, 2019 17:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Solonarv/ca97f4a7e37eb99d1250db7ee4758e73 to your computer and use it in GitHub Desktop.
Save Solonarv/ca97f4a7e37eb99d1250db7ee4758e73 to your computer and use it in GitHub Desktop.
Augmented switch statements in Python, using a hybrid context-manager/decorator approach.
from abc import ABCMeta, abstractmethod
def do_nothing(f):
    """Decorator that discards the decorated function.

    The function object is never called; the decorated name ends up
    bound to ``None``.
    """
    return None
def instantiate(callable):
    """Replace a decorated class (or any zero-argument callable) with the
    result of calling it once.

    Useful for making opaque singleton objects:

    >>> @instantiate
    ... class foo: pass
    >>> type(foo).__name__
    'foo'
    """
    instance = callable()
    return instance
@instantiate
class fallthrough:
    """Special sentinel value: return it from a case body to keep the
    switch() statement going instead of ending it.

    The cases that follow are then tried against their own patterns.
    """

    __slots__ = ()

    def __repr__(self):
        return "fallthrough"

    def __reduce__(self):
        # Returning a plain string makes copy/deepcopy/pickle hand back the
        # module-level singleton, so identity checks such as
        # ``ret is fallthrough`` keep working on copied values.
        return "fallthrough"
class SwitchFinished(Exception):
    """Internal control-flow exception raised once a case body has run.

    It carries the body's return value in ``val`` so that the enclosing
    switch() can record it when the exception reaches its ``__exit__``.
    """

    def __init__(self, val=None):
        message = "switch finished - you should never see this exception"
        super().__init__(message)
        self.val = val
def _just_run_it(body):
    """Call ``body`` with no arguments, then end the switch.

    The switch is ended by raising SwitchFinished with the body's return
    value. If the body returned ``fallthrough``, nothing is raised and
    None is returned.
    """
    result = body()
    if result is fallthrough:
        return None
    raise SwitchFinished(result)
class PatternFailedToMatch(Exception):
    """Raised by a pattern whose scrutinee does not match.

    switch() catches it to decide that a case body should be skipped.
    Combinators either let it propagate (All, Tuple, Chain) or catch it
    and try the next alternative (Any).
    """
class Pattern(metaclass=ABCMeta):
    """A Pattern may be matched against a scrutinee.

    ``match`` either returns the values extracted from the scrutinee or
    raises PatternFailedToMatch.

    Return convention used by switch(): a tuple is splatted into the case
    body as positional arguments. None means "matched, but produced no
    values", and the body is then called with no arguments.
    """

    def __init__(self):
        # Deliberately *not* abstract. Marking __init__ abstract made every
        # subclass that lacks its own constructor (e.g. Trivial)
        # impossible to instantiate.
        pass

    @abstractmethod
    def match(self, scrutinee):
        """Match ``scrutinee``.

        Return a tuple of values or None, or raise PatternFailedToMatch.
        """
class Eq(Pattern):
    """Pattern that matches when the scrutinee equals a reference value.

    It produces no values (``match`` returns None), so switch() calls the
    case body with no arguments.
    """

    def __init__(self, reference):
        self.reference = reference

    def match(self, scrutinee):
        mismatch = scrutinee != self.reference
        if mismatch:
            raise PatternFailedToMatch()
class Test(Pattern):
    """Generalization of Eq: matches when ``test(scrutinee)`` is truthy.

    Like Eq, it produces no values (``match`` returns None).
    """

    def __init__(self, test):
        self.test = test

    def match(self, scrutinee):
        passed = self.test(scrutinee)
        if not passed:
            raise PatternFailedToMatch
class Type(Pattern):
    """Pattern that matches instances of the given type(s).

    On success it yields the scrutinee itself as the single value.
    """

    def __init__(self, *types):
        # A single type is stored bare; several are kept as a tuple.
        # isinstance() accepts either form.
        self.types = types[0] if len(types) == 1 else types

    def match(self, scrutinee):
        if not isinstance(scrutinee, self.types):
            raise PatternFailedToMatch()
        return (scrutinee,)
class Trivial(Pattern):
    """Trivial pattern that always matches, yielding the scrutinee itself."""

    def __init__(self):
        # Explicit override: keeps Trivial instantiable even when Pattern
        # marks __init__ as an abstract method.
        pass

    def match(self, scrutinee):
        return (scrutinee,)
class Apply(Pattern):
    """Pattern that applies a converter function to the scrutinee.

    The pattern matches when ``func(scrutinee, *args, **kwargs)`` returns
    without raising. Its result becomes the single value passed on.

    Exceptions that are instances of ``swallow`` (an exception class or a
    tuple of classes; ``Exception`` by default) are re-raised as
    PatternFailedToMatch. This is intended to catch e.g. the ValueError
    from matching int() against "spam". Other exceptions propagate
    unchanged.

    Example: ``Apply(int, ValueError, 16)`` matches hexadecimal strings.
    """

    def __init__(self, func, swallow=Exception, *args, **kwargs):
        self.func = func
        self.swallow = swallow
        # Extra arguments forwarded to func after the scrutinee.
        self.args = args
        self.kwargs = kwargs

    def match(self, scrutinee):
        try:
            # Previously args/kwargs were stored but never passed along.
            result = self.func(scrutinee, *self.args, **self.kwargs)
        except self.swallow as exc:
            raise PatternFailedToMatch from exc
        return (result,)
class PatternCombinator(Pattern):
    """Common base for patterns composed from other sub-patterns.

    The sub-patterns are kept in ``self.patterns``, as a tuple in the
    order given. Concrete subclasses implement ``match``.
    """

    def __init__(self, *patterns):
        sub_patterns = patterns
        self.patterns = sub_patterns
class All(PatternCombinator):
    """Matches only if every sub-pattern matches.

    The result is a tuple holding each sub-pattern's result, in order. The
    first sub-pattern that fails propagates its PatternFailedToMatch.
    """

    def match(self, scrutinee):
        results = []
        for sub in self.patterns:
            results.append(sub.match(scrutinee))
        return tuple(results)
class Any(PatternCombinator):
    """Matches if any sub-pattern matches.

    Sub-patterns are tried in order, and the first successful one's result
    is returned. If none match, PatternFailedToMatch is raised.
    """

    def match(self, scrutinee):
        for candidate in self.patterns:
            try:
                result = candidate.match(scrutinee)
            except PatternFailedToMatch:
                pass
            else:
                return result
        raise PatternFailedToMatch
class Tuple(PatternCombinator):
    """Matches a sequence of patterns position-by-position against a
    sequence of values.

    By default the match succeeds only if both have the same length. With
    ``lax=True`` the lengths may differ, and the surplus on either side is
    ignored. The result is a tuple of the per-position sub-pattern results.
    A scrutinee without a length (e.g. an int) does not match.
    """

    def __init__(self, *patterns, lax=False):
        super().__init__(*patterns)
        # Previously not stored, so match() raised NameError whenever the
        # lengths differed.
        self.lax = lax

    def match(self, scrutinee):
        try:
            length = len(scrutinee)
        except TypeError:
            # An unsized scrutinee simply isn't a tuple match; don't crash
            # the whole switch with a TypeError.
            raise PatternFailedToMatch() from None
        if length != len(self.patterns) and not self.lax:
            raise PatternFailedToMatch()
        return tuple(pat.match(x) for pat, x in zip(self.patterns, scrutinee))
class Chain(PatternCombinator):
    """Chains a number of patterns one after another.

    Each pattern receives, as positional arguments, the values produced by
    the most recent pattern that produced any (initially the scrutinee).

    A pattern that produces no values (returns None, like Eq or Test) acts
    as a pure check: the current values pass through unchanged to the next
    pattern. The chain returns the last pattern's result as-is, so a chain
    ending in Eq/Test still yields None. An empty chain yields
    ``(scrutinee,)``.
    """

    def match(self, scrutinee):
        vals = (scrutinee,)
        result = vals
        for pat in self.patterns:
            result = pat.match(*vals)
            # A None result used to be splatted into the next pattern
            # (TypeError); treat it as "no new values" instead.
            if result is not None:
                vals = result
        return result
class switch:
    """C-style 'switch' statement, augmented with pattern matching.

    Usage example::

        with switch(input()) as case:
            @case(Chain, Apply(int), Eq(1))
            def _():
                print("one")

            @case(int)
            def _(ival):  # here ival is an int
                print(ival*ival)

            @case()  # default case
            def _():
                print("I didn't understand the input ;(")

    Each ``@case(...)`` decorator tests its pattern against the scrutinee
    immediately. On a match, the decorated body runs right away with the
    values the pattern produced. Unless the body returns ``fallthrough``,
    the switch then ends and the body's return value is stored in
    ``self.val``. A non-matching case binds the decorated name to None.

    Unlike C, returning ``fallthrough`` does not run the next body
    unconditionally: the following cases are still tested against their
    own patterns.

    How ``case(pat, *args, **kwargs)`` interprets ``pat``:

    * omitted / None -> default case; the body runs with no arguments.
    * a Pattern subclass -> instantiated as ``pat(*args, **kwargs)``.
    * a Pattern instance -> used as-is.
    * any other callable -> wrapped as ``Apply(pat, *args, **kwargs)``.
    * anything else -> compared via ``Eq(pat)``.

    Note that ``case(None)`` therefore means "default", not "equals None";
    write ``case(Eq(None))`` for the latter.

    The need for dummy functions is unfortunate but can't be avoided:
    PEP 377, which proposed allowing context managers to skip execution of
    the 'with' block, was rejected and is unlikely to be implemented.
    Though there are hacks to achieve the same behavior, they are not
    portable.
    """

    def __init__(self, scrutinee):
        self.scrut = scrutinee
        self.val = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, trace):
        # Swallow only our own "a case ran" signal; anything else propagates.
        if not isinstance(exc_value, SwitchFinished):
            return None
        self.val = exc_value.val
        return True

    @staticmethod
    def _coerce(pat, args, kwargs):
        """Turn a non-None argument given to case() into a Pattern instance."""
        if isinstance(pat, type) and issubclass(pat, Pattern):
            return pat(*args, **kwargs)
        if isinstance(pat, Pattern):
            return pat
        if callable(pat):
            return Apply(pat, *args, **kwargs)
        return Eq(pat)

    def __call__(self, pat=None, *args, **kwargs):
        if pat is None:
            return _just_run_it
        pattern = self._coerce(pat, args, kwargs)
        try:
            vals = pattern.match(self.scrut)
        except PatternFailedToMatch:
            return do_nothing
        if vals is None:
            # Matched, but the pattern produced no values to pass along.
            return _just_run_it

        def run_alt(body):
            ret = body(*vals)
            if ret is fallthrough:
                return None
            raise SwitchFinished(ret)

        return run_alt
# Example usage: read one line from stdin and dispatch on it.
if __name__ == '__main__':
    with switch(input()) as case:
        @case(Chain, Apply(int), Eq(1))
        def _():
            print("one")

        @case(int)
        def _(x):
            print(x*x)

        # Default case: reached when neither case above matched (e.g. the
        # input is not an integer), instead of silently printing nothing.
        @case()
        def _():
            print("I didn't understand the input ;(")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment