Skip to content

Instantly share code, notes, and snippets.

@miikka
Created April 4, 2015 22:12
Show Gist options
  • Save miikka/71a82ee8c13ecc3cd868 to your computer and use it in GitHub Desktop.
Save miikka/71a82ee8c13ecc3cd868 to your computer and use it in GitHub Desktop.
Pattern matching in Python
"""This is a sketch of implementing pattern matching of lists.
"""
from collections import namedtuple
# Destructuring lists
#
# A destructuring pattern is a list whose elements are:
#
# * V('foo')  - matches one element and binds it to the name foo
# * VV('foo') - matches the rest of the list and binds it to foo; only
#               allowed as the last element of the pattern
# * X         - matches (and ignores) one element
# * XX        - matches (and ignores) the rest of the list; only allowed as
#               the last element of the pattern
# * a list    - a nested pattern, matched against an element that is itself
#               a list
# * anything else is a constant that must compare equal to the element
V = namedtuple('V', ['name'])
VV = namedtuple('VV', ['name'])
X = object()
XX = object()


def destruct(pattern, value):
    """Match the sequence *value* against the list *pattern*.

    Returns a dict mapping every V/VV name in the pattern (including names
    in nested sub-patterns) to the part of *value* it matched, or None if
    *value* does not match *pattern*.
    """
    result = {}
    last = len(pattern) - 1
    for idx, p in enumerate(pattern):
        if isinstance(p, VV) or p is XX:
            # A "rest" marker only makes sense in the final position.
            if idx != last:
                return None
            if isinstance(p, VV):
                result[p.name] = value[idx:]
            return result
        if idx >= len(value):
            # The pattern still needs an element but the value ran out.
            return None
        v = value[idx]
        if isinstance(p, V):
            result[p.name] = v
        elif p is X:
            pass
        elif isinstance(p, list) and isinstance(v, list):
            inner = destruct(p, v)
            if inner is None:
                # A nested mismatch makes the whole pattern fail.
                return None
            result.update(inner)
        elif p != v:
            return None
    # Without a trailing rest marker the lengths must agree exactly.
    if len(value) != len(pattern):
        return None
    return result
# Pattern matching
class EmptyIterator(object):
    """Iterator that is exhausted from the start and never yields a value.

    ``pattern.match`` returns one of these when a pattern does not apply,
    so the body of ``for ... in p.match(...)`` is skipped.
    """

    def __next__(self):
        raise StopIteration

    def __iter__(self):
        return self
class MatchResult(object):
    """Attribute-style view of the bindings produced by a successful match.

    Every key of the mapping passed in becomes an attribute, e.g. the
    bindings ``{'n': 3}`` are read back as ``result.n``.
    """

    def __init__(self, result):
        # Copy rather than adopting the caller's dict as __dict__, so setting
        # attributes on this object never mutates *result*.
        self.__dict__.update(result)

    def __repr__(self):
        return 'MatchResult({!r})'.format(self.__dict__)
class NoMatchingPattern(Exception):
    """Raised when a ``with pattern(value)`` block ends with no pattern matched.

    The unmatched value is the exception's single argument.
    """
class pattern(object):
    """Context manager that dispatches a value to the first matching pattern.

    Usage::

        with pattern(value) as p:
            for m in p.match([...]):
                ...  # runs at most once, and only for the first match

    Once one ``match`` call succeeds, later ``match`` calls yield nothing.
    Leaving the ``with`` block normally without any successful match raises
    NoMatchingPattern.
    """

    def __init__(self, value):
        self.matched = False  # set to True by the first successful match
        self.value = value    # the value every pattern is matched against

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        # Only complain about a missing match when the block finished
        # normally; an exception already propagating out of the block must
        # not be replaced (masked) by NoMatchingPattern.
        exc_type = exc[0] if exc else None
        if exc_type is None and not self.matched:
            raise NoMatchingPattern(self.value)

    def match(self, pattern):
        """Try *pattern* against the stored value.

        Returns a one-element list holding a MatchResult (so a ``for`` body
        runs exactly once) if this is the first pattern to match, and an
        empty iterator otherwise.
        """
        if self.matched:
            # An earlier pattern already won; don't bother matching.
            return EmptyIterator()
        bindings = destruct(pattern, self.value)
        if bindings is None:
            return EmptyIterator()
        self.matched = True
        return [MatchResult(bindings)]
# Example
def fibo(n):
    """Return the n-th Fibonacci number, with fibo(0) == 0 and fibo(1) == 1.

    Written with naive double recursion to demonstrate ``pattern``; it takes
    exponential time, so keep *n* small.

    Raises:
        ValueError: if n is negative. Without this guard the recursion never
            reaches a base case. Non-integer n also ends up here, since it
            eventually recurses to a negative value.
    """
    if n < 0:
        raise ValueError('fibo() is undefined for negative n: {!r}'.format(n))
    with pattern([n]) as p:
        for _ in p.match([0]):
            return 0
        for _ in p.match([1]):
            return 1
        for m in p.match([V('n')]):
            return fibo(m.n - 1) + fibo(m.n - 2)
def flatten(xs):
    """Concatenate the sublists of *xs* into a single list (one level deep).

    An element of *xs* that is not a list makes neither pattern match, so
    NoMatchingPattern is raised. Recursion goes one level per sublist, so
    very long inputs can hit the recursion limit.
    """
    with pattern(xs) as p:
        for _ in p.match([]):
            # Base case: nothing left to concatenate.
            return []
        for m in p.match([[VV('inner')], VV('outer')]):
            # Peel off the first sublist, then recurse on the remainder.
            head, tail = m.inner, m.outer
            return head + flatten(tail)
if __name__ == '__main__':
    # Run the two examples defined above.
    fib7 = fibo(7)
    print('fibo(7) =', fib7)
    flat = flatten([[1, 2], [3, 4], [5]])
    print(flat)

    # Demonstrate the failure mode: no p.match(...) call succeeds inside
    # this block, so leaving it raises NoMatchingPattern (left uncaught on
    # purpose).
    print("\nThere will be an NoMatchingPattern exception:")
    with pattern([1]) as p:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment