Created
December 5, 2012 05:42
-
-
Save skydark/4212705 to your computer and use it in GitHub Desktop.
A simple pattern matcher #python
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
# -*- coding: utf-8 -*- | |
import numbers
import unittest
# Possible results of ordering two pattern elements by generality.
# For element classes in this module, ``a.compare(b)`` yields:
#   GREATER        -- ``a`` is more general than ``b`` (e.g. Variable vs Const)
#   LESSER         -- ``a`` is more specific than ``b``
#   EQUAL          -- the two elements are equivalent
#   NOT_COMPARABLE -- no ordering is known (``None``)
GREATER = 1
LESSER = -1
EQUAL = 0
NOT_COMPARABLE = None
def compare(e1, e2):
    """Order two pattern elements by generality.

    ``e1`` is asked first. If it cannot decide (returns NOT_COMPARABLE),
    ``e2`` is asked and the sign of its answer is flipped. Each element class
    therefore only needs to know how it relates to the classes it
    understands.

    Returns:
        GREATER, LESSER, EQUAL (from ``e1``'s point of view) or
        NOT_COMPARABLE if neither element can order the pair.

    Raises:
        ValueError: if an element's ``compare`` returns something other than
            GREATER, LESSER, EQUAL or NOT_COMPARABLE. Passing such a value
            through (or negating it) would yield a meaningless ordering.
    """
    # Try e1's view unchanged, then e2's view with the sign reversed.
    for first, second, sign in ((e1, e2, 1), (e2, e1, -1)):
        order = first.compare(second)
        if order is NOT_COMPARABLE:
            continue
        if order not in (GREATER, LESSER, EQUAL):
            raise ValueError(
                "%s.compare returned %r; expected GREATER, LESSER, EQUAL "
                "or NOT_COMPARABLE" % (type(first).__name__, order))
        return sign * order
    return NOT_COMPARABLE
class Element(object):
    """Base class for pattern elements; subclasses override both methods.

    * ``match(value)`` returns True when ``value`` fits the element. It may
      record captured data on ``self.val``.
    * ``compare(other)`` returns GREATER, LESSER, EQUAL or NOT_COMPARABLE,
      describing how general ``self`` is relative to ``other``.
    """

    def match(self, value):
        """Test ``value`` against this element (must be overridden)."""
        message = "How to match me, um?"
        raise NotImplementedError(message)

    def compare(self, other):
        """Order this element against ``other`` (must be overridden)."""
        message = "How to compare us, um?"
        raise NotImplementedError(message)
class Const(Element):
    """Pattern element that matches a single literal value via ``==``."""

    def __init__(self, c):
        self.val = c  # the literal this element accepts

    def match(self, value):
        """Return the result of ``self.val == value``."""
        return self.val == value

    def compare(self, other):
        """Return EQUAL for a same-class Const holding an equal value.

        Anything else yields NOT_COMPARABLE. Other element kinds rank
        themselves against Const from their own ``compare``.
        """
        if other.__class__ != self.__class__:
            return NOT_COMPARABLE
        return EQUAL if self.val == other.val else NOT_COMPARABLE
class Variable(Element):
    """Pattern element that matches any value and captures it in ``val``."""

    def __init__(self, name='<NO NAME>'):
        self.name = name  # label only; match/compare never read it
        self.val = None   # value captured by the most recent match()

    def match(self, value):
        """Accept ``value`` unconditionally, remembering it in ``self.val``."""
        self.val = value
        return True

    def compare(self, other):
        """EQUAL to another Variable, GREATER (more general) than a Const."""
        kind = other.__class__
        if kind == self.__class__:
            return EQUAL
        return GREATER if kind == Const else NOT_COMPARABLE
class Succ(Element):
    """Pattern element matching a positive integer as the successor of another.

    A successful match stores the predecessor (``value - 1``) in ``val``, so
    ``Succ`` acts like the pattern ``n + 1`` that binds ``n``.
    """

    def __init__(self, name='<NO NAME>'):
        self.name = name  # label only; match/compare never read it
        self.val = None   # predecessor captured by the last successful match

    def match(self, value):
        """Return True and capture ``value - 1`` if ``value`` is an integer > 0.

        Any ``numbers.Integral`` is accepted (``int``, ``bool``, Python 2
        ``long``, NumPy integer scalars), not only ``int``.
        """
        if isinstance(value, numbers.Integral) and value > 0:
            self.val = value - 1
            return True
        return False

    def compare(self, other):
        """Rank Succ between Const (more specific) and Variable (more general)."""
        kind = other.__class__
        if kind == self.__class__:
            return EQUAL
        if kind == Const:
            return GREATER   # Succ matches many values, a Const only one
        if kind == Variable:
            return LESSER    # a Variable matches everything Succ does, and more
        return NOT_COMPARABLE
class List(Element):
    """Pattern matching a list or tuple element-wise against sub-patterns."""

    def __init__(self, *elements):
        self.val = elements           # tuple of sub-pattern elements
        self.length = len(elements)

    def match(self, value):
        """Match a same-length list/tuple, stopping at the first mismatch."""
        if not isinstance(value, (list, tuple)) or len(value) != self.length:
            return False
        return all(element.match(item)
                   for element, item in zip(self.val, value))

    def compare(self, other):
        """Order two same-length List patterns position by position.

        The result is the single non-EQUAL direction shared by all positions
        (EQUAL if every position is EQUAL). If any position is
        NOT_COMPARABLE, or positions disagree in direction, the result is
        NOT_COMPARABLE.
        """
        if other.__class__ != self.__class__ or self.length != other.length:
            return NOT_COMPARABLE
        result = EQUAL
        for mine, theirs in zip(self.val, other.val):
            order = compare(mine, theirs)
            if order == NOT_COMPARABLE:
                return NOT_COMPARABLE
            if order == EQUAL or order == result:
                continue
            if result != EQUAL:
                return NOT_COMPARABLE  # one position GREATER, another LESSER
            result = order
        return result
class Function(object):
    """A callable that dispatches to the most specific matching pattern.

    Implementations are registered with ``add_pattern``. Calling the Function
    matches its positional arguments (as a tuple) against every registered
    pattern. It then invokes the implementation of the unique most specific
    match, passing that pattern object so the implementation can read the
    captured values.
    """

    def __init__(self):
        # Maps pattern object -> implementation. Presumably keyed by object
        # identity: the element classes shown define no __hash__/__eq__.
        self.patterns = {}

    def add_pattern(self, pattern, func):
        """Register ``func`` to handle calls whose arguments match ``pattern``.

        Registering the same pattern object again replaces its implementation.
        """
        self.patterns[pattern] = func

    def match(self, args):
        """Return the unique most specific registered pattern matching ``args``.

        Every registered pattern's ``match`` is invoked. Patterns that record
        captured values (e.g. ``Variable``) are therefore updated as a side
        effect.

        A candidate is selected only if ``compare`` ranks it strictly LESSER
        (more specific) than every other matching candidate. Each candidate is
        checked against all others, so the choice does not depend on dict
        iteration order. Two mutually incomparable candidates also do not
        prevent selecting a third that is more specific than both.

        Raises:
            TypeError: "No matched!" if no pattern matches, or
                "Multiple implements matched!" if no single candidate is more
                specific than all the others.
        """
        candidates = [pattern for pattern in self.patterns
                      if pattern.match(args)]
        if not candidates:
            raise TypeError("No matched!")
        for candidate in candidates:
            # The identity check skips comparing a candidate with itself.
            if all(other is candidate or compare(candidate, other) == LESSER
                   for other in candidates):
                return candidate
        raise TypeError("Multiple implements matched!")

    def __call__(self, *args):
        """Run the implementation registered for the best-matching pattern.

        The implementation receives the matched pattern object, not ``args``.
        """
        pattern = self.match(args)
        return self.patterns[pattern](pattern)
class PatternMatchTestCase(unittest.TestCase):
    """Exercise Function dispatch with Ackermann-function patterns."""

    def make_ackermann(self):
        """Return a Function computing the Ackermann function.

        Patterns (the most specific matching one is used):
            a(0, n)     = n + 1
            a(m>0, 0)   = a(m - 1, 1)
            a(m>0, n>0) = a(m - 1, a(m, n - 1))
        The (Variable, Variable) pattern also matches the first two shapes.
        It is shadowed by the more specific patterns whenever m == 0, or
        when m > 0 and n == 0.
        """
        ackermann = Function()

        def recurse_both(match):
            m = match.val[0].val
            n = match.val[1].val
            return ackermann(m - 1, ackermann(m, n - 1))

        def second_zero(match):
            # Succ captured the predecessor, so this computes a(m - 1, 1).
            return ackermann(match.val[0].val, 1)

        def first_zero(match):
            return match.val[1].val + 1

        ackermann.add_pattern(List(Variable('m'), Variable('n')), recurse_both)
        ackermann.add_pattern(List(Succ('m'), Const(0)), second_zero)
        ackermann.add_pattern(List(Const(0), Variable('n')), first_zero)
        return ackermann

    def make_failed_ackermann(self):
        """Return a deliberately broken Ackermann Function.

        Only the (m, 0) and (0, n) shapes exist. (3, 1) matches neither, and
        (0, 0) matches both with no ordering between them.
        """
        ackermann = Function()

        def second_zero(match):
            return ackermann(match.val[0].val - 1, 1)

        def first_zero(match):
            return match.val[1].val + 1

        ackermann.add_pattern(List(Variable('m'), Const(0)), second_zero)
        ackermann.add_pattern(List(Const(0), Variable('n')), first_zero)
        return ackermann

    def _assert_type_error_message(self, message, func, *args):
        """Assert ``func(*args)`` raises TypeError whose first arg is ``message``."""
        with self.assertRaises(TypeError) as context:
            func(*args)
        self.assertEqual(context.exception.args[0], message)

    def testNoMatch(self):
        self._assert_type_error_message(
            "No matched!", self.make_failed_ackermann(), 3, 1)

    def testMultiMatch(self):
        self._assert_type_error_message(
            "Multiple implements matched!", self.make_failed_ackermann(), 0, 0)

    def testCalls(self):
        ackermann = self.make_ackermann()
        # (4, 1) -> 65533 is left out: it exceeds the maximum recursion depth.
        cases = [
            (0, 0, 1),
            (0, 1, 2),
            (1, 0, 2),
            (1, 1, 3),
            (2, 3, 9),
            (3, 2, 29),
            (3, 4, 125),
            (4, 0, 13),
        ]
        for m, n, expected in cases:
            self.assertEqual(ackermann(m, n), expected)
if __name__ == "__main__":
    # Run this module's test cases when executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment