Skip to content

Instantly share code, notes, and snippets.

@dutc
Created February 15, 2019 16:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dutc/9aba08d2b42adec97f723d8f1bc50318 to your computer and use it in GitHub Desktop.
Save dutc/9aba08d2b42adec97f723d8f1bc50318 to your computer and use it in GitHub Desktop.
Do What I Mean, Not What I Say
#!/usr/bin/env python3
from nwis import dwim
@dwim
def func(x, y):
    # dwim (imported from nwis) reads this function's source to decide which
    # argument types to accept, so the body itself is the type specification:
    # the only operation applied to x and y is `@`, which presumably makes
    # dwim require both to be instances of a class defining __matmul__.
    return x @ y
if __name__ == '__main__':
    # Demo call. int defines no __matmul__, so the dwim wrapper is expected
    # to reject these arguments with a TypeError before anything is printed.
    result = func(1, 2)
    print(f'func(1, 2) = {result}')
#!/usr/bin/env python3
from ast import (
    parse, walk, NodeVisitor, Name, Store,
    Add, Sub, Mult, MatMult, Div, FloorDiv, Mod, Pow,
    LShift, RShift, BitAnd, BitOr, BitXor,
    USub, UAdd, Invert,
)
from collections import defaultdict
from functools import lru_cache, wraps
from importlib import import_module
from inspect import getsource, signature
from pkgutil import iter_modules
from textwrap import dedent
class OpsVisitor(NodeVisitor):
    """Collect the special methods and attributes each parameter is used with.

    ``OpsVisitor(*names).visit(tree)`` walks *tree* and returns a mapping
    (a ``defaultdict(list)``) from each name in *names* to the dunder methods
    or attribute names that the code applies directly to that bare name:
    ``-x`` records ``'__neg__'``, ``x + y`` records ``'__add__'`` for both
    operands, and ``x.real`` records ``'real'``. Names that are never
    operated on do not appear in the mapping.

    Raises:
        TypeError: if a plain assignment rebinds one of *names*.
    """

    # Operator node type -> special method a parameter operand must provide.
    # Operators missing from this table (e.g. ``not``) impose no requirement.
    op2attr = {
        Add: '__add__', Sub: '__sub__', Mult: '__mul__', MatMult: '__matmul__',
        Div: '__truediv__', FloorDiv: '__floordiv__', Mod: '__mod__',
        Pow: '__pow__', LShift: '__lshift__', RShift: '__rshift__',
        BitAnd: '__and__', BitOr: '__or__', BitXor: '__xor__',
        USub: '__neg__', UAdd: '__pos__', Invert: '__invert__',
    }

    def __init__(self, *params):
        self.params = set(params)
        self.ops = defaultdict(list)  # param name -> required attribute names

    def _require(self, node, attr):
        """Record *attr* against *node* if it is a bare reference to a parameter."""
        if attr is not None and isinstance(node, Name) and node.id in self.params:
            self.ops[node.id].append(attr)

    def visit_UnaryOp(self, node):
        self._require(node.operand, self.op2attr.get(type(node.op)))
        self.generic_visit(node)

    def visit_Attribute(self, node):
        # The base may be any expression (call, subscript, another attribute);
        # only a bare parameter name is constrained, and it must have the
        # attribute *name* (a string), not the base node itself.
        self._require(node.value, node.attr)
        self.generic_visit(node)

    def visit_BinOp(self, node):
        attr = self.op2attr.get(type(node.op))
        # Heuristic: both operands are required to provide the forward method
        # (e.g. __add__), not the reflected one (__radd__).
        self._require(node.left, attr)
        self._require(node.right, attr)
        self.generic_visit(node)

    def visit_Assign(self, node):
        # Walk every target so tuple/list/starred unpacking is checked too.
        # Only Store-context names are rebindings; `x[0] = ...` or
        # `x.a = ...` mutate the argument but do not rebind the parameter.
        assignments = {
            n.id
            for target in node.targets
            for n in walk(target)
            if isinstance(n, Name) and isinstance(n.ctx, Store)
        } & self.params
        if assignments:
            raise TypeError(f'SSA violators will be prosecuted: {assignments}')
        super().visit(node.value)

    def visit(self, node):
        """Visit *node* and return the accumulated ``ops`` mapping."""
        super().visit(node)
        return self.ops
@lru_cache(maxsize=None)
def _all_types():
    """Return every class bound at the top level of an importable module.

    Best-effort scan: imports ``builtins`` plus every top-level module found
    by ``pkgutil.iter_modules``, skipping ``this`` and ``antigravity`` (their
    import prints the Zen of Python / opens a web browser). Modules that fail
    to import and attributes that fail to resolve are ignored. The result is
    cached because the scan imports and inspects the whole module universe,
    which is far too expensive to repeat for every decorated function.
    """
    mods = {import_module('builtins')}
    for m in iter_modules():
        if m.name in {'this', 'antigravity'}:
            continue
        try:
            mods.add(import_module(m.name))
        except Exception:
            pass  # deliberate best effort: unimportable modules contribute nothing
    found = set()
    for mod in mods:
        for name in dir(mod):
            try:
                obj = getattr(mod, name)
            except Exception:
                continue  # lazy/deprecated module attributes may raise on access
            if isinstance(obj, type):
                found.add(obj)
    return frozenset(found)


def dwim(f):
    """Wrap *f* with run-time argument type checks inferred from its body.

    The source of *f* is parsed and ``OpsVisitor`` collects, per parameter,
    the dunder methods and attributes the body applies to it. Every class
    from ``_all_types()`` that has all of them (``hasattr``) becomes a
    candidate type for that parameter.

    Raises:
        TypeError: at decoration time if a parameter that is operated on has
            no candidate type, or if ``OpsVisitor`` rejects the body.

    Returns:
        A ``functools.wraps``-wrapped function that raises ``TypeError`` when
        an explicitly passed argument is not an instance of any candidate
        type for its parameter, and otherwise calls *f*. Parameters the body
        never operates on are left unchecked.
    """
    sig = signature(f)
    # dedent so that methods and nested functions (indented source) parse.
    tree = parse(dedent(getsource(f)))
    ops = OpsVisitor(*sig.parameters).visit(tree)
    typs = _all_types()
    candidates = {k: {t for t in typs if all(hasattr(t, v) for v in vs)}
                  for k, vs in ops.items()}
    for k, v in candidates.items():
        if not v:
            raise TypeError(f'no consistent type for {k} given ops {ops[k]}')
    # isinstance() accepts a tuple of classes; build each tuple once, here.
    allowed = {k: tuple(v) for k, v in candidates.items()}

    @wraps(f)
    def checker(*args, **kwargs):
        for arg, value in sig.bind(*args, **kwargs).arguments.items():
            # Unconstrained parameters have no entry in `allowed`.
            if arg in allowed and not isinstance(value, allowed[arg]):
                raise TypeError(f'{arg} must be one of {candidates[arg]}')
        return f(*args, **kwargs)
    return checker
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment