Created
August 26, 2016 22:07
-
-
Save lukeyeager/dfefe50979f32a8cba43a13997d7e6fd to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python2 | |
import argparse | |
import csv | |
from collections import OrderedDict | |
import os.path | |
import re | |
# A timing value as it appears in the log: plain decimal ("12", "0.5", ".5")
# or scientific notation ("4.5e-05"), which very small timings may be printed
# in. Without the exponent part such lines would be silently skipped.
_NUMBER = r'([0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?)'

# Every log line is "<prefix>] <message>"; the greedy ".+" splits at the last
# "] " so the message is everything after it.
_LOG_PREFIX_RE = re.compile(r'.+\] (.+)')

# "Iteration: <n> forward-backward time: <ms> ms."
_ITERATION_RE = re.compile(
    r'Iteration: \d+ forward-backward time: ' + _NUMBER + r' ms\.')

# "<layer> <direction>: <ms> ms." where direction is a word ending in "ward"
# (only "forward" and "backward" are written to the CSV).
_LAYER_RE = re.compile(r'(\S+)\s+(\S+ward): ' + _NUMBER + r' ms\.')

# Summary lines as (compiled pattern, key in the dict built by _parse_log).
# They are tried before _LAYER_RE because e.g. "Average Forward-Backward: ..."
# would otherwise also match the layer pattern.
_SUMMARY_PATTERNS = (
    (re.compile(r'Average Forward pass: ' + _NUMBER + r' ms\.'),
     'avg_forward'),
    (re.compile(r'Average Backward pass: ' + _NUMBER + r' ms\.'),
     'avg_backward'),
    (re.compile(r'Average Forward-Backward: ' + _NUMBER + r' ms\.'),
     'avg_forward_backward'),
    (re.compile(r'Total Time: ' + _NUMBER + r' ms\.'),
     'total_time'),
)


def _match_summary(message):
    """Return ``(key, value)`` if ``message`` is a summary line, else None."""
    for pattern, key in _SUMMARY_PATTERNS:
        match = pattern.match(message)
        if match:
            return key, float(match.group(1))
    return None


def _parse_log(lines):
    """Extract benchmark markers and timings from an iterable of log lines.

    Returns a dict with:
      * ``benchmark_started`` / ``benchmark_ended`` -- whether the
        ``*** Benchmark begins ***`` / ``*** Benchmark ends ***`` markers
        were seen;
      * ``iterations`` -- list of per-iteration forward-backward times (ms);
      * ``layers`` -- OrderedDict, in first-seen order, mapping layer name to
        a ``{direction: ms}`` dict;
      * ``avg_forward``, ``avg_backward``, ``avg_forward_backward``,
        ``total_time`` -- floats, or None when the line never appeared.
    Lines without a ``"] "`` log prefix are ignored.
    """
    parsed = {
        'benchmark_started': False,
        'benchmark_ended': False,
        'iterations': [],
        'layers': OrderedDict(),
        'avg_forward': None,
        'avg_backward': None,
        'avg_forward_backward': None,
        'total_time': None,
    }
    for line in lines:
        prefix_match = _LOG_PREFIX_RE.match(line)
        if not prefix_match:
            continue
        message = prefix_match.group(1).strip()

        if message == '*** Benchmark begins ***':
            parsed['benchmark_started'] = True
            continue
        if message == '*** Benchmark ends ***':
            parsed['benchmark_ended'] = True
            continue

        summary = _match_summary(message)
        if summary is not None:
            key, value = summary
            parsed[key] = value
            continue

        match = _ITERATION_RE.match(message)
        if match:
            parsed['iterations'].append(float(match.group(1)))
            continue

        match = _LAYER_RE.match(message)
        if match:
            layer_name, direction, timing = match.groups()
            parsed['layers'].setdefault(layer_name, {})[direction] = \
                float(timing)
    return parsed


def _percent(part, whole):
    """Return ``part`` as a percentage of ``whole``, or None if undefined.

    None (written by the csv module as an empty cell) is returned when the
    part is missing or the whole is missing or zero, instead of raising
    TypeError / ZeroDivisionError.
    """
    if part is None or not whole:
        return None
    return 100 * part / whole


def _open_csv_for_writing(path):
    """Open ``path`` for writing in the mode the running csv module expects.

    Python 2's csv module wants a binary file; Python 3's wants a text file
    opened with ``newline=''`` so the writer controls line endings itself.
    """
    import sys  # local import: the file's import block does not include sys
    if sys.version_info[0] < 3:
        return open(path, 'wb')
    return open(path, 'w', newline='')


def _write_csv(writer, parsed):
    """Write the iteration, per-layer and summary sections of the report.

    Sections are separated by a row holding a single empty field. Missing
    values (None) are written by the csv module as empty cells.
    """
    writer.writerow(('Iteration', 'Time (ms)'))
    for number, elapsed in enumerate(parsed['iterations'], 1):
        writer.writerow((number, elapsed))
    writer.writerow(('',))

    writer.writerow(('Layer', 'Forward (ms)', 'Forward (%)',
                     'Backward (ms)', 'Backward (%)'))
    avg_forward = parsed['avg_forward']
    avg_backward = parsed['avg_backward']
    for layer_name, directions in parsed['layers'].items():
        # .get: a layer may not have reported both directions.
        forward = directions.get('forward')
        backward = directions.get('backward')
        writer.writerow((
            layer_name,
            forward,
            _percent(forward, avg_forward),
            backward,
            _percent(backward, avg_backward),
        ))
    writer.writerow(('',))

    writer.writerow(('Average forward pass', avg_forward))
    writer.writerow(('Average backward pass', avg_backward))
    writer.writerow(('Average forward-backward pass',
                     parsed['avg_forward_backward']))
    writer.writerow(('Total time', parsed['total_time']))


def parse_timings(input_filename, output_filename):
    """Parse a benchmark timing log and write its timings to a CSV file.

    Each log line is expected to look like ``<prefix>] <message>``
    (NOTE(review): the messages match the output of Caffe's ``caffe time``
    tool -- confirm that is the only producer). Recognised messages are:
      * the ``*** Benchmark begins ***`` / ``*** Benchmark ends ***`` markers;
      * ``Iteration: N forward-backward time: T ms.``;
      * per-layer ``<layer> <direction>: T ms.`` lines;
      * the ``Average Forward pass``, ``Average Backward pass``,
        ``Average Forward-Backward`` and ``Total Time`` summary lines.

    The CSV has three sections separated by a row holding one empty field:
      1. per-iteration times;
      2. per-layer forward/backward times, each also shown as a percentage of
         the matching average pass;
      3. the summary values.
    Values that never appeared in the log are left as empty cells. So are
    percentages whose average is missing or zero.

    Prints the number of iterations and layers found, then the output path.

    Args:
        input_filename: path of the log file to read.
        output_filename: path of the CSV file to create or overwrite.

    Raises:
        AssertionError: if either benchmark marker line is missing.
    """
    with open(input_filename, 'r') as infile:
        parsed = _parse_log(infile)

    # Raised explicitly rather than via ``assert`` so the checks still run
    # under ``python -O``; AssertionError is kept for existing callers.
    if not parsed['benchmark_started']:
        raise AssertionError('Never received "Benchmark begins" message')
    if not parsed['benchmark_ended']:
        raise AssertionError('Never received "Benchmark ends" message')

    # Single-argument print(...) behaves identically on Python 2 and 3.
    print('%d iterations' % len(parsed['iterations']))
    print('%d layers' % len(parsed['layers']))

    with _open_csv_for_writing(output_filename) as outfile:
        _write_csv(csv.writer(outfile), parsed)
    print('Saved to %s' % output_filename)
def main(argv=None):
    """Command-line entry point: ``<script> INPUT.txt OUTPUT.csv``.

    Parses the arguments, validates them, and converts the input log into a
    CSV file via ``parse_timings``. Invalid arguments are reported with
    ``parser.error`` (usage message, exit status 2) rather than ``assert``,
    which would be stripped under ``python -O`` and otherwise ends in a bare
    traceback.

    Args:
        argv: argument list to parse (without the program name); defaults to
            ``sys.argv[1:]`` when None, as with ``ArgumentParser.parse_args``.
    """
    parser = argparse.ArgumentParser(
        description='Parse iteration and layer timings from a benchmark log '
                    'into a CSV file.')
    parser.add_argument('input_txt', help='log file to read (must end in .txt)')
    parser.add_argument('output_csv', help='CSV file to write (must end in .csv)')
    args = parser.parse_args(argv)

    if not os.path.exists(args.input_txt):
        parser.error('input file does not exist: %s' % args.input_txt)
    if os.path.splitext(args.input_txt)[1] != '.txt':
        parser.error('input file must have a .txt extension: %s'
                     % args.input_txt)
    if os.path.splitext(args.output_csv)[1] != '.csv':
        parser.error('output file must have a .csv extension: %s'
                     % args.output_csv)

    parse_timings(args.input_txt, args.output_csv)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment