saving and loading TensorFlow serialized models
# minimal copy of https://www.tensorflow.org/tutorials/keras/text_classification
### which creates a TF model that classifies movie reviews, then saves the model to a file
#### later, read the model back from the file and test predictions
import os
import re
import shutil
import string
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
# ====================================
# python3 --version
# Python 3.10.9
# virtualenv env
# source env/bin/activate
# pip3 install -r requirements.txt
# mkdir -p my_model/1
# python3 main.py
## using tensorflow serve https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/docker.md
# docker run -t --rm -p 8501:8501 \
# -v `pwd`/my_model:/models/my_model \
# -e MODEL_NAME=my_model \
# tensorflow/serving
# $ curl -d '{"inputs": [
# "The movie was great!",
# "The movie was okay.",
# "The movie was terrible...",
# "awesome"
# ]}' -X POST http://localhost:8501/v1/models/my_model:predict
# {
#   "outputs": [
#     [0.608759],
#     [0.43605575],
#     [0.348938584],
#     [0.58370775]
#   ]
# }
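## the same predict call from Python -- a minimal sketch, assuming the
## tensorflow/serving container above is running on localhost:8501 and that
## the `requests` package is installed (it is not in requirements.txt below):
# import requests
# resp = requests.post(
#     "http://localhost:8501/v1/models/my_model:predict",
#     json={"inputs": ["The movie was great!", "awesome"]})
# print(resp.json()["outputs"])  # one sigmoid score per input string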
print(tf.version.VERSION)
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
dataset = tf.keras.utils.get_file("aclImdb_v1", url,
                                  untar=True, cache_dir='.',
                                  cache_subdir='')
dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
train_dir = os.path.join(dataset_dir, 'train')
os.listdir(train_dir)
sample_file = os.path.join(train_dir, 'pos/1181_9.txt')
with open(sample_file) as f:
    print(f.read())
remove_dir = os.path.join(train_dir, 'unsup')
shutil.rmtree(remove_dir)
batch_size = 32
seed = 42
raw_train_ds = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)
for text_batch, label_batch in raw_train_ds.take(1):
    for i in range(3):
        print("Review", text_batch.numpy()[i])
        print("Label", label_batch.numpy()[i])
print("Label 0 corresponds to", raw_train_ds.class_names[0])
print("Label 1 corresponds to", raw_train_ds.class_names[1])
raw_val_ds = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)
raw_test_ds = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/test',
    batch_size=batch_size)
max_features = 10000
sequence_length = 250
vectorize_layer = layers.TextVectorization(
    standardize='lower_and_strip_punctuation',
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)
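# adapt() builds the string->integer vocabulary from the training text only,
# so no information leaks in from the validation or test splits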
def vectorize_text(text, label):
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label
# retrieve a batch (of 32 reviews and labels) from the dataset
text_batch, label_batch = next(iter(raw_train_ds))
first_review, first_label = text_batch[0], label_batch[0]
print("Review", first_review)
print("Label", raw_train_ds.class_names[first_label])
print("Vectorized review", vectorize_text(first_review, first_label))
print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))
train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)
test_ds = raw_test_ds.map(vectorize_text)
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.cache().prefetch(buffer_size=AUTOTUNE)
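# .cache() keeps the vectorized batches in memory after the first epoch;
# .prefetch() overlaps preprocessing with model execution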
embedding_dim = 16
model = tf.keras.Sequential([
    layers.Embedding(max_features + 1, embedding_dim),
    layers.Dropout(0.2),
    layers.GlobalAveragePooling1D(),
    layers.Dropout(0.2),
    layers.Dense(1)])
model.summary()
model.compile(loss=losses.BinaryCrossentropy(from_logits=True),
              optimizer='adam',
              metrics=tf.metrics.BinaryAccuracy(threshold=0.0))
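# from_logits=True because the final Dense(1) layer has no activation and
# emits raw logits; BinaryAccuracy(threshold=0.0) matches, since a logit of 0
# corresponds to a probability of 0.5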
epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)
loss, accuracy = model.evaluate(test_ds)
print("Loss: ", loss)
print("Accuracy: ", accuracy)
history_dict = history.history
print(history_dict.keys())
acc = history_dict['binary_accuracy']
val_acc = history_dict['val_binary_accuracy']
loss = history_dict['loss']
val_loss = history_dict['val_loss']
epochs = range(1, len(acc) + 1)  # used for plotting curves in the full tutorial
export_model = tf.keras.Sequential([
    vectorize_layer,
    model,
    layers.Activation('sigmoid')
])
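# wrapping the TextVectorization layer and a sigmoid into the exported model
# means the SavedModel accepts raw strings and returns probabilities directly,
# so clients (including tensorflow serving above) need no preprocessing code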
export_model.compile(
    loss=losses.BinaryCrossentropy(from_logits=False), optimizer="adam", metrics=['accuracy']
)
# Test it with `raw_test_ds`, which yields raw strings
loss, accuracy = export_model.evaluate(raw_test_ds)
print(accuracy)
model_version = "1"
path = 'my_model/' + model_version
os.makedirs(path, exist_ok=True)  # my_model/1 may already exist (see setup above)
export_model.save(path)
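## the SavedModel directory written above can be inspected with the
## saved_model_cli tool that ships with tensorflow -- a sketch; signature
## names can vary by version:
# saved_model_cli show --dir my_model/1 --tag_set serve --signature_def serving_default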
### Load the model from disk
import_model = tf.keras.models.load_model(path)
examples = [
    "The movie was great!",
    "The movie was okay.",
    "The movie was terrible...",
    "awesome"
]
#print(export_model.predict(examples))
print(import_model.predict(examples))
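# predict() returns one sigmoid probability per input string (shape [4, 1]),
# comparable to the "outputs" returned by the tensorflow serving call above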
## requirements.txt
# absl-py==1.4.0
# astunparse==1.6.3
# cachetools==5.3.0
# certifi==2022.12.7
# charset-normalizer==3.1.0
# flatbuffers==23.3.3
# gast==0.4.0
# google-auth==2.17.0
# google-auth-oauthlib==0.4.6
# google-pasta==0.2.0
# grpcio==1.53.0
# h5py==3.8.0
# idna==3.4
# jax==0.4.8
# keras==2.12.0
# libclang==16.0.0
# Markdown==3.4.3
# MarkupSafe==2.1.2
# ml-dtypes==0.0.4
# numpy==1.23.5
# oauthlib==3.2.2
# opt-einsum==3.3.0
# packaging==23.0
# protobuf==4.22.1
# pyasn1==0.4.8
# pyasn1-modules==0.2.8
# requests==2.28.2
# requests-oauthlib==1.3.1
# rsa==4.9
# scipy==1.10.1
# six==1.16.0
# tensorboard==2.12.0
# tensorboard-data-server==0.7.0
# tensorboard-plugin-wit==1.8.1
# tensorflow==2.12.0
# tensorflow-estimator==2.12.0
# tensorflow-io-gcs-filesystem==0.32.0
# termcolor==2.2.0
# typing_extensions==4.5.0
# urllib3==1.26.15
# Werkzeug==2.2.3
# wrapt==1.14.1
# setuptools==65.6.3
# wheel==0.38.4