Skip to content

Instantly share code, notes, and snippets.

@ryanholbrook
Created January 9, 2021 14:15
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ryanholbrook/85583a7d847bb1639c3cf8a3769db68e to your computer and use it in GitHub Desktop.
Save ryanholbrook/85583a7d847bb1639c3cf8a3769db68e to your computer and use it in GitHub Desktop.
Optimization Visualization
import math
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from matplotlib import gridspec
# # Activation Model #
def make_activation_model(model, layer_name, filter):
    """Build a model that maps input images to a single filter's feature map.

    Args:
        model: Keras model containing the layer to inspect.
        layer_name: Name handed to ``model.get_layer``.
        filter: Index into the last axis of the layer output (the output is
            indexed as a 4-D tensor).

    Returns:
        A ``keras.Model`` that shares ``model.inputs`` and outputs the
        selected slice of the layer's output.
    """
    selected_map = model.get_layer(layer_name).output[:, :, :, filter]
    # Same inputs as the original model; output is only the chosen filter.
    return keras.Model(inputs=model.inputs, outputs=selected_map)
def show_feature_map(image, model, layer_name, filter, ax=None):
    """Draw the activation of one filter for a single image.

    Args:
        image: A single image; a leading batch axis is added before it is
            passed to the model.
        model: Keras model containing the layer to inspect.
        layer_name: Name of the layer whose output is displayed.
        filter: Index of the filter (last axis of the layer output).
        ax: Matplotlib axes to draw on. A new figure is created when omitted.

    Returns:
        The axes the feature map was drawn on.
    """
    extractor = make_activation_model(model, layer_name, filter)
    batch = tf.expand_dims(image, axis=0)
    activation = tf.squeeze(extractor(batch))

    if ax is None:
        _, ax = plt.subplots()

    # Colour range is pinned to [0, 1] so that maps are comparable.
    ax.imshow(activation, cmap="magma", vmin=0.0, vmax=1.0)
    ax.axis("off")
    return ax
def show_feature_maps(
    image,
    model,
    layer_name,
    offset=0,
    rows=None,
    cols=3,
    width=12,
    cmap="magma",
):
    """Draw a grid of feature maps for consecutive filters of one layer.

    Args:
        image: A single image handed to ``show_feature_map``.
        model: Keras model containing the layer to inspect.
        layer_name: Name of the layer whose filters are displayed.
        offset: Index of the first filter shown; cell ``k`` of the grid (in
            row-major order) shows filter ``k + offset``.
        rows: Number of grid rows. When ``None`` it is derived from the size
            of the layer output's last axis divided by ``cols`` (rounded
            down).
        cols: Number of grid columns.
        width: Figure width; the height is ``width * rows / cols``.
        cmap: Colormap applied to every drawn feature map.

    Returns:
        The matplotlib figure containing the grid.
    """
    if rows is None:
        num_filters = model.get_layer(layer_name).output.shape[-1]
        rows = num_filters // cols

    # squeeze=False keeps `axs` two-dimensional even for a single row or
    # column, so the `axs[r, c]` indexing below is always valid.
    fig, axs = plt.subplots(
        rows,
        cols,
        figsize=(width, (width * rows) / cols),
        gridspec_kw=dict(wspace=0.01, hspace=0.01),
        squeeze=False,
    )
    for f, (r, c) in enumerate(product(range(rows), range(cols))):
        ax = show_feature_map(
            image, model, layer_name, f + offset, ax=axs[r, c]
        )
        # Apply the requested colormap to whatever was just drawn so the
        # `cmap` argument actually takes effect.
        for drawn in ax.images:
            drawn.set_cmap(cmap)
        axs[r, c] = ax
    return fig
# # Optimization Model #
class OptVis(object):
    """Gradient-ascent feature visualisation for one filter of a layer.

    The optimised parameter is ``self.image``, the variable returned by
    ``init_buffer``. When ``fft`` is true that buffer is decoded with
    ``fft_to_rgb`` before use; in both cases it is passed through
    ``to_valid_rgb`` to obtain the image handed to the model.

    Args:
        model: Keras model containing the layer to visualise.
        layer: Name handed to ``model.get_layer``.
        filter: Index into the last axis of the (4-D) layer output.
        neuron: When true, only the activation at the spatial centre of the
            feature map is optimised instead of the whole map.
        size: ``(height, width)`` of the optimised image.
        fft: Whether the parameter buffer is an FFT spectrum (``True``) or
            raw pixels (``False``).
        scale: Forwarded to ``init_buffer``.

    Raises:
        ValueError: If the layer output is not 4-dimensional.
    """

    def __init__(
        self,
        model,
        layer,
        filter,
        neuron=False,
        size=(128, 128),
        fft=True,
        scale=0.01,
    ):
        # Create activation model
        activations = model.get_layer(layer).output
        if len(activations.shape) != 4:
            raise ValueError("Activation shapes other than 4 not implemented.")
        activations = activations[:, :, :, filter]
        if neuron:
            _, y, x = activations.shape
            # find center
            # TODO: need to compute this from selected size, not activations
            yc = int(round(y / 2))
            xc = int(round(x / 2))
            activations = activations[:, yc, xc]
        self.activation_model = keras.Model(
            inputs=model.inputs, outputs=activations
        )

        # Create random initialization buffer
        self.shape = [1, *size, 3]
        self.fft = fft
        self.image = init_buffer(
            height=size[0], width=size[1], fft=fft, scale=scale
        )
        self.fft_scale = fft_scale(size[0], size[1], decay_power=1.0)
        # Set by `compile`; checked in `train_step`.
        self.optimizer = None

    def _render(self):
        """Decode the parameter buffer into the image fed to the model."""
        image = self.image
        if self.fft:
            image = fft_to_rgb(
                shape=self.shape, buffer=image, fft_scale=self.fft_scale
            )
        return to_valid_rgb(image)

    def __call__(self):
        """Return the model's activation for the current decoded image."""
        # The raw buffer is not an image (it is a spectrum when `fft` is
        # true), so decode it the same way `train_step` does.
        return self.activation_model(self._render())

    def compile(self, optimizer):
        """Store the optimizer used by `train_step`."""
        self.optimizer = optimizer

    @tf.function
    def train_step(self):
        """Run one gradient-ascent step and return ``{"loss": loss}``.

        Raises:
            RuntimeError: If `compile` has not been called.
        """
        if self.optimizer is None:
            raise RuntimeError("Call compile(optimizer) before training.")
        # Compute loss
        with tf.GradientTape() as tape:
            image = self._render()
            image = random_transform(
                tf.squeeze(image),
                jitter=8,
                scale=1.1,
                rotate=1.0,
                fill_method="reflect",
            )
            image = tf.expand_dims(image, 0)
            loss = clip_gradients(score(self.activation_model(image)))
        # Apply gradient; negated so the optimizer maximises the score.
        grads = tape.gradient(loss, self.image)
        self.optimizer.apply_gradients([(-grads, self.image)])
        return {"loss": loss}

    @tf.function
    def fit(self, epochs=1, log=False):
        """Run `epochs` training steps and return the decoded image.

        Args:
            epochs: Number of calls to `train_step`.
            log: When true, the score is emitted with ``tf.print`` after
                every step.
        """
        for _ in tf.range(epochs):
            loss = self.train_step()
            if log:
                # tf.print runs on every iteration of the traced loop; a
                # Python print would only run once, while tracing.
                tf.print("Score:", loss["loss"])
        return self._render()
# # Loss and Gradients #
def score(x):
    """Return the mean over all elements of ``x``."""
    return tf.math.reduce_mean(x)
@tf.custom_gradient
def clip_gradients(y):
    """Pass ``y`` through unchanged; clip its gradient to norm 1.0."""

    def clipped_backward(upstream):
        # Only the backward pass is altered.
        return tf.clip_by_norm(upstream, 1.0)

    return y, clipped_backward
# unused
def normalize_gradients(grads, method="l2"):
    """Rescale ``grads`` according to ``method``.

    Args:
        grads: Gradient tensor.
        method: ``"l2"`` applies ``tf.math.l2_normalize``, ``"std"`` divides
            by the standard deviation (plus 1e-8), ``"clip"`` clips to norm
            1.0. Any other value returns ``grads`` unchanged.

    Returns:
        The rescaled gradients.
    """
    if method == "l2":
        return tf.math.l2_normalize(grads)
    if method == "std":
        # Small epsilon guards against division by zero.
        grads /= tf.math.reduce_std(grads) + 1e-8
        return grads
    if method == "clip":
        return tf.clip_by_norm(grads, 1.0)
    return grads
# # Color Transforms #
# ImageNet statistics
# 3x3 colour-correlation matrix, stored as float32.
color_correlation_svd_sqrt = np.array(
    [
        [0.26, 0.09, 0.02],
        [0.27, 0.00, -0.05],
        [0.27, -0.09, 0.03],
    ],
    dtype="float32",
)
# Largest column norm; dividing by it gives the longest column unit length.
max_norm_svd_sqrt = np.linalg.norm(color_correlation_svd_sqrt, axis=0).max()
color_correlation_normalized = color_correlation_svd_sqrt / max_norm_svd_sqrt
# Per-channel mean and standard deviation used by `normalize`.
color_mean = np.array([0.485, 0.456, 0.406])
color_std = np.array([0.229, 0.224, 0.225])
def correlate_color(image):
    """Mix the channels of ``image`` with the normalised colour matrix.

    The image is flattened to rows of 3 values, multiplied by the transpose
    of ``color_correlation_normalized`` and reshaped back to its input shape.
    """
    original_shape = tf.shape(image)
    pixels = tf.reshape(image, [-1, 3])
    mixed = tf.matmul(pixels, color_correlation_normalized.T)
    return tf.reshape(mixed, original_shape)
def normalize(image):
    """Subtract ``color_mean`` from ``image`` and divide by ``color_std``."""
    centered = image - color_mean
    return centered / color_std
def to_valid_rgb(image, crop=False):
    """Map an unconstrained image buffer to values in (0, 1).

    Args:
        image: 4-D image tensor (it is sliced with four indices when
            ``crop`` is set).
        crop: When true, 25 entries are removed from both ends of axes 1
            and 2 before the colour transform.

    Returns:
        The sigmoid of the colour-correlated image.
    """
    if crop:
        image = image[:, 25:-25, 25:-25, :]
    return tf.nn.sigmoid(correlate_color(image))
# # Spatial Transforms #
# Adapted from https://github.com/tensorflow/lucid/blob/master/lucid/optvis/param/spatial.py
# and https://github.com/elichen/Feature-visualization/blob/master/optvis.py
def rfft2d_freqs(h, w):
    """Return the magnitudes of the 2-D frequencies of a real FFT.

    Args:
        h: Image height.
        w: Image width.

    Returns:
        Array of shape ``(h, w // 2 + 1)`` for even ``w`` and
        ``(h, w // 2 + 2)`` for odd ``w`` holding ``sqrt(fx**2 + fy**2)``.
    """
    vertical = np.fft.fftfreq(h).reshape(-1, 1)
    # An odd width keeps one extra frequency; the surplus pixel is cropped
    # later, after the inverse transform.
    kept = w // 2 + (2 if w % 2 == 1 else 1)
    horizontal = np.fft.fftfreq(w)[:kept]
    return np.sqrt(horizontal * horizontal + vertical * vertical)
def fft_scale(h, w, decay_power=1.0):
    """Return per-frequency scale factors as a complex64 tensor.

    Each factor is ``sqrt(w * h) / max(freq, 1 / max(w, h)) ** decay_power``,
    so lower frequencies receive larger factors.

    Args:
        h: Image height.
        w: Image width.
        decay_power: Exponent applied to the clamped frequency.
    """
    # Clamp from below so the zero frequency does not divide by zero.
    smallest = 1.0 / max(w, h)
    clamped = np.maximum(rfft2d_freqs(h, w), smallest)
    gains = 1.0 / clamped ** decay_power
    gains = gains * np.sqrt(w * h)
    return tf.convert_to_tensor(gains, dtype=tf.complex64)
def fft_to_rgb(shape, buffer, fft_scale):
    """Convert an FFT spectrum buffer into an image buffer.

    Args:
        shape: ``(batch, h, w, ch)`` of the desired image.
        buffer: Spectrum parameters; ``buffer[0]`` is used as the real part
            and ``buffer[1]`` as the imaginary part.
        fft_scale: Per-frequency factors multiplied into the spectrum.

    Returns:
        The inverse-transformed image, with axis 1 moved to the end, cropped
        to ``shape`` and divided by 4.
    """
    batch, h, w, ch = shape
    real_part, imag_part = buffer[0], buffer[1]
    spectrum = tf.complex(real_part, imag_part) * fft_scale
    spatial = tf.signal.irfft2d(spectrum)
    spatial = tf.transpose(spatial, (0, 2, 3, 1))
    # Odd spatial sizes produce one surplus pixel, which is cropped here.
    spatial = spatial[:batch, :h, :w, :ch]
    return spatial / 4.0
# # Affine Transforms #
@tf.function
def random_transform(image, jitter=0, rotate=0, scale=1, **kwargs):
    """Apply a random shift, rotation and zoom to ``image``.

    Args:
        image: Image handed to ``apply_affine_transform``.
        jitter: Shifts along x and y are drawn uniformly from
            ``[-jitter, jitter)``.
        rotate: The rotation angle is drawn uniformly from
            ``[-rotate, rotate)``.
        scale: The zoom factor (same for both axes) is drawn uniformly from
            ``[1.0, scale)``.
        **kwargs: Forwarded to ``apply_affine_transform``.

    Returns:
        The transformed image.
    """
    shift_x = tf.random.uniform([], -jitter, jitter)
    shift_y = tf.random.uniform([], -jitter, jitter)
    angle = tf.random.uniform([], -rotate, rotate)
    zoom = tf.random.uniform([], 1.0, scale)
    return apply_affine_transform(
        image, theta=angle, tx=shift_x, ty=shift_y, zx=zoom, zy=zoom, **kwargs,
    )
@tf.function
def apply_affine_transform(
    x,
    theta=0,
    tx=0,
    ty=0,
    shear=0,
    zx=1,
    zy=1,
    row_axis=0,
    col_axis=1,
    channel_axis=2,
    fill_method="reflect",
    cval=0.0,
    interpolation_method="nearest",
):
    """Apply an affine transformation to an image ``x``.

    Args:
        x: Image to transform.
        theta: Rotation angle.
        tx: Shift along x.
        ty: Shift along y.
        shear: Shear angle.
        zx: Zoom along x.
        zy: Zoom along y.
        row_axis: Accepted but not used by this function.
        col_axis: Accepted but not used by this function.
        channel_axis: Accepted but not used by this function.
        fill_method: Forwarded to ``_apply_inverse_affine_transform``.
        cval: Accepted but not used by this function.
        interpolation_method: Forwarded to
            ``_apply_inverse_affine_transform``.

    Returns:
        The transformed image.
    """

    def as_float32(value):
        return tf.convert_to_tensor(value, dtype=tf.float32)

    inverse = _get_inverse_affine_transform(
        as_float32(theta),
        as_float32(tx),
        as_float32(ty),
        as_float32(shear),
        as_float32(zx),
        as_float32(zy),
    )
    return _apply_inverse_affine_transform(
        x,
        inverse,
        fill_method=fill_method,
        interpolation_method=interpolation_method,
    )
@tf.function
def _get_inverse_affine_transform(theta, tx, ty, shear, zx, zy):
    """Construct the inverse affine matrix for the given transformations.

    The transformation is taken with respect to the usual right-handed
    coordinate system. The result starts as the 3x3 identity; a non-zero
    rotation replaces it, and the shift, shear and zoom matrices are then
    right-multiplied onto it in that order, each only when its parameters
    differ from the neutral value.

    Args:
        theta: Rotation angle in degrees.
        tx: Shift along x.
        ty: Shift along y.
        shear: Shear angle in degrees.
        zx: Zoom along x; must be non-zero (not checked).
        zy: Zoom along y; must be non-zero (not checked).

    Returns:
        A 3x3 float32 matrix in homogeneous coordinates.
    """
    transform_matrix = tf.eye(3, dtype=tf.float32)
    if theta != 0:
        theta = theta * math.pi / 180  # convert degrees to radians
        # Rotation by -theta, i.e. the inverse of a rotation by theta.
        rotation_matrix = tf.convert_to_tensor(
            [
                [tf.math.cos(theta), tf.math.sin(theta), 0],
                [-tf.math.sin(theta), tf.math.cos(theta), 0],
                [0, 0, 1],
            ],
            dtype=tf.float32,
        )
        transform_matrix = rotation_matrix
    if tx != 0 or ty != 0:
        shift_matrix = tf.convert_to_tensor(
            [[1, 0, -tx], [0, 1, -ty], [0, 0, 1]], dtype=tf.float32
        )
        transform_matrix = tf.matmul(transform_matrix, shift_matrix)
    if shear != 0:
        shear = shear * math.pi / 180  # convert degrees to radians
        shear_matrix = tf.convert_to_tensor(
            [
                [1, tf.math.sin(shear), 0],
                [0, tf.math.cos(shear), 0],
                [0, 0, 1],
            ],
            dtype=tf.float32,
        )
        transform_matrix = tf.matmul(transform_matrix, shear_matrix)
    if zx != 1 or zy != 1:
        # need to assert !=0
        zoom_matrix = tf.convert_to_tensor(
            [[1 / zx, 0, 0], [0, 1 / zy, 0], [0, 0, 1]], dtype=tf.float32
        )
        transform_matrix = tf.matmul(transform_matrix, zoom_matrix)
    return transform_matrix
@tf.function
def _apply_inverse_affine_transform(A, Ti, fill_method, interpolation_method):
    """Resample image ``A`` through the inverse affine matrix ``Ti``.

    For every output pixel the source location is computed by applying
    ``Ti`` (homogeneous coordinates) to a grid centred on the image, and the
    value at the rounded source location is gathered from ``A``.

    Args:
        A: Image with three axes (rows, columns, channels).
        Ti: 3x3 inverse transform in homogeneous coordinate form.
        fill_method: ``"replicate"`` (111|1234|444) clamps out-of-range
            indices to the border; ``"reflect"`` (432|1234|321) mirrors
            them. Other values leave the indices untouched.
        interpolation_method: Not used by the body; sampling is always
            nearest neighbour.

    Returns:
        The resampled image, with the same shape as ``A``.
    """
    nrows, ncols, _ = A.shape
    row_mid = (nrows - 1) / 2
    col_mid = (ncols - 1) / 2

    # Centred coordinate grid, one column per output pixel (row-major).
    flat = tf.range(ncols * nrows)
    xs = tf.cast(flat % ncols, dtype=tf.float32) - col_mid
    # Negated: left-handed (row index) to right-handed coordinates.
    ys = -(tf.cast(flat // ncols, dtype=tf.float32) - row_mid)
    ones = tf.ones([ncols * nrows], dtype=tf.float32)

    # Apply the transformation.
    source = tf.matmul(Ti, tf.stack([xs, ys, ones]))

    # Convert coordinates back to (approximate) indices.
    i = -source[1, :] + row_mid
    j = source[0, :] + col_mid

    if fill_method == "replicate":
        i = tf.clip_by_value(i, 0.0, nrows - 1)
        j = tf.clip_by_value(j, 0.0, ncols - 1)
    elif fill_method == "reflect":
        i = _reflect_index(i, nrows - 1)
        j = _reflect_index(j, ncols - 1)

    # Nearest-neighbour interpolation.
    indices = tf.cast(tf.round(tf.stack([i, j])), dtype=tf.int32)
    gathered = tf.gather_nd(A, tf.transpose(indices))
    return tf.reshape(gathered, A.shape)
@tf.function
def _reflect_index(i, n):
    """Reflect the index ``i`` into the range ``[0, n]``.

    Indices outside the range are mirrored at the borders with period
    ``2 * n``; the result is floored.
    """
    period = 2 * n
    wrapped = tf.math.floormod(i - n, period)
    return tf.math.floor(tf.math.abs(wrapped - n))
# # Buffer Initializers #
def init_buffer(
    height, width=None, batches=1, channels=3, scale=0.01, fft=True
):
    """Create the trainable variable holding the image parameters.

    Args:
        height: Image height.
        width: Image width; falls back to ``height`` when falsy.
        batches: Batch size.
        channels: Number of channels.
        scale: Forwarded to the initialiser.
        fft: Selects ``init_fft`` when true, otherwise ``init_pixel``.

    Returns:
        A trainable ``tf.Variable`` with the initial values.
    """
    if not width:
        width = height
    initializer = init_fft if fft else init_pixel
    initial_values = initializer([batches, height, width, channels], scale)
    return tf.Variable(initial_values, trainable=True)
def init_pixel(shape, scale=None):
    """Return a uniformly random float32 pixel buffer.

    Args:
        shape: ``(batches, h, w, ch)`` of the buffer.
        scale: Accepted for signature parity with ``init_fft``; not used.
    """
    batches, h, w, ch = shape
    return tf.random.uniform(shape=[batches, h, w, ch], dtype=tf.float32)
def init_fft(shape, scale=0.1):
    """Return a normally distributed float32 FFT parameter buffer.

    Args:
        shape: ``(batch, h, w, ch)`` of the image the spectrum represents.
        scale: Standard deviation of the normal distribution.

    Returns:
        Array of shape ``(2, batch, ch) + rfft2d_freqs(h, w).shape``.
    """
    batch, h, w, ch = shape
    spectrum_shape = rfft2d_freqs(h, w).shape
    full_shape = (2, batch, ch) + spectrum_shape
    samples = np.random.normal(scale=scale, size=full_shape)
    return samples.astype(np.float32)
# # Plotting #
def read_image(filename, size=[256, 256]):
    """Load a JPEG file as a float32 image resized to ``size``.

    Args:
        filename: Path handed to ``tf.io.read_file``.
        size: Target size passed to ``tf.image.resize``.

    Returns:
        The resized image tensor.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.io.decode_jpeg(raw_bytes)
    as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    return tf.image.resize(as_float, size=size)
def visualize(
    model,
    layer,
    filter,
    neuron=False,
    size=[150, 150],
    fft=True,
    lr=0.05,
    epochs=500,
    log=False,
    ax=None,
):
    """Optimise and display the visualisation of one filter.

    Args:
        model: Keras model containing the layer to visualise.
        layer: Name of the layer.
        filter: Index of the filter within the layer.
        neuron: Forwarded to ``OptVis``.
        size: Forwarded to ``OptVis``.
        fft: Forwarded to ``OptVis``.
        lr: Learning rate of the Adam optimizer.
        epochs: Number of optimisation steps.
        log: Forwarded to ``OptVis.fit``.
        ax: Matplotlib axes to draw on. A new figure is created when omitted.

    Returns:
        The axes the image was drawn on.
    """
    optvis = OptVis(model, layer, filter, neuron=neuron, size=size, fft=fft)
    optvis.compile(optimizer=tf.optimizers.Adam(lr))
    result = optvis.fit(epochs=epochs, log=log)

    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(tf.squeeze(result).numpy())
    ax.axis("off")
    return ax
def visualize_layer(
    model,
    layer,
    init_filter=0,
    neuron=False,
    size=(150, 150),
    fft=True,
    lr=0.05,
    epochs=500,
    log=False,
    rows=2,
    cols=4,
    width=16,
):
    """Optimise and display a grid of filter visualisations for one layer.

    Cell ``k`` of the grid (in row-major order) shows filter
    ``k + init_filter``; a fresh ``OptVis`` is optimised for every cell.

    Args:
        model: Keras model containing the layer to visualise.
        layer: Name of the layer.
        init_filter: Index of the first filter shown.
        neuron: Forwarded to ``OptVis``.
        size: Forwarded to ``OptVis``.
        fft: Forwarded to ``OptVis``.
        lr: Learning rate of the Adam optimizer.
        epochs: Number of optimisation steps per filter.
        log: Forwarded to ``OptVis.fit``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        width: Figure width; the height is ``width * rows / cols``.
    """
    gs = gridspec.GridSpec(rows, cols, wspace=0.01, hspace=0.01)
    plt.figure(figsize=(width, (width * rows) / cols))
    for f, (r, c) in enumerate(product(range(rows), range(cols))):
        optvis = OptVis(
            model, layer, f + init_filter, neuron=neuron, size=size, fft=fft
        )
        optvis.compile(optimizer=tf.optimizers.Adam(lr))
        # `log` is forwarded so the parameter is honoured.
        image = optvis.fit(epochs=epochs, log=log)
        plt.subplot(gs[r, c])
        plt.imshow(tf.squeeze(image).numpy())
        plt.axis("off")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment