Skip to content

Instantly share code, notes, and snippets.

@dpattison3
Last active October 20, 2021 07:11
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save dpattison3/0b08002479d4e4f8eb98d6cc55f500a5 to your computer and use it in GitHub Desktop.
Save dpattison3/0b08002479d4e4f8eb98d6cc55f500a5 to your computer and use it in GitHub Desktop.
import collections
import logging
import os
import sys

import commentjson
import numpy as np
import scipy as scp
import scipy.misc
import tensorflow as tf
# Make the modules under 'incl' importable before the imports below run.
sys.path.insert(1, 'incl')

try:
    # Check whether setup was done correctly
    import tensorvision.utils as tv_utils
    import tensorvision.core as core
    import tensorvision.train
    import tensorflow_fcn.utils
except ImportError:
    # You forgot to initialize submodules
    logging.error("Could not import the submodules.")
    logging.error("Please execute:"
                  "'git submodule update --init --recursive'")
    # sys.exit rather than the bare `exit` helper, which is only installed
    # by the `site` module and is not guaranteed to exist.
    sys.exit(1)
# Freeze a trained model into a single GraphDef file with the variables
# converted to constants.
#
# load the network from the working directory
# for some reason I can't get this to work without
# using tensorvision - TODO figure this out so we can
# use any checkpoint from the run directory
logdir = 'CHANGE_THIS'  # directory the hypes, modules and weights are read from

hypes = tv_utils.load_hypes_from_logdir(logdir)
modules = tv_utils.load_modules_from_logdir(logdir)

# Input placeholder for the inference graph. This script used to create the
# placeholder twice and only use the second one; the unused one is gone.
# NOTE(review): the name is pinned to 'Placeholder_1', which is what
# TensorFlow's automatic naming presumably gave the second placeholder, so
# the input node name in the exported graph stays the same -- confirm
# against whatever consumes the .pb file.
image_pl = tf.placeholder(tf.float32, name='Placeholder_1')
# Add a leading axis of size 1 (presumably the batch dimension).
image = tf.expand_dims(image_pl, 0)
pred = core.build_inference_graph(hypes, modules, image=image)

# The session is only needed to load the weights and to read the variable
# values while freezing; the `with` block closes it afterwards.
with tf.Session() as sess:
    # load the weights
    saver = tf.train.Saver()
    core.load_weights(logdir, sess, saver)

    # freeze the graph
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['Validation/decoder/Softmax'])

# remove training only nodes
frozen_graph_def = tf.graph_util.remove_training_nodes(frozen_graph_def)

# save the model
with tf.gfile.GFile('test_frozen_model.pb', 'wb') as f:
    f.write(frozen_graph_def.SerializeToString())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment