Skip to content

Instantly share code, notes, and snippets.

@mattharrison
Last active June 22, 2020 20:48
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mattharrison/2a1a263597d80e99cf85e898b800ec32 to your computer and use it in GitHub Desktop.
Save mattharrison/2a1a263597d80e99cf85e898b800ec32 to your computer and use it in GitHub Desktop.
This script finds the doctest (Python console) snippets in a text file and reformats them with black.
import argparse
import sys
import black
from blib2to3.pgen2.tokenize import TokenError
# Sample document for the line_class / code_chunks doctests: prose, blank
# lines, two console statements (the second with '...' continuation lines)
# and trailing blank lines. The blank lines are significant -- those doctests
# list every line, including the empty ones.
TEST_DATA = """
Normal

>>> name = 'matt'
>>> for num in [1,2,3]:
... print(num)
... print(2, num)

Rest


"""
# Sample document for the test_process doctest: a console snippet with lines
# too long for a narrow width, surrounded by prose and blank lines. The
# continuation lines are indented so that the snippet is valid Python once
# the 4-character prompts are removed; otherwise black cannot parse it.
LONG_LINES = '''

NORMAL

>>> def draw_cmap(name, size=10, aspect=.25):
...     fig=plt.figure()
...     fig.set_size_inches(4,4)
...     #ax = plt.subplot(111)
...     ax = plt.axes([0,0,1,1], frameon=False)
...     mapname = name
...     set_cmap(mapname)
...     colors = getattr(cm, mapname)
...     # Then we disable our xaxis and yaxis completely. If we just say plt.axis('off'),
...     # they are still used in the computation of the image padding.

END
'''
def line_class(fin):
    r"""Tag each line of *fin* with the kind of console prompt it opens with.

    Yields ``(klass, line)`` pairs; *line* is passed through untouched:

    * ``1`` -- the line starts with a ``>>>`` prompt,
    * ``2`` -- the line starts with a ``...`` continuation prompt,
    * ``3`` -- anything else (prose, blank lines, ...).

    Surrounding whitespace is ignored when looking for a prompt.

    >>> for data in line_class(TEST_DATA.split('\n')):
    ...     print(data)
    (3, '')
    (3, 'Normal')
    (3, '')
    (1, ">>> name = 'matt'")
    (1, '>>> for num in [1,2,3]:')
    (2, '... print(num)')
    (2, '... print(2, num)')
    (3, '')
    (3, 'Rest')
    (3, '')
    (3, '')
    (3, '')
    """
    for line in fin:
        content = line.strip()
        if content.startswith('>>>'):
            yield 1, line
        elif content.startswith('...'):
            yield 2, line
        else:
            yield 3, line
def code_chunks(fin):
    r"""Group the lines of *fin* into console-code and plain-text chunks.

    Yields ``(klass, lines)`` tuples, where *klass* is:

    * ``1`` -- console code: a ``>>>`` line followed by the ``...``
      continuation lines directly after it;
    * ``2`` -- non-console text: always a single-line list.

    Every input line appears in exactly one chunk, in input order.

    >>> lines = TEST_DATA.split('\n')
    >>> print(lines)
    ['', 'Normal', '', ">>> name = 'matt'", '>>> for num in [1,2,3]:', '... print(num)', '... print(2, num)', '', 'Rest', '', '', '']
    >>> print((list(code_chunks(lines))))
    [(2, ['']), (2, ['Normal']), (2, ['']), (1, [">>> name = 'matt'"]), (1, ['>>> for num in [1,2,3]:', '... print(num)', '... print(2, num)']), (2, ['']), (2, ['Rest']), (2, ['']), (2, ['']), (2, [''])]

    A ``...`` line that does not follow a ``>>>`` statement (for instance
    prose that opens with an ellipsis) is kept as plain text. It is neither
    dropped nor turned into code:

    >>> list(code_chunks(['>>> x = 1', 'text', '... more']))
    [(1, ['>>> x = 1']), (2, ['text']), (2, ['... more'])]
    """
    # Lines of the console statement being collected; empty when we are not
    # inside one. This replaces a separate in_code flag, which could get out
    # of sync with the chunk after a prose line.
    chunk = []
    for klass, line in line_class(fin):
        if klass == 1:
            # A new '>>>' prompt closes the previous statement.
            if chunk:
                yield 1, chunk
            chunk = [line]
        elif klass == 2 and chunk:
            chunk.append(line)
        else:
            # Plain text -- or an orphan '...' line -- ends any open statement.
            if chunk:
                yield 1, chunk
                chunk = []
            yield 2, [line]
    if chunk:
        yield 1, chunk
def test_process():
    r"""Doctest harness for ``process``; the function body is intentionally empty.

    Blackens ``LONG_LINES`` with ``chars=30``, which gives black a line length
    of 26. The snippet's ``>>>`` prompt is printed as ``PS1``. Otherwise
    doctest would read that line of the expected output as the start of a new
    example, and this doctest could never pass.

    NOTE(review): the expected wrapping reflects the black release the author
    used; other releases may format differently -- confirm against the
    installed version.

    >>> import io
    >>> out = io.StringIO()
    >>> data = [f'{line}\n' for line in LONG_LINES.split('\n')]
    >>> process(data, out, chars=30)
    >>> print(out.getvalue().replace('>>> ', 'PS1 '))
    <BLANKLINE>
    <BLANKLINE>
    NORMAL
    <BLANKLINE>
    PS1 def draw_cmap(
    ...     name,
    ...     size=10,
    ...     aspect=0.25,
    ... ):
    ...     fig = plt.figure()
    ...     fig.set_size_inches(
    ...         4, 4
    ...     )
    ...     # ax = plt.subplot(111)
    ...     ax = plt.axes(
    ...         [0, 0, 1, 1],
    ...         frameon=False,
    ...     )
    ...     mapname = name
    ...     set_cmap(mapname)
    ...     colors = getattr(
    ...         cm, mapname
    ...     )
    ...     # Then we disable our xaxis and yaxis completely. If we just say plt.axis('off'),
    ...     # they are still used in the computation of the image padding.
    <BLANKLINE>
    END
    <BLANKLINE>
    <BLANKLINE>
    """
    pass
def _blacken_chunk(chunk, joiner, mode, line_num):
    """Return one console statement reformatted by black, with prompts restored.

    *chunk* is a ``>>>`` line plus its ``...`` continuation lines, joined with
    *joiner* before formatting. The returned lines carry no line terminator.
    *line_num* is used only in the diagnostics printed to stderr when black
    cannot parse the snippet; the exception is then re-raised.
    """
    first = chunk[0]
    whitespace_len = len(first) - len(first.lstrip())
    # Remove the leading indentation plus the 4-character '>>> ' / '... ' prompt.
    source = joiner.join(line[whitespace_len + 4:] for line in chunk)
    try:
        new_content = black.format_str(source, mode=mode)
    except black.InvalidInput:
        print("LINENUM", line_num, file=sys.stderr)
        print("ERRR!", source, file=sys.stderr)
        raise
    except TokenError:
        print("2LINENUM", line_num, file=sys.stderr)
        print("2ERRR!", source, file=sys.stderr)
        raise
    body = new_content.split('\n')
    # black ends its output with a newline, which leaves an empty last element.
    if len(body) > 1 and not body[-1].strip():
        body.pop()
    pad = ' ' * whitespace_len
    return [f'{pad}>>> {body[0]}'] + [f'{pad}... {text}' for text in body[1:]]


def process(fin, fout, chars=51):
    r"""Reformat every console snippet in *fin* with black and write it to *fout*.

    Parameters
    ----------
    fin : iterable of str
        Lines of the document. Either every line keeps its trailing newline
        (as when iterating over a file) or none does (as with
        ``text.split('\n')``). The first line decides which, and the output
        follows the same convention.
    fout : writable text stream
        Receives the whole rewritten document in a single ``write`` call.
    chars : int
        Line width measured from the start of the ``>>>`` prompt. Black is
        given ``chars - 4`` so that the 4-character prompt still fits.

    Raises
    ------
    black.InvalidInput, TokenError
        Re-raised when a snippet cannot be parsed. The offending source and
        the snippet's 0-based line offset are printed to stderr first, not to
        stdout, which may be the output stream.
    """
    mode = black.FileMode(line_length=chars - 4)  # identical for every chunk
    joiner = '\n'     # how output lines are glued together
    terminator = ''   # appended to each reformatted code line
    out_lines = []
    orig_line_num = 0  # number of input lines consumed before this chunk
    for index, (klass, chunk) in enumerate(code_chunks(fin)):
        if index == 0 and chunk[0].endswith('\n'):
            # Lines carry their own newlines: concatenate them as they are,
            # and give reformatted code lines a newline of their own.
            joiner = ''
            terminator = '\n'
        if klass == 2:
            out_lines.append(chunk[0])
        else:
            formatted = _blacken_chunk(chunk, joiner, mode, orig_line_num)
            # Newline-free input must not get '\n' here *and* from the joiner,
            # or every code line would be followed by a blank line.
            out_lines.extend(line + terminator for line in formatted)
        orig_line_num += len(chunk)
    fout.write(joiner.join(out_lines))
def main(args):
    """Command-line entry point.

    *args* is the argument list without the program name, e.g.
    ``sys.argv[1:]``. With ``-s`` the source file is blackened and written to
    ``-d`` (stdout by default). With ``-t`` the module's doctests are run.
    """
    ap = argparse.ArgumentParser(description='Look for python console snippets and apply black to them')
    ap.add_argument('-s', '--src', help='src file')
    ap.add_argument('-d', '--dst', help='dst file (default stdout)', default=sys.stdout)
    ap.add_argument('-l', '--length', help='black split column (from >>> position, so if you start >>> at col 15 and set this to 30 you only have til column 45) default 51', type=int, default=51)
    ap.add_argument('-t', '--test', help='run doctest', action='store_true')
    opt = ap.parse_args(args)
    if opt.src:
        # Read all input before opening the destination. Opening it with 'w'
        # truncates it, which would wipe the input when -s and -d name the
        # same file.
        with open(opt.src) as fin:
            lines = fin.readlines()
        if opt.dst is sys.stdout:
            process(lines, sys.stdout, chars=opt.length)
        else:
            # `with` guarantees the output file is flushed and closed.
            with open(opt.dst, 'w') as fout:
                process(lines, fout, chars=opt.length)
    if opt.test:
        import doctest
        doctest.testmod()


if __name__ == '__main__':
    main(sys.argv[1:])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment