Skip to content

Instantly share code, notes, and snippets.

@bricakeld
Created April 7, 2019 07:21
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bricakeld/a3d77c19f3eac96852159be9bec963a1 to your computer and use it in GitHub Desktop.
Save bricakeld/a3d77c19f3eac96852159be9bec963a1 to your computer and use it in GitHub Desktop.
Cut down a scikit-learn decision-tree Graphviz (.dot) export so that only a portion of a very large tree is displayed
import sys
import os
import re
# Directory containing this script. Not referenced elsewhere in this script;
# note that the name shadows the builtin dir() at module level.
dir = os.path.split(__file__)[0]


def usage():
    """Show how to invoke this script on stdout."""
    message = 'python tree_truncate.py <input dot file> <number of nodes> <optional: output dot file | default=out.dot>'
    print(message)
def check():
    """Validate the command line and exit with the usage text if it is malformed.

    Expected arguments: <input dot file> <number of nodes> [<output dot file>].
    The node count must be a positive integer.

    Raises:
        SystemExit: with status 2 (the conventional command-line usage-error
            code) on a wrong argument count or a bad node count, so shell
            callers can tell a bad invocation apart from success.
    """
    if not 3 <= len(sys.argv) <= 4:
        usage()
        sys.exit(2)
    # Validate here so a typo produces the usage text instead of a
    # ValueError traceback later in truncate().
    try:
        num_nodes = int(sys.argv[2])
    except ValueError:
        num_nodes = 0
    if num_nodes < 1:
        usage()
        sys.exit(2)
def _parse_dot_lines(content):
    """Tag every line of a dot file with the kind of statement it is.

    Args:
        content: stripped, non-empty lines of the dot file.

    Returns:
        A list of ``(kind, ids, line)`` tuples in file order, where ``kind`` is
        ``'edge'`` for ``<src> -> <dst> ...`` lines (``ids == (src, dst)``),
        ``'node'`` for ``<id> [...]`` lines (``ids == (id,)``), and ``'other'``
        for everything else, e.g. ``digraph Tree {``, ``node [...]`` /
        ``edge [...]`` attribute statements and the closing ``}``
        (``ids == ()``).
    """
    edge_re = re.compile(r'^(\d+) -> (\d+)')
    node_re = re.compile(r'^(\d+) \[')
    parsed = []
    for line in content:
        edge = edge_re.match(line)
        if edge:
            parsed.append(('edge', (int(edge.group(1)), int(edge.group(2))), line))
            continue
        node = node_re.match(line)
        if node:
            parsed.append(('node', (int(node.group(1)),), line))
        else:
            parsed.append(('other', (), line))
    return parsed


def _select_lines(parsed, num_nodes):
    """Return the lines of the graph cut down to nodes ``0 .. num_nodes - 1``.

    Kept, in file order:
      * every 'other' (graph-level) line, including header and closing brace;
      * every edge with at least one endpoint below ``num_nodes``;
      * the declaration of every node that is below ``num_nodes`` or is an
        endpoint of a kept edge, so children at the cut-off still show their
        label instead of appearing as a bare id.
    """
    endpoints = set()
    for kind, ids, _ in parsed:
        if kind == 'edge' and min(ids) < num_nodes:
            endpoints.update(ids)

    selected = []
    for kind, ids, line in parsed:
        if kind == 'other':
            selected.append(line)
        elif kind == 'node':
            if ids[0] < num_nodes or ids[0] in endpoints:
                selected.append(line)
        elif min(ids) < num_nodes:
            selected.append(line)
    return selected


def truncate():
    """Write a copy of a dot file that keeps only the first N nodes.

    Reads the command line: ``sys.argv[1]`` is the input dot file,
    ``sys.argv[2]`` the number of nodes to keep, and the optional
    ``sys.argv[3]`` the output file (default ``out.dot`` in the current
    directory). The input path is printed. Node ids below N are kept together
    with their edges and the declarations of the nodes those edges reach.

    Raises:
        SystemExit: (status 0, after a message) when the input declares no
            more than N nodes, since there is nothing to cut.
    """
    infile = os.path.abspath(sys.argv[1])
    outfile = 'out.dot'
    num_nodes = int(sys.argv[2])
    print(infile)
    if len(sys.argv) == 4:
        outfile = os.path.abspath(sys.argv[3])

    # Read as UTF-8 to match the encoding used for the output below. Blank
    # lines are dropped so a trailing newline can never hide the closing "}".
    with open(infile, encoding='utf-8') as f:
        content = [line.strip() for line in f if line.strip()]

    parsed = _parse_dot_lines(content)
    total_nodes = sum(1 for kind, _, _ in parsed if kind == 'node')
    if total_nodes <= num_nodes:
        print('there are less nodes than requested, no need to truncate')
        sys.exit()

    outlines = _select_lines(parsed, num_nodes)
    with open(outfile, 'w', encoding='utf-8') as f:
        f.writelines('%s\n' % s for s in outlines)
def main():
    """Run the tool: validate the arguments, then write the truncated graph."""
    # check() leaves via sys.exit() on a bad argument count, so truncate()
    # only runs once the command line has passed validation.
    for step in (check, truncate):
        step()


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment