This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from IPython.display import clear_output, Image, display, HTML | |
import numpy as np | |
import tensorflow as tf | |
def strip_consts(graph_def, max_const_size=32): | |
"""Strip large constant values from graph_def.""" | |
strip_def = tf.GraphDef() | |
for n0 in graph_def.node: | |
n = strip_def.node.add() | |
n.MergeFrom(n0) | |
if n.op == 'Const': |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

# Expose the project's VOC2007 folder inside data/VOCdevkit/.
os.makedirs("data/VOCdevkit", exist_ok=True)
voc2007_dir = os.path.join(project_name, "data/VOC2007")
# Same effect as `ln -s {voc2007_dir} data/VOCdevkit`: because data/VOCdevkit
# is a directory, the link lands inside it under the source's basename
# (data/VOCdevkit/VOC2007). os.symlink avoids building a shell command from a
# path (spaces/quotes would break it) and, unlike os.system, does not silently
# discard failures.
voc2007_link = os.path.join("data/VOCdevkit", os.path.basename(voc2007_dir))
if not os.path.lexists(voc2007_link):
    # Use an absolute target: a relative one would be resolved relative to
    # data/VOCdevkit/ and produce a dangling link.
    os.symlink(os.path.abspath(voc2007_dir), voc2007_link, target_is_directory=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

# Run from the mmdetection checkout. os.chdir changes the process working
# directory like the IPython `%cd` magic did, but is also valid outside a
# notebook. NOTE(review): kept before the mmdet imports in case they resolve
# against the checkout — confirm.
os.chdir(mmdetection_dir)
from mmcv.runner import load_checkpoint
from mmdet.apis import inference_detector, show_result, init_detector

checkpoint_file = os.path.join(mmdetection_dir, work_dir, "latest.pth")
# NOTE(review): not used in the lines shown; presumably the score threshold for
# displaying detections later (with show_result) — confirm.
score_thr = 0.8

# build the model from a config file and a checkpoint file
model = init_detector(config_fname, checkpoint_file)

# test a single image and show the results
img = 'data/VOCdevkit/VOC2007/JPEGImages/15.jpg'
result = inference_detector(model, img)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
# Load the converted TFLite model, allocate its tensors, and record the tensor
# indices of the first input and first output for later set/get calls.
interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
def eval_model(interpreter, x_test, y_test): | |
total_seen = 0 | |
num_correct = 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Create the .tflite file
tflite_model_file = "/tmp/sparse_mnist.tflite"
# TFLiteConverter.from_keras_model_file is a TF 1.x constructor; TF 2.x removed
# it from tf.lite and keeps it only under tf.compat.v1.lite. Pick whichever
# class provides it so this runs on both major versions.
_tflite_converter_cls = tf.lite.TFLiteConverter
if not hasattr(_tflite_converter_cls, "from_keras_model_file"):
    _tflite_converter_cls = tf.compat.v1.lite.TFLiteConverter
converter = _tflite_converter_cls.from_keras_model_file(pruned_keras_file)
tflite_model = converter.convert()
# convert() yields the serialized flatbuffer, so write it in binary mode.
with open(tflite_model_file, "wb") as f:
    f.write(tflite_model)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile
import zipfile

# tempfile.mkstemp creates the file AND returns an open OS-level descriptor.
# Only the path is needed here (Keras and zipfile reopen it by name), so close
# the descriptor immediately instead of leaking it.
keras_fd, new_pruned_keras_file = tempfile.mkstemp(".h5")
os.close(keras_fd)
print("Saving pruned model to: ", new_pruned_keras_file)
tf.keras.models.save_model(final_model, new_pruned_keras_file, include_optimizer=False)
# Zip the .h5 model file
zip_fd, zip3 = tempfile.mkstemp(".zip")
os.close(zip_fd)
with zipfile.ZipFile(zip3, "w", compression=zipfile.ZIP_DEFLATED) as f: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tensorflow.keras.models import load_model

# NOTE(review): Keras load_model expects a path to a saved model (e.g. an .h5
# file), not a model object — confirm final_model holds a path here.
model = load_model(final_model)
import numpy as np
for i, w in enumerate(model.get_weights()): | |
print( | |
"{} -- Total:{}, Zeros: {:.2f}%".format( | |
model.weights[i].name, w.size, np.sum(w == 0) / w.size * 100 | |
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Add a pruning step callback to peg the pruning step to the optimizer's
# step. Also add a callback to add pruning summaries to tensorboard.
# NOTE(review): profile_batch=0 presumably disables TensorBoard profiling for
# the summaries callback — confirm against the installed tfmot version.
callbacks = [
    sparsity.UpdatePruningStep(),
    sparsity.PruningSummaries(log_dir=logdir, profile_batch=0),
]
new_pruned_model.fit(x_train, y_train, | |
batch_size=batch_size, | |
epochs=epochs, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from tensorflow_model_optimization.sparsity import keras as sparsity | |
# Backend agnostic way to save/restore models | |
# _, keras_file = tempfile.mkstemp('.h5') | |
# print('Saving model to: ', keras_file) | |
# tf.keras.models.save_model(model, keras_file, include_optimizer=False) | |
# Load the serialized model |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torchvision | |
from torch.utils.tensorboard import SummaryWriter | |
from torchvision import datasets, transforms | |
# Writer will output to ./runs/ directory by default
writer = SummaryWriter()
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then computes
# (x - 0.5) / 0.5 on the single channel, mapping pixels to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = datasets.MNIST('mnist_train', train=True, download=True, transform=transform)