Skip to content

Instantly share code, notes, and snippets.

@kmader
Last active February 18, 2020 08:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kmader/0cc8d40216349dadc3e787adca2bbe77 to your computer and use it in GitHub Desktop.
Save kmader/0cc8d40216349dadc3e787adca2bbe77 to your computer and use it in GitHub Desktop.
Convert a preprocessing function into a convolutional layer (in Keras). It samples inputs from -127 to 127 (configurable), runs them through the given preprocessing function (e.g. `preprocess_input`), and returns a 1x1 convolutional layer whose weights and biases reproduce the preprocessing step as closely as possible. This allows models to be packaged without worrying about a separate preprocessing step.
from typing import Callable

import numpy as np
from keras import layers, models
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from sklearn.linear_model import LinearRegression
def prep_to_conv(
    in_prep_func: Callable[[np.ndarray], np.ndarray],
    *,
    min_val: float = -127,
    max_val: float = 127,
    channels: int = 3,
    verbose: bool = False,
) -> layers.Layer:
    """Turn a preprocessing function into a frozen 1x1 convolutional layer.

    The preprocessing function is sampled on a ramp of 9 evenly spaced values
    spanning ``[min_val, max_val]`` (the same ramp in every channel). An affine
    map ``y = w * x + b`` is then least-squares fitted independently for each
    channel. The fitted slopes become the diagonal of a 1x1 ``Conv2D`` kernel
    and the intercepts become its bias. The layer therefore reproduces the
    preprocessing exactly when it is a per-channel affine transform and
    approximates it otherwise. Cross-channel mixing cannot be represented
    because the kernel is diagonal.

    Args:
        in_prep_func: Preprocessing function. It is called on a float32 array
            of shape ``(1, 3, 3, channels)`` and is expected to return an
            array of the same shape. It receives a copy, so modifying its
            input in place is harmless.
        min_val: Lowest input value sampled.
        max_val: Highest input value sampled.
        channels: Number of channels in the layer's input and output.
        verbose: If True, plot the sampled outputs against the layer's
            predictions for each channel (imports matplotlib lazily).

    Returns:
        A built, non-trainable ``Conv2D`` layer created with
        ``input_shape=(None, None, channels)``.
    """
    test_input, test_output, kernel, bias = _fit_channel_affine(
        in_prep_func, min_val, max_val, channels
    )
    conv1 = layers.Conv2D(
        channels,
        kernel_size=(1, 1),
        activation='linear',
        use_bias=True,
        input_shape=(None, None, channels),
    )
    # Build explicitly and assign the weights instead of passing ``weights=``
    # to the constructor, which not every Keras version accepts.
    conv1.build((None, None, None, channels))
    conv1.set_weights([kernel, bias])
    conv1.trainable = False
    if verbose:
        _plot_channel_fit(conv1, test_input, test_output, channels)
    return conv1


def _fit_channel_affine(
    in_prep_func: Callable[[np.ndarray], np.ndarray],
    min_val: float,
    max_val: float,
    channels: int,
):
    """Sample ``in_prep_func`` on a ramp and fit one affine map per channel.

    Returns ``(test_input, test_output, kernel, bias)``. ``kernel`` has shape
    ``(1, 1, channels, channels)`` with the slopes on its diagonal, and
    ``bias`` has shape ``(channels,)``.
    """
    ramp = np.linspace(min_val, max_val, 9).reshape((1, 3, 3)).astype('float32')
    # One copy of the ramp per channel, so the sample matches the layer width.
    test_input = np.stack([ramp] * channels, axis=-1)
    test_output = np.asarray(in_prep_func(test_input.copy()))
    kernel = np.zeros((1, 1, channels, channels), dtype='float32')
    bias = np.zeros((channels,), dtype='float32')
    for i in range(channels):
        x = test_input[..., i].reshape((-1, 1))
        # A 1-D target makes coef_ shape (1,) and intercept_ a scalar, so
        # plain scalars (not 1-element arrays) are stored below.
        y = test_output[..., i].ravel()
        lin_reg = LinearRegression().fit(x, y)
        kernel[0, 0, i, i] = lin_reg.coef_[0]
        bias[i] = lin_reg.intercept_
    return test_input, test_output, kernel, bias


def _plot_channel_fit(conv_layer, test_input, test_output, channels):
    """Plot sampled preprocessing outputs against the layer's predictions."""
    import matplotlib.pyplot as plt

    checker = models.Sequential()
    checker.add(conv_layer)
    pred_output = checker.predict(test_input)
    # squeeze=False keeps a 2-D array of axes even when channels == 1, where
    # plt.subplots would otherwise return a single, non-iterable Axes.
    _, m_axs = plt.subplots(1, channels, figsize=(4 * channels, 4), squeeze=False)
    for i, c_ax in enumerate(m_axs[0]):
        x = test_input[..., i].ravel()
        c_ax.plot(x, test_output[..., i].ravel(), 's', label='Given')
        c_ax.plot(x, pred_output[..., i].ravel(), '-', label='Predicted')
        c_ax.legend()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment