Skip to content

Instantly share code, notes, and snippets.

@adriangb
Last active February 21, 2021 19:46
Show Gist options
  • Save adriangb/996999e708b0ec38343268e5c4fef908 to your computer and use it in GitHub Desktop.
Save adriangb/996999e708b0ec38343268e5c4fef908 to your computer and use it in GitHub Desktop.
SaveModel vs. Pickle (via SaveModel)
import tempfile
import uuid
from timeit import default_timer, timeit
from typing import List, Tuple

from matplotlib import pyplot as plt
import numpy as np
from tensorflow import keras
from tensorflow.keras.models import load_model

from scikeras._saving_utils import pack_keras_model
def get_model(n_hidden: int, units: int = 64) -> keras.Model:
    """Build and compile a plain stack of ``n_hidden`` Dense layers.

    Args:
        n_hidden: Number of Dense layers stacked after the input.
        units: Input dimension and width of every Dense layer. The default
            of 64 is an arbitrary choice.

    Returns:
        A compiled functional ``keras.Model`` (built with
        ``keras.Model(inputs, outputs)``, not ``keras.Sequential``).
    """
    inp = x = keras.Input((units,))
    for _ in range(n_hidden):
        x = keras.layers.Dense(units)(x)
    model = keras.Model(inp, x)
    model.compile(loss="mse")  # arbitrary loss choice
    return model
def roundtrip_pickle(model: keras.Model) -> keras.Model:
    """Pack ``model`` with scikeras' pickle helper and rebuild it.

    Returns the reconstructed model. The benchmark only times this call, so
    returning the copy is optional for callers.
    """
    # pack_keras_model returns a reduce-style (callable, args) pair, the same
    # shape pickle's __reduce__ protocol uses; calling it rebuilds the model.
    rebuild, args = pack_keras_model(model)
    return rebuild(*args)
def roundtrip_savemodel(model: keras.Model) -> keras.Model:
    """Save ``model`` in SavedModel format and load it back.

    The artifact is written to TensorFlow's in-memory ``ram://`` filesystem.
    Returns the reloaded model; the benchmark only times the call.
    """
    # A random UUID gives each call a unique ram:// location without creating
    # (and leaking) a real directory on disk just to borrow its name.
    location = f"ram://{uuid.uuid4()}"
    model.save(location)
    # NOTE(review): the ram:// copy is never deleted, so memory grows with
    # every call. tf.io.gfile.rmtree(location) would free it, at the cost of
    # adding cleanup time to the measured round trip — decide deliberately.
    return load_model(location)
def _time_roundtrips(roundtrip, model, repeats: int) -> List[float]:
    """Return wall-clock durations (seconds) of ``repeats`` calls to ``roundtrip(model)``."""
    durations = []
    for _ in range(repeats):
        start = default_timer()
        roundtrip(model)
        durations.append(default_timer() - start)
    return durations


def bench_layers(
    n_hidden: int,
    repeats: int = 10,
) -> Tuple[float, float, float, float]:
    """Time pickle vs. SaveModel round trips for a Dense stack of given depth.

    Args:
        n_hidden: Number of hidden Dense layers in the benchmarked model.
        repeats: Number of timed round trips per method.

    Returns:
        ``(pickle_mean, pickle_std, savemodel_mean, savemodel_std)`` in seconds.
    """
    model = get_model(n_hidden)
    # Warm up BOTH paths once so one-time costs (caching, first-use setup)
    # are not charged only to whichever method happened not to be warmed.
    roundtrip_savemodel(model)
    roundtrip_pickle(model)
    pickle_times = _time_roundtrips(roundtrip_pickle, model, repeats)
    savemodel_times = _time_roundtrips(roundtrip_savemodel, model, repeats)
    return (
        float(np.mean(pickle_times)),
        float(np.std(pickle_times)),
        float(np.mean(savemodel_times)),
        float(np.std(savemodel_times)),
    )
def get_mobilenet() -> keras.Model:
    """Return a compiled, minimalistic MobileNetV3Small for benchmarking."""
    net = keras.applications.MobileNetV3Small(minimalistic=True)
    # Any loss would do here; the benchmark only needs a compiled model.
    net.compile(loss="sparse_categorical_crossentropy")
    return net
def bench_mobilenet(repeats: int) -> Tuple[List[float], List[float]]:
    """Time pickle vs. SaveModel round trips of MobileNetV3Small.

    Args:
        repeats: Number of timed round trips per method.

    Returns:
        ``(pickle_times, savemodel_times)``: raw per-call durations in seconds.
    """
    model = get_mobilenet()
    # Warm up BOTH paths once so one-time costs are excluded from either set
    # of timings, not just from the SaveModel ones.
    roundtrip_savemodel(model)
    roundtrip_pickle(model)

    def timed(roundtrip) -> float:
        # Duration in seconds of a single round trip of `model`.
        start = default_timer()
        roundtrip(model)
        return default_timer() - start

    pickle_times = [timed(roundtrip_pickle) for _ in range(repeats)]
    savemodel_times = [timed(roundtrip_savemodel) for _ in range(repeats)]
    return (pickle_times, savemodel_times)
# Depths (number of hidden Dense layers) to sweep.
n_hiddens = [1, 5, 10, 25, 50, 100]

# Per-depth summaries from bench_layers: mean and std of round-trip time
# for the pickle path and for the SaveModel path.
pickle_times = []
pickle_std = []
savemodel_times = []
savemodel_std = []
for n_hidden in n_hiddens:
    pt, ps, st, ss = bench_layers(n_hidden)
    pickle_times.append(pt)
    pickle_std.append(ps)
    savemodel_times.append(st)
    savemodel_std.append(ss)

# Raw per-call timings (30 per method) for the MobileNetV3Small comparison.
time_pickle, time_savemodel = bench_mobilenet(30)

fig, (ax1, ax2) = plt.subplots(ncols=2)

# Nudge the two series left/right of each depth so their error bars don't
# overlap. Scale by the largest depth, so the list need not be sorted.
depths = np.asarray(n_hiddens)
x_offset = max(n_hiddens) / 50
ax1.errorbar(
    depths - x_offset,
    pickle_times,
    yerr=pickle_std,
    label="Pickle",
    fmt="o",
    color="blue",
    ecolor="cornflowerblue",
    elinewidth=3,
    capsize=0,
)
ax1.errorbar(
    depths + x_offset,
    savemodel_times,
    yerr=savemodel_std,
    label="SaveModel",
    fmt="o",
    color="red",
    ecolor="lightcoral",
    elinewidth=3,
    capsize=0,
)
ax1.legend()
ax1.set_xlabel("Number of hidden layers")
ax1.set_ylabel("Roundtrip time (s)")
ax1.set_title("Dense stack")

# Boxes sit at the default positions 1..N. Label the ticks explicitly instead
# of passing boxplot(labels=...), which Matplotlib 3.9 deprecated in favour of
# tick_labels; this works on both old and new versions.
ax2.boxplot([time_pickle, time_savemodel])
ax2.set_xticks([1, 2])
ax2.set_xticklabels(["Pickle", "SaveModel"])
ax2.set_ylabel("Roundtrip time (s)")
ax2.set_title("MobileNetV3Small")
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment