Last active
October 7, 2022 01:59
-
-
Save risteon/263516474fe3c7ed03144fd24c175012 to your computer and use it in GitHub Desktop.
TensorFlow checkpoints: print all variables (name, shape, dtype) or export them to a .csv file for further inspection.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
import pathlib | |
import click | |
import csv | |
import prettytable | |
import tensorflow as tf | |
def get(checkpoints, value: str):
    """Fetch the tensor called *value* from each checkpoint reader.

    Args:
        checkpoints: Iterable of reader objects exposing ``get_tensor(name)``
            (presumably readers from ``tf.train.load_checkpoint``).
        value: Name of the variable to read from every reader.

    Returns:
        A list holding one tensor per reader, in iteration order.
    """
    tensors = []
    for reader in checkpoints:
        tensors.append(reader.get_tensor(value))
    return tensors
def _load_variable_rows(file_name):
    """Return ``[name, shape, dtype]`` string rows for every variable, sorted by name."""
    reader = tf.train.load_checkpoint(file_name)
    var_to_dtype_map = reader.get_variable_to_dtype_map()
    var_to_shape_map = reader.get_variable_to_shape_map()
    # Names are unique dict keys, so sorting the items orders by name only.
    return [
        [str(name), str(shape), var_to_dtype_map[name].name]
        for name, shape in sorted(var_to_shape_map.items())
    ]


def _print_table(fieldnames, rows):
    """Print *rows* as a left-aligned prettytable with *fieldnames* as headers."""
    table = prettytable.PrettyTable()
    table.field_names = fieldnames
    for field in fieldnames:
        table.align[field] = "l"
    for row in rows:
        table.add_row(row)
    print(table)


def _write_csv(path, fieldnames, rows):
    """Write a header line followed by *rows* to the CSV file at *path*."""
    # The csv module requires newline="" on the file object; otherwise an
    # extra "\r" is emitted per row on platforms using "\r\n" line endings.
    with open(path, mode="w", newline="", encoding="utf-8") as csv_file:
        csv_writer = csv.writer(
            csv_file, delimiter=",", quotechar='"', quoting=csv.QUOTE_MINIMAL
        )
        csv_writer.writerow(fieldnames)
        csv_writer.writerows(rows)


def _print_error_hints(message, file_name):
    """Print a remediation hint for recognised checkpoint-loading error messages."""
    file_name = str(file_name)
    if "corrupted compressed block contents" in message:
        print(
            "It's likely that your checkpoint file has been compressed "
            "with SNAPPY."
        )
    if "Data loss" in message and any(
        suffix in file_name for suffix in (".index", ".meta", ".data")
    ):
        # V2 checkpoints are addressed by their prefix (e.g. "model.ckpt"),
        # not by one of their component files; drop the last dotted part.
        proposed_file = ".".join(file_name.split(".")[:-1])
        v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*. Try removing the '.' and extension, i.e. pass this as CHECKPOINT:
{}"""
        print(v2_file_error_template.format(proposed_file))


def inspect(file_name, export_csv: bool = False, quiet: bool = False):
    """Print and/or export the name, shape and dtype of every checkpoint variable.

    Rows are sorted by variable name.

    Args:
        file_name: Checkpoint path or prefix, passed to
            ``tf.train.load_checkpoint``.
        export_csv: If true, also write the rows to ``vars_<basename>.csv``
            in the current working directory.
        quiet: If true, do not print the table to stdout.

    Nothing is raised: any exception is printed, followed by a hint when the
    message matches a known failure mode (SNAPPY compression, or a V2
    checkpoint given by component file name instead of prefix).
    """
    try:
        fieldnames = ["name", "shape", "dtype"]
        rows = _load_variable_rows(file_name)
        if not quiet:
            _print_table(fieldnames, rows)
        if export_csv:
            csv_path = "vars_{}.csv".format(pathlib.Path(file_name).name)
            _write_csv(csv_path, fieldnames, rows)
    except Exception as err:  # pylint: disable=broad-except
        # Deliberate best-effort CLI behaviour: report the problem, don't crash.
        print(str(err))
        _print_error_hints(str(err), file_name)
@click.command()
@click.argument("checkpoint")
@click.option(
    "--export-csv/--no-export-csv",
    default=False,
    help="Also write the variable table to vars_<CHECKPOINT basename>.csv.",
)
@click.option(
    "--quiet/--no-quiet",
    default=False,
    help="Do not print the variable table to stdout.",
)
def main(checkpoint, export_csv, quiet):
    """List the name, shape and dtype of every variable in CHECKPOINT.

    For V2 checkpoints pass the file prefix (e.g. model.ckpt), not one of
    its component files such as model.ckpt.index.
    """
    # click renders this docstring and the option help strings in --help.
    inspect(checkpoint, export_csv=export_csv, quiet=quiet)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment