Skip to content

Instantly share code, notes, and snippets.

@lukmanr
Last active October 29, 2018 16:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukmanr/4bb2d9c53df0133930f325fe7a1845ad to your computer and use it in GitHub Desktop.
TF Model Optimization code 1
from __future__ import print_function
import os
import numpy as np
from datetime import datetime
import sys
import tensorflow as tf
from tensorflow import data
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import freeze_graph
from tensorflow.python import ops
from tensorflow.tools.graph_transforms import TransformGraph
# Number of output units in the classifier head (one per MNIST digit class).
NUM_CLASSES = 10
# Location / base name for model artifacts; not referenced by the functions
# visible here -- presumably used by export code elsewhere in the script.
MODELS_LOCATION = "models/mnist"
MODEL_NAME = "keras_classifier"
def load_mnist_keras():
    """Fetch MNIST via the Keras datasets API and flatten the result.

    ``tf.keras.datasets.mnist.load_data()`` yields two ``(data, labels)``
    pairs (training split first, evaluation split second); this helper
    unpacks them into a single flat 4-tuple.

    Returns:
        ``(train_data, train_labels, eval_data, eval_labels)``.
    """
    train_split, eval_split = tf.keras.datasets.mnist.load_data()
    train_data, train_labels = train_split
    eval_data, eval_labels = eval_split
    return train_data, train_labels, eval_data, eval_labels
def keras_model_fn(params):
    """Build the MNIST convolutional classifier as a functional Keras model.

    Architecture: reshape -> ``params.num_conv_layers`` x
    (Conv2D -> MaxPool2D -> BatchNormalization) -> flatten ->
    one (Dense -> Dropout) pair per entry of ``params.hidden_units`` ->
    Dense logits -> softmax.

    Args:
        params: Hyper-parameter container read via attribute access:
            ``num_conv_layers`` (number of conv blocks), ``init_filters``
            (filters in the first conv layer; doubled for each later one),
            ``hidden_units`` (iterable of dense-layer widths) and
            ``dropout`` (rate passed to ``tf.keras.layers.Dropout``).

    Returns:
        An uncompiled ``tf.keras.models.Model`` mapping a ``(28, 28)`` input
        to ``NUM_CLASSES`` softmax outputs.
    """
    inputs = tf.keras.layers.Input(shape=(28, 28), name='input_image')
    # Conv2D needs an explicit channel axis; the input has a single channel.
    net = tf.keras.layers.Reshape(target_shape=(28, 28, 1), name='reshape')(inputs)

    # Convolutional layers: filter count doubles with depth while each
    # stride-2 max-pool roughly halves the spatial resolution.
    for i in range(params.num_conv_layers):
        filters = params.init_filters * (2 ** i)
        net = tf.keras.layers.Conv2D(
            kernel_size=3, filters=filters, strides=1, padding='SAME',
            activation='relu')(net)
        net = tf.keras.layers.MaxPool2D(pool_size=2, strides=2, padding='SAME')(net)
        net = tf.keras.layers.BatchNormalization()(net)

    net = tf.keras.layers.Flatten(name='flatten')(net)

    # Fully-connected layers: iterate the widths directly instead of by index,
    # so any iterable of ints (not only an indexable sequence) is accepted.
    for units in params.hidden_units:
        net = tf.keras.layers.Dense(units=units, activation='relu')(net)
        net = tf.keras.layers.Dropout(params.dropout)(net)

    # Softmax classifier head.
    logits = tf.keras.layers.Dense(units=NUM_CLASSES, name='logits')(net)
    softmax = tf.keras.layers.Activation('softmax', name='softmax')(logits)
    return tf.keras.models.Model(inputs, softmax)
def create_estimator_keras(params, run_config):
    """Build, compile and wrap the Keras MNIST classifier as an Estimator.

    Args:
        params: Hyper-parameters forwarded to ``keras_model_fn``; this
            function additionally reads ``params.learning_rate`` for the
            Adam optimizer.
        run_config: Passed as ``config`` to
            ``tf.keras.estimator.model_to_estimator``.

    Returns:
        The estimator returned by ``tf.keras.estimator.model_to_estimator``.
    """
    keras_model = keras_model_fn(params)
    # summary() prints the table itself and returns None; wrapping it in
    # print() only emitted an extra "None" line.
    keras_model.summary()

    # Compile with the configured optimizer instance. Passing the string
    # 'adam' (as before) constructed a default Adam and silently discarded
    # params.learning_rate.
    # NOTE(review): `lr` matches the TF 1.x Keras optimizer signature; newer
    # TF releases name this argument `learning_rate`.
    optimizer = tf.keras.optimizers.Adam(lr=params.learning_rate)
    keras_model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=optimizer,
        metrics=['accuracy'],
    )

    return tf.keras.estimator.model_to_estimator(
        keras_model=keras_model,
        config=run_config,
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment