Created
May 8, 2009 19:28
-
-
Save andreyvit/108955 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import compiler | |
from compiler.ast import * | |
import pprint | |
from StringIO import StringIO | |
from tempfile import TemporaryFile | |
import re | |
#b = 1 | |
#c = 2 | |
#a = b+c | |
#print 'a =', a | |
import sys, imp, __builtin__ | |
# Replacement for __import__() | |
def import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
    """Replacement for the builtin ``__import__``.

    Resolves the dotted *name* one component at a time through
    ``find_head_package`` and ``load_tail``.

    Parameters mirror the builtin ``__import__``:
      name     -- dotted module name being imported.
      globals  -- globals of the importing module; used to find its package.
      locals   -- accepted for signature compatibility; only forwarded when
                  the call is delegated to the saved builtin.
      fromlist -- names listed in ``from name import ...``; empty/None for a
                  plain ``import name``.
      level    -- -1: try package-relative first, then absolute (the
                  historical behaviour of this hook); 0: absolute only;
                  > 0: explicit relative import.

    Returns the head package when *fromlist* is empty, otherwise the module
    named by the last component of *name*.

    Raises ImportError when a component cannot be found.
    """
    if level > 0:
        # Explicit relative imports ("from . import x") are not modelled by
        # the helpers below, so hand them to the importer saved in
        # ``original_import`` instead of failing.
        return original_import(name, globals, locals, fromlist, level)

    # level == 0 means "absolute only": skip the package-relative attempt.
    if level == 0:
        parent = None
    else:
        parent = determine_parent(globals)

    q, tail = find_head_package(parent, name)
    m = load_tail(q, tail)
    if not fromlist:
        # "import a.b.c" binds the head package "a" in the importer.
        return q
    if hasattr(m, "__path__"):
        # Only packages can have submodules that still need importing.
        ensure_fromlist(m, fromlist)
    return m
def determine_parent(globals):
    """Return the package module that imports from *globals* are relative to.

    globals -- the importing module's global namespace (may be None/empty).

    Returns:
      * None when *globals* is empty or has no ``__name__``;
      * the importing module itself when it is a package (it has
        ``__path__``);
      * the enclosing package (looked up in ``sys.modules``) when the
        importing module's name is dotted;
      * None for a top-level, non-package module.

    Raises KeyError if the expected module is missing from ``sys.modules``.
    """
    if not globals or "__name__" not in globals:
        return None
    pname = globals["__name__"]
    if "__path__" in globals:
        # The importer is a package: relative imports resolve inside it.
        parent = sys.modules[pname]
        assert globals is parent.__dict__
        return parent
    if "." in pname:
        # The importer is a submodule: strip the last component to get its
        # package.
        pname = pname[:pname.rfind(".")]
        parent = sys.modules[pname]
        assert parent.__name__ == pname
        return parent
    return None
def find_head_package(parent, name):
    """Import the first component of the dotted *name*.

    parent -- package the import is relative to, or None.
    name   -- dotted module name.

    With a *parent*, the component is first looked up inside that package
    and then, if that fails, as a top-level module.

    Returns ``(module, tail)`` where *tail* is the part of *name* after the
    first dot ("" if there is none).

    Raises ImportError when no attempt succeeds.
    """
    head, _, tail = name.partition(".")

    # (fully-qualified name, package to search) pairs, in lookup order.
    if parent:
        attempts = [
            ("%s.%s" % (parent.__name__, head), parent),
            (head, None),
        ]
    else:
        attempts = [(head, parent)]

    for qname, package in attempts:
        found = import_module(head, qname, package)
        if found:
            return found, tail
    raise ImportError("No module named " + qname)
def load_tail(q, tail):
    """Import each remaining component of a dotted name below module *q*.

    q    -- module already imported for the head of the name.
    tail -- remaining dotted components ("" when there are none).

    Returns the module for the last component (or *q* when *tail* is empty).
    Raises ImportError naming the first component that cannot be imported.
    """
    current = q
    remaining = tail
    while remaining:
        part, _, remaining = remaining.partition(".")
        mname = "%s.%s" % (current.__name__, part)
        current = import_module(part, mname, current)
        if not current:
            raise ImportError("No module named " + mname)
    return current
def ensure_fromlist(m, fromlist, recursive=0):
    """Import the submodules of package *m* that *fromlist* refers to.

    m         -- package module.
    fromlist  -- names from ``from m import ...``.
    recursive -- true when called for the expansion of "*"; a "*" found
                 while expanding is ignored.

    A "*" entry is expanded to ``m.__all__`` when that attribute exists.
    Names that are already attributes of *m* are left alone.

    Raises ImportError when a missing name cannot be imported as a
    submodule.
    """
    for sub in fromlist:
        if sub == "*":
            if not recursive:
                try:
                    exported = m.__all__
                except AttributeError:
                    # No __all__: nothing to expand.
                    exported = ()
                ensure_fromlist(m, exported, 1)
            continue
        if hasattr(m, sub):
            continue
        subname = "%s.%s" % (m.__name__, sub)
        if not import_module(sub, subname, m):
            raise ImportError("No module named " + subname)
def import_module(partname, fqname, parent):
    """Import one module, instrumenting it first when it is plain source.

    partname -- last component of the module name (what is searched for).
    fqname   -- fully-qualified name used as the ``sys.modules`` key.
    parent   -- package module to search in, or None for a top-level search.

    Returns the module, or None when ``imp.find_module`` cannot locate it.
    A module already present in ``sys.modules`` is returned as is.
    """
    try:
        return sys.modules[fqname]
    except KeyError:
        pass
    try:
        fp, pathname, stuff = imp.find_module(partname,
                                              parent and parent.__path__)
    except ImportError:
        return None
    # find_module returns fp=None for packages and builtin modules, and a
    # binary handle for compiled/extension modules. Only real Python source
    # can be read and rewritten, so everything else is loaded untouched.
    if fp is not None and stuff[2] == imp.PY_SOURCE:
        fp = instrument_source(fp, fqname)
    try:
        m = imp.load_module(fqname, fp, pathname, stuff)
    finally:
        if fp:
            fp.close()
    if parent:
        # Make the submodule reachable as an attribute of its package.
        setattr(parent, partname, m)
    return m
# Replacement for reload() | |
def reload_hook(module):
    """Replacement for ``reload()``: re-run ``import_module`` for *module*.

    The module's ``__name__`` is split at its last dot; the part before it
    names the parent package (looked up in ``sys.modules``) and the part
    after it is the name searched for. A name without a dot is imported as
    a top-level module.

    Returns whatever ``import_module`` returns.
    """
    fqname = module.__name__
    package_name, dot, partname = fqname.rpartition(".")
    if not dot:
        return import_module(fqname, fqname, None)
    return import_module(partname, fqname, sys.modules[package_name])
def name(block):
    """Return the name of the class of *block*."""
    # __class__ (not type()) so that old-style class instances report their
    # own class name.
    cls = block.__class__
    return cls.__name__
def expr2str(node): | |
if isinstance(node, Name): | |
return node.name | |
elif isinstance(node, Getattr): | |
return '%s.%s' % (expr2str(node.expr), node.attrname) | |
else: | |
return '???' | |
class Visitor(object):
    """AST visitor that collects a print statement for every assignment target.

    ``additions`` maps a source line number to the list of statements to be
    inserted after that line. ``self.visit`` is not defined here; it is
    expected to be attached by the tree walker that drives this visitor.
    """

    def __init__(self):
        # line number -> list of generated statements for that line
        self.additions = {}

    def add(self, lineno, name):
        """Record (and echo) a statement printing *name* for line *lineno*."""
        bucket = self.additions.setdefault(lineno, [])
        statement = 'print "%d %s " + repr(%s)' % (lineno, name, name)
        bucket.append(statement)
        print(statement)

    def visitAssign(self, node):
        """Echo the assignment's line number and visit each of its targets."""
        print(node.lineno)
        if not hasattr(node, 'nodes'):
            return
        for target in node.nodes:
            self.visit(target)

    def visitAssName(self, node):
        """Plain name target: record it under the node's line."""
        self.add(node.lineno, node.name)

    def visitAssAttr(self, node):
        """Attribute target: record it when it renders as a dotted name."""
        target = '%s.%s' % (expr2str(node.expr), node.attrname)
        if '???' not in target:
            self.add(node.lineno, target)
        else:
            # expr2str could not render part of the expression.
            print("COMPLEX %s" % target)
def rewrite(block, additions):
    """Debug walk over an AST: dump every node and echo prints for AssName.

    block     -- AST node to dump; its ``nodes`` children are walked
                 recursively.
    additions -- passed down unchanged to the recursive calls; not otherwise
                 used here.

    Returns *block* when it has no ``nodes`` attribute, otherwise None.
    """
    label = name(block)
    attrs = ', '.join(
        "%s=%s" % (key, value)
        for key, value in block.__dict__.items()
        if key not in ('nodes', 'code')
    )
    print("%s : %s" % (label, attrs))

    if not hasattr(block, 'nodes'):
        return block

    for child in block.nodes:
        rewrite(child, additions)
        if isinstance(child, AssName):
            print('print "%d %s " + repr(%s)' % (child.lineno, child.name, child.name))
def instrument_source(fp, fqname, indent_re = re.compile('\\s*')):
    """Return a file holding the module source with trace prints inserted.

    fp        -- open file with the module's source; it is read and closed.
    fqname    -- fully-qualified module name, used in messages and as the
                 filename for the syntax check.
    indent_re -- compiled pattern matching a line's leading whitespace.

    After every line on which ``Visitor`` recorded an assignment target, the
    recorded print statements are inserted at that line's indentation.
    Instrumentation is best effort: if parsing fails, or the instrumented
    text no longer compiles, a FAILED message is printed and the original
    source is used instead.

    Returns a temporary file positioned at offset 0; the caller closes it.
    """
    print(fqname)
    try:
        text = fp.read()
    finally:
        fp.close()
    try:
        ast = compiler.parse(text)
        visitor = Visitor()
        compiler.visitor.walk(ast.node, visitor)
        additions = visitor.additions

        new_lines = []
        for lineno, line in enumerate(text.split("\n"), 1):
            new_lines.append(line)
            extra = additions.get(lineno)
            if extra:
                indent = indent_re.match(line).group()
                new_lines.extend(indent + statement for statement in extra)
        instrumented = "\n".join(new_lines)

        # A print inserted after the first line of a multi-line statement or
        # after a block header (e.g. "for x in y:") yields invalid code.
        # Check before committing so a bad rewrite cannot break the import.
        compile(instrumented + "\n", fqname, "exec")
        text = instrumented
    except StandardError as e:
        # Format the exception itself: e.message is deprecated and is empty
        # for exceptions built from several arguments.
        print("FAILED %s (%s): %s" % (fqname, e.__class__.__name__, e))
    t = TemporaryFile()
    t.write(text)
    t.seek(0)
    return t
#walk(ast, myvisitor()) | |
# Install the hook: remember the stock importer in original_import and make
# import_hook the function the import statement calls from here on.
original_import, __builtin__.__import__ = __builtin__.__import__, import_hook

# Imported only after the hook is in place, so this import is served by
# import_hook.
import updater
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment