Skip to content

Instantly share code, notes, and snippets.

@xiang-deng
Created October 1, 2019 19:28
Show Gist options
  • Save xiang-deng/4b97f0750a965a24dbdc6d3be82149b0 to your computer and use it in GitHub Desktop.
Save xiang-deng/4b97f0750a965a24dbdc6d3be82149b0 to your computer and use it in GitHub Desktop.
Inspect a TensorFlow checkpoint
import numpy as np
import json
import pickle
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
import pdb
import argparse
def build_parser():
    """Build the command-line parser for this script.

    Both options are required: the script cannot do anything useful without
    a checkpoint to read and a file to write, so argparse should reject the
    invocation up front instead of failing later on ``open(None, ...)``.

    Returns:
        An ``argparse.ArgumentParser`` with ``-c/--checkpoint`` and
        ``-o/--output`` string options.
    """
    parser = argparse.ArgumentParser(
        description="Dump every tensor in a TensorFlow checkpoint to a pickle file.",
    )
    parser.add_argument(
        "-c",
        "--checkpoint",
        required=True,
        type=str,
        help="Checkpoint to read (passed to NewCheckpointReader).",
    )
    parser.add_argument(
        "-o",
        "--output",
        required=True,
        type=str,
        help="File to write the pickled {tensor_name: value} dict to.",
    )
    return parser


def read_checkpoint_tensors(reader, verbose=True):
    """Return a dict mapping every variable name in *reader* to its value.

    Args:
        reader: any object exposing ``get_variable_to_shape_map()`` and
            ``get_tensor(name)``, such as the object returned by
            ``pywrap_tensorflow.NewCheckpointReader``.
        verbose: when true, print each tensor name as it is read.

    Returns:
        dict of tensor name -> ``reader.get_tensor(name)``, inserted in
        sorted-name order (values are presumably numpy arrays for a real
        TensorFlow reader -- not verified here).
    """
    dump_tensor = {}
    for key in sorted(reader.get_variable_to_shape_map()):
        if verbose:
            print("tensor_name: ", key)
        dump_tensor[key] = reader.get_tensor(key)
    return dump_tensor


def main(argv=None):
    """Read the checkpoint named on the command line and pickle its tensors.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)
    # NOTE(review): NewCheckpointReader lives in pywrap_tensorflow only on
    # older TensorFlow releases; newer ones expose tf.train.load_checkpoint.
    reader = pywrap_tensorflow.NewCheckpointReader(args.checkpoint)
    dump_tensor = read_checkpoint_tensors(reader)
    # The former pdb.set_trace() breakpoint here was removed: it halted every
    # run in an interactive debugger before the output file was written.
    with open(args.output, "wb") as f:
        pickle.dump(dump_tensor, f)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment