# Last active September 24, 2017 00:46
# Advanced TFI + TensorFlow SavedModel Example
# Assumes you have already run:
# git clone https://github.com/tensorflow/models
# export PYTHONPATH=$PWD/models/research/slim
from datasets import dataset_factory
from nets import nets_factory
import os.path
import tensorflow as tf
import tfi
from urllib.request import urlretrieve
# TF-slim Inception V1 checkpoint: the archive to download, the checkpoint
# file whose presence skips the download, and the SHA-256 the downloaded
# archive must match before it is unpacked.
CHECKPOINT_URL = "https://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz"
CHECKPOINT_FILE = "inception_v1.ckpt"
CHECKPOINT_SHA256 = "7a620c430fcaba8f8f716241f5148c4c47c035cce4e49ef02cfbe6cd1adf96a6"
class InceptionV1(tfi.saved_model.Base):
    """Inception V1 ImageNet classifier exposed to TFI for SavedModel export.

    ``__init__`` builds the TensorFlow graph (input placeholder -> network ->
    softmax). ``predict`` only declares which tensors form the exported
    signature's inputs and outputs.
    """

    def __init__(self):
        # Only the label metadata (labels_to_names, num_classes) of the slim
        # ImageNet dataset is used below.
        dataset = dataset_factory.get_dataset('imagenet', 'train', '')
        # Sort (index, name) pairs by index so entry i of the constant is the
        # name for class i (assumes the keys are contiguous class indices).
        categories = [name for _, name in sorted(dataset.labels_to_names.items())]
        self._labels = tf.constant(categories)

        # is_training=False selects inference behaviour for the network.
        network_fn = nets_factory.get_network_fn(
            'inception_v1',
            num_classes=dataset.num_classes,
            is_training=False)
        image_size = network_fn.default_image_size
        # Variable-size batch of 3-channel images at the network's native
        # resolution; tf.placeholder requires an explicit dtype.
        # NOTE(review): no preprocessing/scaling is applied here — presumably
        # the caller supplies already-normalized pixels; confirm.
        self._placeholder = tf.placeholder(
            dtype=tf.float32,
            shape=[None, image_size, image_size, 3])

        logits, _ = network_fn(self._placeholder)
        self._scores = tf.nn.softmax(logits)

    # NOTE(review): these annotations reference `self`, which plain Python
    # cannot resolve while the class body runs; this presumably relies on
    # tfi.saved_model.Base supplying a stand-in — confirm against tfi.
    def predict(self, *, images: self._placeholder) -> {
        'scores': self._scores,
        'categories': self._labels,
    }:
        """Map a batch of images to softmax scores plus the category names.

        The body is intentionally empty: the input and output tensors are
        declared in the annotations and the computation lives in the graph
        built by ``__init__``.
        """
# Lazily download the checkpoint archive, verify its digest, and unpack it.
if not os.path.exists(CHECKPOINT_FILE):
    import hashlib
    import tarfile

    def sha256(filename, blocksize=65536):
        """Return the hex SHA-256 digest of *filename*, read in *blocksize* chunks."""
        digest = hashlib.sha256()
        with open(filename, "rb") as f:
            for block in iter(lambda: f.read(blocksize), b""):
                digest.update(block)
        return digest.hexdigest()

    downloaded = urlretrieve(CHECKPOINT_URL)[0]
    try:
        actual = sha256(downloaded)
        if actual != CHECKPOINT_SHA256:
            # Abort instead of unpacking an archive we cannot vouch for.
            raise SystemExit("invalid fetch of {} {} != {}".format(
                CHECKPOINT_URL, actual, CHECKPOINT_SHA256))
        # The digest matched, so the archive contents are trusted; it is
        # presumably expected to contain CHECKPOINT_FILE.
        with tarfile.open(downloaded, 'r|gz') as tar:
            tar.extractall()
    finally:
        # urlretrieve without a filename leaves a temporary file behind.
        os.remove(downloaded)

# Do the actual export!
tfi.saved_model.export("./inception_v1.saved_model", InceptionV1)
# Load the exported SavedModel back as a Python class and classify one image.
import tfi

InceptionV1 = tfi.saved_model.as_class("./inception_v1.saved_model")
model = InceptionV1()
image = tfi.data.file("./dog-medium-landing-hero.jpg")
result = model.predict(images=[image])
categories, scores = result.categories, result.scores[0]
# argsort is ascending: reverse it, then keep the top_k highest scores.
# A reversed slice like [:-5:-1] stops before index -5 and yields only 4.
top_k = 5
top = scores.argsort()[::-1][:top_k]
print([(scores[i], categories[i].decode()) for i in top])
