Skip to content

Instantly share code, notes, and snippets.

@skydark
Created December 5, 2012 05:42
Show Gist options
  • Save skydark/4212705 to your computer and use it in GitHub Desktop.
Save skydark/4212705 to your computer and use it in GitHub Desktop.
A simple pattern matcher #python
#!/usr/bin/python
# -*- coding: utf-8 -*-
import unittest
# Specificity ordering between pattern elements. LESSER means "strictly more
# specific" (Function.match prefers it); GREATER means "strictly more
# general"; NOT_COMPARABLE (None) means neither side can rank the pair.
GREATER, LESSER, EQUAL, NOT_COMPARABLE = 1, -1, 0, None


def _normalize_order(order):
    """Collapse a cmp-style integer to its sign: GREATER, LESSER or EQUAL."""
    return (order > 0) - (order < 0)


def compare(e1, e2):
    """Compare two pattern elements by specificity.

    ``e1`` is asked first. If it cannot decide (returns NOT_COMPARABLE),
    ``e2`` is asked and its answer is negated, so each element class only
    has to know how it ranks against the classes it was written for.

    Any non-None integer returned by an element's ``compare`` is reduced to
    its sign first. This guarantees the result, and in particular the
    negated answer from ``e2``, is always one of GREATER / LESSER / EQUAL.

    Returns:
        GREATER, LESSER or EQUAL if either element can rank the pair,
        otherwise NOT_COMPARABLE.
    """
    order = e1.compare(e2)
    if order is not NOT_COMPARABLE:
        return _normalize_order(order)
    order = e2.compare(e1)
    if order is not NOT_COMPARABLE:
        # e2 ranked itself against e1; flip it to express e1 against e2.
        return -_normalize_order(order)
    return NOT_COMPARABLE
class Element(object):
    """Abstract base for one element of a pattern.

    Concrete elements override both hooks:

    * ``match(value)`` -- return whether ``value`` fits this element;
    * ``compare(other)`` -- rank this element against ``other`` by
      specificity, or report that it cannot.

    Calling either hook on the base class raises NotImplementedError.
    """

    def match(self, value):
        """Hook: report whether ``value`` fits. Subclasses must override."""
        message = "How to match me, um?"
        raise NotImplementedError(message)

    def compare(self, other):
        """Hook: rank ``self`` against ``other``. Subclasses must override."""
        message = "How to compare us, um?"
        raise NotImplementedError(message)
class Const(Element):
    """Element that matches one fixed value.

    ``val`` holds the constant; an input matches when ``val == input``.
    """

    def __init__(self, c):
        self.val = c

    def match(self, value):
        """Return whether ``value`` equals the stored constant."""
        return self.val == value

    def compare(self, other):
        """Rank against ``other``.

        Two Consts of the same class with equal values are EQUAL. Every
        other pairing is NOT_COMPARABLE from this side; more general kinds
        (e.g. Variable) rank themselves against Const.
        """
        same_kind = other.__class__ == self.__class__
        if not (same_kind and self.val == other.val):
            return NOT_COMPARABLE
        return EQUAL
class Variable(Element):
    """Wildcard element: matches any value and binds it to ``val``.

    Ranks EQUAL to another Variable and GREATER (more general) than a
    Const; against any other kind it returns NOT_COMPARABLE.
    """

    def __init__(self, name='<NO NAME>'):
        self.name = name  # descriptive label; matching never reads it
        self.val = None   # value bound by the most recent match()

    def match(self, value):
        """Accept ``value`` unconditionally and remember it in ``val``."""
        self.val = value
        return True

    def compare(self, other):
        """Rank this Variable against ``other`` by specificity."""
        # Checked in order; the first class equal to other's class wins.
        rules = (
            (self.__class__, EQUAL),
            (Const, GREATER),
        )
        for kind, order in rules:
            if other.__class__ == kind:
                return order
        return NOT_COMPARABLE
class Succ(Element):
    """Successor element: matches a positive int ``k`` and binds ``k - 1``.

    Sits between Const and Variable in specificity: GREATER than a Const,
    LESSER than a Variable, EQUAL to another Succ, NOT_COMPARABLE otherwise.
    """

    def __init__(self, name='<NO NAME>'):
        self.name = name  # descriptive label; matching never reads it
        self.val = None   # predecessor bound by the most recent match()

    def match(self, value):
        """Match an int greater than zero, binding its predecessor."""
        if not isinstance(value, int) or not value > 0:
            return False
        self.val = value - 1
        return True

    def compare(self, other):
        """Rank this Succ against ``other`` by specificity."""
        # Checked in order; the first class equal to other's class wins.
        rules = (
            (self.__class__, EQUAL),
            (Const, GREATER),
            (Variable, LESSER),
        )
        for kind, order in rules:
            if other.__class__ == kind:
                return order
        return NOT_COMPARABLE
class List(Element):
    """Fixed-length sequence pattern.

    Matches a list or tuple with exactly as many items as there are
    sub-patterns, each item matching the sub-pattern at the same position.
    ``val`` is the tuple of sub-patterns and ``length`` its size.
    """

    def __init__(self, *elements):
        self.val = elements
        self.length = len(elements)

    def match(self, value):
        """Match ``value`` position by position, stopping at the first miss."""
        if not isinstance(value, (list, tuple)) or len(value) != self.length:
            return False
        return all(sub.match(item) for sub, item in zip(self.val, value))

    def compare(self, other):
        """Rank two same-length Lists by the pointwise order of their items.

        The result is EQUAL if every position is EQUAL. It is GREATER (or
        LESSER) if every position is EQUAL or GREATER (or LESSER) and at
        least one is not EQUAL. It is NOT_COMPARABLE if ``other`` is not a
        List of the same length, if any position is NOT_COMPARABLE, or if
        two positions point in opposite directions.
        """
        if other.__class__ != self.__class__ or other.length != self.length:
            return NOT_COMPARABLE
        result = EQUAL
        for mine, theirs in zip(self.val, other.val):
            order = compare(mine, theirs)
            if order == NOT_COMPARABLE:
                return NOT_COMPARABLE
            if order == EQUAL or order == result:
                continue
            if result != EQUAL:
                # One position said GREATER and another LESSER.
                return NOT_COMPARABLE
            result = order
        return result
class Function(object):
    """A function defined by pattern clauses and dispatched on its arguments.

    Each clause pairs a pattern object (matched against the whole argument
    tuple) with a handler. Calling the Function selects the single most
    specific matching pattern and calls its handler with that pattern
    object, whose elements hold the bindings produced by the match.

    NOTE(review): bindings live on the shared pattern objects. A handler
    that calls back into the same Function should therefore read every
    binding it needs *before* the recursive call.
    """

    def __init__(self):
        self.patterns = {}  # pattern object -> handler(pattern)

    def add_pattern(self, pattern, func):
        """Register ``func`` as the handler for ``pattern``.

        Registering the same pattern object again replaces its handler.
        """
        self.patterns[pattern] = func

    def match(self, args):
        """Return the most specific registered pattern that matches ``args``.

        A matching pattern is selected only if it compares LESSER (strictly
        more specific) than every other matching pattern. The outcome does
        not depend on the order in which patterns are iterated.

        Raises:
            TypeError: "No matched!" when no pattern matches, or
                "Multiple implements matched!" when no single matching
                pattern is more specific than all the others.
        """
        candidates = [pattern for pattern in self.patterns
                      if pattern.match(args)]
        if not candidates:
            raise TypeError("No matched!")
        for candidate in candidates:
            # A sequential "keep the smaller" scan is wrong for a partial
            # order: two incomparable patterns seen first would abort the
            # search even when a later pattern is more specific than both.
            if all(compare(candidate, other) == LESSER
                   for other in candidates if other is not candidate):
                return candidate
        raise TypeError("Multiple implements matched!")

    def __call__(self, *args):
        """Dispatch ``args`` to the handler of the best-matching pattern."""
        min_match = self.match(args)
        func = self.patterns[min_match]
        return func(min_match)
class PatternMatchTestCase(unittest.TestCase):
    """Tests for Function dispatch, using Ackermann's function as the example."""

    def make_ackermann(self):
        """Return a correct pattern-matched Ackermann function.

        Clauses, registered in this order:
            a(m, n)     = a(m - 1, a(m, n - 1))   -- meant for m > 0, n > 0
            a(m > 0, 0) = a(m - 1, 1)
            a(0, n)     = n + 1
        The first clause is the most general (two Variables); Function
        prefers one of the more specific clauses whenever it also matches.
        """
        ackermann = Function()

        def recurse_both(match):
            # Read both bindings before recursing into the same Function.
            m, n = match.val[0].val, match.val[1].val
            return ackermann(m - 1, ackermann(m, n - 1))

        def recurse_first(match):
            # Succ has already bound m - 1.
            return ackermann(match.val[0].val, 1)

        def base_case(match):
            return match.val[1].val + 1

        ackermann.add_pattern(List(Variable('m'), Variable('n')), recurse_both)
        ackermann.add_pattern(List(Succ('m'), Const(0)), recurse_first)
        ackermann.add_pattern(List(Const(0), Variable('n')), base_case)
        return ackermann

    def make_failed_ackermann(self):
        """Return a deliberately broken Ackermann with two clauses.

        a(m, 0) and a(0, n) both match (0, 0) and neither is more specific.
        Nothing matches when both arguments are non-zero.
        """
        ackermann = Function()

        def zero_second(match):
            return ackermann(match.val[0].val - 1, 1)

        def zero_first(match):
            return match.val[1].val + 1

        ackermann.add_pattern(List(Variable('m'), Const(0)), zero_second)
        ackermann.add_pattern(List(Const(0), Variable('n')), zero_first)
        return ackermann

    def testNoMatch(self):
        broken = self.make_failed_ackermann()
        with self.assertRaises(TypeError) as caught:
            broken(3, 1)
        self.assertEqual(caught.exception.args[0], "No matched!")

    def testMultiMatch(self):
        broken = self.make_failed_ackermann()
        with self.assertRaises(TypeError) as caught:
            broken(0, 0)
        self.assertEqual(caught.exception.args[0],
                         "Multiple implements matched!")

    def testCalls(self):
        ackermann = self.make_ackermann()
        # Rows are (m, n, expected a(m, n)). a(4, 1) == 65533 is left out
        # because computing it exceeds the maximum recursion depth.
        cases = (
            (0, 0, 1),
            (0, 1, 2),
            (1, 0, 2),
            (1, 1, 3),
            (2, 3, 9),
            (3, 2, 29),
            (3, 4, 125),
            (4, 0, 13),
        )
        for m, n, expected in cases:
            self.assertEqual(ackermann(m, n), expected)
if __name__ == '__main__':
    # Run as a script: load and run the TestCase classes defined in this
    # module ('__main__' is unittest.main's default module argument).
    unittest.main(module='__main__')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment