Skip to content

Instantly share code, notes, and snippets.

@wenfahu
Created August 27, 2017 16:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wenfahu/e822ef09716b63936984d181d34c6d28 to your computer and use it in GitHub Desktop.
Save wenfahu/e822ef09716b63936984d181d34c6d28 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import pickle
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
import re
import argparse
import pdb, traceback, sys
def main(args):
    """Prune the Conv2D filters of a frozen InceptionResnetV1 graph.

    This opening part loads the serialized GraphDef found at
    ``args.GRAPH_PATH`` into the default graph and creates the bookkeeping
    containers that the remainder of the routine fills in.

    :args: parsed command line; ``GRAPH_PATH``, ``MASK_PATH`` and
        ``output_graph`` are read here, ``output_mask`` further down.
    """
    GRAPH_PATH = args.GRAPH_PATH
    MASK_PATH = args.MASK_PATH
    output_graph = args.output_graph

    with tf.Session() as sess:
        with tf.gfile.FastGFile(GRAPH_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name="")

    # Serialize the default graph once and reuse the copy; every call to
    # as_graph_def() builds a fresh GraphDef.
    gd = tf.get_default_graph().as_graph_def()
    nodes_map = {}  # Keyed by bare node name.
    conv_nodes = [n for n in gd.node if n.op == 'Conv2D']
    edges = {}  # Keyed by the dest node name.
    # Keeps track of node sequences. It is important to still output the
    # operations in the original order.
    node_seq = {}  # Keyed by node name.
    seq = 0
def dfs_last_conv(node):
    """Return the closest upstream ``Conv2D`` node of *node*, or ``None``.

    Only the first entry of each ``input`` list is followed (as before).
    The walk gives up with ``None`` when it meets a concat op or a node
    without inputs, because masks across a concat need channel offsets.

    :node: node_def whose producer chain is searched
    :returns: the ``Conv2D`` node_def found, or ``None``
    """
    current = node
    while current.input:
        # nodes_map is keyed by bare names, so strip "^" / ":port" first.
        current = nodes_map[_node_name(current.input[0])]
        if current.op == 'Conv2D':
            return current
        # tf.concat is emitted as ConcatV2 by TensorFlow >= 1.0.
        if current.op in ('Concat', 'ConcatV2'):
            return None
    return None
def _node_name(n):
    """Return the bare node name behind a tensor or control-input name.

    A leading ``^`` is dropped; otherwise everything from the first ``:``
    onwards is dropped.
    """
    if n.startswith("^"):
        return n[1:]
    return n.split(":")[0]
# Index every node by bare name and remember its inputs and its position
# in the original GraphDef.
for node in gd.node:
    n = _node_name(node.name)
    nodes_map[n] = node
    edges[n] = [_node_name(x) for x in node.input]
    node_seq[n] = seq
    seq += 1

# Pickles are binary data: open with 'rb' (text mode breaks on Python 3 and
# on platforms that translate line endings).
# NOTE: pickle.load executes arbitrary code; only load mask files you trust.
with open(MASK_PATH, 'rb') as f:
    mask = pickle.load(f)
def get_node_output_channels(node_name):
    """Return the size of dimension 3 of the tensor stored in a node.

    The node is looked up in ``nodes_map`` and its ``value`` attribute is
    read; callers pass conv weight names, for which index 3 is presumably
    the output-channel axis (HWIO layout -- confirm for other tensors).

    :node_name: bare name of the weight node
    :returns: the dimension size as ``int``
    """
    dims = nodes_map[node_name].attr['value'].tensor.tensor_shape.dim
    return int(dims[3].size)
class GraphNode(object):
    """Pruning info for a milestone node of the network."""

    def __init__(self, channels, indices):
        """Store the pruning info.

        :channels: channel count recorded for the node
        :indices: channel indices recorded as pruned for the node
        """
        self.channels = channels
        self.indices = indices

    def __repr__(self):
        return "GraphNode(channels={!r}, indices={!r})".format(
            self.channels, self.indices)
# For every Conv2D node, find the Const node that feeds it its weights.
# node_n[i] is the conv name, weights_n[i] the bare name of its weights.
node_n = [n.name for n in conv_nodes]
weights_n = []
for conv in node_n:
    const_inputs = [
        name
        for name in (_node_name(x) for x in nodes_map[conv].input)
        if nodes_map[name].op == 'Const'
    ]
    if not const_inputs:
        # Fail here with a clear message instead of storing None and
        # crashing later on an unrelated-looking lookup.
        raise ValueError(
            "Conv2D node {} has no Const input holding its weights".format(conv))
    # As before, the last Const input wins if there are several.
    weights_n.append(const_inputs[-1])
def get_conv_weights(node):
    """Return the name of the first ``Const`` input of *node*.

    :node: conv node_def whose inputs are inspected
    :returns: bare name of the first input whose op is ``Const``, or
        ``None`` when there is none
    """
    for n in node.input:
        # nodes_map is keyed by bare names.
        name = _node_name(n)
        if nodes_map[name].op == 'Const':
            return name
    return None
input_mask = {}   # pruning mask for the input channels
node_info = {}    # store info for the milestone nodes
output_mask = {}  # pruning mask for the output channels
bias_mask = {}    # pruning mask for the biases

# Masks are index arrays.  Keep them integer-typed throughout: a float
# empty array (plain np.array([])) would turn every concatenation it takes
# part in into float indices, which np.delete does not accept.
for k, v in mask.items():
    output_mask[_node_name(k)] = np.array(v, dtype=np.int64)

# Conv weights absent from the mask file keep all their output channels.
for n in weights_n:
    output_mask.setdefault(n, np.array([], dtype=np.int64))

# Default input mask: the output mask of the nearest upstream conv, if any.
for conv_name, n in zip(node_n, weights_n):
    last_conv = dfs_last_conv(nodes_map[conv_name])
    if last_conv is not None:
        last_conv_w = get_conv_weights(last_conv)
        input_mask[n] = output_mask[last_conv_w]
    else:
        input_mask[n] = np.array([], dtype=np.int64)
# Hand-written mask wiring for the InceptionResnetV1 topology.
root = 'InceptionResnetV1/'


def _stack_masks(weight_names, tail_indices=None):
    """Merge the output masks of concatenated branches into one index array.

    The indices of branch k are shifted by the summed channel counts (as
    stored in the graph) of branches 0..k-1.  ``tail_indices``, when given,
    is appended shifted by the channel count of all named branches.

    :returns: ``(indices, channels)`` with ``channels`` the summed channel
        count of the named branches
    """
    parts, offset = [], 0
    for name in weight_names:
        parts.append(np.array(output_mask[name]) + offset)
        offset += get_node_output_channels(name)
    if tail_indices is not None:
        parts.append(np.array(tail_indices) + offset)
    return np.concatenate(parts), offset


# General convs: a plain chain, so each conv's input mask is the output
# mask of the conv before it.
stem = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'Conv2d_4b_3x3']
for prev, cur in zip(stem, stem[1:]):
    input_mask[root + cur + '/weights'] = output_mask[root + prev + '/weights']

# Pruning the sub graphs with residual blocks:
# the pruning mask of a residual block is determined by the shortcut
# projection.  The residual blocks in the network are 5 x block35,
# 10 x block17 and 5 x block8 (plus the final Block8).

# 5 x Inception-Resnet-A
stem_out = output_mask[root + 'Conv2d_4b_3x3/weights']
for i in range(1, 6):
    scope = root + 'Repeat/block35_{}/'.format(i)
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[mixed_conv] = _stack_masks([
        scope + 'Branch_0/Conv2d_1x1/weights',
        scope + 'Branch_1/Conv2d_0b_3x3/weights',
        scope + 'Branch_2/Conv2d_0c_3x3/weights',
    ])[0]
    # Inception-Resnet-A's last conv has to stay consistent with
    # Conv2d_4b_3x3.
    output_mask[mixed_conv] = stem_out
    bias_mask[scope + 'Conv2d_1x1/biases'] = stem_out
node_info['inception-resnet-a'] = GraphNode(256, stem_out)

# Reduction A: output is a concat of feature maps
# [branch0 + branch1 + branch2 (inception-resnet-a)]
m6a = root + 'Mixed_6a/'
a_info = node_info['inception-resnet-a']
input_mask[m6a + 'Branch_0/Conv2d_1a_3x3/weights'] = a_info.indices
input_mask[m6a + 'Branch_1/Conv2d_0a_1x1/weights'] = a_info.indices
red_a_indices, red_a_convs = _stack_masks(
    [m6a + 'Branch_0/Conv2d_1a_3x3/weights',
     m6a + 'Branch_1/Conv2d_1a_3x3/weights'],
    tail_indices=a_info.indices)
node_info['reduction_a'] = GraphNode(red_a_convs + a_info.channels,
                                     red_a_indices)

# 10 x Inception-Resnet-B
red_a = node_info['reduction_a']
for i in range(1, 11):
    scope = root + 'Repeat_1/block17_{}/'.format(i)
    branch_0 = scope + 'Branch_0/Conv2d_1x1/weights'
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[branch_0] = red_a.indices
    input_mask[scope + 'Branch_1/Conv2d_0a_1x1/weights'] = red_a.indices
    input_mask[mixed_conv] = _stack_masks(
        [branch_0, scope + 'Branch_1/Conv2d_0c_7x1/weights'])[0]
    output_mask[mixed_conv] = red_a.indices
    bias_mask[scope + 'Conv2d_1x1/biases'] = red_a.indices
node_info['inception-resnet-b'] = red_a

# Reduction B
m7a = root + 'Mixed_7a/'
b_info = node_info['inception-resnet-b']
branches = ('Branch_0', 'Branch_1', 'Branch_2')
for branch in branches:
    input_mask[m7a + branch + '/Conv2d_0a_1x1/weights'] = b_info.indices
red_b_indices, red_b_convs = _stack_masks(
    [m7a + branch + '/Conv2d_1a_3x3/weights' for branch in branches],
    tail_indices=b_info.indices)
node_info['reduction_b'] = GraphNode(red_b_convs + b_info.channels,
                                     red_b_indices)

# 5 x Inception-Resnet-C, followed by the last Block8 (same wiring)
red_b = node_info['reduction_b']
node_info['inception-resnet-c'] = red_b
block8_scopes = [root + 'Repeat_2/block8_{}/'.format(i) for i in range(1, 6)]
block8_scopes.append(root + 'Block8/')
for scope in block8_scopes:
    branch_0 = scope + 'Branch_0/Conv2d_1x1/weights'
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[branch_0] = red_b.indices
    input_mask[scope + 'Branch_1/Conv2d_0a_1x1/weights'] = red_b.indices
    input_mask[mixed_conv] = _stack_masks(
        [branch_0, scope + 'Branch_1/Conv2d_0c_3x1/weights'])[0]
    output_mask[mixed_conv] = red_b.indices
    bias_mask[scope + 'Conv2d_1x1/biases'] = red_b.indices
# batch norm nodes
bn_mask = {}    # BatchNorm parameter node name -> indices to drop
new_shape = {}  # moments/Shape node name -> number of dropped channels
residual_projection = re.compile(r'block\d+_\d+/Conv2d_1x1/weights$')
for n in weights_n:
    # The residual projection convs get a bias mask instead of BatchNorm
    # masks (presumably they have no BatchNorm -- confirm against the model).
    if n == 'InceptionResnetV1/Block8/Conv2d_1x1/weights':
        continue
    if residual_projection.search(n):
        continue
    bn_prefix = re.sub(r'weights$', 'BatchNorm/', n)
    for suffix in ('beta', 'moving_mean', 'moving_variance'):
        bn_mask[bn_prefix + suffix] = output_mask[n]
    new_shape[bn_prefix + 'moments/Shape'] = output_mask[n].shape[0]
# Record, for every masked node, its stored tensor shape and the indices
# that get pruned, then pickle the result to args.output_mask.
retrain_mask = {}
for masks in (output_mask, bn_mask, bias_mask):
    for node, indices in masks.items():
        dims = nodes_map[node].attr['value'].tensor.tensor_shape.dim
        retrain_mask[node] = {
            'shape': [dim.size for dim in dims],
            'indices': list(indices),
        }
with open(args.output_mask, 'wb') as f:
    pickle.dump(retrain_mask, f)
def update_node_shape(in_node, delta):
    """Return a Const copy of *in_node* whose value is reduced by *delta*.

    :in_node: Const node_def holding the tensor to adjust
    :delta: amount subtracted from the stored tensor
    :returns: new Const node_def with the same name and dtype
    """
    out_node = node_def_pb2.NodeDef()
    out_node.op = 'Const'
    out_node.name = in_node.name
    dtype = in_node.attr["dtype"]
    out_node.attr["dtype"].CopyFrom(dtype)
    # NOTE(review): delta is subtracted from every element of the stored
    # tensor; confirm the Shape constants this is applied to only hold the
    # channel dimension.
    reduced_shape = tensor_util.MakeNdarray(in_node.attr['value'].tensor) - delta
    out_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(reduced_shape, dtype=dtype.type)))
    return out_node
def prune_bn_node(input_node, mask):
    """Prune a batch norm parameter node according to the given mask.

    :input_node: Const node_def holding the parameter tensor
    :mask: indices of the entries to remove
    :returns: new Const node_def (same name and dtype) without those entries
    """
    values = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    # np.delete needs integer indices; masks built from np.array([]) are
    # float64, so coerce before use.
    pruned_ndarray = np.delete(values, np.asarray(mask, dtype=np.int64))
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
def prune_bias_node(input_node, mask):
    """Prune a biases node according to the given mask.

    :input_node: Const node_def holding the bias tensor
    :mask: indices of the entries to remove
    :returns: new Const node_def (same name and dtype) without those entries
    """
    biases = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    # np.delete needs integer indices; masks built from np.array([]) are
    # float64, so coerce before use.
    pruned_ndarray = np.delete(biases, np.asarray(mask, dtype=np.int64))
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
def prune_node(input_node, out_indices, in_indices):
    """Prune a conv weight node along its output and input channel axes.

    :input_node: Const node_def holding the weight tensor
    :out_indices: indices removed along axis 3
    :in_indices: indices removed along axis 2
    :returns: new Const node_def (same name and dtype) with the pruned tensor
    """
    pruned_ndarray = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    for indices, axis in ((out_indices, 3), (in_indices, 2)):
        # np.delete needs integer indices; concatenated masks may be float.
        indices = np.asarray(indices, dtype=np.int64)
        if indices.size:
            pruned_ndarray = np.delete(pruned_ndarray, indices, axis=axis)
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
# Rebuild the graph node by node (original order), swapping every masked
# constant for its pruned version, then serialize it to output_graph.
output_graph_def = graph_pb2.GraphDef()
weight_names = set(weights_n)  # O(1) membership test per graph node
last_conv_w = 'InceptionResnetV1/Block8/Conv2d_1x1/weights'


def _const_like(src, values, shape=None):
    """Return a Const node_def named and typed like *src* holding *values*."""
    const = node_def_pb2.NodeDef()
    const.op = 'Const'
    const.name = src.name
    dtype = src.attr["dtype"]
    const.attr["dtype"].CopyFrom(dtype)
    const.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(values,
                                             dtype=dtype.type,
                                             shape=shape)))
    return const


for node in gd.node:
    name = node.name
    if name in weight_names:
        print("Pruning node {}".format(name))
        output_node = prune_node(node, output_mask[name], input_mask[name])
    elif name in bn_mask:
        print("pruning node {}".format(name))
        output_node = prune_bn_node(node, bn_mask[name])
    elif name in bias_mask:
        print("pruning node {}".format(name))
        output_node = prune_bias_node(node, bias_mask[name])
    elif name in new_shape:
        print("pruning node {}".format(name))
        output_node = update_node_shape(node, new_shape[name])
    elif name == 'Bottleneck/weights':
        # Drop the rows fed by the pruned channels of the last conv.
        weights = tensor_util.MakeNdarray(node.attr['value'].tensor)
        weights = np.delete(
            weights, np.asarray(output_mask[last_conv_w], dtype=np.int64), 0)
        output_node = _const_like(node, weights, shape=weights.shape)
    elif name == 'InceptionResnetV1/Logits/Flatten/Reshape/shape':
        # The flattened width shrinks by the number of pruned channels.
        last_shape = nodes_map[last_conv_w].attr['value'].tensor.tensor_shape.dim[-1].size
        _shape = last_shape - output_mask[last_conv_w].shape[0]
        output_node = _const_like(node, np.array([-1, _shape]))
    else:
        output_node = node_def_pb2.NodeDef()
        output_node.CopyFrom(node)
    output_graph_def.node.extend([output_node])

with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
if __name__ == "__main__":
    # Command line: <GRAPH_PATH> <MASK_PATH> <output_graph> <output_mask>
    parser = argparse.ArgumentParser()
    parser.add_argument('GRAPH_PATH',
                        help='serialized GraphDef file to prune')
    parser.add_argument('MASK_PATH',
                        help='pickled pruning mask to apply')
    parser.add_argument('output_graph',
                        help='file the pruned GraphDef is written to')
    parser.add_argument('output_mask',
                        help='file the pickled shape/indices mask is written to')
    args = parser.parse_args()
    main(args)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
from time import gmtime, strftime
import pickle
import pdb
import os
import argparse
import re
def main(args):
    """Zero out low-energy conv filters and record their indices.

    Restores the checkpoint ``args.ckpt`` into the meta graph ``args.meta``.
    For every trainable conv weight variable, each filter whose sum of
    squared weights is below ``args.threshold`` is set to zero in the
    session.  The zeroed filter indices, keyed by variable name, are
    pickled to ``args.mask_path``.

    When ``--threshold`` is omitted or not positive nothing is pruned and an
    empty mask is written (previously this raised ``NameError``).

    :args: parsed command line (``meta``, ``ckpt``, ``mask_path``,
        ``threshold``)
    """
    threshold = args.threshold
    pruning_enabled = threshold is not None and threshold > 0.0

    zero_idx = {}
    num_total_2d_kernels = 0
    num_pruned = 0
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, args.ckpt)

        # variables holding the conv weights
        graph_def = tf.get_default_graph().as_graph_def()
        conv_nodes_names = [n.name for n in graph_def.node if n.op == 'Conv2D']
        conv_vars_names = set(re.sub('convolution$', 'weights:0', name)
                              for name in conv_nodes_names)
        conv_vars = [v for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
                     if v.name in conv_vars_names]

        for conv in conv_vars:
            weights = conv.eval()
            _, _, input_channels, num_filter = weights.shape
            num_total_2d_kernels += input_channels * num_filter
            print(conv.name)
            if not pruning_enabled:
                continue
            # Energy of each output filter: sum of squares over axes 0-2.
            energy = np.sum(np.square(weights), axis=(0, 1, 2))
            zero_indices = np.where(energy < threshold)[0]
            if not zero_indices.size:
                continue
            n_zero_filters = zero_indices.shape[0]
            print('{} kernels pruned'.format(n_zero_filters))
            num_pruned += n_zero_filters * input_channels
            weights[:, :, :, zero_indices] = 0
            conv.assign(weights).eval()
            zero_idx[conv.name] = zero_indices

    print("A total number of {} kernels prunned".format(num_pruned))
    # Guard against graphs without conv variables.
    rate = num_pruned / num_total_2d_kernels if num_total_2d_kernels else 0.0
    print("Kernel pruning rate: {}".format(rate))

    with open(args.mask_path, 'wb') as f:
        pickle.dump(zero_idx, f)
if __name__ == '__main__':
    # Command line: <meta> <ckpt> <mask_path> [--threshold T]
    parser = argparse.ArgumentParser()
    parser.add_argument('meta', help='meta graph file to import')
    parser.add_argument('ckpt', help='checkpoint to restore')
    parser.add_argument('mask_path',
                        help='file the pickled pruning mask is written to')
    parser.add_argument('--threshold', type=float,
                        help='filters whose sum of squared weights is below '
                             'this value are zeroed')
    args = parser.parse_args()
    main(args)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import pickle
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import gfile
import re
import argparse
import pdb, traceback, sys
def main(args):
    """Prune the Conv2D filters of a frozen InceptionResnetV1 graph.

    This opening part loads the serialized GraphDef found at
    ``args.GRAPH_PATH`` into the default graph and creates the bookkeeping
    containers that the remainder of the routine fills in.

    :args: parsed command line; ``GRAPH_PATH``, ``MASK_PATH`` and
        ``output_graph`` are read here, ``output_mask`` further down.
    """
    GRAPH_PATH = args.GRAPH_PATH
    MASK_PATH = args.MASK_PATH
    output_graph = args.output_graph

    with tf.Session() as sess:
        with tf.gfile.FastGFile(GRAPH_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            _ = tf.import_graph_def(graph_def, name="")

    # Serialize the default graph once and reuse the copy; every call to
    # as_graph_def() builds a fresh GraphDef.
    gd = tf.get_default_graph().as_graph_def()
    nodes_map = {}  # Keyed by bare node name.
    conv_nodes = [n for n in gd.node if n.op == 'Conv2D']
    edges = {}  # Keyed by the dest node name.
    # Keeps track of node sequences. It is important to still output the
    # operations in the original order.
    node_seq = {}  # Keyed by node name.
    seq = 0
def dfs_last_conv(node):
    """Return the closest upstream ``Conv2D`` node of *node*, or ``None``.

    Only the first entry of each ``input`` list is followed (as before).
    The walk gives up with ``None`` when it meets a concat op or a node
    without inputs, because masks across a concat need channel offsets.

    :node: node_def whose producer chain is searched
    :returns: the ``Conv2D`` node_def found, or ``None``
    """
    current = node
    while current.input:
        # nodes_map is keyed by bare names, so strip "^" / ":port" first.
        current = nodes_map[_node_name(current.input[0])]
        if current.op == 'Conv2D':
            return current
        # tf.concat is emitted as ConcatV2 by TensorFlow >= 1.0.
        if current.op in ('Concat', 'ConcatV2'):
            return None
    return None
def _node_name(n):
    """Return the bare node name behind a tensor or control-input name.

    A leading ``^`` is dropped; otherwise everything from the first ``:``
    onwards is dropped.
    """
    if n.startswith("^"):
        return n[1:]
    return n.split(":")[0]
# Index every node by bare name and remember its inputs and its position
# in the original GraphDef.
for node in gd.node:
    n = _node_name(node.name)
    nodes_map[n] = node
    edges[n] = [_node_name(x) for x in node.input]
    node_seq[n] = seq
    seq += 1

# Pickles are binary data: open with 'rb' (text mode breaks on Python 3 and
# on platforms that translate line endings).
# NOTE: pickle.load executes arbitrary code; only load mask files you trust.
with open(MASK_PATH, 'rb') as f:
    mask = pickle.load(f)
def get_node_output_channels(node_name):
    """Return the size of dimension 3 of the tensor stored in a node.

    The node is looked up in ``nodes_map`` and its ``value`` attribute is
    read; callers pass conv weight names, for which index 3 is presumably
    the output-channel axis (HWIO layout -- confirm for other tensors).

    :node_name: bare name of the weight node
    :returns: the dimension size as ``int``
    """
    dims = nodes_map[node_name].attr['value'].tensor.tensor_shape.dim
    return int(dims[3].size)
class GraphNode(object):
    """Pruning info for a milestone node of the network."""

    def __init__(self, channels, indices):
        """Store the pruning info.

        :channels: channel count recorded for the node
        :indices: channel indices recorded as pruned for the node
        """
        self.channels = channels
        self.indices = indices

    def __repr__(self):
        return "GraphNode(channels={!r}, indices={!r})".format(
            self.channels, self.indices)
# For every Conv2D node, find the Const node that feeds it its weights.
# node_n[i] is the conv name, weights_n[i] the bare name of its weights.
node_n = [n.name for n in conv_nodes]
weights_n = []
for conv in node_n:
    const_inputs = [
        name
        for name in (_node_name(x) for x in nodes_map[conv].input)
        if nodes_map[name].op == 'Const'
    ]
    if not const_inputs:
        # Fail here with a clear message instead of storing None and
        # crashing later on an unrelated-looking lookup.
        raise ValueError(
            "Conv2D node {} has no Const input holding its weights".format(conv))
    # As before, the last Const input wins if there are several.
    weights_n.append(const_inputs[-1])
def get_conv_weights(node):
    """Return the name of the first ``Const`` input of *node*.

    :node: conv node_def whose inputs are inspected
    :returns: bare name of the first input whose op is ``Const``, or
        ``None`` when there is none
    """
    for n in node.input:
        # nodes_map is keyed by bare names.
        name = _node_name(n)
        if nodes_map[name].op == 'Const':
            return name
    return None
input_mask = {}   # pruning mask for the input channels
node_info = {}    # store info for the milestone nodes
output_mask = {}  # pruning mask for the output channels
bias_mask = {}    # pruning mask for the biases

# Each entry of the mask file is a dict; only its 'indices' are used here.
# Masks are index arrays.  Keep them integer-typed throughout: a float
# empty array (plain np.array([])) would turn every concatenation it takes
# part in into float indices, which np.delete does not accept.
for k, v in mask.items():
    output_mask[_node_name(k)] = np.array(v['indices'], dtype=np.int64)

# Conv weights absent from the mask file keep all their output channels.
for n in weights_n:
    output_mask.setdefault(n, np.array([], dtype=np.int64))

# Default input mask: the output mask of the nearest upstream conv, if any.
for conv_name, n in zip(node_n, weights_n):
    last_conv = dfs_last_conv(nodes_map[conv_name])
    if last_conv is not None:
        last_conv_w = get_conv_weights(last_conv)
        input_mask[n] = output_mask[last_conv_w]
    else:
        input_mask[n] = np.array([], dtype=np.int64)
# Hand-written mask wiring for the InceptionResnetV1 topology.
root = 'InceptionResnetV1/'


def _stack_masks(weight_names, tail_indices=None):
    """Merge the output masks of concatenated branches into one index array.

    The indices of branch k are shifted by the summed channel counts (as
    stored in the graph) of branches 0..k-1.  ``tail_indices``, when given,
    is appended shifted by the channel count of all named branches.

    :returns: ``(indices, channels)`` with ``channels`` the summed channel
        count of the named branches
    """
    parts, offset = [], 0
    for name in weight_names:
        parts.append(np.array(output_mask[name]) + offset)
        offset += get_node_output_channels(name)
    if tail_indices is not None:
        parts.append(np.array(tail_indices) + offset)
    return np.concatenate(parts), offset


# General convs: a plain chain, so each conv's input mask is the output
# mask of the conv before it.
stem = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
        'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'Conv2d_4b_3x3']
for prev, cur in zip(stem, stem[1:]):
    input_mask[root + cur + '/weights'] = output_mask[root + prev + '/weights']

# Pruning the sub graphs with residual blocks:
# the pruning mask of a residual block is determined by the shortcut
# projection.  The residual blocks in the network are 5 x block35,
# 10 x block17 and 5 x block8 (plus the final Block8).

# 5 x Inception-Resnet-A
stem_out = output_mask[root + 'Conv2d_4b_3x3/weights']
for i in range(1, 6):
    scope = root + 'Repeat/block35_{}/'.format(i)
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[mixed_conv] = _stack_masks([
        scope + 'Branch_0/Conv2d_1x1/weights',
        scope + 'Branch_1/Conv2d_0b_3x3/weights',
        scope + 'Branch_2/Conv2d_0c_3x3/weights',
    ])[0]
    # Inception-Resnet-A's last conv has to stay consistent with
    # Conv2d_4b_3x3.
    output_mask[mixed_conv] = stem_out
    bias_mask[scope + 'Conv2d_1x1/biases'] = stem_out
node_info['inception-resnet-a'] = GraphNode(256, stem_out)

# Reduction A: output is a concat of feature maps
# [branch0 + branch1 + branch2 (inception-resnet-a)]
m6a = root + 'Mixed_6a/'
a_info = node_info['inception-resnet-a']
input_mask[m6a + 'Branch_0/Conv2d_1a_3x3/weights'] = a_info.indices
input_mask[m6a + 'Branch_1/Conv2d_0a_1x1/weights'] = a_info.indices
red_a_indices, red_a_convs = _stack_masks(
    [m6a + 'Branch_0/Conv2d_1a_3x3/weights',
     m6a + 'Branch_1/Conv2d_1a_3x3/weights'],
    tail_indices=a_info.indices)
node_info['reduction_a'] = GraphNode(red_a_convs + a_info.channels,
                                     red_a_indices)

# 10 x Inception-Resnet-B
red_a = node_info['reduction_a']
for i in range(1, 11):
    scope = root + 'Repeat_1/block17_{}/'.format(i)
    branch_0 = scope + 'Branch_0/Conv2d_1x1/weights'
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[branch_0] = red_a.indices
    input_mask[scope + 'Branch_1/Conv2d_0a_1x1/weights'] = red_a.indices
    input_mask[mixed_conv] = _stack_masks(
        [branch_0, scope + 'Branch_1/Conv2d_0c_7x1/weights'])[0]
    output_mask[mixed_conv] = red_a.indices
    bias_mask[scope + 'Conv2d_1x1/biases'] = red_a.indices
node_info['inception-resnet-b'] = red_a

# Reduction B
m7a = root + 'Mixed_7a/'
b_info = node_info['inception-resnet-b']
branches = ('Branch_0', 'Branch_1', 'Branch_2')
for branch in branches:
    input_mask[m7a + branch + '/Conv2d_0a_1x1/weights'] = b_info.indices
red_b_indices, red_b_convs = _stack_masks(
    [m7a + branch + '/Conv2d_1a_3x3/weights' for branch in branches],
    tail_indices=b_info.indices)
node_info['reduction_b'] = GraphNode(red_b_convs + b_info.channels,
                                     red_b_indices)

# 5 x Inception-Resnet-C, followed by the last Block8 (same wiring)
red_b = node_info['reduction_b']
node_info['inception-resnet-c'] = red_b
block8_scopes = [root + 'Repeat_2/block8_{}/'.format(i) for i in range(1, 6)]
block8_scopes.append(root + 'Block8/')
for scope in block8_scopes:
    branch_0 = scope + 'Branch_0/Conv2d_1x1/weights'
    mixed_conv = scope + 'Conv2d_1x1/weights'
    input_mask[branch_0] = red_b.indices
    input_mask[scope + 'Branch_1/Conv2d_0a_1x1/weights'] = red_b.indices
    input_mask[mixed_conv] = _stack_masks(
        [branch_0, scope + 'Branch_1/Conv2d_0c_3x1/weights'])[0]
    output_mask[mixed_conv] = red_b.indices
    bias_mask[scope + 'Conv2d_1x1/biases'] = red_b.indices
# Record, for every conv weight node, its stored tensor shape and the output
# channel indices that get pruned, then pickle the result to
# args.output_mask.
retrain_mask = {}
for node, indices in output_mask.items():
    dims = nodes_map[node].attr['value'].tensor.tensor_shape.dim
    retrain_mask[node] = {
        'shape': np.array([dim.size for dim in dims]),
        # Persist integer indices so consumers can index with them directly
        # (an empty np.array([]) would otherwise be stored as float64).
        'indices': np.asarray(indices, dtype=np.int64),
    }
with open(args.output_mask, 'wb') as f:
    pickle.dump(retrain_mask, f)
# batch norm nodes
bn_mask = {}    # BatchNorm parameter node name -> indices to drop
new_shape = {}  # moments/Shape node name -> number of dropped channels
# `\d+` (not `\d`) so that two-digit repeats such as block17_10 are matched.
residual_projection = re.compile(r'block\d+_\d+/Conv2d_1x1/weights$')
for n in weights_n:
    # The residual projection convs get a bias mask instead of BatchNorm
    # masks (presumably they have no BatchNorm -- confirm against the model).
    if n == 'InceptionResnetV1/Block8/Conv2d_1x1/weights':
        continue
    if residual_projection.search(n):
        continue
    bn_prefix = re.sub(r'weights$', 'BatchNorm/', n)
    for suffix in ('beta', 'moving_mean', 'moving_variance',
                   'moving_mean_variance'):
        bn_mask[bn_prefix + suffix] = output_mask[n]
    new_shape[bn_prefix + 'moments/Shape'] = output_mask[n].shape[0]
def update_node_shape(in_node, delta):
    """Return a Const copy of *in_node* whose value is reduced by *delta*.

    :in_node: Const node_def holding the tensor to adjust
    :delta: amount subtracted from the stored tensor
    :returns: new Const node_def with the same name and dtype
    """
    out_node = node_def_pb2.NodeDef()
    out_node.op = 'Const'
    out_node.name = in_node.name
    dtype = in_node.attr["dtype"]
    out_node.attr["dtype"].CopyFrom(dtype)
    # NOTE(review): delta is subtracted from every element of the stored
    # tensor; confirm the Shape constants this is applied to only hold the
    # channel dimension.
    reduced_shape = tensor_util.MakeNdarray(in_node.attr['value'].tensor) - delta
    out_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(reduced_shape, dtype=dtype.type)))
    return out_node
def prune_bn_node(input_node, mask):
    """Prune a batch norm parameter node according to the given mask.

    :input_node: Const node_def holding the parameter tensor
    :mask: indices of the entries to remove
    :returns: new Const node_def (same name and dtype) without those entries
    """
    values = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    # np.delete needs integer indices; masks built from np.array([]) are
    # float64, so coerce before use.
    pruned_ndarray = np.delete(values, np.asarray(mask, dtype=np.int64))
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
def prune_bias_node(input_node, mask):
    """Prune a biases node according to the given mask.

    :input_node: Const node_def holding the bias tensor
    :mask: indices of the entries to remove
    :returns: new Const node_def (same name and dtype) without those entries
    """
    biases = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    # np.delete needs integer indices; masks built from np.array([]) are
    # float64, so coerce before use.
    pruned_ndarray = np.delete(biases, np.asarray(mask, dtype=np.int64))
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
def prune_node(input_node, out_indices, in_indices):
    """Prune a conv weight node along its output and input channel axes.

    :input_node: Const node_def holding the weight tensor
    :out_indices: indices removed along axis 3
    :in_indices: indices removed along axis 2
    :returns: new Const node_def (same name and dtype) with the pruned tensor
    """
    pruned_ndarray = tensor_util.MakeNdarray(input_node.attr['value'].tensor)
    for indices, axis in ((out_indices, 3), (in_indices, 2)):
        # np.delete needs integer indices; concatenated masks may be float.
        indices = np.asarray(indices, dtype=np.int64)
        if indices.size:
            pruned_ndarray = np.delete(pruned_ndarray, indices, axis=axis)
    output_node = node_def_pb2.NodeDef()
    output_node.op = 'Const'
    output_node.name = input_node.name
    dtype = input_node.attr["dtype"]
    output_node.attr["dtype"].CopyFrom(dtype)
    output_node.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(pruned_ndarray,
                                             dtype=dtype.type,
                                             shape=pruned_ndarray.shape)))
    return output_node
# Rebuild the graph node by node (original order), swapping every masked
# constant for its pruned version, then serialize it to output_graph.
# (The leftover pdb.set_trace() that paused every run before the write has
# been removed.)
output_graph_def = graph_pb2.GraphDef()
weight_names = set(weights_n)  # O(1) membership test per graph node
last_conv_w = 'InceptionResnetV1/Block8/Conv2d_1x1/weights'


def _const_like(src, values, shape=None):
    """Return a Const node_def named and typed like *src* holding *values*."""
    const = node_def_pb2.NodeDef()
    const.op = 'Const'
    const.name = src.name
    dtype = src.attr["dtype"]
    const.attr["dtype"].CopyFrom(dtype)
    const.attr['value'].CopyFrom(attr_value_pb2.AttrValue(
        tensor=tensor_util.make_tensor_proto(values,
                                             dtype=dtype.type,
                                             shape=shape)))
    return const


for node in gd.node:
    name = node.name
    if name in weight_names:
        print("Pruning node {}".format(name))
        output_node = prune_node(node, output_mask[name], input_mask[name])
    elif name in bn_mask:
        print("pruning node {}".format(name))
        output_node = prune_bn_node(node, bn_mask[name])
    elif name in bias_mask:
        print("pruning node {}".format(name))
        output_node = prune_bias_node(node, bias_mask[name])
    elif name in new_shape:
        print("pruning node {}".format(name))
        output_node = update_node_shape(node, new_shape[name])
    elif name == 'Bottleneck/weights':
        # Drop the rows fed by the pruned channels of the last conv.
        weights = tensor_util.MakeNdarray(node.attr['value'].tensor)
        weights = np.delete(
            weights, np.asarray(output_mask[last_conv_w], dtype=np.int64), 0)
        output_node = _const_like(node, weights, shape=weights.shape)
    elif name == 'InceptionResnetV1/Logits/Flatten/Reshape/shape':
        # The flattened width shrinks by the number of pruned channels.
        last_shape = nodes_map[last_conv_w].attr['value'].tensor.tensor_shape.dim[-1].size
        _shape = last_shape - output_mask[last_conv_w].shape[0]
        output_node = _const_like(node, np.array([-1, _shape]))
    else:
        output_node = node_def_pb2.NodeDef()
        output_node.CopyFrom(node)
    output_graph_def.node.extend([output_node])

with gfile.GFile(output_graph, "wb") as f:
    f.write(output_graph_def.SerializeToString())
if __name__ == "__main__":
    # Command line: <GRAPH_PATH> <MASK_PATH> <output_graph> <output_mask>
    parser = argparse.ArgumentParser()
    parser.add_argument('GRAPH_PATH',
                        help='serialized GraphDef file to prune')
    parser.add_argument('MASK_PATH',
                        help='pickled mask whose entries carry an "indices" field')
    parser.add_argument('output_graph',
                        help='file the pruned GraphDef is written to')
    parser.add_argument('output_mask',
                        help='file the pickled shape/indices mask is written to')
    args = parser.parse_args()
    main(args)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
import pickle
import argparse
def main(args):
    """Re-apply a pruning mask to a checkpoint and save the result.

    Restores ``args.ckpt`` into the meta graph ``args.meta``, loads the
    pickled mask at ``args.zidx`` and, for every global variable whose bare
    name is a key of the mask, sets the entries selected by the mask's
    ``'indices'`` along the last axis to zero.  The session is then saved
    to ``args.output``.

    :args: parsed command line (``meta``, ``ckpt``, ``zidx``, ``output``)
    """
    def _node_name(n):
        """Return the bare node name behind a tensor or control-input name."""
        if n.startswith("^"):
            return n[1:]
        return n.split(":")[0]

    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(args.meta)
        saver.restore(sess, args.ckpt)

        # NOTE: pickle.load executes arbitrary code; only load trusted files.
        with open(args.zidx, 'rb') as f:
            zero_idx = pickle.load(f)

        for v in tf.global_variables():
            node_name = _node_name(v.name)
            if node_name not in zero_idx:
                continue
            # Fancy indexing needs integer indices; stored masks may be
            # float64 (e.g. an empty np.array([])).
            indices = np.asarray(zero_idx[node_name]['indices'], dtype=np.int64)
            if not indices.size:
                continue  # nothing to zero for this variable
            weights = v.eval()
            weights[..., indices] = 0
            v.assign(weights).eval()

        saver.save(sess, args.output)
if __name__ == "__main__":
    # Command line: <meta> <ckpt> <zidx> <output>
    parser = argparse.ArgumentParser()
    parser.add_argument('meta', help='meta graph file to import')
    parser.add_argument('ckpt', help='checkpoint to restore')
    parser.add_argument('zidx', help='pickled mask of indices to zero')
    parser.add_argument('output', help='path the updated checkpoint is saved to')
    args = parser.parse_args()
    main(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment