Skip to content

Instantly share code, notes, and snippets.

@ericgj
Last active June 18, 2018 09:41
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericgj/8af30c2a89278a2442625aa7c6bd18dc to your computer and use it in GitHub Desktop.
Save ericgj/8af30c2a89278a2442625aa7c6bd18dc to your computer and use it in GitHub Desktop.
simple tagged union type matching in python
"""
Derived from [fn.py](https://github.com/kachayev/fn.py) function 'curried'
Amended to fix wrapping error: cf. https://github.com/kachayev/fn.py/pull/75
Copyright 2013 Alexey Kachayev
Under the Apache License, Version 2.0
http://www.apache.org/licenses/LICENSE-2.0
"""
from functools import partial, wraps, update_wrapper
try:
    from inspect import getargspec
except ImportError:  # getargspec was removed in Python 3.11
    from inspect import getfullargspec as getargspec
def curry(func):
    """A decorator that makes the function curried.

    The wrapped function may be called with any number of its positional
    parameters at a time; it keeps returning a new curried function until
    every positional parameter of the underlying function has been
    supplied (positionally or by keyword), at which point the underlying
    function is called and its result returned.  Parameters with default
    values are still counted, i.e. they must be supplied explicitly.

    Supplying more positional arguments than the function declares calls
    it immediately, so functions taking ``*args`` work and genuine
    over-application raises the usual ``TypeError`` instead of silently
    returning yet another curried function.

    Usage example:

    >>> @curry
    ... def sum5(a, b, c, d, e):
    ...     return a + b + c + d + e
    ...
    >>> sum5(1)(2)(3)(4)(5)
    15
    >>> sum5(1, 2, 3)(4, 5)
    15
    >>> sum5(1, 2, 3, 4, e=5)
    15
    """
    # inspect.getargspec was removed in Python 3.11 (and rejects annotated
    # functions on earlier 3.x); prefer getfullargspec, fall back on Python 2.
    try:
        from inspect import getfullargspec as _argspec
    except ImportError:
        from inspect import getargspec as _argspec

    @wraps(func)
    def _curry(*args, **kwargs):
        # Walk down the chain of partials built by earlier calls, tallying
        # everything that has been bound so far, until the real function
        # is reached.
        target = func
        supplied = len(args)
        named = set(kwargs)
        while isinstance(target, partial):
            supplied += len(target.args)
            # partial.keywords may be None on Python 2
            named.update(target.keywords or ())
            target = target.func

        params = _argspec(target).args
        # A keyword only fills a slot that is not already taken positionally.
        supplied += sum(1 for name in params[supplied:] if name in named)

        if supplied >= len(params):
            return func(*args, **kwargs)

        pending = partial(func, *args, **kwargs)
        # partial objects carry no __name__/__doc__ of their own; copy them
        # from the real function so that wraps() in the recursive call works.
        update_wrapper(pending, target)
        return curry(pending)

    return _curry
from f import curry
@curry
def match(uniontype, cases, target):
    r"""
    Return the result of the case matching a target instance of a union type.

    Cases are expressed as a dict with types as keys and functions as values.
    The function of the first case whose key `target` is an instance of is
    called with the target's items unpacked as positional arguments, so
    target values must be iterable. Typically target is a named tuple.

    Union types are from the `typing` library (stdlib on Python 3.5+, or
    install from PyPI).

    Because `match` is curried, it can be partially applied to the union type
    and cases to build a reusable function of the target.

    Usage example:

        from typing import NamedTuple, Union, List

        Ok = NamedTuple('Ok', [('message', str)])
        ClientErr = NamedTuple('ClientErr', [('message', str), ('code', int)])
        ServerErr = NamedTuple('ServerErr', [('message', str), ('code', int), ('backtrace', List[str])])

        Response = Union[Ok, ClientErr, ServerErr]

        display_response = (
            match(Response, {
                Ok: (lambda msg: "Everything ok: %s" % msg),
                ClientErr: (lambda msg, code: "Oops, did you mean to do that? (%d %s)" % (code, msg)),
                ServerErr: (lambda msg, code, backtrace: "Something bad happened: (%d %s)\n\n%s" % (code, msg, "\n".join(backtrace)))
            })
        )

        # ...
        response = Ok("beautiful")
        display_response(response)   # "Everything ok: beautiful"

    You must either specify a case for every type in the union type, or
    include a case for `type(None)`, which will be used as a fallback if no
    cases match the target (called with no parameters):

        match(Response, {
            Ok: (lambda msg: msg),
            type(None): (lambda: "Something went wrong")
        })

    Raises AssertionError if the target's class is not in the union type, if
    cases are missing, or if the matched case is not callable; raises
    TypeError if no case matches.
    """
    none_type = type(None)

    # Early `typing` releases exposed the members of a Union as
    # `__union_set_params__`; later releases expose them as `__args__`.
    # Plain classes (a union flattened to a single type) are used as-is.
    # NOTE(review): a non-class parameterized generic passed directly as
    # `uniontype` would have its type arguments read as members -- confirm
    # whether that case needs support.
    utypes = getattr(uniontype, '__union_set_params__', None)
    if utypes is None and not isinstance(uniontype, type):
        utypes = getattr(uniontype, '__args__', None)
    if not utypes:
        utypes = [uniontype]  # in case where union type is flattened to single type
    utypes = tuple(utypes)

    # Check against the members rather than the Union object itself: later
    # `typing` versions refuse issubclass() on a subscripted Union.
    assert issubclass(target.__class__, utypes), \
        "%s is not in union type" % target.__class__.__name__

    # `in` instead of dict.has_key(), which does not exist on Python 3.
    has_fallback = none_type in cases
    missing = [
        getattr(t, '__name__', repr(t)) for t in utypes
        if not (has_fallback or t in cases)
    ]
    assert len(missing) == 0, \
        "No case found for the following type(s): %s" % ", ".join(missing)

    # The `type(None)` key is only ever the parameterless fallback; it is
    # skipped here so that a None target (e.g. from an Optional) is not
    # star-unpacked below.
    matched = next(
        (klass for klass in cases
         if klass is not none_type and isinstance(target, klass)),
        none_type
    )
    wildcard = matched is none_type
    fn = cases.get(matched, None)

    # note should never happen due to type assertions above
    if fn is None:
        raise TypeError("No cases match %s" % target.__class__.__name__)

    assert callable(fn), \
        "Matched case is not callable; check your cases"

    return fn() if wildcard else fn(*target)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment