Skip to content

Instantly share code, notes, and snippets.

@wangjiezhe
Last active March 20, 2024 04:40
Show Gist options
  • Save wangjiezhe/050854e2a12a4f05eab66da43b579dd9 to your computer and use it in GitHub Desktop.
Accelerating Inference in TensorFlow with TensorRT User Guide https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html
import tensorflow as tf
from tensorflow import keras

# Build a simple sequential classifier for 28x28 grayscale images (MNIST):
# flatten -> 128-unit ReLU layer -> dropout -> 10-way logits output.
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10),  # raw logits; loss below uses from_logits=True
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Load MNIST and scale pixel values from [0, 255] down to [0, 1].
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = tf.cast(x_train, dtype=tf.float32)
y_train = tf.cast(y_train, dtype=tf.float32)
x_test = tf.cast(x_test, dtype=tf.float32)
y_test = tf.cast(y_test, dtype=tf.float32)

# Train the model.
model.fit(x_train, y_train, epochs=5)
# Evaluate model accuracy on the held-out test set.
model.evaluate(x_test, y_test, verbose=2)

# Export in the SavedModel format so TF-TRT can convert it below.
SAVED_MODEL_DIR = "./models/native_saved_model"
# BUG FIX: the original called tf.saved__model.save (double underscore),
# which raises AttributeError at runtime; the correct module is tf.saved_model.
tf.saved_model.save(model, SAVED_MODEL_DIR)
from tensorflow.python.compiler.tensorrt import trt_convert as trt

# Create a TF-TRT converter that reads the native SavedModel exported above
# and targets full FP32 precision (no reduced-precision kernels).
converter = trt.TrtGraphConverterV2(
    input_saved_model_dir=SAVED_MODEL_DIR, precision_mode=trt.TrtPrecisionMode.FP32)

# Partition the graph: TensorRT-compatible subgraphs are replaced, and the
# returned callable runs the converted graph.
trt_func = converter.convert()

# Print a summary of the conversion result.
converter.summary()
MAX_BATCH_SIZE = 128


def input_fn():
    """Yield one representative input batch so the TRT engines are pre-built
    offline (via converter.build) instead of lazily at first inference."""
    yield [x_test[0:MAX_BATCH_SIZE, :]]


converter.build(input_fn=input_fn)

# Persist the TF-TRT converted model alongside the native one.
OUTPUT_SAVED_MODEL_DIR = "./models/tftrt_saved_model"
converter.save(output_saved_model_dir=OUTPUT_SAVED_MODEL_DIR)
# Run inference over ten consecutive half-size batches of the test set.
infer_batch_size = MAX_BATCH_SIZE // 2
for step in range(10):
    print(f"Step: {step}")
    lo = step * infer_batch_size
    hi = lo + infer_batch_size
    trt_func(x_test[lo:hi, :])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment