Last active
December 15, 2015 02:59
-
-
Save moyix/5191442 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import ast | |
import inspect | |
import copy | |
import stackexchange | |
import BeautifulSoup | |
namegen_count = 0  # How many identifiers namegen() has handed out so far.


def namegen():
    """Return a new identifier: 'name_1' on the first call, then 'name_2', ..."""
    global namegen_count
    namegen_count = namegen_count + 1
    return 'name_' + str(namegen_count)
def is_sorted(l):
    """Return True if the sequence *l* is in non-decreasing order.

    Used as an oracle on arbitrary candidate outputs, so any error while
    inspecting *l* (no len(), not indexable, incomparable elements, None)
    is reported as "not sorted" rather than raised.  KeyboardInterrupt is
    deliberately not caught so a long search can still be aborted.

    :param l: an indexable sequence (list, tuple, str, ...).
    :returns: True or False.
    """
    try:
        # range (not xrange) so this also works on Python 3, where a
        # NameError for xrange would otherwise be swallowed as False.
        return all(l[i] <= l[i + 1] for i in range(len(l) - 1))
    except Exception:
        return False
# Some samples: example snippets in the shapes make_functions tries to turn
# into callables -- a bare statement, a bare return, and a complete def.
s1, s2, s3 = (
    'rows.sort()',
    'return sorted(rows)',
    'def dothesort(rows): return sorted(rows)',
)
class FunctionVisitor(ast.NodeVisitor):
    """Collect functions defined in an AST that can be called as func(arg).

    Each FunctionDef reached during traversal is compiled on its own and
    executed in a fresh, empty globals dict.  The resulting object is kept
    in ``functions_found`` when it has exactly one required positional
    argument, or no required arguments but at least one defaulted one.
    Definitions nested inside a FunctionDef are not visited, because
    visit_FunctionDef does not call generic_visit.  A definition is silently
    skipped if it fails to compile, raises while being defined, or produces
    an object whose argument list cannot be inspected.
    """

    def __init__(self):
        super(FunctionVisitor, self).__init__()
        # Accepted callables, in traversal order.
        self.functions_found = []

    def visit_FunctionDef(self, node):
        func_name = node.name
        try:
            # compile() expects a Module.  type_ignores is a required field
            # on Python 3.8+; Python 2 simply stores it as an attribute.
            module = ast.Module(body=[node], type_ignores=[])
            code_obj = compile(module, 'stackoverflow', 'exec')
        except Exception:
            # Some errors only surface here, not in ast.parse -- e.g. a
            # SyntaxError for 'global x' when x is also a parameter.
            return
        # Evaluate it and get the compiled function out
        gdict = {}  # Fresh globals dict for the snippet
        try:
            # XXX: THIS IS THE DANGEROUS PART.  Executing the def runs its
            # decorators and default-argument expressions, taken verbatim
            # from untrusted Stack Overflow text.
            eval(code_obj, gdict)
        except (Exception, SystemExit):
            # SystemExit is caught so a snippet calling exit() does not end
            # the search; KeyboardInterrupt propagates so Ctrl-C still works.
            return
        func = gdict[func_name]
        # Check that the function could be called as func(arg).
        # inspect.getargspec is gone in Python 3.11 and getfullargspec does
        # not exist in Python 2; both results expose .args and .defaults.
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        try:
            argspec = getargspec(func)
        except (TypeError, ValueError):
            # e.g. a decorator such as @property replaced the function with
            # a non-function object.
            return
        defaults = argspec.defaults or ()
        required = len(argspec.args) - len(defaults)
        # One required argument, or only defaulted arguments.
        if required == 1 or (required == 0 and len(defaults) > 0):
            self.functions_found.append(func)
class NameVisitor(ast.NodeVisitor):
    """Record the identifier of every Name node, in traversal order.

    ``names_found`` holds one entry per occurrence, so repeated names
    appear repeatedly.
    """

    def __init__(self):
        ast.NodeVisitor.__init__(self)
        self.names_found = []

    def visit_Name(self, node):
        # A Name has no child expressions, so there is nothing further to visit.
        self.names_found += [node.id]
def make_functions(s):
    """Makes function(s) out of an arbitrary snippet of code by
    trying a few different heuristics.

    1. Every function the snippet itself defines that FunctionVisitor
       accepts (callable as ``func(arg)``).
    2. For each distinct name in the snippet, a new function (named by
       namegen()) whose only parameter is that name and whose body is the
       whole snippet -- e.g. ``rows.sort()`` becomes
       ``def name_N(rows): rows.sort()``.

    :param s: source text of the snippet.
    :returns: list of candidate callables; empty if *s* does not parse.
    """
    try:
        parsed = ast.parse(s)
    except (SyntaxError, TypeError, ValueError):
        # TypeError/ValueError: source containing NUL bytes.
        return []
    func_candidates = []

    # First try: run all functions with appropriate arity
    fv = FunctionVisitor()
    try:
        fv.visit(parsed)
    except Exception:
        # A definition that fails to compile aborts the traversal; keep
        # whatever was collected before it instead of losing everything.
        pass
    func_candidates.extend(fv.functions_found)

    # Second try: get names of possible arguments, and make a function
    nv = NameVisitor()
    nv.visit(parsed)
    seen = set()
    for name in nv.names_found:
        # One function per distinct name; repeats would only add duplicate
        # candidates that get evaluated and tried again.
        if name in seen:
            continue
        seen.add(name)
        fdef = ast.FunctionDef(
            name=namegen(),
            args=ast.arguments(
                # Python 2 AST: parameters are Name nodes with a Param context.
                args=[ast.Name(id=name, ctx=ast.Param())],
                defaults=[],
            ),
            body=parsed.body,
            decorator_list=[],
        )
        ast.fix_missing_locations(fdef)
        # Visit each synthesized def separately so one that fails to compile
        # does not discard the candidates built for the other names.
        fv = FunctionVisitor()
        try:
            fv.visit(fdef)
        except Exception:
            continue
        func_candidates.extend(fv.functions_found)
    return func_candidates
def try_functions(test_data, funcs, oracle, byref=True):
    """Try a list of funcs to see if test_data satisfies some
    (arity 1) oracle. If byref is True, then assume the function
    might modify its input and try oracle on the input as well.

    Each candidate receives its own shallow copy of test_data, so in-place
    changes to the top-level container do not leak between candidates.
    A candidate is skipped if it, or the oracle applied to its output,
    raises an exception or calls exit().  KeyboardInterrupt propagates so
    the user can abort the search.

    :returns: the first accepted return value, or the first accepted
        modified copy (when byref), otherwise None.
    """
    for func in funcs:
        try:
            lcopy = copy.copy(test_data)
            res = func(lcopy)
            # Returns the answer?
            if oracle(res):
                return res
            # Modifies the input in place?
            if byref and oracle(lcopy):
                return lcopy
        except (Exception, SystemExit):
            continue
    return None
def try_tests(test_data, answers, funcs):
    """Try a list of funcs to see if a list of test_data produces the correct
    answers.

    Each candidate is called as ``func(copy.copy(test))`` for every
    (test, answer) pair.  The copy is shallow, as in try_functions, so a
    candidate that modifies its argument in place cannot corrupt the
    inputs seen by later candidates.  A candidate is abandoned at its first
    mismatch, or when it (or the ``==`` comparison) raises or calls exit().
    KeyboardInterrupt propagates.

    :param test_data: sequence of inputs, one per expected answer.
    :param answers: expected results, compared with ``==``.
    :param funcs: candidate callables, tried in order.
    :returns: True for the first candidate matching every answer, else None.
    :raises ValueError: if test_data and answers differ in length.
    """
    # Explicit check rather than assert, which is stripped under python -O.
    if len(test_data) != len(answers):
        raise ValueError('test_data and answers must have the same length')
    for func in funcs:
        try:
            matched = all(func(copy.copy(test)) == ans
                          for test, ans in zip(test_data, answers))
        except (Exception, SystemExit):
            continue
        if matched:
            return True
    return None
def extract_snippets(body):
    """Yield the text of every <code> element in an HTML post body.

    HTML entities are decoded, so code containing ``<``, ``>`` or ``&``
    (and interpreter prompts, which arrive as ``&gt;&gt;&gt;``) comes
    through as real characters.  Interactive-interpreter prompts are then
    stripped so the text can be parsed (see _strip_prompts).
    """
    # BeautifulSoup 3 leaves entities undecoded unless asked to.
    soup = BeautifulSoup.BeautifulSoup(
        body, convertEntities=BeautifulSoup.BeautifulSoup.HTML_ENTITIES)
    for c in soup.findAll('code'):
        # Try interpreter text as well
        yield _strip_prompts(c.text)


def _strip_prompts(text):
    """Remove '>>> ' / '... ' prompts from the start of each line.

    Only applied when the text contains a '>>>' prompt at all, and only at
    line starts, so string literals such as '>>> ' inside code are kept.
    """
    if '>>>' not in text:
        return text
    lines = []
    for line in text.split('\n'):
        for prompt in ('>>> ', '... '):
            if line.startswith(prompt):
                line = line[len(prompt):]
                break
        else:
            # Bare prompts on otherwise empty lines.
            if line.rstrip('\r') in ('>>>', '...'):
                line = ''
        lines.append(line)
    return '\n'.join(lines)
def search_so(query):
    """Yield code snippets from Python-tagged Stack Overflow questions.

    Searches for questions tagged 'python' whose title matches *query*.
    For each question, in the order the search returns them, the snippets
    from the question body are yielded first, then those from each answer.
    Prints one progress line per question.
    """
    so = stackexchange.StackOverflow()
    so.be_inclusive()  # Fetch answer bodies
    #so.impose_throttling = True  # Be polite
    for question in so.search(tagged=['python'], intitle=query):
        # Single-argument print() behaves the same on Python 2 and 3.
        print("Checking answers for: '%s'" % question)
        question.fetch()
        # Maybe the submitter got it right?
        for snip in extract_snippets(question.body):
            yield snip
        # Ok, then maybe one of the answers is correct
        for ans in question.answers:
            for snip in extract_snippets(ans.body):
                yield snip
def main(oracle, test_data, query):
    """Search Stack Overflow for *query* and report the first snippet that
    yields a function whose output on *test_data* the *oracle* accepts.

    Prints the snippet and the accepted result, then stops.  A falsy
    accepted result is treated as no success and the search continues.
    """
    for snippet in search_so(query):
        result = try_functions(test_data, make_functions(snippet), oracle)
        if not result:
            continue
        print("Success with:")
        print(snippet)
        print(result)
        return
def main2(test_data, answers, query):
    """Search Stack Overflow for *query* and print the first snippet from
    which some derived function maps each test_data item to the matching
    entry of *answers*.
    """
    for snippet in search_so(query):
        if try_tests(test_data, answers, make_functions(snippet)):
            print("Success with:")
            print(snippet)
            return
if __name__ == "__main__":
    # Demo: find a Stack Overflow snippet that sorts this list.
    unsorted_sample = [2, 5, 6, 3, 1]
    main(is_sorted, unsorted_sample, "sort a list")
    # Alternative demo, matching exact answers instead of using an oracle:
    #nums = [37, 2, 10, 7, 99]
    #primes = [True, True, False, True, False]
    #main2(nums, primes, "primality")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment