import os
import sys
import tarfile
import urllib.request

import tensorflow as tf
from tensorflow.contrib import slim as contrib_slim

slim = contrib_slim
# VGG-16 expects 224x224 inputs with per-channel (RGB) mean subtraction.
image_size = 224
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94

url = "http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz"
# '~' is not expanded by tf.gfile or os.path.join, so expand it explicitly.
checkpoints_dir = os.path.expanduser('~/vgg')
def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads `tarball_url` and uncompresses it locally.

  Args:
    tarball_url: The URL of a tarball file.
    dataset_dir: The directory where the temporary files are stored.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()

  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
  print()
  statinfo = os.stat(filepath)
  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
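
# Note: the VGG tarball linked above extracts to a single checkpoint file,
# `vgg_16.ckpt`, which `slim.assign_from_checkpoint_fn` loads further below.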
def _mean_image_subtraction(image, means):
  """Subtracts the given per-channel means from each channel of `image`."""
  if image.get_shape().ndims != 3:
    raise ValueError('Input must be of size [height, width, C>0]')
  num_channels = image.get_shape().as_list()[-1]
  if len(means) != num_channels:
    raise ValueError('len(means) must match the number of channels')

  channels = tf.split(axis=2, num_or_size_splits=num_channels, value=image)
  for i in range(num_channels):
    channels[i] -= means[i]
  return tf.concat(axis=2, values=channels)
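
# A minimal alternative sketch, not used below: TensorFlow broadcasts a
# length-3 constant across the trailing channel axis, so the split/concat
# above can also be written as a single subtraction. Hypothetical helper,
# shown for illustration only.
def _mean_image_subtraction_broadcast(image, means):
  # `means` is a Python list such as [_R_MEAN, _G_MEAN, _B_MEAN].
  return image - tf.constant(means, dtype=image.dtype)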
def create_readable_names_for_imagenet_labels():
  """Create a dict mapping label id to human-readable string.

  Returns:
    labels_to_names: dictionary where keys are integers from 0 to 1000
      and values are human-readable names.

  We retrieve a synset file, which contains a list of valid synset labels used
  by the ILSVRC competition. There is one synset per line, e.g.:
    # n01440764
    # n01443537
  We also retrieve a synset-to-human file, which contains a mapping from
  synsets to human-readable names for every synset in ImageNet. These are
  stored in TSV format, as follows:
    # n02119247 black fox
    # n02119359 silver fox
  We assign each synset (in alphabetical order) an integer, starting from 1
  (since 0 is reserved for the background class).

  Code is based on
  https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py#L463
  """
  # pylint: disable=g-line-too-long
  base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/research/inception/inception/data'
  synset_url = '{}/imagenet_lsvrc_2015_synsets.txt'.format(base_url)
  synset_to_human_url = '{}/imagenet_metadata.txt'.format(base_url)

  filename, _ = urllib.request.urlretrieve(synset_url)
  synset_list = [s.strip() for s in open(filename).readlines()]
  num_synsets_in_ilsvrc = len(synset_list)
  assert num_synsets_in_ilsvrc == 1000

  filename, _ = urllib.request.urlretrieve(synset_to_human_url)
  synset_to_human_list = open(filename).readlines()
  num_synsets_in_all_imagenet = len(synset_to_human_list)
  assert num_synsets_in_all_imagenet == 21842

  synset_to_human = {}
  for s in synset_to_human_list:
    parts = s.strip().split('\t')
    assert len(parts) == 2
    synset = parts[0]
    human = parts[1]
    synset_to_human[synset] = human

  # Label 0 is reserved for 'background'; ILSVRC synsets map to labels 1-1000.
  label_index = 1
  labels_to_names = {0: 'background'}
  for synset in synset_list:
    name = synset_to_human[synset]
    labels_to_names[label_index] = name
    label_index += 1

  return labels_to_names
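
# For illustration (assuming the standard ILSVRC-2012 synset ordering), the
# returned dict begins roughly:
#   {0: 'background', 1: 'tench, Tinca tinca', 2: 'goldfish, Carassius auratus', ...}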
def vgg_arg_scope(weight_decay=0.0005):
  """Defines the VGG arg scope: ReLU, L2 weight decay, zero-init biases."""
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      activation_fn=tf.nn.relu,
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      biases_initializer=tf.zeros_initializer()):
    with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
      return arg_sc
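
# A sketch of what the arg scope provides: inside
# `with slim.arg_scope(vgg_arg_scope()):`, a bare call like
#     slim.conv2d(net, 64, [3, 3])
# behaves as if its defaults were spelled out:
#     slim.conv2d(net, 64, [3, 3], padding='SAME', activation_fn=tf.nn.relu,
#                 weights_regularizer=slim.l2_regularizer(0.0005),
#                 biases_initializer=tf.zeros_initializer())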
def vgg_16(inputs,
           num_classes=1000,
           is_training=True,
           dropout_keep_prob=0.5,
           spatial_squeeze=True,
           reuse=None,
           scope='vgg_16',
           fc_conv_padding='VALID',
           global_pool=False):
  """VGG-16 from the slim model zoo; returns logits and an end-points dict."""
  with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc:
    end_points_collection = sc.original_name_scope + '_end_points'
    # Collect outputs for conv2d, fully_connected and max_pool2d.
    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
                        outputs_collections=end_points_collection):
      net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
      net = slim.max_pool2d(net, [2, 2], scope='pool1')
      net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
      net = slim.max_pool2d(net, [2, 2], scope='pool2')
      net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
      net = slim.max_pool2d(net, [2, 2], scope='pool3')
      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
      net = slim.max_pool2d(net, [2, 2], scope='pool4')
      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
      net = slim.max_pool2d(net, [2, 2], scope='pool5')

      # Use conv2d instead of fully_connected layers.
      net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6')
      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                         scope='dropout6')
      net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
      # Convert end_points_collection into an end_point dict.
      end_points = slim.utils.convert_collection_to_dict(end_points_collection)
      if global_pool:
        net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool')
        end_points['global_pool'] = net
      if num_classes:
        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                           scope='dropout7')
        net = slim.conv2d(net, num_classes, [1, 1],
                          activation_fn=None,
                          normalizer_fn=None,
                          scope='fc8')
        if spatial_squeeze:
          net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
        end_points[sc.name + '/fc8'] = net
      return net, end_points
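
# Shape walk-through for a 224x224 RGB input (each 2x2 pool halves the
# spatial size): 224 -> 112 -> 56 -> 28 -> 14 -> 7, so pool5 is
# [N, 7, 7, 512]. The 7x7 VALID conv in fc6 then yields [N, 1, 1, 4096],
# fc8 yields [N, 1, 1, 1000], and the spatial squeeze produces logits of
# shape [N, 1000].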
# Fetch the pretrained weights once, on the first run.
if not tf.gfile.Exists(checkpoints_dir):
  tf.gfile.MakeDirs(checkpoints_dir)
  download_and_uncompress_tarball(url, checkpoints_dir)
with tf.Graph().as_default():
  # url = 'https://upload.wikimedia.org/wikipedia/commons/d/d9/First_Student_IC_school_bus_202076.jpg'
  url = 'https://firebasestorage.googleapis.com/v0/b/machine-learning-site.appspot.com/o/bus.jpg?alt=media&token=78144953-72f7-42b0-a3b6-1143cd8dde76'
  image_string = urllib.request.urlopen(url).read()

  batch_size = 1
  height = 224
  width = 224

  # Preprocess: decode, resize to the network's expected input size, and
  # subtract the per-channel ImageNet means (no scaling is applied for VGG).
  image = tf.image.decode_jpeg(image_string, channels=3)
  image = tf.dtypes.cast(image, tf.float32)
  image = tf.image.resize_images(image, [height, width])
  image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
  image = tf.expand_dims(image, 0)  # Add a batch dimension: [1, 224, 224, 3].

  with slim.arg_scope(vgg_arg_scope()):
    logits, end_points = vgg_16(image, num_classes=1000, is_training=False)

  # Inspect the collected end points (layer name -> output tensor).
  for name, tensor in end_points.items():
    print(name)
    print(tensor)

  probabilities = tf.nn.softmax(logits)

  init_fn = slim.assign_from_checkpoint_fn(
      os.path.join(checkpoints_dir, 'vgg_16.ckpt'),
      slim.get_model_variables('vgg_16'))

  with tf.Session() as sess:
    init_fn(sess)
    probabilities = sess.run(probabilities)
    probabilities = probabilities[0]
    # Class indices sorted by descending probability.
    sorted_inds = [i[0] for i in sorted(
        enumerate(-probabilities), key=lambda x: x[1])]

    names = create_readable_names_for_imagenet_labels()
    for i in range(5):
      index = sorted_inds[i]
      # Shift the index of a class name by one (label 0 is 'background').
      print('Probability %0.2f%% => [%s]' % (
          probabilities[index] * 100, names[index + 1]))
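
# The loop above prints five lines of the form
#   Probability NN.NN% => [class name]
# For the bus photo fetched above, the top class should be a bus-related
# ImageNet label, though the exact probabilities depend on the image served.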