Last active
July 18, 2022 22:45
-
-
Save dhilst/24bfd7904ccefb542abf7fa099e7e516 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import * | |
import ast | |
from dataclasses import dataclass | |
@dataclass(frozen=True)
class FuncSig:
    """Signature of a checked function: its name, argument types and return type.

    Type names are plain strings taken from the function's source
    annotations (for example ``"int"``).
    """

    name: str
    args: list[str]  # argument type names, in declaration order
    ret: str  # return type name

    def __repr__(self):
        # Curried arrow notation, e.g. ``inc : int -> int -> int``.
        signature_parts = self.args + [self.ret]
        args = " -> ".join(signature_parts)
        return f"{self.name} : {args}"
class Typechecker(ast.NodeVisitor):
    """Tiny, name-based checker for calls to annotated functions.

    Walks a module AST and records a ``FuncSig`` for every function whose
    positional parameters and return value are annotated with plain names
    (``int``, ``str``, ...).  A call to such a function raises ``TypeError``
    when the textual type names of its positional arguments differ from the
    declared ones.  Calls it cannot reason about are skipped rather than
    reported.
    """

    def __init__(self):
        super().__init__()
        # Maps a name to either a type-name string (annotated parameter)
        # or a FuncSig (checkable function).
        self.typeenv = {}

    @staticmethod
    def _annotation_name(annotation):
        """Return the type name of a plain ``Name`` annotation, else ``None``."""
        return annotation.id if isinstance(annotation, ast.Name) else None

    def _signature_for(self, node):
        """Build a FuncSig for *node*, or return ``None`` if it cannot be checked."""
        # Defaults and *args make the positional arity variable, which the
        # exact list comparison in _check_call cannot express.
        if node.args.defaults or node.args.vararg is not None:
            return None
        params = node.args.posonlyargs + node.args.args
        arg_types = [self._annotation_name(p.annotation) for p in params]
        ret_type = self._annotation_name(node.returns)
        if ret_type is None or None in arg_types:
            return None
        return FuncSig(node.name, arg_types, ret_type)

    def _bind(self, name, value):
        """Set ``typeenv[name]`` to *value*; drop the entry when *value* is None."""
        if value is None:
            # Shadow any outer binding so it is not wrongly reused.
            self.typeenv.pop(name, None)
        else:
            self.typeenv[name] = value

    def visit_FunctionDef(self, node):
        signature = self._signature_for(node)
        oldenv = self.typeenv.copy()
        # Bind the function's own name first so recursive calls are checked,
        # then its parameters, which may shadow it.
        self._bind(node.name, signature)
        for param in node.args.posonlyargs + node.args.args + node.args.kwonlyargs:
            self._bind(param.arg, self._annotation_name(param.annotation))
        for star in (node.args.vararg, node.args.kwarg):
            if star is not None:
                # *args / **kwargs hold a tuple / dict, not the annotated type.
                self._bind(star.arg, None)
        self.generic_visit(node)
        self.typeenv = oldenv
        self._bind(node.name, signature)

    visit_AsyncFunctionDef = visit_FunctionDef

    def _expr_type(self, expr):
        """Return the type-name string of *expr*, or ``None`` if unknown."""
        if isinstance(expr, ast.Constant):
            return type(expr.value).__name__
        if isinstance(expr, ast.Name):
            found = self.typeenv.get(expr.id)
            # Function-valued names (FuncSig) have no comparable type name.
            return found if isinstance(found, str) else None
        return None

    def _check_call(self, node):
        """Raise TypeError if *node* calls a known function with mismatched types."""
        if not isinstance(node.func, ast.Name):
            return
        signature = self.typeenv.get(node.func.id)
        if not isinstance(signature, FuncSig):
            # Unknown callee, or a variable bound to a type-name string.
            return
        if node.keywords or any(isinstance(a, ast.Starred) for a in node.args):
            # Only plain positional calls can be compared against the signature.
            return
        actual_args = []
        for arg in node.args:
            arg_type = self._expr_type(arg)
            if arg_type is None:
                # Cannot typecheck, no type information
                return
            actual_args.append(arg_type)
        expected_args = signature.args
        # dumb typechecking: exact, textual comparison of type names
        if actual_args != expected_args:
            raise TypeError(f"Type error in call for {node.func.id}, "
                            f"expected : {expected_args}, found : {actual_args}")

    def visit_Call(self, node):
        self._check_call(node)
        # Always descend so calls nested inside arguments are checked too.
        self.generic_visit(node)
def typecheck(text, typeenv=None):
    """Parse *text* as Python source and type-check the calls it contains.

    :param text: Python source code to check.
    :param typeenv: optional initial environment mapping names to type-name
        strings or ``FuncSig`` objects, e.g. signatures of functions defined
        outside *text*.  Its entries are copied into the checker; the mapping
        itself is never mutated.
    :raises TypeError: on the first call whose argument types do not match.
    :raises SyntaxError: if *text* is not valid Python.
    """
    checker = Typechecker()
    if typeenv:
        # Previously accepted but silently ignored; now seeds the environment.
        checker.typeenv.update(typeenv)
    checker.visit(ast.parse(text))
if __name__ == "__main__":
    # Demo: the second argument of ``inc`` is a ``str`` literal, so the
    # checker rejects the call inside ``foo``.
    try:
        typecheck("""
def inc(a: int, b: int) -> int:
    return a + 1
def foo(a: int) -> float:
    return inc(a, "a") # type error here
""")
    except TypeError as e:
        print(e)  # Type error in call for inc,
        # expected : ['int', 'int'], found : ['int', 'str']
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment