Skip to content

Instantly share code, notes, and snippets.

@dibrinsofor
Last active April 25, 2024 14:47
Show Gist options
  • Save dibrinsofor/b6eee7d3799de58d9e13f831c6c09069 to your computer and use it in GitHub Desktop.
Save dibrinsofor/b6eee7d3799de58d9e13f831c6c09069 to your computer and use it in GitHub Desktop.
Flagging interesting patterns from source code that uses Mypy
import pathlib
import click
import io, os
from enum import Enum
from typing import Optional, Union
import mypy.nodes
from mypy.parse import parse
from mypy.options import Options
from mypy.errors import CompileError, Errors
import sys
sys.path.append("../check")
from check import flush_errs, serialize, get_file_name, log_error, get_files, check_error_found, rare_error_found, parse_call_expr
# Module-level error registry; `sample` inspects it after each scan to decide
# whether `flush_errs` should be called.
ERRORS: set = set()


class PATTERNS(Enum):
    """Kinds of weak-typing patterns the scanner can attach to a finding."""

    Unpack = 1
    Wrapper = 2
    DepDicts = 3
# patterns
# 1. instances that depend on untyped keys in dicts
# 2. Uses of the unpack type [TODO: ignore]
# 3. wrappers called with any args
class Flagged:
    """A single flagged occurrence of a pattern.

    Attributes:
        line: Line number supplied by the caller (the detectors in this file
            pass the line of the enclosing function definition).
        var_name: Name of the argument the pattern was detected on.
        pattern: Which pattern was detected.
    """

    line: int
    var_name: str
    # Quoted so the class body does not need PATTERNS at definition time.
    pattern: "PATTERNS"

    def __init__(self, line: int, patterns: "PATTERNS", v_n: str) -> None:
        self.line = line
        self.pattern = patterns
        self.var_name = v_n

    def __repr__(self) -> str:
        # Results are printed directly by the CLI; without this they would
        # show up as opaque `<Flagged object at 0x...>` entries.
        return (
            f"{type(self).__name__}(line={self.line!r}, "
            f"pattern={self.pattern!r}, var_name={self.var_name!r})"
        )
def trace_untyped_wrapper(name: str, stmt: "mypy.nodes.Node") -> bool:
    """Report whether ``stmt`` calls ``name`` with exactly ``*args, **kwargs``.

    The search recurses through nested function bodies, the expression of a
    ``return`` statement, and the right-hand side of an assignment. A call
    matches when its callee is named ``name`` and its argument kinds are
    exactly one star argument followed by one double-star argument.

    Args:
        name: Name of the callable argument being traced.
        stmt: AST node (statement or call expression) to inspect.

    Returns:
        True if a matching call was found, False otherwise. Node kinds that
        are not handled are reported through ``rare_error_found`` and yield
        False.
    """
    # TODO: flag any cases that attempt to type the args and kwargs
    star_kinds = [mypy.nodes.ArgKind.ARG_STAR, mypy.nodes.ArgKind.ARG_STAR2]

    if isinstance(stmt, mypy.nodes.FuncDef):
        # Stops at the first nested statement that matches.
        return any(
            trace_untyped_wrapper(name, sub_stmt) for sub_stmt in stmt.body.body
        )

    if isinstance(stmt, mypy.nodes.ReturnStmt):
        # Only returned expressions that look like calls are followed.
        if hasattr(stmt.expr, 'callee'):
            return trace_untyped_wrapper(name, stmt.expr)
        return False

    if isinstance(stmt, mypy.nodes.AssignmentStmt):
        if isinstance(stmt.rvalue, mypy.nodes.CallExpr):
            return trace_untyped_wrapper(name, stmt.rvalue)
        return False

    if isinstance(stmt, mypy.nodes.CallExpr):
        # The callee can be any expression (e.g. `f()(...)` or `fs[0](...)`),
        # and not all of them carry a `name` attribute.
        callee_name = getattr(stmt.callee, 'name', None)
        return callee_name == name and star_kinds == stmt.arg_kinds

    rare_error_found('Statement')
    return False
def trace_untyped_dict(name: str, stmt: "mypy.nodes.Node") -> bool:
    """Report whether ``stmt`` accesses dict ``name`` through a literal key.

    Recognised shapes:
      * ``<literal> in name`` (any operator containing ``in``) inside a
        comparison,
      * ``name[<literal>]`` index expressions, including ones found on the
        left of a non-``in`` comparison or on either side of a binary
        operator,
      * any of the above inside the condition expressions of an ``if``.

    Args:
        name: Name of the dict argument being traced.
        stmt: AST node (statement or expression) to inspect.

    Returns:
        True if such an access was found, False otherwise. Node kinds that
        are not handled are reported through ``rare_error_found`` and yield
        False.
    """
    found = False
    # Literal expression kinds that count as a hard-coded key.
    literal_keys = (
        mypy.nodes.StrExpr,
        mypy.nodes.IntExpr,
        mypy.nodes.BytesExpr,
        mypy.nodes.FloatExpr,
    )

    if isinstance(stmt, mypy.nodes.IfStmt):
        # `IfStmt.expr` is iterated here: one condition per if/elif branch.
        for condition in stmt.expr:
            found = trace_untyped_dict(name, condition)
            if found:
                break
    elif isinstance(stmt, mypy.nodes.ComparisonExpr):
        for op in stmt.operators:
            if "in" in op:
                # The right-most operand may be any expression (call, list
                # literal, ...), so its `name` is looked up defensively.
                container = getattr(stmt.operands[-1], 'name', None)
                if name == container and isinstance(stmt.operands[0], literal_keys):
                    return True
            elif isinstance(stmt.operands[0], mypy.nodes.IndexExpr):
                found = trace_untyped_dict(name, stmt.operands[0])
    elif isinstance(stmt, mypy.nodes.OpExpr):
        # Both sides are always traced, even when the left already matched.
        left = trace_untyped_dict(name, stmt.left)
        right = trace_untyped_dict(name, stmt.right)
        if left or right:
            return True
    elif isinstance(stmt, mypy.nodes.IndexExpr):
        # The base may itself be an index or call expression without `name`.
        base_name = getattr(stmt.base, 'name', None)
        if base_name == name and isinstance(stmt.index, literal_keys):
            return True
    else:
        rare_error_found('Statement')
    return found
def has_wrapper_pattern(node: "mypy.nodes.FuncDef") -> "list[Flagged]":
    """Collect wrapper-pattern findings for one function definition.

    For every argument whose declared type is a weak callable, each top-level
    statement of the function body is traced; every statement that forwards
    ``*args, **kwargs`` to that argument produces one finding.

    Args:
        node: Function definition to inspect.

    Returns:
        A list of ``Flagged`` entries tagged ``PATTERNS.Wrapper``; empty when
        nothing matched or the function carries no type signature.
    """
    found_patts = []
    signature = node.type
    if signature is None:
        # No signature object means there are no argument types to inspect.
        return found_patts

    line = node.line
    for idx, arg in enumerate(signature.arg_types):
        if not is_weak_callable(arg):
            continue
        var_name = node.arg_names[idx]
        for stmt in node.body.body:
            if trace_untyped_wrapper(var_name, stmt):
                # Was tagged DepDicts, which mislabelled wrapper findings.
                found_patts.append(Flagged(line, PATTERNS.Wrapper, var_name))
    return found_patts
def has_unpacked_type(node: "mypy.nodes.FuncDef") -> "list[Flagged]":
    """Placeholder for detecting uses of the unpack type.

    Args:
        node: Function definition to inspect (currently unused).

    Raises:
        NotImplementedError: Always; the detector has not been written yet.
    """
    raise NotImplementedError("unpack-type detection is not implemented")
def has_dep_dicts(node: "mypy.nodes.FuncDef") -> "list[Flagged]":
    """Collect dict-key-dependency findings for one function definition.

    For every argument whose declared type is a weak dict, each top-level
    statement of the function body is traced; every statement that reads the
    dict through a literal key produces one finding.

    Args:
        node: Function definition to inspect.

    Returns:
        A list of ``Flagged`` entries tagged ``PATTERNS.DepDicts``; empty when
        nothing matched or the function carries no type signature.
    """
    found_patts = []
    signature = node.type
    if signature is None:
        # No signature object means there are no argument types to inspect.
        return found_patts

    line = node.line
    for idx, arg in enumerate(signature.arg_types):
        if not is_weak_dict(arg):
            continue
        var_name = node.arg_names[idx]
        found_patts.extend(
            Flagged(line, PATTERNS.DepDicts, var_name)
            for stmt in node.body.body
            if trace_untyped_dict(var_name, stmt)
        )
    return found_patts
def check_for_pattern(node: "mypy.nodes.FuncDef") -> "list[Flagged]":
    """Run every pattern detector on ``node`` and merge their findings.

    Detectors run in the order wrapper, unpack, dependent-dicts, and their
    results are concatenated in that order.

    Args:
        node: Function definition to inspect.

    Returns:
        All findings from the detectors; empty when none matched.
    """
    patterns_found = []
    for detector in (has_wrapper_pattern, has_unpacked_type, has_dep_dicts):
        try:
            hits = detector(node)
        except NotImplementedError:
            # A detector that is still a stub must not abort the whole scan
            # and throw away what the other detectors found.
            continue
        patterns_found.extend(hits)
    return patterns_found
def is_weak_type(t: object) -> bool:
    """Placeholder for a general weak-type check.

    Args:
        t: Type to classify (currently unused).

    Raises:
        NotImplementedError: Always; only the callable- and dict-specific
            checks exist so far.
    """
    raise NotImplementedError("general weak-type detection is not implemented")
def is_weak_callable(t: object) -> bool:
    """Report whether ``t`` renders as an unparameterised or Any-typed callable.

    The check is purely textual and case-insensitive: ``str(t)`` must either
    be exactly ``Callable?`` or contain both ``Callable?`` and ``Any?``.
    (Presumably the trailing ``?`` is how the parsed type objects print
    themselves; confirm against the mypy version in use.)

    Args:
        t: Any object; only its string form is inspected.

    Returns:
        True if the rendered type matches either form, False otherwise.
    """
    rendered = str(t).upper()
    return rendered == "CALLABLE?" or (
        "CALLABLE?" in rendered and "ANY?" in rendered
    )
def is_weak_dict(t: object) -> bool:
    """Report whether ``t`` renders as an unparameterised or Any-typed dict.

    The check is purely textual and case-insensitive: ``str(t)`` must either
    be exactly ``Dict?`` or contain both ``Dict?`` and ``Any?``.
    (Presumably the trailing ``?`` is how the parsed type objects print
    themselves; confirm against the mypy version in use.)

    Args:
        t: Any object; only its string form is inspected.

    Returns:
        True if the rendered type matches either form, False otherwise.
    """
    rendered = str(t).upper()
    return rendered == "DICT?" or ("DICT?" in rendered and "ANY?" in rendered)
# only looking for weak callables or dicts now
def func_is_weak(node: "mypy.nodes.FuncDef") -> bool:
    """Report whether any argument of ``node`` has a weak callable or dict type.

    Args:
        node: Function definition whose signature is inspected.

    Returns:
        True if at least one argument type is a weak callable or weak dict;
        False otherwise, including when the function has no type signature.
    """
    signature = node.type
    if signature is None:
        # The parser leaves `type` unset for functions with no annotations,
        # so there is nothing to classify.
        return False
    return any(
        is_weak_callable(arg) or is_weak_dict(arg)
        for arg in signature.arg_types
    )
def parse_expr(stmt: any) -> any:
    """Reduce an expression node to a plain value.

    Literal nodes (int, str, bytes, float, complex) yield their ``value``,
    name expressions yield their ``name``, and call expressions are delegated
    to ``parse_call_expr``. Ellipsis expressions yield None. Anything else is
    reported through ``rare_error_found`` and also yields None.
    """
    literal_nodes = (
        mypy.nodes.IntExpr,
        mypy.nodes.StrExpr,
        mypy.nodes.BytesExpr,
        mypy.nodes.FloatExpr,
        mypy.nodes.ComplexExpr,
    )
    if isinstance(stmt, literal_nodes):
        return stmt.value
    if isinstance(stmt, mypy.nodes.NameExpr):
        return stmt.name
    if isinstance(stmt, mypy.nodes.EllipsisExpr):
        # Nothing to extract from `...`.
        return None
    if isinstance(stmt, mypy.nodes.CallExpr):
        return parse_call_expr(stmt)

    rare_error_found('Expression')
    return None
def parse_node(node: "mypy.nodes.Node") -> "Optional[list[list[Flagged]]]":
    """Scan one statement for flagged patterns.

    Function definitions (plain or decorated) whose signature is weak are
    run through ``check_for_pattern``. Class definitions are scanned
    recursively, statement by statement. Overloaded functions, expression
    statements and assignments are only logged for now. Any other node kind
    is reported through ``rare_error_found``.

    Args:
        node: Statement taken from a module or class body.

    Returns:
        A list with one inner list of findings per function that had any;
        empty when nothing was flagged.
    """
    patts_found = []
    if isinstance(node, mypy.nodes.FuncDef):
        if func_is_weak(node):
            patts = check_for_pattern(node)
            if patts:
                patts_found.append(patts)
    elif isinstance(node, mypy.nodes.Decorator):
        # The decorated function itself lives on `.func`.
        if func_is_weak(node.func):
            patts = check_for_pattern(node.func)
            if patts:
                patts_found.append(patts)
    elif isinstance(node, mypy.nodes.OverloadedFuncDef):
        print("overloaded", node.line)
    elif isinstance(node, mypy.nodes.ClassDef):
        print("ClassDef", node.line)
        # TODO: top level has type_vars and decorators attr. look into them
        for stmt in node.defs.body:
            res = parse_node(stmt)
            if res is not None:
                patts_found.extend(res)
    elif isinstance(node, mypy.nodes.ExpressionStmt):
        print("ExpressionStmt")
    elif isinstance(node, mypy.nodes.AssignmentStmt):
        # `is_alias_def` is a plain bool attribute on AssignmentStmt; calling
        # it raised TypeError. Both branches are still placeholders.
        if node.is_alias_def:
            ...
        else:
            ...
        print("Assignment")
    else:
        rare_error_found('Statement')
    return patts_found
def scan_file(file_name: str) -> "list[list[Flagged]]":
    """Parse one Python source file and collect every flagged pattern in it.

    Args:
        file_name: Path of the file to scan; must end in ``.py``.

    Returns:
        The concatenated results of ``parse_node`` for each top-level
        statement. Empty when the file name does not end in ``.py`` or the
        file cannot be parsed; both cases are reported through
        ``check_error_found``.
    """
    flagged = []
    if not file_name.endswith(".py"):
        # NOTE(review): this message reads inverted for a ".py"-only check;
        # left unchanged in case other tooling matches on it.
        check_error_found("expected type stub not source code")
        return flagged

    # Read eagerly so the handle is closed on every path, including the
    # CompileError early return below.
    with open(pathlib.Path(file_name), "r", io.DEFAULT_BUFFER_SIZE) as f:
        source = f.read()

    options = Options()
    errors = Errors(options)
    try:
        ast = parse(source, file_name, None, errors, options)
    except CompileError:
        check_error_found("unable to scan file")
        return flagged

    for stmt in ast.defs:
        result = parse_node(stmt)
        # Skip both None and empty results (the old `or` test was always true).
        if result:
            flagged.extend(result)
    return flagged
@click.command()
@click.option('-r', '--run-tests', 'tests', is_flag=True, help="Run tests without generating new scan files")
@click.argument("filepath", type=click.Path(exists=True))
def sample(filepath, tests) -> None:
    """Scan a single file, or every file found under a directory.

    In directory mode the findings of each file are printed. After each
    scanned file, ``flush_errs`` is called for it if ``ERRORS`` is non-empty.
    """
    scanning_dir = os.path.isdir(filepath)
    targets = get_files(filepath) if scanning_dir else [filepath]

    for target in targets:
        findings = scan_file(target)
        if scanning_dir:
            # Only directory mode echoes the findings.
            print(findings)
        if ERRORS:
            flush_errs(target)

    if not tests:
        # write output
        ...
# Script entry point: invoke the click command when run directly.
if __name__ == "__main__":
    sample()
@dibrinsofor
Copy link
Author

[WIP]: Parser to flag interesting patterns in source code (identified during a manual study) that may inform the need for more precise Mypy types.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment