Skip to content

Instantly share code, notes, and snippets.

@vsajip
Forked from DasIch/bar.py
Created February 26, 2014 10:45
Show Gist options
  • Star (0): You must be signed in to star a gist.
  • Fork (0): You must be signed in to fork a gist.
  • Save vsajip/9227411 to your computer and use it in GitHub Desktop.
from nonlocal_ import nonlocal_
# Smoke test for the nonlocal_ import hook. nonlocal_ must already have been
# imported (which installs its hook) before this module is imported, so that
# the nonlocal_('a') call below is rewritten at import time; otherwise the
# call reaches nonlocal_ itself, which raises RuntimeError.
def foo():
    a = 1

    def bar():
        # Declares that ``a`` refers to foo()'s variable, like ``nonlocal a``.
        nonlocal_('a')
        # With the rewrite in place this rebinds foo()'s ``a``.
        a = 2

    bar()
    return a


# bar() must have updated foo()'s ``a``; the failure message shows foo()'s result.
assert foo() == 2, foo()
import nonlocal_
import bar
import os
import sys
import imp
import types
import ast
def _literal_name(argument):
    """Return the variable name given as one ``nonlocal_(...)`` argument.

    ``nonlocal_`` is resolved while the module source is being rewritten, so
    every argument must be a string literal. Both the legacy ``ast.Str`` node
    (Python 2 and Python 3 before 3.8) and ``ast.Constant`` are accepted; the
    node type is compared by name because neither class is available on every
    Python version this module targets.

    Raises:
        SyntaxError: if *argument* is not a string literal.
    """
    node_type = type(argument).__name__
    if node_type == 'Str':
        return argument.s
    if node_type == 'Constant' and isinstance(argument.value, str):
        return argument.value
    raise SyntaxError('nonlocal_() arguments must be string literals')


# Statement types that open a new scope: assignments inside them bind names in
# that inner scope, not in the function being searched.
_NESTED_SCOPE_TYPES = tuple(
    getattr(ast, type_name)
    for type_name in ('FunctionDef', 'AsyncFunctionDef', 'ClassDef')
    if hasattr(ast, type_name)
)


def _target_binds(target, name):
    """Return True if the assignment target *target* binds *name*.

    Only plain names (possibly inside tuple/list unpacking or a starred
    target) bind; ``x.attr = ...`` and ``x[i] = ...`` mutate an existing
    object and leave the binding of ``x`` untouched.
    """
    if isinstance(target, ast.Name):
        return target.id == name
    if isinstance(target, (ast.Tuple, ast.List)):
        return any(_target_binds(element, name) for element in target.elts)
    if type(target).__name__ == 'Starred':
        return _target_binds(target.value, name)
    return False


def _assigns_name(statements, name):
    """Return True if an ``=`` statement in *statements* binds *name*.

    Compound statements are searched recursively, including their ``else``,
    ``finally``, exception-handler and ``match``-case bodies; nested function
    and class definitions are not, since they form their own scopes.
    """
    for statement in statements:
        if isinstance(statement, ast.Assign):
            if any(_target_binds(target, name) for target in statement.targets):
                return True
            continue
        if isinstance(statement, _NESTED_SCOPE_TYPES):
            continue
        clauses = (list(getattr(statement, 'handlers', ()))
                   + list(getattr(statement, 'cases', ())))
        blocks = [getattr(statement, field, None)
                  for field in ('body', 'orelse', 'finalbody')]
        blocks.extend(clause.body for clause in clauses)
        for block in blocks:
            # Python 2 ``exec`` statements have a non-list ``body`` expression.
            if isinstance(block, list) and _assigns_name(block, name):
                return True
    return False


def _find_enclosing_binding(function_stack, name):
    """Locate the function that a ``nonlocal_(name)`` declaration refers to.

    *function_stack* lists the ``FunctionDef`` nodes currently being visited,
    outermost first; its last entry is the function containing the
    declaration and is therefore skipped. The remaining entries are searched
    from the innermost outwards, mirroring Python's closure lookup.

    Returns:
        ``(index, function_def)`` for the nearest enclosing function that
        binds *name* with an ``=`` assignment, or None if there is none.
    """
    for index in range(len(function_stack) - 2, -1, -1):
        function_def = function_stack[index]
        if _assigns_name(function_def.body, name):
            return index, function_def
    return None


if sys.version_info[0] == 2:
    class NonLocalTransformer(ast.NodeTransformer):
        """Emulate ``nonlocal`` on Python 2 by rewriting the AST.

        For every ``nonlocal_('x')`` call statement, the nearest enclosing
        function that assigns ``x`` is rewritten by ListWrapTransformation so
        that ``x`` holds a one-element list and other occurrences use
        ``x[0]``; the call statement itself is replaced by ``pass``.
        """

        def __init__(self):
            ast.NodeTransformer.__init__(self)
            # FunctionDef nodes currently being visited, outermost first.
            self.function_stack = []

        def visit_FunctionDef(self, node):
            self.function_stack.append(node)
            for i, statement in enumerate(node.body):
                node.body[i] = self.visit(statement)
            # A nonlocal_ declaration in a nested function may have replaced
            # this entry, so return whatever is on the stack now.
            return self.function_stack.pop()

        def visit_Expr(self, node):
            node.value = self.generic_visit(node.value)
            call = node.value
            if not (isinstance(call, ast.Call)
                    and isinstance(call.func, ast.Name)
                    and call.func.id == 'nonlocal_'):
                return node
            for argument in call.args:
                name = _literal_name(argument)
                found = self.find_defining_function(name)
                if found is None:
                    # Same outcome as a real ``nonlocal`` on Python 3.
                    raise SyntaxError(
                        "no binding for nonlocal '%s' found" % (name,)
                    )
                index, function_def = found
                self.function_stack[index] = ListWrapTransformation(
                    name
                ).visit(function_def)
            return ast.copy_location(ast.Pass(), node)

        def find_defining_function(self, name):
            """Return ``(index, function_def)`` binding *name*, or None."""
            return _find_enclosing_binding(self.function_stack, name)

    class ListWrapTransformation(ast.NodeTransformer):
        """Box one variable of a function in a single-element list.

        The first plain ``name = value`` assignment becomes
        ``name = [value]``; later occurrences of ``name`` become ``name[0]``.
        The whole visited subtree is rewritten, nested functions included.
        """
        # NOTE(review): no scope analysis is done, so a nested function that
        # uses ``name`` as its own local is rewritten as well.

        def __init__(self, name):
            ast.NodeTransformer.__init__(self)
            self.name = name
            # True until the first occurrence of ``name`` has been boxed.
            self.first = True

        def visit_Assign(self, node):
            target = node.targets[0]
            if (self.first and isinstance(target, ast.Name)
                    and target.id == self.name):
                if len(node.targets) != 1:
                    raise SyntaxError(
                        "chained assignment to nonlocal '%s' is not supported"
                        % (self.name,)
                    )
                node.value = ast.List([node.value], ast.Load())
                self.first = False
                return node
            # Any other assignment is kept and rewritten recursively.
            node.targets = [self.visit(t) for t in node.targets]
            node.value = self.visit(node.value)
            return node

        def visit_Name(self, node):
            if node.id != self.name:
                return node
            if self.first:
                self.first = False
                return ast.copy_location(ast.List([node], node.ctx), node)
            return ast.copy_location(
                ast.Subscript(
                    ast.Name(node.id, ast.Load()),
                    ast.Index(ast.Num(0)),
                    node.ctx,
                ),
                node,
            )
else:
    class NonLocalTransformer(ast.NodeTransformer):
        """Replace ``nonlocal_('x', ...)`` call statements with ``nonlocal x, ...``."""

        def visit_Expr(self, node):
            call = node.value
            if (isinstance(call, ast.Call)
                    and isinstance(call.func, ast.Name)
                    and call.func.id == 'nonlocal_'):
                names = [_literal_name(argument) for argument in call.args]
                # Keep the call's position so errors point at the right line.
                return ast.copy_location(ast.Nonlocal(names), node)
            return node
class NonLocalImporter(object):
    """PEP 302 meta-path finder and loader that rewrites ``nonlocal_`` calls.

    Any module that ``imp.find_module`` can locate is claimed. Python source
    is parsed, passed through ``NonLocalTransformer`` and executed in the
    module's namespace; everything else (extension and builtin modules, or
    compiled modules without their ``.py`` source) is delegated to
    ``imp.load_module`` untouched.
    """

    def __init__(self):
        # Full module name -> ``imp.find_module`` result, kept from
        # find_module() until the matching load_module() call.
        self._found_modules = {}

    def find_module(self, name, path=None):
        """Claim *name* if ``imp.find_module`` can locate it, else return None."""
        # imp.find_module takes a single name component; for a submodule the
        # parent package's __path__ arrives as *path*.
        try:
            found = imp.find_module(name.rpartition('.')[2], path)
        except ImportError:
            return None
        self._found_modules[name] = found
        return self

    def load_module(self, name):
        """Load, rewrite and execute the module previously found as *name*."""
        file, filename, description = self._found_modules.pop(name)
        try:
            source_path, package_dir = self._locate_source(filename, description)
            if source_path is None:
                return imp.load_module(name, file, filename, description)
            if description[2] == imp.PY_SOURCE:
                code = file.read()
            else:
                # Bytes let ast.parse honour any PEP 263 coding declaration.
                with open(source_path, 'rb') as source_file:
                    code = source_file.read()
        finally:
            # imp.find_module opens the file; the caller must close it.
            if file is not None:
                file.close()
        return self._exec_source(name, source_path, package_dir, code)

    @staticmethod
    def _locate_source(filename, description):
        """Return ``(source_path, package_dir)`` for a found module.

        ``source_path`` is None when there is no Python source to rewrite;
        ``package_dir`` is the package directory for packages, else None.
        """
        kind = description[2]
        if kind == imp.PY_SOURCE:
            return filename, None
        if kind == imp.PY_COMPILED:
            source_path = filename[:-1]  # .pyc or .pyo -> .py
            package_dir = None
        elif kind == imp.PKG_DIRECTORY:
            source_path = os.path.join(filename, '__init__.py')
            package_dir = filename
        else:
            return None, None
        if not os.path.isfile(source_path):
            return None, None
        return source_path, package_dir

    def _exec_source(self, name, source_path, package_dir, code):
        """Rewrite *code* with NonLocalTransformer and run it as module *name*."""
        tree = NonLocalTransformer().visit(ast.parse(code, source_path))
        ast.fix_missing_locations(tree)
        # flags=0, dont_inherit=True (positional for Python 2): do not apply
        # this file's __future__ flags to the loaded module.
        code_object = compile(tree, source_path, 'exec', 0, True)

        # Reuse an existing module object (e.g. on reload), as PEP 302 asks.
        module = sys.modules.get(name)
        is_new = module is None
        if is_new:
            module = types.ModuleType(name)
        module.__file__ = source_path
        module.__loader__ = self
        if package_dir is not None:
            # __path__ must list directories, not the __init__.py file.
            module.__path__ = [package_dir]
            module.__package__ = name
        else:
            module.__package__ = name.rpartition('.')[0]
        sys.modules[name] = module
        try:
            exec(code_object, module.__dict__)
        except BaseException:
            if is_new:
                # Do not leave a half-initialised module behind.
                sys.modules.pop(name, None)
            raise
        return sys.modules[name]
def nonlocal_(*names):
    """Stand-in for a ``nonlocal`` declaration of *names*.

    Modules loaded through the import hook have statement calls to this
    function rewritten away before they run, so it is only reached when that
    rewrite did not happen (for instance, when the calling module was imported
    before this module installed its hook).

    Raises:
        RuntimeError: always.
    """
    message = 'nonlocal_ needs to be imported before a module using it is imported'
    raise RuntimeError(message)
# Install the hook at the front of sys.meta_path so it is consulted before the
# default finders for subsequent imports of modules not yet in sys.modules.
sys.meta_path.insert(0, NonLocalImporter())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.