Skip to content

Instantly share code, notes, and snippets.

@ajbouh
Last active September 24, 2017 00:46
Show Gist options
  • Star (0) — you must be signed in to star a gist.
  • Fork (0) — you must be signed in to fork a gist.
  • Save ajbouh/2b1d076f4af569edeeab819a39184abf to your computer and use it in GitHub Desktop.
Advanced TFI + TensorFlow SavedModel Example
# Assumes you have already run:
# git clone https://github.com/tensorflow/models
# export PYTHONPATH=$PWD/models/research/slim
from datasets import dataset_factory
from nets import nets_factory
import os.path
import tensorflow as tf
import tfi
from urllib.request import urlretrieve
# Source of the pretrained Inception v1 checkpoint archive (a .tar.gz).
CHECKPOINT_URL = (
    "http://download.tensorflow.org/models/"
    "inception_v1_2016_08_28.tar.gz"
)
# Checkpoint file name in the working directory: its existence skips the
# download step, and the model restores its weights from it.
CHECKPOINT_FILE = "inception_v1.ckpt"
# Pinned hex SHA-256 digest that the downloaded archive must match.
CHECKPOINT_SHA256 = (
    "7a620c430fcaba8f8f716241f5148c4c47c035cce4e49ef02cfbe6cd1adf96a6"
)
class InceptionV1(tfi.saved_model.Base):
    """Inception v1 ImageNet classifier, exported through tfi as a SavedModel.

    The constructor builds the TF-slim ``inception_v1`` graph, adds a softmax
    over its logits, and restores variables from ``CHECKPOINT_FILE``. The
    annotations on ``predict`` name the input and output tensors of the
    exported ``predict`` method.
    """

    def __init__(self):
        """Build the classification graph and restore checkpoint weights."""
        # Class names come from TF-slim's ImageNet dataset definition.
        # NOTE(review): dataset_dir is '' here, so presumably only the label
        # metadata is used and no data files are read — confirm against
        # datasets/imagenet.py.
        dataset = dataset_factory.get_dataset('imagenet', 'train', '')
        category_items = list(dataset.labels_to_names.items())
        category_items.sort() # sort by index
        # Names ordered by label key. Presumably position i lines up with
        # logit column i (assumes keys are contiguous ints from 0 — TODO
        # confirm).
        categories = [label for _, label in category_items]
        # Embed the names as a constant string tensor so they can be
        # returned as a model output.
        self._labels = tf.constant(categories)
        # One output logit per category; is_training=False requests the
        # inference-mode network.
        network_fn = nets_factory.get_network_fn(
            'inception_v1',
            num_classes=len(categories),
            is_training=False)
        # Square spatial input size the network definition expects.
        image_size = network_fn.default_image_size
        # Float32 input of shape [batch, image_size, image_size, 3]; the
        # batch dimension is left dynamic (None).
        self._placeholder = tf.placeholder(
            name='input',
            dtype=tf.float32,
            shape=[None, image_size, image_size, 3])
        logits, _ = network_fn(self._placeholder)
        # Per-class probabilities for each image in the batch.
        self._scores = tf.nn.softmax(logits)
        # Load trained variable values from the extracted checkpoint file.
        tfi.checkpoint.restore(CHECKPOINT_FILE)

    # NOTE(review): these annotations reference `self` while the class body
    # is being executed, which plain Python would reject with NameError.
    # This presumably relies on tfi.saved_model.Base's metaclass making
    # `self` resolvable there — confirm against tfi's implementation.
    def predict(self, *, images: self._placeholder) -> {
        'scores': self._scores,
        'categories': self._labels,
    }:
        """Classify a batch of images.

        The body is only `pass`. Presumably tfi wires the method from its
        annotations: keyword-only ``images`` feeds the input placeholder, and
        the result exposes ``scores`` (softmax probabilities) and
        ``categories`` (label names).
        """
        pass
# Lazily download the checkpoint archive, verify its SHA-256 digest, and
# extract it into the current directory. Skipped entirely when
# CHECKPOINT_FILE already exists, so later runs reuse the extracted files.
if not os.path.exists(CHECKPOINT_FILE):
    import hashlib
    import sys
    import tarfile
    from urllib.request import urlcleanup

    def sha256(filename, blocksize=65536):
        """Return the hex SHA-256 digest of ``filename``.

        The file is read in ``blocksize``-byte chunks so a large archive is
        never held in memory all at once.
        """
        digest = hashlib.sha256()  # named to avoid shadowing builtin `hash`
        with open(filename, "rb") as f:
            for block in iter(lambda: f.read(blocksize), b""):
                digest.update(block)
        return digest.hexdigest()

    # With no target filename, urlretrieve() saves to a temporary file.
    # urlcleanup() in the finally clause deletes it on every exit path,
    # including the sys.exit() call below.
    downloaded = urlretrieve(CHECKPOINT_URL)[0]
    try:
        s = sha256(downloaded)
        if s != CHECKPOINT_SHA256:
            # Use sys.exit rather than the bare `exit` helper: `exit` is
            # injected by the `site` module and is absent under `python -S`.
            print("invalid fetch of", CHECKPOINT_URL, s, "!=", CHECKPOINT_SHA256,
                  file=sys.stderr)
            sys.exit(1)
        with tarfile.open(downloaded, 'r|gz') as tar:
            # The digest matched the pinned value, so the archive is trusted.
            # Where available, still apply the 'data' extraction filter. It
            # rejects members that would land outside the target directory
            # (absolute paths, '..' components, links pointing outside).
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(filter='data')
            else:
                tar.extractall()
    finally:
        urlcleanup()
    # Fail here with a clear message rather than later inside restore().
    if not os.path.exists(CHECKPOINT_FILE):
        print("archive from", CHECKPOINT_URL, "did not contain", CHECKPOINT_FILE,
              file=sys.stderr)
        sys.exit(1)
# Hand the InceptionV1 class to tfi, which writes it out as a TensorFlow
# SavedModel in the directory below.
SAVED_MODEL_DIR = "./inception_v1.saved_model"
tfi.saved_model.export(SAVED_MODEL_DIR, InceptionV1)
import tfi
# Load the exported SavedModel back as a Python class and classify one image.
InceptionV1 = tfi.saved_model.as_class("./inception_v1.saved_model")
model = InceptionV1()
image = tfi.data.file("./dog-medium-landing-hero.jpg")
# predict() takes a batch; here the batch holds a single image, and the
# outputs are read back as attributes of the result.
result = model.predict(images=[image])
categories, scores = result.categories, result.scores[0]
# Show the TOP_K most probable (score, label) pairs, best first.
# argsort() is ascending, so reverse it before taking the first TOP_K.
# (A slice such as [:-5:-1] would yield only 4 indices.)
TOP_K = 5
top_indices = scores.argsort()[::-1][:TOP_K]
# Labels are decoded with .decode(), i.e. presumably bytes from the string
# tensor output.
print([(scores[i], categories[i].decode()) for i in top_indices])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment