Skip to content

Instantly share code, notes, and snippets.

@dhermes
Last active August 29, 2015 14:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dhermes/85c3a3a464d2cff312ea to your computer and use it in GitHub Desktop.
Save dhermes/85c3a3a464d2cff312ea to your computer and use it in GitHub Desktop.
An `ast` parser which (mostly) correctly indents Python docstrings.
import __builtin__
import ast
import collections
import shutil
import sys
import tempfile
TRIPLE_QUOTES = ('"""', '\'\'\'')


# Also see: http://stackoverflow.com/a/17478618/1068170
# and: ('https://bitbucket.org/aivarannamaa/thonny/src/'
#       '85e09e98a08db63d75f158b435dd07fc7a00c27c/src/'
#       'ast_utils.py?at=default')
class HeavyAst(object):
    """A parsed source file that also keeps the raw source lines around.

    The `ast` module only reports coarse positions, so the raw lines are
    kept to locate where a docstring really starts and to rewrite the file.

    Attributes:
      filename: Path of the source file that was parsed.
      ast_tree: Result of `ast.parse` on the file contents.
      as_lines: The file contents split on newlines (no terminators).
    """

    def __init__(self, filename):
        """Reads and parses `filename`.

        Args:
          filename: Path to a Python source file.
        """
        self.filename = filename
        # NOTE: Universal newlines are needed so splitting on '\n' works.
        #       Python 2 needs the explicit 'U' flag; Python 3 text mode
        #       already translates newlines and 3.11+ rejects 'U'.
        read_mode = 'rU' if sys.version_info[0] == 2 else 'r'
        with open(filename, read_mode) as fh:
            contents = fh.read()
        self.ast_tree = ast.parse(contents)
        self.as_lines = contents.split('\n')

    def find_content(self, doc_str):
        """Finds the (1-indexed) line on which a docstring opens.

        The first line of `doc_str` must occur on exactly one source line.
        Starting from that line, the search moves upwards until it reaches
        a line that begins with triple quotes.

        Args:
          doc_str: The docstring text, as returned by `ast.get_docstring`.

        Returns:
          The 1-indexed line number of the line holding the opening quotes.

        Raises:
          ValueError: If the first docstring line is missing or ambiguous,
              or if no line starting with triple quotes precedes it.
        """
        # NOTE: Split rather than slice at find('\n'): find() returns -1 for
        #       a one-line docstring, which would drop its last character.
        first_line_docstring = doc_str.split('\n', 1)[0]
        matches = [i for i, line in enumerate(self.as_lines)
                   if first_line_docstring in line]
        if len(matches) != 1:
            raise ValueError('Can\'t find line in source code.')

        # Walk upwards, but never past the first line of the file.
        for candidate_line in range(matches[0], -1, -1):
            first_three = self.as_lines[candidate_line].lstrip()[:3]
            if first_three in TRIPLE_QUOTES:
                # NOTE: We add 1 since the lines in the file are 1-indexed.
                return candidate_line + 1
        raise ValueError('Can\'t find line in source code.')

    def rewrite_file(self, all_docstrings):
        """Rewrites the file with re-indented docstrings, keeping a backup.

        The unchanged source is saved to `<filename>.bak`. Both versions
        are staged in anonymous temporary files first, so the real files
        are only touched once every line has been rendered.

        Args:
          all_docstrings: Dictionary mapping the 1-indexed start line of
              each docstring to an object exposing `end` and
              `write_line(fh, line_no)`.

        Raises:
          ValueError: If a docstring starts before the previous one ended.
        """
        with tempfile.TemporaryFile('w+') as backup_fh:
            with tempfile.TemporaryFile('w+') as new_fh:
                self._write_lines(backup_fh, new_fh, all_docstrings)
                self._copy_to(backup_fh, self.filename + '.bak')
                print('Created ' + self.filename + '.bak')
                self._copy_to(new_fh, self.filename)
                print('Over-wrote ' + self.filename)

    def _write_lines(self, backup_fh, new_fh, all_docstrings):
        """Writes the original lines to one handle, rewritten to another."""
        curr_docstring = None
        for line_no, line_val in enumerate(self.as_lines, 1):
            if line_no != 1:  # First line does not have a preceding line.
                new_fh.write('\n')
                backup_fh.write('\n')

            # Write the lines as-is to the backup.
            backup_fh.write(line_val)

            # Check if a docstring is starting.
            if line_no in all_docstrings:
                if curr_docstring is not None:
                    raise ValueError('Two docstrings can\'t be simultaneous.')
                curr_docstring = all_docstrings[line_no]

            if curr_docstring is None:
                new_fh.write(line_val)
            else:
                curr_docstring.write_line(new_fh, line_no)
                if line_no == curr_docstring.end:
                    curr_docstring = None

    @staticmethod
    def _copy_to(src_fh, dest_filename):
        """Copies everything written to `src_fh` into `dest_filename`."""
        src_fh.seek(0)
        with open(dest_filename, 'w') as dest_fh:
            shutil.copyfileobj(src_fh, dest_fh)
class DocstringObj(object):
    """A docstring found in a source file, ready to be re-indented.

    Attributes:
      doc_str: The docstring text as returned by `ast.get_docstring`.
      doc_str_lines: `doc_str` split into lines, padded with empty lines
          so there is one entry per source line from `start` to `end`.
      start: 1-indexed source line holding the opening quotes.
      end: 1-indexed source line holding the closing quotes.
      col_offset: Number of spaces every rewritten line is indented by.
    """

    def __init__(self, heavy_ast, ast_parent, ast_docstr_expr):
        """Locates the docstring of `ast_parent` in the source.

        Args:
          heavy_ast: Object exposing `find_content(doc_str)`, which returns
              the 1-indexed line on which the docstring opens.
          ast_parent: The module, class or function node owning the
              docstring.
          ast_docstr_expr: The expression node (first statement of the
              parent's body) that holds the docstring.

        Raises:
          ValueError: If the docstring has more lines than fit between the
              detected start and end lines.
        """
        self.doc_str = ast.get_docstring(ast_parent)
        self.doc_str_lines = self.doc_str.split('\n')

        # NOTE: A module docstring does not have to open on line 1 (it may
        #       follow a shebang or comment header), so it is located in
        #       the source the same way as every other docstring.
        self.start = heavy_ast.find_content(self.doc_str)
        if isinstance(ast_parent, ast.Module):
            self.col_offset = 0
        else:
            # The body is assumed to be indented 4 spaces past its parent.
            self.col_offset = ast_parent.col_offset + 4

        # NOTE: Interpreters before Python 3.8 report the last line of a
        #       multi-line string as `lineno`; newer ones report the first
        #       line there and expose the last line as `end_lineno`.
        end_lineno = getattr(ast_docstr_expr, 'end_lineno', None)
        if end_lineno is None:
            end_lineno = ast_docstr_expr.lineno
        self.end = end_lineno

        stated_length = self.end - self.start + 1
        missing_length = stated_length - len(self.doc_str_lines)
        if missing_length < 0:
            raise ValueError('Docstring is too long for reported start and end.')
        # Blank lines at the end of a docstring (e.g. closing quotes on a
        # line of their own) are missing from the cleaned text; pad them.
        self.doc_str_lines += [''] * missing_length

    def get_line(self, line_no):
        """Returns the docstring text belonging to source line `line_no`."""
        return self.doc_str_lines[line_no - self.start]

    def write_line(self, fh, line_no):
        """Writes the re-indented form of source line `line_no` to `fh`.

        The opening and closing quotes are always written as three double
        quotes. Trailing whitespace is stripped and no newline is written.

        Args:
          fh: Writable file-like object.
          line_no: 1-indexed source line, between `start` and `end`.
        """
        line_val = self.get_line(line_no)
        if line_no == self.start:
            line_val = '"""' + line_val
        if line_no == self.end:
            line_val += '"""'
        line_val = (' ' * self.col_offset) + line_val
        fh.write(line_val.rstrip())

    def __repr__(self):
        return 'DocstringObj(start=%d,end=%d)' % (self.start, self.end)
def get_docstring_obj(ast_obj, heavy_ast):
    """Builds a `DocstringObj` for `ast_obj` if it carries a docstring.

    Args:
      ast_obj: Any AST node.
      heavy_ast: The parsed file, passed through to `DocstringObj`.

    Returns:
      A `DocstringObj`, or None when `ast_obj` is not a module, class or
      function, or when its body does not begin with a docstring.
    """
    if not isinstance(ast_obj, (ast.Module, ast.ClassDef, ast.FunctionDef)):
        # Only module, class and function/methods can have a docstring.
        return None

    # NOTE: Let the standard library decide whether the first statement is
    #       a docstring; the node class used for string literals differs
    #       between interpreter versions. `clean=False` skips the unneeded
    #       re-formatting, and an empty docstring ('') still counts.
    if ast.get_docstring(ast_obj, clean=False) is None:
        return None

    return DocstringObj(heavy_ast, ast_obj, ast_obj.body[0])
def _get_all_docstrings(ast_obj, result, heavy_ast):
    """Collects every docstring at or below `ast_obj` into `result`.

    Every descendant node is visited, so definitions nested under
    `else`, `except`, `finally` and similar clauses are included as well
    as those directly in a `body`.

    Args:
      ast_obj: The AST node to start from (it is inspected as well).
      result: Dictionary updated in place; maps the start line of each
          docstring to its `DocstringObj`.
      heavy_ast: The parsed file, passed through to `get_docstring_obj`.

    Raises:
      KeyError: If two docstrings report the same start line.
    """
    for node in ast.walk(ast_obj):
        docstring_obj = get_docstring_obj(node, heavy_ast)
        if docstring_obj is None:
            continue
        if docstring_obj.start in result:
            raise KeyError('Start: %d already in result.' % (docstring_obj.start,))
        result[docstring_obj.start] = docstring_obj
def get_all_docstrings(heavy_ast):
    """Returns all docstrings of a parsed file, keyed by start line.

    Args:
      heavy_ast: Parsed file; its `ast_tree` must be an `ast.Module`.

    Returns:
      Dictionary filled in by `_get_all_docstrings`.

    Raises:
      TypeError: If the parsed tree is not a module.
    """
    module_node = heavy_ast.ast_tree
    if isinstance(module_node, ast.Module):
        found = {}
        _get_all_docstrings(module_node, found, heavy_ast)
        return found
    raise TypeError('Expected tree to be a module.')
def rewrite_docstrings(filename):
    """Re-indents every docstring in `filename`, rewriting it in place.

    Args:
      filename: Path to the Python source file to rewrite.

    Returns:
      The dictionary of docstrings that was passed to `rewrite_file`.
    """
    parsed_file = HeavyAst(filename)
    docstrings_by_start = get_all_docstrings(parsed_file)
    parsed_file.rewrite_file(docstrings_by_start)
    return docstrings_by_start
def example():
    """Runs a self-check of `rewrite_docstrings` on a small sample file.

    The sample is written to a temporary file and rewritten in place.
    The `.bak` copy is then compared with the sample and the rewritten
    file with the expected text.

    Raises:
      ValueError: If the back-up or the rewritten file has unexpected
          contents.
    """
    # NOTE(review): The three samples below differ only in `(name,)` vs
    #               `(name, )`, and A_PEP8IFY equals A_DESIRED as written.
    #               Presumably the leading indentation inside these string
    #               literals was collapsed when this listing was extracted.
    #               TODO: confirm against the original source.
    # NOTE(review): A_ORIG is not referenced anywhere else in this function.
    A_ORIG = '\n'.join([
        'def hello_func(name):',
        ' """Prints hello with the name.',
        '',
        ' Args:',
        ' name: String, to print.',
        ' """',
        ' print \'Hello %s, nice to meet you.\' % (name,)',
        '',
    ])
    # The text that is written to the temp file and rewritten.
    A_PEP8IFY = '\n'.join([
        'def hello_func(name):',
        ' """Prints hello with the name.',
        '',
        ' Args:',
        ' name: String, to print.',
        ' """',
        ' print \'Hello %s, nice to meet you.\' % (name, )',
        '',
    ])
    # The text the file is expected to contain after the rewrite.
    A_DESIRED = '\n'.join([
        'def hello_func(name):',
        ' """Prints hello with the name.',
        '',
        ' Args:',
        ' name: String, to print.',
        ' """',
        ' print \'Hello %s, nice to meet you.\' % (name, )',
        '',
    ])
    # NOTE(review): mkstemp() also returns an open file descriptor; it is
    #               discarded here and never closed, and the temp file is
    #               not removed afterwards.
    filename = tempfile.mkstemp()[1]
    print 'Making example with temp file:', filename
    with open(filename, 'w') as fh:
        fh.write(A_PEP8IFY)
    all_docstrings = rewrite_docstrings(filename)
    print '=' * 70
    print 'All docstrings found:'
    for start in sorted(all_docstrings.keys()):
        print all_docstrings[start]
    # Check that back-up worked.
    with open(filename + '.bak', 'r') as fh:
        backed_up = fh.read()
    if backed_up == A_PEP8IFY:
        print 'Back-up succeeded.'
    else:
        raise ValueError('Back-up did not work correctly.')
    # Check that the indent was correct.
    with open(filename, 'r') as fh:
        rewrite_content = fh.read()
    if rewrite_content == A_DESIRED:
        print 'Indent succeeded, new file:'
        print ('=' * 70)
        print rewrite_content, ('=' * 70)
    else:
        raise ValueError('Indent did not work correctly.')
if __name__ == '__main__':
    # Script entry point: rewrite the file named on the command line, or
    # run the built-in example when no argument is given. Nothing is run
    # when the `__IPYTHON__` marker is present on the builtins module.
    # H/T: http://stackoverflow.com/a/9093598/1068170
    if hasattr(__builtin__, '__IPYTHON__'):
        print('In IPYTHON, not running main().')
    elif len(sys.argv) > 1:
        filename = sys.argv[1]
        rewrite_docstrings(filename)
    else:
        example()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment