|
# Assumes you have already run: |
|
# git clone https://github.com/tensorflow/models |
|
# export PYTHONPATH=$PWD/models/research/slim |
|
|
|
from datasets import dataset_factory |
|
from nets import nets_factory |
|
import os.path |
|
import tensorflow as tf |
|
import tfi |
|
from urllib.request import urlretrieve |
|
|
|
# Inception V1 checkpoint archive, fetched over HTTPS so it is protected in
# transit; its SHA-256 (below) is still verified before the archive is unpacked.
CHECKPOINT_URL = "https://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz"

# Checkpoint path, relative to the current working directory, that the
# archive is expected to produce once extracted.
CHECKPOINT_FILE = "inception_v1.ckpt"

# Expected hex SHA-256 digest of the downloaded archive.
CHECKPOINT_SHA256 = "7a620c430fcaba8f8f716241f5148c4c47c035cce4e49ef02cfbe6cd1adf96a6"
|
|
|
class InceptionV1(tfi.saved_model.Base):
    """Inception V1 ImageNet classifier exposed as a tfi SavedModel.

    The TensorFlow graph is built in ``__init__``; ``predict`` only declares
    its signature through annotations that point at the graph's tensors.
    """

    def __init__(self):
        # Category names ordered by their integer label id.
        # NOTE(review): the list position is presumably the network's output
        # index; with an empty dataset_dir, slim's imagenet dataset may also
        # write a labels file into the working directory — confirm.
        imagenet = dataset_factory.get_dataset('imagenet', 'train', '')
        names = [name for _, name in sorted(imagenet.labels_to_names.items())]
        self._labels = tf.constant(names)

        build_net = nets_factory.get_network_fn(
            'inception_v1',
            num_classes=len(names),
            is_training=False)

        # Batch of square, 3-channel float images at the net's default size.
        side = build_net.default_image_size
        self._placeholder = tf.placeholder(
            name='input',
            dtype=tf.float32,
            shape=[None, side, side, 3])

        logits, _ = build_net(self._placeholder)
        self._scores = tf.nn.softmax(logits)

        # Restore variable values from the checkpoint file on disk.
        tfi.checkpoint.restore(CHECKPOINT_FILE)

    # NOTE(review): the annotations reference ``self`` at class-definition
    # time; this presumably relies on tfi.saved_model.Base resolving them
    # lazily — do not rewrite them as plain Python.
    def predict(self, *, images: self._placeholder) -> {
        'scores': self._scores,
        'categories': self._labels,
    }:
        """Map ``images`` to softmax ``scores`` plus the ``categories`` table.

        The body is intentionally empty: the signature above is the contract.
        """
        pass
|
|
|
# Lazily download the checkpoint archive, verify its SHA-256 digest, and
# unpack it into the current directory. Skipped entirely when
# CHECKPOINT_FILE already exists from a previous run.
if not os.path.exists(CHECKPOINT_FILE):
    import hashlib
    import tarfile

    def sha256(filename, blocksize=65536):
        """Return the hex SHA-256 digest of *filename*, read in *blocksize* chunks."""
        digest = hashlib.sha256()
        with open(filename, "rb") as f:
            for block in iter(lambda: f.read(blocksize), b""):
                digest.update(block)
        return digest.hexdigest()

    # With no target filename, urlretrieve writes to a temporary file; it is
    # removed below whether or not verification and extraction succeed.
    downloaded = urlretrieve(CHECKPOINT_URL)[0]
    try:
        s = sha256(downloaded)
        if s != CHECKPOINT_SHA256:
            # SystemExit(str) prints the message to stderr and exits with
            # status 1, without relying on the site-provided ``exit`` builtin.
            raise SystemExit(
                "invalid fetch of %s %s != %s" % (CHECKPOINT_URL, s, CHECKPOINT_SHA256))
        # The archive is trusted only because its digest matched above.
        with tarfile.open(downloaded, 'r|gz') as tar:
            tar.extractall()
    finally:
        os.remove(downloaded)

    # Fail here with a clear message rather than later inside checkpoint restore.
    if not os.path.exists(CHECKPOINT_FILE):
        raise SystemExit(
            "%s did not contain %s" % (CHECKPOINT_URL, CHECKPOINT_FILE))
|
|
|
# Write the InceptionV1 SavedModel to a directory relative to the current
# working directory.
# NOTE(review): tfi presumably instantiates InceptionV1 here, which builds the
# graph and restores CHECKPOINT_FILE — confirm against tfi.saved_model.export.
export_path = "./inception_v1.saved_model"
tfi.saved_model.export(export_path, InceptionV1)