Skip to content

Instantly share code, notes, and snippets.

@maksbotan
Created April 8, 2020 15:23
Show Gist options
  • Save maksbotan/32c81fdc8d3dcad32583c9faaa355e34 to your computer and use it in GitHub Desktop.
Save maksbotan/32c81fdc8d3dcad32583c9faaa355e34 to your computer and use it in GitHub Desktop.
Simplistic python typechecking
import ast
import sys
from enum import Enum
from typing import Dict, NamedTuple, Union
class PrimType(Enum):
    """Primitive (non-function) types understood by the checker."""

    ty_int = "int"
    ty_bool = "bool"

    def __str__(self) -> str:
        # Show just the type name (e.g. ``int``) instead of ``PrimType.ty_int``.
        name: str = self.value
        return name
# A type is either a primitive type or a one-argument function type.
Type = Union[PrimType, "FunctionType"]


class FunctionType(NamedTuple):
    """Type of a single-argument function, rendered as ``arg -> res``."""

    arg: Type  # type of the single parameter
    res: Type  # type of the value returned

    def __str__(self):
        # ``->`` is right-associative: ``int -> int -> int`` means
        # ``int -> (int -> int)``.  A function-typed *argument* must therefore
        # be parenthesised, otherwise ``(int -> int) -> int`` would print
        # identically to ``int -> (int -> int)``.
        if isinstance(self.arg, FunctionType):
            arg = f"({self.arg})"
        else:
            arg = str(self.arg)
        return f"{arg} -> {self.res}"


# Typing environment: maps variable names to their types.
Env = Dict[str, Type]
class TypeCheckError(Exception):
    """Raised by the checker when an expression cannot be assigned a type."""
# Binary operators accepted on two ``int`` operands, yielding ``int``.
# NOTE(review): ``/`` is true division in Python and produces a float at
# runtime; it is still typed ``int`` here only to stay compatible with the
# checker's earlier behaviour.
_INT_BINOPS = (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod)
# Unary operators accepted on an ``int`` operand, yielding ``int``.
_INT_UNARYOPS = (ast.UAdd, ast.USub, ast.Invert)
# Ordering comparisons; only valid between primitive values (comparing two
# functions with ``<`` raises TypeError at runtime).
_ORDER_CMPOPS = (ast.Lt, ast.LtE, ast.Gt, ast.GtE)
# Membership tests need a container on the right-hand side, which this type
# system cannot express, so they are always rejected.
_REJECTED_CMPOPS = (ast.In, ast.NotIn)


def typecheck(expr: ast.expr, env: Env) -> Type:
    """Infer the type of the expression node *expr* under *env*.

    Supported forms:

    * ``int`` and ``bool`` literals,
    * variables bound in *env*,
    * integer arithmetic ``+ - * / // %`` and unary ``+ - ~``,
    * ``not``, ``and`` and ``or`` on bools,
    * (chained) comparisons whose operands all share one type,
    * conditional expressions ``a if cond else b``,
    * single-parameter lambdas (the parameter is always assumed ``int``),
    * calls with exactly one positional argument.

    Requires Python 3.8+, where literals are parsed as ``ast.Constant``.

    Args:
        expr: the expression node to check.
        env: maps variable names to their types; it is not mutated.

    Returns:
        The inferred ``PrimType`` or ``FunctionType``.

    Raises:
        TypeCheckError: if *expr* or any sub-expression is ill-typed or
            unsupported. The message is the ``ast.dump`` of the innermost
            node that could not be typed.
    """
    if isinstance(expr, ast.Constant):
        # bool is a subclass of int, so it must be tested first.
        if isinstance(expr.value, bool):
            return PrimType.ty_bool
        if isinstance(expr.value, int):
            return PrimType.ty_int
    elif isinstance(expr, ast.Name):
        ty = env.get(expr.id)
        if ty is not None:
            return ty
    elif isinstance(expr, ast.BinOp):
        if (
            isinstance(expr.op, _INT_BINOPS)
            and typecheck(expr.left, env) == PrimType.ty_int
            and typecheck(expr.right, env) == PrimType.ty_int
        ):
            return PrimType.ty_int
    elif isinstance(expr, ast.UnaryOp):
        ty_operand = typecheck(expr.operand, env)
        if isinstance(expr.op, ast.Not) and ty_operand == PrimType.ty_bool:
            return PrimType.ty_bool
        if isinstance(expr.op, _INT_UNARYOPS) and ty_operand == PrimType.ty_int:
            return PrimType.ty_int
    elif isinstance(expr, ast.BoolOp):
        if all(typecheck(value, env) == PrimType.ty_bool for value in expr.values):
            return PrimType.ty_bool
    elif isinstance(expr, ast.Compare):
        # Check every operand of a chain such as ``a < b < c``, not just the
        # first pair.
        operand_types = [
            typecheck(operand, env) for operand in (expr.left, *expr.comparators)
        ]
        first = operand_types[0]
        same_types = all(ty == first for ty in operand_types)
        no_membership = not any(isinstance(op, _REJECTED_CMPOPS) for op in expr.ops)
        orderable = isinstance(first, PrimType) or not any(
            isinstance(op, _ORDER_CMPOPS) for op in expr.ops
        )
        if same_types and no_membership and orderable:
            return PrimType.ty_bool
    elif isinstance(expr, ast.IfExp):
        ty_test = typecheck(expr.test, env)
        ty_then = typecheck(expr.body, env)
        ty_else = typecheck(expr.orelse, env)
        if ty_test == PrimType.ty_bool and ty_then == ty_else:
            return ty_then
    elif isinstance(expr, ast.Lambda):
        params = expr.args
        positional = params.posonlyargs + params.args
        # Only ``lambda x: ...`` is modelled: one plain parameter, no defaults,
        # no *args / **kwargs / keyword-only parameters.
        if (
            len(positional) == 1
            and params.vararg is None
            and params.kwarg is None
            and not params.kwonlyargs
            and not params.defaults
        ):
            body_env = {**env, positional[0].arg: PrimType.ty_int}
            return FunctionType(PrimType.ty_int, typecheck(expr.body, body_env))
    elif isinstance(expr, ast.Call):
        # Functions take exactly one argument, so require exactly one
        # positional argument and no keywords.
        if len(expr.args) == 1 and not expr.keywords:
            ty_func = typecheck(expr.func, env)
            ty_arg = typecheck(expr.args[0], env)
            if isinstance(ty_func, FunctionType) and ty_func.arg == ty_arg:
                return ty_func.res
    raise TypeCheckError(ast.dump(expr))
def main(argv=None) -> int:
    """Typecheck the expression given on the command line and print its type.

    Args:
        argv: argument list excluding the program name; defaults to
            ``sys.argv[1:]``. Only the first argument is used.

    Returns:
        A process exit status:

        * 0 on success (the type is printed to stdout),
        * 1 on a type error,
        * 2 on a usage or syntax error.

        Diagnostics are written to stderr.
    """
    args = sys.argv[1:] if argv is None else argv
    if not args:
        print(f"usage: {sys.argv[0]} EXPRESSION", file=sys.stderr)
        return 2
    try:
        # mode="eval" accepts exactly one expression. Statements such as
        # ``x = 1`` and empty input are rejected instead of being
        # half-interpreted.
        tree = ast.parse(args[0], mode="eval")
    except (SyntaxError, ValueError) as err:
        print(f"syntax error: {err}", file=sys.stderr)
        return 2
    try:
        ty = typecheck(tree.body, {})
    except TypeCheckError as err:
        print(f"type error: {err}", file=sys.stderr)
        return 1
    print(ty)
    return 0


if __name__ == "__main__":
    sys.exit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment