Skip to content

Instantly share code, notes, and snippets.

@lukeyeager
Created August 26, 2016 22:07
Show Gist options
  • Save lukeyeager/dfefe50979f32a8cba43a13997d7e6fd to your computer and use it in GitHub Desktop.
Save lukeyeager/dfefe50979f32a8cba43a13997d7e6fd to your computer and use it in GitHub Desktop.
#!/usr/bin/env python2
import argparse
import csv
from collections import OrderedDict
import os.path
import re
# A number as printed by C++ iostreams: plain decimals ("3", "12.5", ".5") or,
# for very small values, scientific notation ("5e-05"). Exactly one group.
_NUMBER = r'((?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)(?:[eE][-+]?[0-9]+)?)'

# Log prefix ending in "] ", e.g. "I0826 22:07:00.000000 4242 caffe.cpp:400] ".
# Non-greedy so that a "] " inside the message is not taken as the prefix end.
_LOG_LINE_RE = re.compile(r'.+?\] (.+)')

# Summary lines, paired with the key under which their value is stored.
_SUMMARY_RES = (
    ('avg_forward', re.compile(r'Average Forward pass: ' + _NUMBER + r' ms\.')),
    ('avg_backward', re.compile(r'Average Backward pass: ' + _NUMBER + r' ms\.')),
    ('avg_forward_backward',
     re.compile(r'Average Forward-Backward: ' + _NUMBER + r' ms\.')),
    ('total_time', re.compile(r'Total Time: ' + _NUMBER + r' ms\.')),
)
_ITERATION_RE = re.compile(
    r'Iteration: \d+ forward-backward time: ' + _NUMBER + r' ms\.')
_LAYER_RE = re.compile(r'(\S+)\s+(forward|backward): ' + _NUMBER + r' ms\.')


def _parse_log(lines):
    """Collect benchmark timings from the lines of a "caffe time" log.

    Returns a dict with keys:
      'started' / 'ended': whether the "*** Benchmark begins ***" /
          "*** Benchmark ends ***" markers were seen;
      'iterations': list of per-iteration forward-backward times (ms);
      'layers': OrderedDict of layer name -> dict holding whichever of
          'forward' / 'backward' (ms) were found, in first-seen order;
      'avg_forward', 'avg_backward', 'avg_forward_backward', 'total_time':
          float ms, or None when the corresponding line is absent.
    """
    timings = {
        'started': False,
        'ended': False,
        'iterations': [],
        'layers': OrderedDict(),
        'avg_forward': None,
        'avg_backward': None,
        'avg_forward_backward': None,
        'total_time': None,
    }
    for line in lines:
        match = _LOG_LINE_RE.match(line)
        if not match:
            continue
        # Layer names are right-aligned with padding, so strip whitespace.
        message = match.group(1).strip()
        if message == '*** Benchmark begins ***':
            timings['started'] = True
            continue
        if message == '*** Benchmark ends ***':
            timings['ended'] = True
            continue
        for key, pattern in _SUMMARY_RES:
            match = pattern.match(message)
            if match:
                timings[key] = float(match.group(1))
                break
        else:  # not a summary line
            match = _ITERATION_RE.match(message)
            if match:
                timings['iterations'].append(float(match.group(1)))
                continue
            match = _LAYER_RE.match(message)
            if match:
                layer_name, direction, value = match.groups()
                layer = timings['layers'].setdefault(layer_name, {})
                layer[direction] = float(value)
    return timings


def _percent(part, whole):
    """Return `part` as a percentage of `whole`, or '' if that is undefined."""
    # `not whole` covers both a missing (None) and a zero average.
    if part is None or not whole:
        return ''
    return 100 * part / whole


def _write_csv(outfile, timings):
    """Write the per-iteration, per-layer and summary tables to `outfile`."""
    writer = csv.writer(outfile)
    writer.writerow(('Iteration', 'Time (ms)'))
    for index, value in enumerate(timings['iterations'], 1):
        writer.writerow((index, value))
    writer.writerow(('',))
    writer.writerow(('Layer', 'Forward (ms)', 'Forward (%)',
                     'Backward (ms)', 'Backward (%)'))
    for layer_name, by_direction in timings['layers'].items():
        # A missing direction becomes an empty cell (csv writes None as '').
        forward = by_direction.get('forward')
        backward = by_direction.get('backward')
        writer.writerow((
            layer_name,
            forward,
            _percent(forward, timings['avg_forward']),
            backward,
            _percent(backward, timings['avg_backward']),
        ))
    writer.writerow(('',))
    writer.writerow(('Average forward pass', timings['avg_forward']))
    writer.writerow(('Average backward pass', timings['avg_backward']))
    writer.writerow(('Average forward-backward pass',
                     timings['avg_forward_backward']))
    writer.writerow(('Total time', timings['total_time']))


def parse_timings(input_filename, output_filename):
    """Convert the log of a "caffe time" run into a CSV timing report.

    The CSV has three tables separated by blank rows:
      1. per-iteration forward-backward times;
      2. per-layer forward/backward times, each also given as a percentage
         of the average forward/backward pass;
      3. the average forward, backward and forward-backward passes and the
         total time.
    Progress messages are printed to stdout.

    Args:
        input_filename: path of the text log to read.
        output_filename: path of the CSV file to (over)write.

    Raises:
        AssertionError: if the log lacks the "*** Benchmark begins ***" or
            "*** Benchmark ends ***" marker (checks skipped under python -O).
    """
    with open(input_filename, 'r') as infile:
        timings = _parse_log(infile)
    assert timings['started'], 'Never received "Benchmark begins" message'
    assert timings['ended'], 'Never received "Benchmark ends" message'
    # Single-argument print(...) behaves the same on Python 2 and 3.
    print('%d iterations' % len(timings['iterations']))
    print('%d layers' % len(timings['layers']))
    # NOTE(review): Python 3's csv docs recommend newline='' here; plain 'w'
    # is kept so the script still runs under its python2 shebang.
    with open(output_filename, 'w') as outfile:
        _write_csv(outfile, timings)
    print('Saved to %s' % output_filename)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Convert the log of a "caffe time" run into a CSV report.')
    parser.add_argument('input_txt',
                        help='caffe time log to read (must end in .txt)')
    parser.add_argument('output_csv',
                        help='CSV file to write (must end in .csv)')
    args = parser.parse_args()
    # parser.error() prints the usage line plus the message and exits;
    # unlike assert, these checks are not stripped when run with python -O.
    if not os.path.isfile(args.input_txt):
        parser.error('input file not found: %s' % args.input_txt)
    if os.path.splitext(args.input_txt)[1] != '.txt':
        parser.error('input file must have a .txt extension: %s'
                     % args.input_txt)
    if os.path.splitext(args.output_csv)[1] != '.csv':
        parser.error('output file must have a .csv extension: %s'
                     % args.output_csv)
    parse_timings(args.input_txt, args.output_csv)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment