Created
November 15, 2019 17:01
-
-
Save nkreeger/29a6498348d201306baa127774b13a60 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
# Sample model: two fully-connected (dense) layers followed by a softmax.
# Both Dense layers have 48 units; the first one fixes the input to
# 16 features per example.
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(48, input_shape=(16,)),
    tf.keras.layers.Dense(48),
    tf.keras.layers.Softmax(),
])

# Integer class labels are used for training, hence the sparse variant of
# categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Print the layer/parameter overview.
model.summary()
# Generate random data: 16 examples with 16 uniform [0, 1) features each,
# and one integer label (0 or 1) per example.
data_x = np.random.rand(16, 16)
data_y = np.random.randint(0, 2, size=(16, 1))

# Train for 16 epochs on the synthetic data.
model.fit(x=data_x, y=data_y, epochs=16)
# Representative dataset for full quantization:
def representative_dataset_gen(num_samples=16, input_dim=16):
    """Yield random calibration samples for post-training quantization.

    Each yielded item is a one-element list holding a float32 array of
    shape ``(1, input_dim)`` filled with uniform values in ``[0, 1)``.

    Args:
        num_samples: Number of calibration samples to yield. Defaults to 16.
        input_dim: Number of features per sample. Defaults to 16, matching
            the ``input_shape=(16,)`` of the model's first layer.

    Yields:
        list: ``[sample]`` where ``sample`` is a ``(1, input_dim)`` float32
        NumPy array.
    """
    for _ in range(num_samples):
        # np.random.rand returns float64; cast to float32 before yielding.
        yield [np.random.rand(1, input_dim).astype(np.float32)]
# Convert to fixed-point
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Target the int8 builtin op set and request int8 input/output tensors.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
# The converter is handed the generator function itself (not its result).
converter.representative_dataset = representative_dataset_gen
tflite_model = converter.convert()

# Save
# Use a context manager so the file handle is flushed and closed
# deterministically instead of being left to garbage collection.
with open("/tmp/fc_fp_quant_model.tflite", "wb") as model_file:
    model_file.write(tflite_model)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment