Skip to content

Instantly share code, notes, and snippets.

@risteon
Last active October 7, 2022 01:59
Show Gist options
  • Save risteon/263516474fe3c7ed03144fd24c175012 to your computer and use it in GitHub Desktop.
Save risteon/263516474fe3c7ed03144fd24c175012 to your computer and use it in GitHub Desktop.
TensorFlow checkpoints: print all variables (name, shape, dtype) or export them to a .csv file for further inspection.
# -*- coding: utf-8 -*-
import pathlib
import click
import csv
import prettytable
import tensorflow as tf
def get(checkpoints, value: str):
    """Read the tensor named ``value`` from each checkpoint, keeping input order.

    Args:
        checkpoints: Iterable of objects exposing ``get_tensor(name)``
            (presumably checkpoint readers from ``tf.train.load_checkpoint``).
        value: Name of the variable to read from every checkpoint.

    Returns:
        A list with one tensor per checkpoint. Any exception raised by
        ``get_tensor`` propagates unchanged.
    """
    tensors = []
    for reader in checkpoints:
        tensors.append(reader.get_tensor(value))
    return tensors
def inspect(file_name, export_csv: bool = False, quiet: bool = False):
    """Print, and optionally export, every variable stored in a checkpoint.

    The checkpoint is opened with ``tf.train.load_checkpoint``. Each variable
    is listed with its name, shape and dtype, sorted by name.

    Args:
        file_name: Checkpoint to open. For V2 checkpoints this is the file
            *prefix*, without the ``.index`` / ``.data`` extension.
        export_csv: If True, also write the table to
            ``vars_<basename of file_name>.csv`` in the current working
            directory.
        quiet: If True, do not print the table to stdout.

    Nothing is raised. Any exception is printed, followed by a hint when the
    message matches a known failure mode (SNAPPY compression, or a V2
    checkpoint passed with its file extension).
    """
    fieldnames = ["name", "shape", "dtype"]
    try:
        reader = tf.train.load_checkpoint(file_name)
        var_to_dtype_map = reader.get_variable_to_dtype_map()
        var_to_shape_map = reader.get_variable_to_shape_map()
        # One row per variable, sorted by name. The shape is stringified and
        # the dtype is rendered through its ``.name`` attribute.
        rows = [
            [str(name), str(var_to_shape_map[name]), var_to_dtype_map[name].name]
            for name in sorted(var_to_shape_map)
        ]

        if not quiet:
            table = prettytable.PrettyTable()
            table.field_names = fieldnames
            for field in fieldnames:
                table.align[field] = "l"
            for row in rows:
                table.add_row(row)
            print(table)

        if export_csv:
            csv_path = "vars_{}.csv".format(pathlib.Path(file_name).name)
            # newline="" lets the csv writer control line endings. Without
            # it, every row is followed by a blank line on Windows.
            with open(csv_path, mode="w", newline="", encoding="utf-8") as csv_file:
                csv_writer = csv.writer(
                    csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
                )
                csv_writer.writerow(fieldnames)
                csv_writer.writerows(rows)
    except Exception as err:  # pylint: disable=broad-except
        # Best-effort CLI tool: report the failure instead of raising.
        message = str(err)
        print(message)
        if "corrupted compressed block contents" in message:
            print(
                "It's likely that your checkpoint file has been compressed "
                "with SNAPPY."
            )
        # Look only at the basename, so a directory such as ".data/" in the
        # path does not trigger the V2 hint.
        base_name = pathlib.Path(file_name).name
        if "Data loss" in message and any(
            suffix in base_name for suffix in (".index", ".meta", ".data")
        ):
            # base_name contains a ".", so rsplit always finds one. This drops
            # only the final extension and recovers the checkpoint prefix.
            proposed_file = file_name.rsplit(".", 1)[0]
            print(
                "\nIt's likely that this is a V2 checkpoint and you need to "
                "provide the filename\n*prefix*. Try removing the '.' and "
                "extension, i.e. pass this as the checkpoint argument:\n"
                "{}".format(proposed_file)
            )
@click.command()
@click.argument("checkpoint")
@click.option(
    "--export-csv/--no-export-csv",
    default=False,
    help="Also write the variable table to vars_<checkpoint name>.csv "
    "in the current directory.",
)
@click.option(
    "--quiet/--no-quiet",
    default=False,
    help="Do not print the variable table to stdout.",
)
def main(checkpoint, export_csv, quiet):
    """List the name, shape and dtype of every variable in CHECKPOINT.

    For V2 checkpoints, pass the file prefix rather than the .index or
    .data file.
    """
    # Click uses this docstring as the command's --help text.
    inspect(checkpoint, export_csv, quiet)
# Script entry point. Click parses sys.argv and supplies main()'s arguments,
# so it is called here without any.
if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment