Created
August 27, 2017 16:25
-
-
Save wenfahu/e822ef09716b63936984d181d34c6d28 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
import numpy as np | |
import pickle | |
from tensorflow.core.framework import attr_value_pb2 | |
from tensorflow.core.framework import graph_pb2 | |
from tensorflow.core.framework import node_def_pb2 | |
from tensorflow.python.framework import tensor_util | |
from tensorflow.python.platform import gfile | |
import re | |
import argparse | |
import pdb, traceback, sys | |
def main(args):
    """Remove pruned channels from an InceptionResnetV1 ``GraphDef``.

    Reads the serialized ``GraphDef`` at ``args.GRAPH_PATH`` (conv weights
    are expected to be ``Const`` inputs of the ``Conv2D`` nodes) and the
    pickled mask at ``args.MASK_PATH`` (a dict mapping a weight tensor name
    to the indices of its output channels to drop).  The per-conv output
    masks are propagated to the input channels of the consumers, to the
    batch-norm constants and to the biases, and two files are written:

    * ``args.output_mask``  -- pickled dict
      ``{node name: {'shape': [...], 'indices': [...]}}`` holding the
      original tensor shape and the removed indices of every masked node;
    * ``args.output_graph`` -- the serialized pruned ``GraphDef``.

    The network topology (stem, block35/block17/block8 repeats, Mixed_6a,
    Mixed_7a, Bottleneck) is hard-coded through node names.

    :param args: namespace with ``GRAPH_PATH``, ``MASK_PATH``,
        ``output_graph`` and ``output_mask`` attributes.
    :raises ValueError: if a ``Conv2D`` node has no ``Const`` input.
    """
    graph_path = args.GRAPH_PATH
    mask_path = args.MASK_PATH
    output_graph = args.output_graph

    with tf.Session():
        with tf.gfile.FastGFile(graph_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name="")
        gd = tf.get_default_graph().as_graph_def()

    def _node_name(n):
        """Strip the control-input marker or the output-port suffix."""
        if n.startswith("^"):
            return n[1:]
        return n.split(":")[0]

    nodes_map = {_node_name(node.name): node for node in gd.node}
    node_n = [node.name for node in gd.node if node.op == 'Conv2D']

    def dfs_last_conv(node):
        """Return the nearest upstream ``Conv2D`` node, or ``None``.

        Only the first input of every node is followed.  The walk stops
        with ``None`` at a ``Concat`` node or at a node without inputs.
        Iterative so that long chains cannot hit the recursion limit.
        """
        while node.input:
            in_node = nodes_map[_node_name(node.input[0])]
            if in_node.op == 'Conv2D':
                return in_node
            if in_node.op == 'Concat':
                return None
            node = in_node
        return None

    def get_conv_weights(node):
        """Return the name of the ``Const`` input of ``node`` (or ``None``).

        When several inputs are ``Const`` the last one wins.
        """
        for n in reversed(node.input):
            if nodes_map[_node_name(n)].op == 'Const':
                return n
        return None

    def get_node_output_channels(node_name):
        """Return the size of dimension 3 of the node's stored tensor."""
        tensor = nodes_map[node_name].attr['value'].tensor
        return int(tensor.tensor_shape.dim[3].size)

    def empty_indices():
        """Empty *integer* index array (a float one breaks ``np.delete``)."""
        return np.array([], dtype=np.intp)

    class GraphNode(object):
        """Channel count and pruned indices at a milestone of the trunk."""

        def __init__(self, channels, indices):
            self.channels = channels
            self.indices = indices

    # NOTE(security): pickle.load can execute arbitrary code; only load mask
    # files that you produced yourself.
    with open(mask_path, 'rb') as f:
        mask = pickle.load(f)

    weights_n = [get_conv_weights(nodes_map[name]) for name in node_n]
    orphans = [name for name, w in zip(node_n, weights_n) if w is None]
    if orphans:
        raise ValueError(
            'Conv2D node(s) without a Const weight input: {}'.format(orphans))

    input_mask = {}   # pruning mask for the input channels
    output_mask = {}  # pruning mask for the output channels
    bias_mask = {}    # pruning mask for the biases

    for key, indices in mask.items():
        output_mask[_node_name(key)] = np.array(indices, dtype=np.intp)
    for n in weights_n:
        if n not in output_mask:
            output_mask[n] = empty_indices()

    # Default rule: a conv loses the input channels that the nearest
    # upstream conv loses as output channels.
    for conv_name, n in zip(node_n, weights_n):
        last_conv = dfs_last_conv(nodes_map[conv_name])
        if last_conv is not None:
            input_mask[n] = output_mask[get_conv_weights(last_conv)]
        else:
            input_mask[n] = empty_indices()

    def concat_branches(branches, tail=None):
        """Mask over the channel-wise concatenation of ``branches``.

        The output mask of each branch is shifted by the summed unpruned
        output width of the branches before it.  ``tail`` optionally adds
        one more index array placed after all the branches.
        """
        offset = 0
        parts = []
        for name in branches:
            parts.append(output_mask[name] + offset)
            offset += get_node_output_channels(name)
        if tail is not None:
            parts.append(np.asarray(tail) + offset)
        return np.concatenate(parts)

    # General convs of the stem: each one consumes the previous one.
    stem = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
            'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'Conv2d_4b_3x3']
    for prev, cur in zip(stem, stem[1:]):
        input_mask['InceptionResnetV1/{}/weights'.format(cur)] = \
            output_mask['InceptionResnetV1/{}/weights'.format(prev)]
    stem_out = output_mask['InceptionResnetV1/Conv2d_4b_3x3/weights']

    # Residual blocks: the projection conv that closes a block must drop the
    # same output channels as the trunk it is added to.

    # 5 x Inception-Resnet-A (block35), trunk = output of Conv2d_4b_3x3.
    for i in range(1, 6):
        scope = 'InceptionResnetV1/Repeat/block35_{}/'.format(i)
        mixed_conv = scope + 'Conv2d_1x1/weights'
        mixed_bias = re.sub(r'weights', 'biases', mixed_conv)
        input_mask[mixed_conv] = concat_branches([
            scope + 'Branch_0/Conv2d_1x1/weights',
            scope + 'Branch_1/Conv2d_0b_3x3/weights',
            scope + 'Branch_2/Conv2d_0c_3x3/weights',
        ])
        output_mask[mixed_conv] = stem_out
        bias_mask[mixed_bias] = output_mask[mixed_conv]
    resnet_a = GraphNode(256, stem_out)

    # Reduction A (Mixed_6a): concat of [branch0, branch1, trunk].
    mix6a_b0 = 'InceptionResnetV1/Mixed_6a/Branch_0/Conv2d_1a_3x3/weights'
    mix6a_b1 = 'InceptionResnetV1/Mixed_6a/Branch_1/Conv2d_1a_3x3/weights'
    input_mask[mix6a_b0] = resnet_a.indices
    input_mask['InceptionResnetV1/Mixed_6a/Branch_1/Conv2d_0a_1x1/weights'] = \
        resnet_a.indices
    reduction_a = GraphNode(
        get_node_output_channels(mix6a_b0)
        + get_node_output_channels(mix6a_b1)
        + resnet_a.channels,
        concat_branches([mix6a_b0, mix6a_b1], tail=resnet_a.indices))

    def mask_two_branch_block(scope, branch_1_last, trunk):
        """Fill the masks of a two-branch residual block under ``scope``."""
        branch_0 = scope + 'Branch_0/Conv2d_1x1/weights'
        branch_1 = scope + 'Branch_1/' + branch_1_last + '/weights'
        branch_1_in = scope + 'Branch_1/Conv2d_0a_1x1/weights'
        mixed_conv = scope + 'Conv2d_1x1/weights'
        mixed_bias = re.sub(r'weights', 'biases', mixed_conv)
        input_mask[branch_0] = trunk.indices
        input_mask[branch_1_in] = trunk.indices
        input_mask[mixed_conv] = concat_branches([branch_0, branch_1])
        output_mask[mixed_conv] = trunk.indices
        bias_mask[mixed_bias] = output_mask[mixed_conv]

    # 10 x Inception-Resnet-B (block17), trunk = output of Reduction A.
    for i in range(1, 11):
        mask_two_branch_block(
            'InceptionResnetV1/Repeat_1/block17_{idx}/'.format(idx=i),
            'Conv2d_0c_7x1', reduction_a)
    resnet_b = reduction_a

    # Reduction B (Mixed_7a): concat of [branch0, branch1, branch2, trunk].
    mix7a_b0 = 'InceptionResnetV1/Mixed_7a/Branch_0/Conv2d_1a_3x3/weights'
    mix7a_b1 = 'InceptionResnetV1/Mixed_7a/Branch_1/Conv2d_1a_3x3/weights'
    mix7a_b2 = 'InceptionResnetV1/Mixed_7a/Branch_2/Conv2d_1a_3x3/weights'
    for branch in ('Branch_0', 'Branch_1', 'Branch_2'):
        input_mask['InceptionResnetV1/Mixed_7a/' + branch
                   + '/Conv2d_0a_1x1/weights'] = resnet_b.indices
    reduction_b = GraphNode(
        get_node_output_channels(mix7a_b0)
        + get_node_output_channels(mix7a_b1)
        + get_node_output_channels(mix7a_b2)
        + resnet_b.channels,
        concat_branches([mix7a_b0, mix7a_b1, mix7a_b2],
                        tail=resnet_b.indices))

    # 5 x Inception-Resnet-C (block8) plus the final Block8, trunk = output
    # of Reduction B.
    for i in range(1, 6):
        mask_two_branch_block(
            'InceptionResnetV1/Repeat_2/block8_{idx}/'.format(idx=i),
            'Conv2d_0c_3x1', reduction_b)
    mask_two_branch_block('InceptionResnetV1/Block8/', 'Conv2d_0c_3x1',
                          reduction_b)
    block8_weights = 'InceptionResnetV1/Block8/Conv2d_1x1/weights'

    # Batch-norm constants follow the output mask of their conv.  The
    # projection convs closing the residual blocks are skipped.
    bn_mask = {}
    new_shape = {}
    for n in weights_n:
        if n == block8_weights:
            continue
        if re.search(r'block\d+_\d+/Conv2d_1x1/weights$', n):
            continue
        bn_prefix = re.sub(r'weights$', 'BatchNorm/', n)
        for suffix in ('beta', 'moving_mean', 'moving_variance'):
            bn_mask[bn_prefix + suffix] = output_mask[n]
        new_shape[bn_prefix + 'moments/Shape'] = output_mask[n].shape[0]

    # Original shape + removed indices of every masked constant.
    retrain_mask = {}
    for masks in (output_mask, bias_mask, bn_mask):
        for name, indices in masks.items():
            dims = nodes_map[name].attr['value'].tensor.tensor_shape.dim
            retrain_mask[name] = {
                'shape': [dim.size for dim in dims],
                'indices': list(indices),
            }
    with open(args.output_mask, 'wb') as f:
        pickle.dump(retrain_mask, f)

    def make_const_node(src_node, values):
        """Build a ``Const`` node named and typed like ``src_node``."""
        out_node = node_def_pb2.NodeDef()
        out_node.op = 'Const'
        out_node.name = src_node.name
        dtype = src_node.attr["dtype"]
        out_node.attr["dtype"].CopyFrom(dtype)
        out_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
            tensor=tensor_util.make_tensor_proto(values,
                                                 dtype=dtype.type,
                                                 shape=values.shape)))
        return out_node

    def update_node_shape(in_node, delta):
        """Subtract ``delta`` from every element of a stored shape tensor.

        NOTE(review): the subtraction is applied to all elements, which is
        only right if the stored shape has a single entry -- confirm
        against the graph.
        """
        shape = tensor_util.MakeNdarray(in_node.attr['value'].tensor)
        return make_const_node(in_node, shape - delta)

    def prune_vector_node(input_node, indices):
        """Delete ``indices`` from the flattened tensor (BN params, biases)."""
        values = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
        return make_const_node(input_node, np.delete(values, indices))

    def prune_node(input_node, out_indices, in_indices):
        """Delete output (axis 3) and input (axis 2) channels of a kernel."""
        kernel = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
        if out_indices.size:
            kernel = np.delete(kernel, out_indices, axis=3)
        if in_indices.size:
            kernel = np.delete(kernel, in_indices, axis=2)
        return make_const_node(input_node, kernel)

    weight_names = set(weights_n)
    output_graph_def = graph_pb2.GraphDef()
    # Nodes are emitted in their original order.
    for node in gd.node:
        name = node.name
        if name in weight_names:
            print("Pruning node {}".format(name))
            output_node = prune_node(node, output_mask[name], input_mask[name])
        elif name in bn_mask:
            print("pruning node {}".format(name))
            output_node = prune_vector_node(node, bn_mask[name])
        elif name in bias_mask:
            print("pruning node {}".format(name))
            output_node = prune_vector_node(node, bias_mask[name])
        elif name in new_shape:
            print("pruning node {}".format(name))
            output_node = update_node_shape(node, new_shape[name])
        elif name == 'Bottleneck/weights':
            # Rows of the bottleneck matrix follow the channels leaving Block8.
            weights = tensor_util.MakeNdarray(node.attr['value'].tensor)
            weights = np.delete(weights, output_mask[block8_weights], 0)
            output_node = make_const_node(node, weights)
        elif name == 'InceptionResnetV1/Logits/Flatten/Reshape/shape':
            block8_tensor = nodes_map[block8_weights].attr['value'].tensor
            last_shape = block8_tensor.tensor_shape.dim[-1].size
            _shape = last_shape - output_mask[block8_weights].shape[0]
            output_node = make_const_node(node, np.array([-1, _shape]))
        else:
            output_node = node_def_pb2.NodeDef()
            output_node.CopyFrom(node)
        output_graph_def.node.extend([output_node])

    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
if __name__ == "__main__":
    # Command-line entry point: four required positional paths, in order.
    parser = argparse.ArgumentParser(
        description='Prune channels from a serialized GraphDef according '
                    'to a pickled pruning mask.')
    parser.add_argument('GRAPH_PATH',
                        help='serialized GraphDef file to read')
    parser.add_argument('MASK_PATH',
                        help='pickled dict mapping weight names to the '
                             'output-channel indices to remove')
    parser.add_argument('output_graph',
                        help='path the pruned GraphDef is written to')
    parser.add_argument('output_mask',
                        help='path the pickled per-node shape/indices mask '
                             'is written to')
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
import numpy as np | |
from time import gmtime, strftime | |
import pickle | |
import pdb | |
import os | |
import argparse | |
import re | |
def main(args):
    """Zero out low-energy conv filters and pickle their indices.

    Restores the checkpoint ``args.ckpt`` into the graph described by
    ``args.meta``.  For every trainable variable whose name matches a
    ``Conv2D`` node (``.../convolution`` -> ``.../weights:0``), the filters
    whose sum of squared weights is below ``args.threshold`` are set to
    zero in the session.  A dict ``{variable name: indices of zeroed
    filters}`` is pickled to ``args.mask_path``; variables with no zeroed
    filter are left out.

    A missing or non-positive threshold disables pruning: statistics are
    still printed and an empty mask is written.

    NOTE(review): the zeroed weights are only assigned inside the session;
    nothing in this function saves them back to a checkpoint.

    :param args: namespace with ``meta``, ``ckpt``, ``mask_path`` and
        ``threshold`` attributes.
    """
    threshold = args.threshold
    # ``None > 0.0`` raises TypeError on Python 3, so test for None first.
    pruning_enabled = threshold is not None and threshold > 0.0

    zero_idx = {}
    num_total_2d_kernels = 0
    num_pruned = 0

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, args.ckpt)

        # Names of the variables holding the conv weights.
        graph_nodes = tf.get_default_graph().as_graph_def().node
        conv_vars_names = {
            re.sub('convolution$', 'weights:0', node.name)
            for node in graph_nodes if node.op == 'Conv2D'
        }
        trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        conv_vars = [v for v in trainable if v.name in conv_vars_names]

        for conv in conv_vars:
            weights = conv.eval()
            _, _, input_channels, num_filter = weights.shape
            num_total_2d_kernels += input_channels * num_filter

            if pruning_enabled:
                # One energy value per filter (last axis).
                energy = np.sum(np.square(weights), axis=(0, 1, 2))
                zero_indices = np.where(energy < threshold)[0]
            else:
                zero_indices = np.array([], dtype=np.intp)

            print(conv.name)
            if zero_indices.size:
                n_zero_filters = zero_indices.shape[0]
                print('{} kernels pruned'.format(n_zero_filters))
                num_pruned += n_zero_filters * input_channels
                weights[:, :, :, zero_indices] = 0
                conv.assign(weights).eval()
                zero_idx[conv.name] = zero_indices

    print("A total number of {} kernels prunned".format(num_pruned))
    # Guard against a graph without any conv variable.
    if num_total_2d_kernels:
        pruning_rate = num_pruned / num_total_2d_kernels
    else:
        pruning_rate = 0.0
    print("Kernel pruning rate: {}".format(pruning_rate))

    with open(args.mask_path, 'wb') as f:
        pickle.dump(zero_idx, f)
if __name__ == '__main__':
    # Command-line entry point: three required paths plus the threshold.
    parser = argparse.ArgumentParser(
        description='Zero out conv filters whose squared-weight sum is '
                    'below a threshold and pickle their indices.')
    parser.add_argument('meta', help='meta graph file to import')
    parser.add_argument('ckpt', help='checkpoint to restore the weights from')
    parser.add_argument('mask_path',
                        help='path the pickled pruning mask is written to')
    # An explicit default keeps the comparison in main() well defined when
    # the option is omitted.
    parser.add_argument('--threshold', type=float, default=0.0,
                        help='filters with a sum of squared weights below '
                             'this value are pruned (default: 0.0)')
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
import numpy as np | |
import pickle | |
from tensorflow.core.framework import attr_value_pb2 | |
from tensorflow.core.framework import graph_pb2 | |
from tensorflow.core.framework import node_def_pb2 | |
from tensorflow.python.framework import tensor_util | |
from tensorflow.python.platform import gfile | |
import re | |
import argparse | |
import pdb, traceback, sys | |
def main(args): | |
GRAPH_PATH= args.GRAPH_PATH | |
MASK_PATH= args.MASK_PATH | |
output_graph=args.output_graph | |
with tf.Session() as sess: | |
with tf.gfile.FastGFile(GRAPH_PATH, 'rb') as f: | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(f.read()) | |
_ = tf.import_graph_def(graph_def, name="") | |
# nodes_map = dict( [(n.name, n) for n in tf.get_default_graph().as_graph_def().node ]) | |
nodes_map = {} | |
conv_nodes = [ n for n in tf.get_default_graph().as_graph_def().node if n.op == 'Conv2D'] | |
edges = {} # Keyed by the dest node name. | |
# Keeps track of node sequences. It is important to still output the | |
# operations in the original order. | |
node_seq = {} # Keyed by node name. | |
seq = 0 | |
gd = tf.get_default_graph().as_graph_def() | |
def dfs_last_conv(node): | |
if node.input: | |
for n in node.input: | |
in_node = nodes_map[n] | |
if in_node.op == 'Conv2D': | |
return in_node | |
elif in_node.op == 'Concat': | |
return None | |
else: | |
return dfs_last_conv(in_node) | |
else: | |
return None | |
def _node_name(n): | |
if n.startswith("^"): | |
return n[1:] | |
else: | |
return n.split(":")[0] | |
for node in gd.node: | |
n = _node_name(node.name) | |
nodes_map[n] = node | |
edges[n] = [_node_name(x) for x in node.input] | |
node_seq[n] = seq | |
seq += 1 | |
with open(MASK_PATH, 'r') as f: | |
mask = pickle.load(f) | |
def get_node_output_channels(node_name): | |
weight_node = nodes_map[node_name] | |
tensor = weight_node.attr['value'].tensor | |
return int(tensor.tensor_shape.dim[3].size) | |
class GraphNode(object): | |
"""Pruning info for a node. """ | |
def __init__(self, channels, indices): | |
"""TODO: to be defined1. """ | |
self.channels = channels | |
self.indices = indices | |
# pdb.set_trace() | |
node_n = [n.name for n in conv_nodes] | |
weights_n = [None] * len(node_n) | |
for idx, conv in enumerate(node_n): | |
node = nodes_map[conv] | |
for n in node.input: | |
if nodes_map[n].op == 'Const': | |
weights_n[idx] = n | |
def get_conv_weights(node): | |
''' | |
return the nodes contains weights of input conv node | |
''' | |
for n in node.input: | |
if nodes_map[n].op == 'Const': | |
return n | |
input_mask = {} # pruning mask for the input channels | |
node_info = {} # store info for the milestone nodes | |
output_mask = {} # pruning mask for the output channels | |
bias_mask = {} # pruning mask for the biases | |
for k, v in mask.iteritems(): | |
output_mask[_node_name(k)] = np.array(v['indices']) | |
pdb.set_trace() | |
for idx, n in enumerate(weights_n): | |
if n in output_mask: | |
pass | |
else: | |
output_mask[n] = np.array([]) | |
for idx, n in enumerate(weights_n): | |
conv_name = node_n[idx] | |
last_conv = dfs_last_conv(nodes_map[conv_name]) | |
if last_conv: | |
last_conv_w = get_conv_weights(last_conv) | |
input_mask[n] = output_mask[last_conv_w] | |
else: | |
input_mask[n] = np.array([]) | |
# pdb.set_trace() | |
# General convs | |
# shallow_convs = [n for n in node_n if re.match(r'InceptionResnetV1/Conv2d_\d[a-z]_\dx\d', n)] | |
input_mask['InceptionResnetV1/Conv2d_2a_3x3/weights'] = output_mask['InceptionResnetV1/Conv2d_1a_3x3/weights'] | |
input_mask['InceptionResnetV1/Conv2d_2b_3x3/weights'] = output_mask['InceptionResnetV1/Conv2d_2a_3x3/weights'] | |
input_mask['InceptionResnetV1/Conv2d_3b_1x1/weights'] = output_mask['InceptionResnetV1/Conv2d_2b_3x3/weights'] | |
input_mask['InceptionResnetV1/Conv2d_4a_3x3/weights'] = output_mask['InceptionResnetV1/Conv2d_3b_1x1/weights'] | |
input_mask['InceptionResnetV1/Conv2d_4b_3x3/weights'] = output_mask['InceptionResnetV1/Conv2d_4a_3x3/weights'] | |
# Pruning the sub graphs with residual block , | |
# the pruning mask of residual block is determined by the shortcut projection | |
# the residual blocks in the network are the 5 x Block35, 10 x block17, 5 x block5 | |
# Incepiton module for inception_renset_a | |
for i in range(1, 6): | |
# pdb.set_trace() | |
branch_0 = 'InceptionResnetV1/Repeat/block35_{}/Branch_0/Conv2d_1x1/weights'.format(i) | |
branch_1 = 'InceptionResnetV1/Repeat/block35_{}/Branch_1/Conv2d_0b_3x3/weights'.format(i) | |
branch_2 = 'InceptionResnetV1/Repeat/block35_{}/Branch_2/Conv2d_0c_3x3/weights'.format(i) | |
mixed_conv = 'InceptionResnetV1/Repeat/block35_{}/Conv2d_1x1/weights'.format(i) | |
mixed_bias = re.sub(r'weights', 'biases', mixed_conv) | |
input_mask[mixed_conv] = np.concatenate((np.array(output_mask[branch_0]), | |
(np.array(output_mask[branch_1]) + get_node_output_channels(branch_0)), | |
(np.array(output_mask[branch_2]) + get_node_output_channels(branch_0) + get_node_output_channels(branch_1))) ) | |
# Inception-Resnet-A's last conv has to keep consistence with conv2d_4b_3x3 | |
output_mask[mixed_conv] = output_mask['InceptionResnetV1/Conv2d_4b_3x3/weights'] | |
bias_mask[mixed_bias] = output_mask[mixed_conv] | |
node_info['inception-resnet-a'] = GraphNode(256, output_mask['InceptionResnetV1/Conv2d_4b_3x3/weights']) | |
# Reduction A: output has a concat of feature maps [ branch0 + branch1 + branch2 ( inception-resnet-a )] | |
# pdb.set_trace() | |
branch_0_in = 'InceptionResnetV1/Mixed_6a/Branch_0/Conv2d_1a_3x3/weights' | |
branch_1_in = 'InceptionResnetV1/Mixed_6a/Branch_1/Conv2d_0a_1x1/weights' | |
indices = node_info['inception-resnet-a'].indices | |
input_mask[branch_0_in] = indices | |
input_mask[branch_1_in] = indices | |
num_fm_mix6a_branch0 = get_node_output_channels('InceptionResnetV1/Mixed_6a/Branch_0/Conv2d_1a_3x3/weights') | |
num_fm_mix6a_branch1 = get_node_output_channels('InceptionResnetV1/Mixed_6a/Branch_1/Conv2d_1a_3x3/weights') | |
num_fm_mix6a_branch2 = node_info['inception-resnet-a'].channels | |
node_info['reduction_a'] = GraphNode(np.sum([num_fm_mix6a_branch0, num_fm_mix6a_branch1, num_fm_mix6a_branch2]), | |
np.concatenate((output_mask['InceptionResnetV1/Mixed_6a/Branch_0/Conv2d_1a_3x3/weights'], | |
np.array(output_mask['InceptionResnetV1/Mixed_6a/Branch_1/Conv2d_1a_3x3/weights']) + num_fm_mix6a_branch0, | |
np.array(node_info['inception-resnet-a'].indices) + num_fm_mix6a_branch0 + num_fm_mix6a_branch1 ) )) | |
# 10 x Inception-Resnet-B | |
# Incepiton module for inception_renset_b | |
for i in range(1, 11): | |
branch_0 = 'InceptionResnetV1/Repeat_1/block17_{idx}/Branch_0/Conv2d_1x1/weights'.format(idx=i) | |
branch_1 = 'InceptionResnetV1/Repeat_1/block17_{idx}/Branch_1/Conv2d_0c_7x1/weights'.format(idx=i) | |
branch_1_in = 'InceptionResnetV1/Repeat_1/block17_{idx}/Branch_1/Conv2d_0a_1x1/weights'.format(idx=i) | |
mixed_conv = 'InceptionResnetV1/Repeat_1/block17_{idx}/Conv2d_1x1/weights'.format(idx=i) | |
mixed_bias = re.sub(r'weights', 'biases', mixed_conv) | |
input_mask[branch_0] = node_info['reduction_a'].indices | |
input_mask[branch_1_in] = node_info['reduction_a'].indices | |
input_mask[mixed_conv] = np.concatenate((np.array(output_mask[branch_0]), | |
(np.array(output_mask[branch_1]) + get_node_output_channels(branch_0)) ) ) | |
output_mask[mixed_conv] = node_info['reduction_a'].indices | |
bias_mask[mixed_bias] = output_mask[mixed_conv] | |
node_info['inception-resnet-b'] = node_info['reduction_a'] | |
# Reduction B | |
num_fm_mix7a_branch0 = get_node_output_channels('InceptionResnetV1/Mixed_7a/Branch_0/Conv2d_1a_3x3/weights') | |
num_fm_mix7a_branch1 = get_node_output_channels('InceptionResnetV1/Mixed_7a/Branch_1/Conv2d_1a_3x3/weights') | |
num_fm_mix7a_branch2 = get_node_output_channels('InceptionResnetV1/Mixed_7a/Branch_2/Conv2d_1a_3x3/weights') | |
num_fm_mix7a_branch3 = node_info['inception-resnet-b'].channels | |
branch_0_in = 'InceptionResnetV1/Mixed_7a/Branch_0/Conv2d_0a_1x1/weights' | |
branch_1_in = 'InceptionResnetV1/Mixed_7a/Branch_1/Conv2d_0a_1x1/weights' | |
branch_2_in = 'InceptionResnetV1/Mixed_7a/Branch_2/Conv2d_0a_1x1/weights' | |
indices = node_info['inception-resnet-b'].indices | |
input_mask[branch_0_in] = indices | |
input_mask[branch_1_in] = indices | |
input_mask[branch_2_in] = indices | |
node_info['reduction_b'] = GraphNode(np.sum([num_fm_mix7a_branch0, num_fm_mix7a_branch1, num_fm_mix7a_branch2, num_fm_mix7a_branch3]), | |
np.concatenate((output_mask['InceptionResnetV1/Mixed_7a/Branch_0/Conv2d_1a_3x3/weights'], | |
np.array(output_mask['InceptionResnetV1/Mixed_7a/Branch_1/Conv2d_1a_3x3/weights']) + num_fm_mix7a_branch0, | |
np.array(output_mask['InceptionResnetV1/Mixed_7a/Branch_2/Conv2d_1a_3x3/weights']) + num_fm_mix7a_branch0 + num_fm_mix7a_branch1, | |
np.array(node_info['inception-resnet-b'].indices) + num_fm_mix7a_branch0 + num_fm_mix7a_branch1 + num_fm_mix7a_branch2 ) )) | |
# 5 x Inception-Resnet-C | |
# Incepiton module for inception_renset_c | |
for i in range(1, 6): | |
branch_0 = 'InceptionResnetV1/Repeat_2/block8_{idx}/Branch_0/Conv2d_1x1/weights'.format(idx=i) | |
branch_1 = 'InceptionResnetV1/Repeat_2/block8_{idx}/Branch_1/Conv2d_0c_3x1/weights'.format(idx=i) | |
branch_1_in = 'InceptionResnetV1/Repeat_2/block8_{idx}/Branch_1/Conv2d_0a_1x1/weights'.format(idx=i) | |
mixed_conv = 'InceptionResnetV1/Repeat_2/block8_{idx}/Conv2d_1x1/weights'.format(idx=i) | |
mixed_bias = re.sub(r'weights', 'biases', mixed_conv) | |
input_mask[branch_0] = node_info['reduction_b'].indices | |
input_mask[branch_1_in] = node_info['reduction_b'].indices | |
input_mask[mixed_conv] = np.concatenate((np.array(output_mask[branch_0]), | |
(np.array(output_mask[branch_1]) + get_node_output_channels(branch_0)) )) | |
output_mask[mixed_conv] = node_info['reduction_b'].indices | |
#pdb.set_trace() | |
bias_mask[mixed_bias] = output_mask[mixed_conv] | |
node_info['inception-resnet-c'] = node_info['reduction_b'] | |
# the last block8 | |
branch_0 = 'InceptionResnetV1/Block8/Branch_0/Conv2d_1x1/weights'.format(idx=i) | |
branch_1 = 'InceptionResnetV1/Block8/Branch_1/Conv2d_0c_3x1/weights'.format(idx=i) | |
branch_1_in = 'InceptionResnetV1/Block8/Branch_1/Conv2d_0a_1x1/weights'.format(idx=i) | |
mixed_conv = 'InceptionResnetV1/Block8/Conv2d_1x1/weights'.format(idx=i) | |
mixed_bias = re.sub(r'weights', 'biases', mixed_conv) | |
input_mask[branch_0] = node_info['inception-resnet-c'].indices | |
input_mask[branch_1_in] = node_info['inception-resnet-c'].indices | |
input_mask[mixed_conv] = np.concatenate((np.array(output_mask[branch_0]), | |
(np.array(output_mask[branch_1]) + get_node_output_channels(branch_0)) )) | |
output_mask[mixed_conv] = node_info['inception-resnet-c'].indices | |
bias_mask[mixed_bias] = output_mask[mixed_conv] | |
retrain_mask = {} | |
for node, indices in output_mask.iteritems(): | |
shape = nodes_map[node].attr['value'].tensor.tensor_shape | |
shape = np.array([dim.size for dim in shape.dim]) | |
retrain_mask[node] = { | |
'shape': shape, | |
'indices': indices | |
} | |
with open(args.output_mask, 'wb') as f: | |
pickle.dump(retrain_mask, f) | |
# batch norm nodes | |
bn_mask = {} | |
new_shape = {} | |
for n in weights_n: | |
if not re.search(r'block\d+_\d/Conv2d_1x1/weights$', n): | |
bn_prefix = re.sub(r'weights$', 'BatchNorm/', n) | |
beta = bn_prefix + 'beta' | |
moving_variance = bn_prefix + 'moving_variance' | |
moving_mean = bn_prefix + 'moving_mean' | |
moving_mean_variance = bn_prefix + 'moving_mean_variance' | |
moments_shape = bn_prefix + 'moments/Shape' | |
bn_mask[beta] = output_mask[n] | |
bn_mask[moving_mean] = output_mask[n] | |
bn_mask[moving_variance] = output_mask[n] | |
bn_mask[moving_mean_variance] = output_mask[n] | |
new_shape[moments_shape] = output_mask[n].shape[0] | |
def update_node_shape(in_node, delta):
    """Build a Const node holding ``in_node``'s value reduced by ``delta``.

    :in_node: node_def whose 'value' attr stores the tensor to adjust
    :delta: amount subtracted from the decoded value
    :returns: a new node_def with the same name and dtype as ``in_node``
    """
    current_value = tensor_util.MakeNdarray(in_node.attr['value'].tensor)
    dtype_attr = in_node.attr["dtype"]
    # NOTE(review): the subtraction applies to every element of the decoded
    # array; confirm that is intended when the stored shape has several dims.
    shrunk_proto = tensor_util.make_tensor_proto(current_value - delta,
                                                 dtype=dtype_attr.type)

    result = node_def_pb2.NodeDef()
    result.op = 'Const'
    result.name = in_node.name
    result.attr["dtype"].CopyFrom(dtype_attr)
    result.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=shrunk_proto))
    return result
def prune_bn_node(input_node, mask):
    """Return a Const copy of ``input_node`` with the masked entries removed.

    :input_node: node_def whose 'value' attr stores the tensor to prune
    :mask: indices handed to ``np.delete`` (no axis given, so the array is
           treated as flat)
    :returns: a new node_def with the same name and dtype as ``input_node``
    """
    full_values = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    kept_values = np.delete(full_values, mask)
    dtype_attr = input_node.attr["dtype"]
    kept_proto = tensor_util.make_tensor_proto(kept_values,
                                               dtype=dtype_attr.type,
                                               shape=kept_values.shape)

    pruned = node_def_pb2.NodeDef()
    pruned.op = 'Const'
    pruned.name = input_node.name
    pruned.attr["dtype"].CopyFrom(dtype_attr)
    pruned.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=kept_proto))
    return pruned
def prune_bias_node(input_node, mask):
    """Return a Const copy of a biases node without the masked entries.

    :input_node: node_def whose 'value' attr stores the bias tensor
    :mask: indices to drop; passed to ``np.delete`` without an axis
    :returns: a new node_def carrying the remaining values, with the name
              and dtype of ``input_node``
    """
    source_tensor = input_node.attr['value'].tensor
    remaining = np.delete(tensor_util.MakeNdarray(source_tensor), mask)
    dtype_attr = input_node.attr["dtype"]

    replacement = node_def_pb2.NodeDef()
    replacement.op = 'Const'
    replacement.name = input_node.name
    replacement.attr["dtype"].CopyFrom(dtype_attr)
    value_attr = attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(remaining,
                                             dtype=dtype_attr.type,
                                             shape=remaining.shape))
    replacement.attr['value'].CopyFrom(value_attr)
    return replacement
def prune_node(input_node, out_indices, in_indices):
    """Return a Const copy of a weights node with selected slices removed.

    :input_node: node_def whose 'value' attr stores the weight tensor
    :out_indices: ndarray of positions to delete along axis 3
    :in_indices: ndarray of positions to delete along axis 2
    :returns: a new node_def with the same name and dtype as ``input_node``

    An empty index array leaves the corresponding axis untouched.
    """
    # presumably a conv kernel laid out (height, width, in, out) -- confirm
    kernel = tensor_util.MakeNdarray(input_node.attr['value'].tensor)

    # Axis 3 is handled before axis 2, as in the original ordering.
    for indices, axis in ((out_indices, 3), (in_indices, 2)):
        if indices.size:
            kernel = np.delete(kernel, indices, axis=axis)

    dtype_attr = input_node.attr["dtype"]
    kernel_proto = tensor_util.make_tensor_proto(kernel,
                                                 dtype=dtype_attr.type,
                                                 shape=kernel.shape)

    pruned = node_def_pb2.NodeDef()
    pruned.op = 'Const'
    pruned.name = input_node.name
    pruned.attr["dtype"].CopyFrom(dtype_attr)
    pruned.attr['value'].CopyFrom(attr_value_pb2.AttrValue(tensor=kernel_proto))
    return pruned
output_graph_def = graph_pb2.GraphDef() | |
for node in gd.node: | |
output_node = node_def_pb2.NodeDef() | |
# pdb.set_trace() | |
if node.name in weights_n: | |
name = node.name | |
print("Pruning node {}".format(name)) | |
output_node = prune_node(node, output_mask[name], input_mask[name]) | |
elif node.name in bn_mask: | |
name = node.name | |
print("pruning node {}".format(name)) | |
output_node = prune_bn_node(node, bn_mask[name]) | |
elif node.name in bias_mask: | |
name = node.name | |
print ("pruning node {}".format(name)) | |
output_node = prune_bias_node(node, bias_mask[name]) | |
elif node.name in new_shape: | |
name = node.name | |
print("pruning node {}".format(name)) | |
output_node = update_node_shape(node, new_shape[name]) | |
elif node.name == 'Bottleneck/weights': | |
weights = tensor_util.MakeNdarray(node.attr['value'].tensor) | |
weights = np.delete(weights, output_mask['InceptionResnetV1/Block8/Conv2d_1x1/weights'], 0) | |
output_node.op = 'Const' | |
output_node.name = node.name | |
dtype = node.attr["dtype"] | |
output_node.attr["dtype"].CopyFrom(dtype) | |
output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue( | |
tensor=tensor_util.make_tensor_proto(weights, | |
dtype=dtype.type, | |
shape=weights.shape))) | |
elif node.name =='InceptionResnetV1/Logits/Flatten/Reshape/shape' : | |
last_shape = nodes_map['InceptionResnetV1/Block8/Conv2d_1x1/weights'].attr['value'].tensor.tensor_shape.dim[-1].size | |
_shape = last_shape - output_mask['InceptionResnetV1/Block8/Conv2d_1x1/weights'].shape[0] | |
output_node.op = 'Const' | |
output_node.name = node.name | |
dtype = node.attr["dtype"] | |
output_node.attr["dtype"].CopyFrom(dtype) | |
output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue( | |
tensor=tensor_util.make_tensor_proto(np.array([-1, _shape]), | |
dtype=dtype.type | |
))) | |
else: | |
output_node.CopyFrom(node) | |
output_graph_def.node.extend([output_node]) | |
out_map = dict( [ (n.name, n) for n in output_graph_def.node]) | |
pdb.set_trace() | |
with gfile.GFile(output_graph, "wb") as f: | |
f.write(output_graph_def.SerializeToString()) | |
if __name__ == "__main__":
    # Four required positionals, in this order:
    #   GRAPH_PATH   - serialized GraphDef that main() parses
    #   MASK_PATH    - path to the mask input
    #   output_graph - where the rewritten GraphDef is written
    #   output_mask  - where the pickled retrain mask is written
    parser = argparse.ArgumentParser()
    for positional in ('GRAPH_PATH', 'MASK_PATH', 'output_graph', 'output_mask'):
        parser.add_argument(positional)
    args = parser.parse_args()
    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import tensorflow as tf | |
import numpy as np | |
import pickle | |
import argparse | |
def main(args):
    """Zero selected slices of checkpoint variables and save the result.

    :args: parsed arguments providing
           ``meta``   - meta graph file given to ``tf.train.import_meta_graph``
           ``ckpt``   - checkpoint restored into the session
           ``zidx``   - pickle file mapping a variable's node name to a dict
                        whose 'indices' entry selects positions on the last axis
           ``output`` - path passed to ``saver.save``
    :returns: None
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, args.ckpt)

        # NOTE(review): pickle.load can execute arbitrary code; only feed this
        # script mask files from a trusted source.
        with open(args.zidx, 'rb') as mask_file:
            zero_idx = pickle.load(mask_file)

        for variable in tf.global_variables():
            # Reduce the tensor name to its node name: drop a leading "^",
            # otherwise drop the ":<output index>" suffix.
            tensor_name = variable.name
            if tensor_name.startswith("^"):
                node_name = tensor_name[1:]
            else:
                node_name = tensor_name.split(":")[0]

            if node_name not in zero_idx:
                continue

            values = variable.eval()
            # The ellipsis keeps every leading axis; only the last one is indexed.
            values[..., zero_idx[node_name]['indices']] = 0
            variable.assign(values).eval()

        saver.save(sess, args.output)
if __name__ == "__main__":
    # Four required positionals, in this order:
    #   meta   - meta graph file to import
    #   ckpt   - checkpoint to restore
    #   zidx   - pickled mapping of node names to the indices to zero
    #   output - destination for the saved checkpoint
    parser = argparse.ArgumentParser()
    for positional in ('meta', 'ckpt', 'zidx', 'output'):
        parser.add_argument(positional)
    args = parser.parse_args()
    main(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment