Skip to content

Instantly share code, notes, and snippets.

@andreyvit
Created May 8, 2009 19:28
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save andreyvit/108955 to your computer and use it in GitHub Desktop.
Save andreyvit/108955 to your computer and use it in GitHub Desktop.
import compiler
from compiler.ast import *
import pprint
from StringIO import StringIO
from tempfile import TemporaryFile
import re
#b = 1
#c = 2
#a = b+c
#print 'a =', a
import sys, imp, __builtin__
# Replacement for __import__()
def import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
    """Replacement for the built-in ``__import__``.

    Resolves ``name`` by first trying it relative to the importing package
    and then as a top-level module, loading each component through
    ``import_module``.

    Arguments:
      name     -- dotted module name as written in the import statement.
      globals  -- global namespace of the importing module; passed to
                  ``determine_parent`` to find the importing package.
      locals   -- accepted for signature compatibility; not used.
      fromlist -- names listed in ``from name import ...`` (empty or None
                  for a plain ``import name``).
      level    -- -1 (default): try package-relative, then absolute;
                  0: absolute only; > 0: explicit relative import.

    Returns the head package when ``fromlist`` is empty, otherwise the
    module named by the full dotted path.

    Raises ImportError when a component cannot be found.
    """
    if level > 0:
        # Explicit relative imports ("from . import x") are not
        # re-implemented here; hand them to the interpreter's own importer
        # instead of failing on the extra argument.
        return original_import(name, globals, locals, fromlist, level)
    if level == 0:
        # Absolute-only import: skip the package-relative attempt.
        parent = None
    else:
        parent = determine_parent(globals)
    q, tail = find_head_package(parent, name)
    m = load_tail(q, tail)
    if not fromlist:
        # "import a.b.c" binds the head package "a" in the caller.
        return q
    if hasattr(m, "__path__"):
        # Only packages can have submodules that still need loading.
        ensure_fromlist(m, fromlist)
    return m
def determine_parent(globals):
    """Return the package that relative imports should be resolved against.

    ``globals`` is the importing module's global namespace (or None).
    Returns the module object taken from ``sys.modules``: the importer
    itself when its namespace has ``__path__``, otherwise the module named
    by everything before the last dot of ``__name__``. Returns None when
    there is no namespace, no ``__name__``, or the name has no dot.
    """
    if not globals or "__name__" not in globals:
        return None
    modname = globals['__name__']
    if "__path__" in globals:
        # The importer is itself a package: it is its own parent.
        package = sys.modules[modname]
        assert globals is package.__dict__
        return package
    if '.' not in modname:
        return None
    # Plain module inside a package: strip the last dotted component.
    package_name = modname[:modname.rfind('.')]
    package = sys.modules[package_name]
    assert package.__name__ == package_name
    return package
def find_head_package(parent, name):
    """Import the first dotted component of ``name``.

    The component is tried as a submodule of ``parent`` first (when a
    parent is given) and then as a top-level module.

    Returns a ``(module, tail)`` pair where ``tail`` is the part of
    ``name`` after the first dot ("" when there is none).

    Raises ImportError when neither attempt yields a module.
    """
    head, _, tail = name.partition('.')
    if parent:
        # Package-relative attempt first, then fall back to top level.
        attempts = [("%s.%s" % (parent.__name__, head), parent),
                    (head, None)]
    else:
        attempts = [(head, parent)]
    for qname, package in attempts:
        module = import_module(head, qname, package)
        if module:
            return module, tail
    raise ImportError("No module named " + qname)
def load_tail(q, tail):
    """Import the remaining dotted components below the head package.

    ``q`` is the already-imported head package and ``tail`` the rest of the
    dotted name ("" when nothing is left). Each component is imported as a
    submodule of the previous one.

    Returns the module for the last component (``q`` itself for an empty
    tail). Raises ImportError when a component cannot be imported.
    """
    module = q
    remaining = tail
    while remaining:
        part, _, remaining = remaining.partition('.')
        fqname = "%s.%s" % (module.__name__, part)
        module = import_module(part, fqname, module)
        if not module:
            raise ImportError("No module named " + fqname)
    return module
def ensure_fromlist(m, fromlist, recursive=0):
    """Make sure every name in ``fromlist`` is available on package ``m``.

    Names that are not yet attributes of ``m`` are imported as submodules.
    A "*" entry expands to ``m.__all__`` when that attribute exists; the
    expansion happens one level deep only (``recursive`` is set on the
    nested call so a "*" inside ``__all__`` is ignored).

    Raises ImportError when a missing name cannot be imported.
    """
    for entry in fromlist:
        if entry == "*":
            if recursive:
                continue
            try:
                exported = m.__all__
            except AttributeError:
                # No __all__: nothing to expand.
                continue
            ensure_fromlist(m, exported, 1)
            continue
        if hasattr(m, entry):
            continue
        fqname = "%s.%s" % (m.__name__, entry)
        if not import_module(entry, fqname, m):
            raise ImportError("No module named " + fqname)
def import_module(partname, fqname, parent):
    """Import a single module, instrumenting its source when possible.

    Arguments:
      partname -- last component of the module name (what is searched for).
      fqname   -- fully qualified name used as the ``sys.modules`` key.
      parent   -- enclosing package module, or None for a top-level module.

    Returns the module object, the cached one if ``fqname`` is already in
    ``sys.modules``, or None when ``imp.find_module`` cannot locate it.
    On success the module is also bound as attribute ``partname`` of
    ``parent``.
    """
    try:
        return sys.modules[fqname]
    except KeyError:
        pass
    try:
        fp, pathname, stuff = imp.find_module(partname,
                                              parent and parent.__path__)
    except ImportError:
        return None
    try:
        # Only plain source files can be rewritten. find_module() returns
        # fp=None for packages and built-in modules, and byte-compiled or
        # extension files are not source text, so those load untouched.
        if fp is not None and stuff[2] == imp.PY_SOURCE:
            fp = instrument_source(fp, fqname)
        m = imp.load_module(fqname, fp, pathname, stuff)
    finally:
        # Runs even if instrumentation fails, so the file is never leaked.
        if fp: fp.close()
    if parent:
        setattr(parent, partname, m)
    return m
# Replacement for reload()
def reload_hook(module):
    """Replacement for the built-in ``reload``.

    Re-runs ``import_module`` for ``module``, passing the parent package
    from ``sys.modules`` when the module name is dotted.

    NOTE(review): ``import_module`` returns the ``sys.modules`` entry when
    one exists, so for an already-imported module this looks like it hands
    back the existing object without re-executing it -- confirm.
    """
    fqname = module.__name__
    package_name, dot, partname = fqname.rpartition('.')
    if not dot:
        # Top-level module: no parent package.
        return import_module(fqname, fqname, None)
    return import_module(partname, fqname, sys.modules[package_name])
def name(block):
    """Return the name of the class ``block`` is an instance of."""
    cls = block.__class__
    return cls.__name__
def expr2str(node):
if isinstance(node, Name):
return node.name
elif isinstance(node, Getattr):
return '%s.%s' % (expr2str(node.expr), node.attrname)
else:
return '???'
class Visitor(object):
    """AST visitor that records a trace statement for every assignment target.

    After a walk, ``additions`` maps a source line number to the list of
    ``print`` statements (as source text) generated for targets on that
    line. Each generated statement is also echoed to stdout.
    """

    def __init__(self):
        # lineno -> list of statement strings generated for that line
        self.additions = {}

    def add(self, lineno, name):
        """Record (and echo) a statement that prints ``name``'s repr."""
        bucket = self.additions.setdefault(lineno, [])
        statement = 'print "%d %s " + repr(%s)' % (lineno, name, name)
        bucket.append(statement)
        print(statement)

    def visitAssign(self, node):
        """Echo the assignment's line number, then visit each target."""
        print(node.lineno)
        if not hasattr(node, 'nodes'):
            return
        # self.visit is not defined on this class; presumably it is
        # attached by the walker that drives the visitor -- confirm.
        for target in node.nodes:
            self.visit(target)

    def visitAssName(self, node):
        """Plain-name target: trace the name itself."""
        self.add(node.lineno, node.name)

    def visitAssAttr(self, node):
        """Attribute target: trace it if it renders as a dotted name."""
        target = '%s.%s' % (expr2str(node.expr), node.attrname)
        if '???' not in target:
            self.add(node.lineno, target)
        else:
            # The base expression is not a simple name/attribute chain.
            print("COMPLEX %s" % target)
def rewrite(block, additions):
    """Debug helper: dump an AST subtree to stdout.

    Prints the node's class name and its attributes (except 'nodes' and
    'code'), recurses into ``block.nodes`` when present, and prints the
    trace statement that would be generated for each ``AssName`` child.
    ``additions`` is only passed along to the recursive calls.

    Returns ``block`` for a node without ``nodes``; otherwise None.
    """
    attrs = ', '.join(["%s=%s" % (key, value)
                       for key, value in block.__dict__.iteritems()
                       if key not in ('nodes', 'code')])
    print("%s : %s" % (name(block), attrs))
    if not hasattr(block, 'nodes'):
        return block
    for child in block.nodes:
        rewrite(child, additions)
        if isinstance(child, AssName):
            print('print "%d %s " + repr(%s)'
                  % (child.lineno, child.name, child.name))
def instrument_source(fp, fqname, indent_re = re.compile('\\s*')):
    """Return a file object holding an instrumented copy of a module's source.

    The source read from ``fp`` is parsed, and after every line that holds
    an assignment target collected by ``Visitor`` a ``print`` statement for
    that target is inserted at the same indentation as the line.

    Arguments:
      fp        -- open file with the module source; it is closed here.
                   None (no file to rewrite) is passed back unchanged.
      fqname    -- fully qualified module name, used in progress output.
      indent_re -- compiled pattern matching a line's leading whitespace.

    Returns a temporary file positioned at offset 0. It contains the
    instrumented source, or the original source when parsing or
    instrumentation fails (a "FAILED" line is printed in that case).
    """
    print(fqname)
    if fp is None:
        # Nothing to read or rewrite; let the caller load the module as is.
        return None
    try:
        text = fp.read()
    finally:
        fp.close()
    try:
        ast = compiler.parse(text)
        visitor = Visitor()
        compiler.visitor.walk(ast.node, visitor)
        additions = visitor.additions
        new_lines = []
        for index, line in enumerate(text.split("\n")):
            new_lines.append(line)
            indent = indent_re.match(line).group()
            # AST line numbers are 1-based.
            for statement in additions.get(index + 1) or []:
                new_lines.append(indent + statement)
        instrumented = "\n".join(new_lines)
        # An injected print can land right after a compound-statement
        # header ("for x in y:") or inside a multi-line statement and break
        # the syntax. Verify the result compiles before committing to it;
        # a failure falls through to the handler and keeps the original.
        compile(instrumented + "\n", fqname, "exec")
        text = instrumented
    except StandardError:
        # sys.exc_info() instead of "except ..., e" keeps one spelling that
        # works on every Python 2 release.
        e = sys.exc_info()[1]
        # Format the exception itself: e.message is empty for exceptions
        # built with several arguments (SyntaxError among them).
        print("FAILED %s (%s): %s" % (fqname, e.__class__.__name__, e))
    t = TemporaryFile()
    t.write(text)
    t.seek(0)
    return t
#walk(ast, myvisitor())
# Now install our hooks
# Keep a reference to the interpreter's own importer before replacing it.
original_import = __builtin__.__import__
# Import statements executed after this line are routed through import_hook.
# (reload_hook is defined above but is not installed here.)
__builtin__.__import__ = import_hook
# Imported only after the hook is in place, so it is loaded by import_hook.
# NOTE(review): `updater` is not defined in this file -- presumably the
# program being traced; confirm.
import updater
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment