Created
February 15, 2019 16:31
-
-
Save dutc/9aba08d2b42adec97f723d8f1bc50318 to your computer and use it in GitHub Desktop.
Do What I Mean, Not What I Say
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from nwis import dwim | |
@dwim
def func(x, y):
    """Apply the ``@`` (matmul) operator to ``x`` and ``y`` and return the result."""
    product = x @ y
    return product
if __name__ == '__main__':
    # Demo call: evaluate first, then report the result on stdout.
    outcome = func(1, 2)
    print(f'func(1, 2) = {outcome}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from ast import parse, NodeVisitor, Name, Add, Sub, Mult, USub, MatMult | |
from collections import defaultdict | |
from functools import wraps | |
from importlib import import_module | |
from inspect import getsource, signature | |
from pkgutil import iter_modules | |
class OpsVisitor(NodeVisitor):
    """Record which dunder methods / attributes a function body uses on its parameters.

    After ``visit(tree)`` the ``ops`` mapping holds, for every tracked
    parameter name, the list of attribute names that were used on it:
    the dunder method for each supported operator (see ``op2attr``) and the
    attribute name for each ``param.attr`` access.

    Operators that are not listed in ``op2attr`` are ignored rather than
    raising ``KeyError``; they simply add no constraint.
    """

    # Operator AST node type -> dunder method that implements it.
    # NOTE(review): right-hand operands are recorded with the same forward
    # method as left-hand ones (not ``__radd__`` etc.); kept as in the
    # original design.
    op2attr = {Add: '__add__', Sub: '__sub__', Mult: '__mul__', MatMult: '__matmul__',
               USub: '__neg__'}

    def __init__(self, *params):
        """Track the given parameter names; ``ops`` starts out empty."""
        self.params = set(params)
        self.ops = defaultdict(list)

    def _record(self, operand, op):
        """Note ``op`` against ``operand`` if it is a tracked parameter and the op is known."""
        attr = self.op2attr.get(type(op))
        if attr is None:
            return
        if isinstance(operand, Name) and operand.id in self.params:
            self.ops[operand.id].append(attr)

    def visit_UnaryOp(self, node):
        """Record a unary operator applied directly to a parameter, then descend."""
        self._record(node.operand, node.op)
        super().visit(node.operand)

    def visit_Attribute(self, node):
        """Record ``param.attr`` accesses by attribute *name*, then descend into the base."""
        base = node.value
        # The base may be any expression (call, subscript, ...), not just a Name.
        if isinstance(base, Name) and base.id in self.params:
            # Store the attribute name string so callers can use hasattr().
            self.ops[base.id].append(node.attr)
        super().visit(base)

    def visit_BinOp(self, node):
        """Record a binary operator for either operand that is a parameter, then descend."""
        self._record(node.left, node.op)
        self._record(node.right, node.op)
        super().visit(node.left)
        super().visit(node.right)

    def visit_Assign(self, node):
        """Reject rebinding of a tracked parameter; otherwise inspect the assigned value.

        Raises:
            TypeError: if any assignment target (including names nested in
                tuple/list/starred targets) rebinds a tracked parameter.
        """
        from ast import Store, walk

        # Only names in Store context are being rebound; ``d[x] = ...`` merely
        # reads ``x`` and ``d``.
        rebound = {
            n.id
            for target in node.targets
            for n in walk(target)
            if isinstance(n, Name) and isinstance(n.ctx, Store)
        }
        assignments = rebound & self.params
        if assignments:
            raise TypeError(f'SSA violators will be prosecuted: {assignments}')
        super().visit(node.value)

    def visit(self, node):
        """Visit ``node`` and return the accumulated ``ops`` mapping."""
        super().visit(node)
        return self.ops
def dwim(f):
    """Decorator: infer acceptable argument types from the operations ``f`` performs.

    The source of ``f`` is parsed and ``OpsVisitor`` collects, per parameter,
    the attribute names used on it. Every importable top-level module
    (plus ``builtins``) is then scanned for types, and a type is a candidate
    for a parameter when it has all of that parameter's recorded attributes.
    The returned wrapper checks each bound argument against its candidates
    before delegating to ``f``.

    Parameters without any recorded operation are unconstrained and are not
    checked.

    Raises:
        TypeError: at decoration time if some parameter has recorded
            operations but no discovered type supports all of them; at call
            time if an argument is not an instance of any candidate type.
    """
    from textwrap import dedent

    sig = signature(f)
    # Dedent so that methods and nested functions (indented source) still parse.
    tree = parse(dedent(getsource(f)))
    ops = OpsVisitor(*sig.parameters).visit(tree)

    mods = {import_module('builtins')}
    for m in iter_modules():
        # Skip modules whose import is known to have visible side effects.
        if m.name in {'this', 'antigravity'}:
            continue
        try:
            mods.add(import_module(m.name))
        except Exception:
            # Deliberately best-effort: a module that fails to import just
            # contributes no candidate types.
            pass

    # dir() may list names that do not resolve; fall back to None for those.
    objs = [getattr(m, x, None) for m in mods for x in dir(m)]
    typs = {x for x in objs if isinstance(x, type)}
    candidates = {k: {t for t in typs if all(hasattr(t, v) for v in vs)}
                  for k, vs in ops.items()}
    for k, v in candidates.items():
        if not v:
            raise TypeError(f'no consistent type for {k} given ops {ops[k]}')

    @wraps(f)
    def checker(*args, **kwargs):
        for arg, value in sig.bind(*args, **kwargs).arguments.items():
            allowed = candidates.get(arg)
            if allowed is None:
                # No operations were recorded for this parameter.
                continue
            if not any(isinstance(value, t) for t in allowed):
                raise TypeError(f'{arg} must be one of {allowed}')
        return f(*args, **kwargs)
    return checker
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment