Skip to content

Instantly share code, notes, and snippets.

@odashi
Last active July 23, 2017 14:28
Show Gist options
  • Save odashi/eba45d1d79d58e07f8d8 to your computer and use it in GitHub Desktop.
Save odashi/eba45d1d79d58e07f8d8 to your computer and use it in GitHub Desktop.
Pythonで型付きラムダ計算
# debug.py
# coding: utf-8
import sys
# Master switch for debug output: print() below writes nothing while this is falsy.
DEBUG = False

# ANSI escape sequences used to colorize debug output.
WHITE = '\033[37m'
RED = '\033[91m'
GREEN = '\033[92m'
BLUE = '\033[94m'


def print(*args, **kwargs):
    """Write a colorized debug message to stdout, but only when DEBUG is set.

    Note: this shadows the builtin ``print`` inside this module.

    Args:
        *args: Objects to print; each one is converted with ``str()``.
        color: ANSI escape sequence prefixed to the message (default BLUE).
        sep: Separator placed between ``args`` (default ``' '``).
        end: String appended after the last argument (default newline).

    Passing ``None`` for ``color``, ``sep`` or ``end`` selects the default,
    mirroring the builtin ``print``.  Nothing is written at all when the
    module-level ``DEBUG`` flag is falsy.
    """
    if not DEBUG:
        return
    color = kwargs.get('color')
    sep = kwargs.get('sep')
    end = kwargs.get('end')
    if color is None:
        color = BLUE
    if sep is None:
        sep = ' '
    if end is None:
        end = '\n'
    # The trailing '\033[0m' resets the terminal attributes after the message.
    sys.stdout.write(color + sep.join(str(arg) for arg in args) + end + '\033[0m')
def tracefunc(func):
    """Decorator that announces each call to *func* through this module's print().

    Before delegating, the wrapper emits ``<module>.<name>`` of *func* in
    GREEN; it then calls *func* with the same arguments and returns its
    result unchanged.  Metadata (name, docstring, ``__wrapped__``) is copied
    onto the wrapper.
    """
    from functools import wraps

    @wraps(func)
    def traced(*args, **kwargs):
        label = '.'.join((func.__module__, func.__name__))
        print(label, color=GREEN)
        return func(*args, **kwargs)

    return traced
# error.py
# coding: utf-8
class NLSolverError(Exception):
    """Base class for every error raised by the solver package.

    The message is forwarded to ``Exception.__init__`` so that ``e.args``,
    ``repr(e)`` and pickling behave like any other exception (previously
    ``args`` was empty and unpickling failed for lack of ``message``).
    """

    def __init__(self, message):
        super().__init__(message)
        self.__m = message

    def __str__(self):
        # str() so that a non-string message cannot make str(e) itself fail.
        return str(self.__m)


class ParseError(NLSolverError):
    """Raised when a term or type description cannot be parsed."""


class ExecutionError(NLSolverError):
    """Raised when evaluating a term fails (e.g. a lambda argument type mismatch)."""
def asserttype(name, value, tp):
    """Raise TypeError unless *value* is an instance of *tp*.

    Args:
        name: Parameter name shown in the error message.
        value: Object being checked.
        tp: A class, or anything else ``isinstance`` accepts (e.g. a tuple
            of classes).

    Raises:
        TypeError: if ``isinstance(value, tp)`` is false.
    """
    if isinstance(value, tp):
        return
    if isinstance(tp, tuple):
        # A tuple has no __name__; list the alternatives instead of crashing.
        expected = ' or '.join(getattr(t, '__name__', str(t)) for t in tp)
    else:
        expected = getattr(tp, '__name__', str(tp))
    raise TypeError('\'%s\' must be <%s> (<%s> given)' % (
        name,
        expected,
        type(value).__name__,
    ))
# lambdacalc.py
# coding: utf-8
from collections import defaultdict
from . import error
from . import debug
def _tostr(term):
    """Render *term* for display: raw Python strings are wrapped in double quotes, anything else goes through str()."""
    return '"%s"' % term if isinstance(term, str) else str(term)
class Type:
    """Abstract base of the type hierarchy.

    Subclasses override ``__str__`` and ``__eq__``.  The hash is derived
    from the printed form (prefixed with ``TYPE#``) and memoised on first use.
    """

    def __init__(self):
        # Lazily computed hash; filled in by __hash__.
        self.__hash = None

    def __str__(self):
        return '<Abstract Type>'

    def __eq__(self, other):
        # The abstract type compares unequal to everything.
        return False

    def __hash__(self):
        cached = self.__hash
        if cached is None:
            cached = self.__hash = hash('TYPE#' + str(self))
        return cached
class DataType(Type):
    """An atomic, named type; printed as its bare name."""

    def __init__(self, name):
        Type.__init__(self)
        error.asserttype('name', name, str)
        self.__name = name

    def name(self):
        return self.__name

    def __str__(self):
        return self.__name

    def __eq__(self, other):
        # Equal exactly when the other object is a DataType with the same name.
        if not isinstance(other, DataType):
            return False
        return self.__name == other.__name

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Type.__hash__
class FunctionType(Type):
    """Type of a function from ``argtype`` to ``rettype``.

    Printed as ``A->B``.  When the argument is itself a function type it is
    wrapped in square brackets (``[A->B]->C``) to keep the arrow unambiguous.
    """

    def __init__(self, argtype, rettype):
        Type.__init__(self)
        error.asserttype('argtype', argtype, Type)
        error.asserttype('rettype', rettype, Type)
        self.__argtype = argtype
        self.__rettype = rettype

    def argtype(self):
        return self.__argtype

    def rettype(self):
        return self.__rettype

    def __str__(self):
        arg = str(self.__argtype)
        if isinstance(self.__argtype, FunctionType):
            arg = '[' + arg + ']'
        return arg + '->' + str(self.__rettype)

    def __eq__(self, other):
        if not isinstance(other, FunctionType):
            return False
        return (self.__argtype == other.__argtype
                and self.__rettype == other.__rettype)

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Type.__hash__
class Term:
    """Abstract base of the term hierarchy.

    Subclasses override ``__str__`` and ``__eq__``.  The hash is derived
    from the printed form (prefixed with ``TERM#``) and memoised on first use.
    """

    def __init__(self):
        self.__hash = None  # filled in lazily by __hash__

    def __str__(self):
        return '<Abstract Term>'

    def __eq__(self, other):
        # The abstract term compares unequal to everything.
        return False

    def __hash__(self):
        if self.__hash is None:
            self.__hash = hash(''.join(('TERM#', str(self))))
        return self.__hash
class Constant(Term):
    """A named constant term; printed as its bare name."""

    def __init__(self, name):
        Term.__init__(self)
        error.asserttype('name', name, str)
        self.__name = name

    def name(self):
        return self.__name

    def __str__(self):
        return self.__name

    def __eq__(self, other):
        if isinstance(other, Constant):
            return self.__name == other.__name
        return False

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Term.__hash__
class Variable(Term):
    """A named variable term; printed with a leading ``$``."""

    def __init__(self, name):
        Term.__init__(self)
        error.asserttype('name', name, str)
        self.__name = name

    def name(self):
        return self.__name

    def __str__(self):
        return '$%s' % self.__name

    def __eq__(self, other):
        if isinstance(other, Variable):
            return self.__name == other.__name
        return False

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Term.__hash__
class Application(Term):
    """Function application, printed as ``(lhs rhs)``.

    Operands are not type-checked: besides Terms they may be plain Python
    values produced by the parser (raw strings are printed quoted).
    """

    def __init__(self, lhs, rhs):
        Term.__init__(self)
        self.__lhs, self.__rhs = lhs, rhs

    def lhs(self):
        return self.__lhs

    def rhs(self):
        return self.__rhs

    def __str__(self):
        return '(%s %s)' % (_tostr(self.__lhs), _tostr(self.__rhs))

    def __eq__(self, other):
        if not isinstance(other, Application):
            return False
        return self.__lhs == other.__lhs and self.__rhs == other.__rhs

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Term.__hash__
class Abstraction(Term):
    """Typed lambda abstraction, printed as ``(\\ $var :type body)``.

    *var* must be a Variable and *tp* a Type; the body may be any term or
    plain value.
    """

    def __init__(self, var, tp, term):
        Term.__init__(self)
        error.asserttype('var', var, Variable)
        error.asserttype('tp', tp, Type)
        self.__var = var
        self.__tp = tp
        self.__term = term

    def var(self):
        return self.__var

    def type(self):
        return self.__tp

    def term(self):
        return self.__term

    def __str__(self):
        return '(\\ %s :%s %s)' % (self.__var, self.__tp, _tostr(self.__term))

    def __eq__(self, other):
        if not isinstance(other, Abstraction):
            return False
        return (self.__var == other.__var
                and self.__tp == other.__tp
                and self.__term == other.__term)

    # Overriding __eq__ would otherwise reset __hash__ to None.
    __hash__ = Term.__hash__
def _parse_s_form(text, pos):
    """Parse one S-expression from *text* starting at index *pos*.

    An atom is a maximal run of characters that are neither whitespace nor
    parentheses and is returned as a ``str``.  A parenthesised list is
    returned as a ``tuple`` of its parsed elements, nested as in the input.

    Returns:
        ``(form, next_pos)`` where ``next_pos`` is the index just past the form.

    Raises:
        error.ParseError: on premature end of input or a stray ``)``.
    """
    size = len(text)
    if pos >= size:
        raise error.ParseError('invalid S-form')

    if text[pos] == '(':
        items = []
        pos += 1
        while True:
            # Skip whitespace between elements.
            while pos < size and text[pos].isspace():
                pos += 1
            if pos >= size:
                raise error.ParseError('invalid S-form')
            if text[pos] == ')':
                return tuple(items), pos + 1
            item, pos = _parse_s_form(text, pos)
            items.append(item)

    if text[pos] == ')':
        raise error.ParseError('invalid S-form')
    end = pos
    while end < size and not text[end].isspace() and text[end] not in '()':
        end += 1
    return text[pos:end], end
def _parse_typeelm(text, pos):
    """Parse one element of a type chain starting at *pos*.

    An element is either a bracketed chain ``[ ... ]`` or a bare DataType
    name.  A bare name extends up to ``[``, ``]``, ``>`` or an ``->`` arrow.

    Returns:
        ``(type, next_pos)``.

    Raises:
        error.ParseError: if no element starts at *pos* or a bracket is not
            closed.
    """
    if pos >= len(text) or text[pos] in (']', '>') or text.startswith('->', pos):
        raise error.ParseError('invalid Type format')
    if text[pos] == '[':
        tp, pos = _parse_typechain(text, pos + 1)
        if pos >= len(text) or text[pos] != ']':
            raise error.ParseError('invalid Type format')
        return tp, pos + 1
    begin = pos
    # Stop at '->' as well as '>', so that "e->t" yields DataType('e') rather
    # than DataType('e-'). This is the arrow FunctionType.__str__ prints.
    while (pos < len(text)
           and text[pos] not in ('[', ']', '>')
           and not text.startswith('->', pos)):
        pos += 1
    return DataType(text[begin:pos]), pos


def _parse_typechain(text, pos):
    """Parse a right-associative arrow chain ``A->B->...`` starting at *pos*.

    Both ``->`` (as printed by FunctionType.__str__) and a bare ``>`` are
    accepted as the arrow.  Parsing stops at a closing ``]`` or end of input.

    Returns:
        ``(type, next_pos)`` with ``next_pos`` at a ``]`` or ``len(text)``.

    Raises:
        error.ParseError: on empty input or when something other than an
            arrow follows an element.
    """
    if pos >= len(text):
        raise error.ParseError('invalid Type format')
    tp, pos = _parse_typeelm(text, pos)
    if pos < len(text) and text[pos] != ']':
        if text.startswith('->', pos):
            pos += 2
        elif text[pos] == '>':
            pos += 1
        else:
            # ParseError, like every other malformed-type path (callers catch it).
            raise error.ParseError('invalid Type format')
        # Right-associative: A->B->C means A->[B->C].  The recursive call
        # consumes the rest of the chain, so a single step suffices.
        rettype, pos = _parse_typechain(text, pos)
        tp = FunctionType(tp, rettype)
    return tp, pos
def _convert(text):
    """Convert one S-form atom into a Python value, Type or Term.

    * leading digit -> ``int`` if possible, otherwise ``float``
    * ``"..."``     -> ``str`` with the quotes stripped
    * ``:<type>``   -> Type, via parsetype
    * ``$<name>``   -> Variable
    * otherwise     -> Constant

    Raises:
        error.ParseError: on an empty atom or a malformed number, string or
            variable.
    """
    if not text:
        raise error.ParseError('empty term')
    head = text[0]
    if head.isdigit():
        # Try [0-9]+ first, then [0-9]+\.[0-9]* (float() also takes exponents).
        # Catch only ValueError so KeyboardInterrupt etc. are not swallowed.
        try:
            return int(text)
        except ValueError:
            pass
        try:
            return float(text)
        except ValueError:
            # Suppress the ValueError context; the ParseError is the real report.
            raise error.ParseError('invalid int/float format') from None
    if head == '"':
        if len(text) < 2 or text[-1] != '"':
            raise error.ParseError('invalid str format')
        return text[1:-1]
    if head == ':':
        return parsetype(text[1:])
    if head == '$':
        if len(text) < 2:
            raise error.ParseError('invalid Variable format')
        return Variable(text[1:])
    return Constant(text)
def _make_term(s_form):
    """Build a term (or plain value) from a parsed S-form.

    * a 4-tuple headed by a backslash atom -> Abstraction(var, type, body)
    * any other tuple of length >= 2        -> left-nested Application,
      i.e. ``(f a b)`` becomes ``((f a) b)``
    * an atom                               -> result of _convert

    Raises:
        error.ParseError: for tuples shorter than 2, malformed lambda
            abstractions, or a bare backslash atom.
    """
    if not isinstance(s_form, tuple):
        # Every atom except a lone backslash is well-formed.
        if s_form == '\\':
            raise error.ParseError('unexpected \'\\\'')
        return _convert(s_form)

    if len(s_form) < 2:
        raise error.ParseError('ill-formed length of tuple')

    if s_form[0] == '\\':
        # Lambda abstraction: (backslash, Variable, Type, body)
        if len(s_form) != 4:
            raise error.ParseError('invalid lambda-abstraction term')
        var, tp, body = (_make_term(part) for part in s_form[1:])
        if not isinstance(var, Variable) or not isinstance(tp, Type):
            raise error.ParseError('invalid lambda-abstraction term')
        return Abstraction(var, tp, body)

    # Application chain, folded to the left.
    head, *rest = s_form
    term = _make_term(head)
    for arg in rest:
        term = Application(term, _make_term(arg))
    return term
def parsetype(text):
    """Parse a complete type description (a DataType name or an arrow chain).

    Surrounding whitespace is ignored; any unconsumed trailing input is
    rejected with error.ParseError.
    """
    stripped = text.strip()
    tp, end = _parse_typechain(stripped, 0)
    if end < len(stripped):
        raise error.ParseError('invalid Type format')
    return tp
def parse(text):
    """Parse *text* as exactly one S-form and convert it into a term.

    Surrounding whitespace is ignored; any unconsumed trailing input is
    rejected with error.ParseError.
    """
    stripped = text.strip()
    form, end = _parse_s_form(stripped, 0)
    if end < len(stripped):
        raise error.ParseError('invalid S-form')
    return _make_term(form)
def _varnames(term):
    """Return the set of all variable names occurring anywhere in *term*.

    Free and bound occurrences both count, including each binder's own name.
    Non-variable leaves (constants, plain values) contribute nothing.
    """
    if isinstance(term, Variable):
        return {term.name()}
    if isinstance(term, Application):
        return _varnames(term.lhs()).union(_varnames(term.rhs()))
    if isinstance(term, Abstraction):
        names = _varnames(term.term())
        names.add(term.var().name())
        return names
    return set()
def _othervar(term1, term2):
    """Return a Variable whose name occurs in neither *term1* nor *term2*.

    The new name is the lexicographically greatest existing name with '0'
    appended (or just '0' if there are none).  Appending to the maximum gives
    a string greater than every existing name, so it cannot collide.
    """
    names = _varnames(term1) | _varnames(term2)
    base = max(names) if names else ''
    return Variable(base + '0')
def _substitute_var(term, before, after):
    """Rename free occurrences of Variable *before* to Variable *after*.

    Written ``term[after / before]``.  Subterms under an abstraction that
    rebinds *before* are left alone.

    NOTE: no capture avoidance is done here.  If *after* is bound somewhere
    inside *term*, renamed occurrences under that binder will be captured.
    """
    error.asserttype('before', before, Variable)
    error.asserttype('after', after, Variable)
    if isinstance(term, Variable):
        return after if term == before else term
    if isinstance(term, Application):
        return Application(_substitute_var(term.lhs(), before, after),
                           _substitute_var(term.rhs(), before, after))
    if isinstance(term, Abstraction):
        if term.var() == before:
            # *before* is shadowed here, so nothing below refers to it.
            return term
        body = _substitute_var(term.term(), before, after)
        return Abstraction(term.var(), term.type(), body)
    return term
# term[before := after]
def _substitute(term, before, after):
    """Capture-avoiding substitution ``term[before := after]``.

    Replaces free occurrences of Variable *before* in *term* with *after*,
    which may be any term or plain value.  Each abstraction that does not
    rebind *before* first has its bound variable renamed to a fresh name, so
    free variables of *after* cannot be captured.
    """
    error.asserttype('before', before, Variable)
    if isinstance(term, Variable):
        return after if term == before else term
    if isinstance(term, Application):
        lhs = _substitute(term.lhs(), before, after)
        rhs = _substitute(term.rhs(), before, after)
        return Application(lhs, rhs)
    if isinstance(term, Abstraction):
        if term.var() == before:
            # *before* is shadowed: there are no free occurrences below.
            return term
        # The fresh binder must also avoid *before*'s name.  Otherwise, when
        # *before* does not occur in *term*, the fresh name can equal it and
        # the renamed binder would then be replaced by *after* below, e.g.
        # (\ $x $x)[x0 := c] would become (\ $x0 c).
        ov = _othervar(Application(term, before), after)
        elm = _substitute_var(term.term(), term.var(), ov)
        return Abstraction(ov, term.type(), _substitute(elm, before, after))
    # Constants and plain values contain no variables.
    return term
# calc Type description of term
def gettype(term, valuetype_func, const_types, var_types):
    """Infer the Type of *term*, or return None if it is unknown or ill-typed.

    Args:
        term: A Term or a plain value (e.g. a number or string from the parser).
        valuetype_func: Called with any non-Term value to obtain its type.
        const_types: Mapping from Constant name to Type.
        var_types: Mapping from Variable name to Type for variables in scope.
            It is temporarily extended while inside an abstraction and
            restored to its previous contents afterwards, even if an
            exception propagates.

    Returns:
        The inferred Type, or None when a name is unbound or an
        application's argument type does not match the function's.
    """
    if isinstance(term, Constant):
        return const_types.get(term.name())
    if isinstance(term, Variable):
        return var_types.get(term.name())
    if isinstance(term, Application):
        ltp = gettype(term.lhs(), valuetype_func, const_types, var_types)
        rtp = gettype(term.rhs(), valuetype_func, const_types, var_types)
        if not isinstance(ltp, FunctionType) or not isinstance(rtp, Type):
            return None
        if ltp.argtype() != rtp:
            return None
        return ltp.rettype()
    if isinstance(term, Abstraction):
        varname = term.var().name()
        was_bound = varname in var_types
        saved = var_types.get(varname)
        var_types[varname] = term.type()
        try:
            rtp = gettype(term.term(), valuetype_func, const_types, var_types)
        finally:
            # Restore the caller's scope exactly: drop the binding if it was
            # absent before, instead of leaving a stale None entry behind.
            if was_bound:
                var_types[varname] = saved
            else:
                del var_types[varname]
        if not isinstance(rtp, Type):
            return None
        return FunctionType(term.type(), rtp)
    return valuetype_func(term)
def _reduce(abst, term, valuetype_func, const_types, var_types):
    """Beta-reduce the application ``(abst term)`` after a type check.

    The type of *term* is computed with gettype in the current scope and
    must equal the abstraction's declared parameter type.

    Raises:
        error.ExecutionError: when the two types differ.
    """
    error.asserttype('abst', abst, Abstraction)
    expected = abst.type()
    actual = gettype(term, valuetype_func, const_types, var_types)
    if actual != expected:
        raise error.ExecutionError(
            'unexpected lambda type (:%s expected, :%s given)' % (expected, actual))
    return _substitute(abst.term(), abst.var(), term)
# recursive beta-reduction
def solve(term, valuetype_func, const_types, var_types):
    """Reduce *term* by repeated, type-checked beta-reduction.

    Both sides of an application are solved first.  Whenever the function
    side becomes an Abstraction, it is reduced via _reduce and the result is
    solved again.  Abstraction bodies are solved with the bound variable's
    type added to *var_types*.

    Args:
        term: Term (or plain value) to reduce.
        valuetype_func, const_types, var_types: as for gettype.  *var_types*
            is extended while inside an abstraction and restored to its
            previous contents afterwards, even if an exception propagates.

    Raises:
        error.ExecutionError: from _reduce when an argument's type does not
            match the abstraction's parameter type.
    """
    if isinstance(term, Abstraction):
        varname = term.var().name()
        was_bound = varname in var_types
        saved = var_types.get(varname)
        var_types[varname] = term.type()
        try:
            elm = solve(term.term(), valuetype_func, const_types, var_types)
        finally:
            # Restore the caller's scope exactly, even when _reduce raised.
            if was_bound:
                var_types[varname] = saved
            else:
                del var_types[varname]
        return Abstraction(term.var(), term.type(), elm)
    if isinstance(term, Application):
        lhs = solve(term.lhs(), valuetype_func, const_types, var_types)
        rhs = solve(term.rhs(), valuetype_func, const_types, var_types)
        if isinstance(lhs, Abstraction):
            elm = _reduce(lhs, rhs, valuetype_func, const_types, var_types)
            return solve(elm, valuetype_func, const_types, var_types)
        return Application(lhs, rhs)
    # Constants, variables and plain values are already irreducible.
    return term
def normal(term, lambda_depth=0):
    """Return *term* with every bound variable renamed canonically.

    The binder of an abstraction at nesting depth d (counting from
    *lambda_depth*) is renamed to ``$<d>``, and its occurrences in the body
    are renamed to match.  No beta-reduction is performed.

    NOTE(review): variables (free or bound) whose names are decimal digits,
    e.g. ``$0``, can be captured by the generated binders.  Confirm that
    callers never produce such names.
    """
    if isinstance(term, Application):
        return Application(normal(term.lhs(), lambda_depth),
                           normal(term.rhs(), lambda_depth))
    if isinstance(term, Abstraction):
        fresh = Variable(str(lambda_depth))
        body = _substitute_var(term.term(), term.var(), fresh)
        return Abstraction(fresh, term.type(), normal(body, lambda_depth + 1))
    # Constants, variables and plain values are returned unchanged.
    return term
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment