Skip to content

Instantly share code, notes, and snippets.

@ppwwyyxx
Created September 4, 2020 09:05
Show Gist options
  • Save ppwwyyxx/b28c0dc3062fdae5f0a47801afae35ba to your computer and use it in GitHub Desktop.
Save ppwwyyxx/b28c0dc3062fdae5f0a47801afae35ba to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
import sys
import pprint
import copy
from collections import OrderedDict
class ComputeAPICall:
    """One parsed cuDNN compute API call taken from an API log.

    Attributes:
        api: Name of the cuDNN function, e.g. ``"cudnnConvolutionForward"``.
        original: The kept raw log lines of the call, joined with newlines.
        attrs: Ordered mapping of dotted argument name -> logged ``val`` string.
    """

    def __init__(self, api, original):
        self.api = api
        self.original = original  # original logs
        self.attrs = OrderedDict()

    def __str__(self):
        return self.api + "\n" + pprint.pformat(self.attrs) + "\n"

    __repr__ = __str__

    def _hash_attrs(self, exclude=()):
        """Hash the API name together with every attr whose key is not in *exclude*."""
        items = frozenset(
            (key, value) for key, value in self.attrs.items() if key not in exclude
        )
        return hash(items | frozenset({self.api}))

    def shape_hash(self):
        """Return a hash of the call that ignores the chosen algorithm.

        Calls that differ only in their "algo" attr share this hash. A call
        without an "algo" attr is hashed from its remaining attrs instead of
        raising KeyError.
        """
        return self._hash_attrs(exclude=("algo",))

    def all_hash(self):
        """Return a hash of the API name and every attr, including "algo"."""
        return self._hash_attrs()
# Log lines containing any of these substrings are dropped before parsing.
_IGNORED_SUBSTRINGS = (
    "handle", "workSpace", "workSpaceSizeInBytes",
    "arrayLength", "reorderType", "dxData", "vect:", "nbDims:",
    "dyData", "wData", "alpha", "beta", "xData", "yData",
    "Process", "mathType:", "mode:", "Time:",
)


def _read_block_lines(file):
    """Consume *file* up to the next blank line (or EOF) and return the kept lines.

    Each line is stripped. Lines containing an entry of _IGNORED_SUBSTRINGS
    are dropped.
    """
    kept = []
    for raw in file:
        raw = raw.strip()
        if not raw:
            break
        if not any(token in raw for token in _IGNORED_SUBSTRINGS):
            kept.append(raw)
    return kept


def _parse_attr_line(text):
    """Parse ``"name: k1=v1; k2=v2;"`` into ``{"name": name, "k1": "v1", ...}``.

    Trailing ':' / ';' are ignored and the ``type`` key is skipped. Only the
    first ':' and the first '=' of each field are treated as separators, so
    values may themselves contain those characters.
    """
    fields = {}
    fields["name"], rest = text.strip(":").split(":", 1)
    for piece in rest.strip(";").split(";"):
        piece = piece.strip()
        if not piece:
            continue
        key, value = piece.split("=", 1)
        if key == "type":
            continue
        fields[key] = value
    return fields


def parse_compute_block(api, file):
    """Parse the argument block of one cuDNN API call from an open log.

    Reads *file* from its current position until the next blank line (or
    EOF), so a caller iterating the same file resumes after this block.
    Argument nesting is given by the indentation that follows the ``i!``
    prefix, and nested names are joined with ``"."``.

    Args:
        api: Name of the cuDNN function the block belongs to.
        file: Iterable of text lines positioned just after the call's
            header line.

    Returns:
        A ComputeAPICall whose ``attrs`` map dotted argument paths to the
        ``val`` field logged for them (only lines that carry a ``val``).
    """
    lines = _read_block_lines(file)
    call = ComputeAPICall(api, "\n".join(lines))

    # (indent, name) for each argument enclosing the current line. Popping
    # every entry indented at least as deep as the current line handles
    # dedents of any depth, and never pops from an empty stack.
    stack = []
    for line in lines:
        if line.startswith("i!"):
            line = line[2:]
        text = line.lstrip()
        indent = len(line) - len(text)
        fields = _parse_attr_line(text)
        while stack and stack[-1][0] >= indent:
            stack.pop()
        stack.append((indent, fields["name"]))
        if "val" in fields:
            call.attrs[".".join(name for _, name in stack)] = fields["val"]
    return call
def find_compute_calls(filename, apis=None):
    """Yield a parsed ComputeAPICall for each selected compute call in a cuDNN log.

    A call starts at a line beginning with ``I!`` that contains ``<api>()``.
    Its argument block is parsed by parse_compute_block, which consumes the
    file up to the blank line that ends the block.

    Args:
        filename: Path of the cuDNN API log.
        apis: Optional iterable of cuDNN function names to extract. Defaults
            to cudnnConvolutionForward, cudnnConvolutionBackwardFilter and
            cudnnConvolutionBackwardData.

    Yields:
        ComputeAPICall objects, in log order.
    """
    if apis is None:
        apis = ("cudnnConvolutionForward", "cudnnConvolutionBackwardFilter",
                "cudnnConvolutionBackwardData")
    markers = [(api, api + "()") for api in apis]
    with open(filename) as f:
        for line in f:
            line = line.strip()
            if not line.startswith("I!"):
                continue
            for api, marker in markers:
                if marker in line:
                    # parse_compute_block advances f past this call's block,
                    # so the outer loop resumes after it. Stop at the first
                    # match so overlapping names never parse the block twice.
                    yield parse_compute_block(api, f)
                    break
def find_compute_calls_dedup(filename, max_dups=2000):
    """Yield the calls of find_compute_calls(filename), skipping exact repeats.

    A call is a repeat when its ``all_hash()`` equals that of a call already
    yielded.

    Args:
        filename: Path of the cuDNN API log.
        max_dups: Stop once more than this many repeats have been seen in
            total. This is a heuristic: a long tail of repeats is taken to
            mean no new convolutions will appear. ``None`` disables the
            cutoff and scans the whole log.

    Yields:
        The first occurrence of each distinct call, in log order.
    """
    seen = set()
    dups = 0
    for call in find_compute_calls(filename):
        key = call.all_hash()
        if key not in seen:
            seen.add(key)
            yield call
            continue
        dups += 1
        if max_dups is not None and dups > max_dups:
            return
if __name__ == "__main__":
    # Compare the convolution calls found in two cuDNN logs (the file names
    # suggest cuDNN 7 and cuDNN 8 runs). Print each v8 call whose shape_hash()
    # matches a v7 call but whose all_hash() differs.
    # Single-file alternative: dump the unique calls of one log passed on the
    # command line.
    # filename = sys.argv[1]
    # for call in find_compute_calls_dedup(filename):
    #     print(call)
    v7 = list(find_compute_calls_dedup("cudnnlog_cudnn7_cu102.txt"))
    v8 = list(find_compute_calls_dedup("cudnnlog_cudnn8_cu102.txt"))
    # Index v7 calls by shape_hash(); a later call with the same key replaces
    # an earlier one.
    v7_map = {x.shape_hash() : x for x in v7}
    for v8_call in v8:
        shape_hash = v8_call.shape_hash()
        if shape_hash in v7_map:
            v7_call = v7_map[shape_hash]
            # Equal shape_hash but different all_hash: the two calls differ
            # in "algo", the only attr that shape_hash() leaves out.
            if v7_call.all_hash() != v8_call.all_hash():
                print("v7", v7_call)
                print("v8", v8_call)
                print('--' * 10)
        # NOTE(review): this copy lost its indentation; this `else` is assumed
        # to pair with `if shape_hash in v7_map` (not with the `for`) — confirm.
        # No v7 call has this shape: open an interactive IPython shell
        # (third-party dependency) to inspect the state.
        else:
            import IPython as IP; IP.embed()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment