Skip to content

Instantly share code, notes, and snippets.

@alexlee-gk
Last active December 7, 2018 21:25
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alexlee-gk/cbc9bfa6e5be51b53c622684cec0a3f3 to your computer and use it in GitHub Desktop.
Save alexlee-gk/cbc9bfa6e5be51b53c622684cec0a3f3 to your computer and use it in GitHub Desktop.
SSIM TensorFlow implementation that matches scikit-image's compare_ssim
import tensorflow as tf
from tensorflow.python.util import nest
def _with_flat_batch(flat_batch_fn):
    """Wrap ``flat_batch_fn`` so it accepts any number of leading batch dims.

    The returned function takes a tensor ``x`` of shape
    ``[..., height, width, channels]``. It works in three steps:

    1. Collapse every leading dimension into a single batch dimension.
    2. Call ``flat_batch_fn`` on the resulting rank-4 tensor, forwarding any
       extra positional or keyword arguments.
    3. Reshape every tensor in the (possibly nested) result back to
       ``[..., *rest]``, where ``rest`` is that output's shape after its
       batch dimension.

    Args:
        flat_batch_fn: callable whose first argument is a rank-4 tensor
            ``[batch, height, width, channels]``.

    Returns:
        The wrapped callable.
    """
    def fn(x, *args, **kwargs):
        shape = tf.shape(x)
        # Collapse all leading dims into one: [-1, height, width, channels].
        flat_batch_x = tf.reshape(x, tf.concat([[-1], shape[-3:]], axis=0))
        flat_batch_r = flat_batch_fn(flat_batch_x, *args, **kwargs)

        def _unflatten(flat):
            rest = flat.shape[1:]
            # A static TensorShape can only be converted to a tensor when it
            # is fully defined. Otherwise (e.g. unknown height/width), fall
            # back to the dynamic shape so the reshape still works.
            if not rest.is_fully_defined():
                rest = tf.shape(flat)[1:]
            return tf.reshape(flat, tf.concat([shape[:-3], rest], axis=0))

        return nest.map_structure(_unflatten, flat_batch_r)
    return fn
def structural_similarity(X, Y, K1=0.01, K2=0.03, win_size=7,
                          data_range=1.0, use_sample_covariance=True):
    """
    Structural SIMilarity (SSIM) index between two images.

    Local statistics are computed with a uniform ``win_size x win_size``
    window, applied to each channel independently. Only windows lying
    entirely inside the image contribute ('VALID' padding).

    Args:
        X: A tensor of shape `[..., in_height, in_width, in_channels]`.
        Y: A tensor of shape `[..., in_height, in_width, in_channels]`.
        K1: Stabilizing constant for the luminance term; C1 = (K1 * data_range) ** 2.
        K2: Stabilizing constant for the contrast/structure term;
            C2 = (K2 * data_range) ** 2.
        win_size: Side length of the square uniform window.
        data_range: Dynamic range of the pixel values (max possible - min possible).
        use_sample_covariance: If True, scale the (co)variances by NP / (NP - 1)
            (sample covariance), where NP = win_size ** 2. If False, use the
            population covariance, as in Wang et al. 2004.

    Returns:
        A tensor of shape `[...]`: the SSIM between images X and Y, averaged
        over height, width and channels.

    Reference:
        https://github.com/scikit-image/scikit-image/blob/master/skimage/measure/_structural_similarity.py

    Broadcasting is supported, both across the leading dims and across the
    channel axis.
    """
    X = tf.convert_to_tensor(X)
    Y = tf.convert_to_tensor(Y)

    ndim = 2  # number of spatial dimensions
    NP = win_size ** ndim
    # Use float division explicitly. Under Python 2, `1 / win_size ** 2` and
    # `NP / (NP - 1)` are integer divisions that evaluate to 0 and 1. A zero
    # kernel makes every mean and (co)variance 0, so SSIM is always exactly 1.0.
    weight = 1.0 / NP
    if use_sample_covariance:
        cov_norm = NP / (NP - 1.0)  # sample covariance
    else:
        cov_norm = 1.0  # population covariance to match Wang et al. 2004

    flat_depthwise_conv2d = _with_flat_batch(tf.nn.depthwise_conv2d)

    def filter_func(inp):
        # Uniform mean filter, applied per channel.
        # The kernel is sized from this input's own channel count, so X and Y
        # may broadcast along the channel axis. It is built directly in the
        # input dtype to avoid a float32 round trip for float64 inputs.
        # The kernel already normalizes by NP.
        kernel = tf.fill([win_size, win_size, tf.shape(inp)[-1], 1],
                         tf.constant(weight, dtype=inp.dtype))
        return flat_depthwise_conv2d(inp, filter=kernel, strides=[1] * 4,
                                     padding='VALID')

    # Local means.
    ux = filter_func(X)
    uy = filter_func(Y)

    # Local variances and covariance: E[ab] - E[a]E[b], rescaled by cov_norm.
    uxx = filter_func(X * X)
    uyy = filter_func(Y * Y)
    uxy = filter_func(X * Y)
    vx = cov_norm * (uxx - ux * ux)
    vy = cov_norm * (uyy - uy * uy)
    vxy = cov_norm * (uxy - ux * uy)

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    A1 = 2 * ux * uy + C1
    A2 = 2 * vxy + C2
    B1 = ux ** 2 + uy ** 2 + C1
    B2 = vx + vy + C2
    S = (A1 * A2) / (B1 * B2)

    # Mean SSIM over the valid spatial region and all channels.
    return tf.reduce_mean(S, axis=[-3, -2, -1])
def main():
    """Compare this SSIM implementation against scikit-image on random images.

    Prints two numbers, which are expected to agree:
    - the batch-mean SSIM computed with TensorFlow;
    - the mean of the per-image ``compare_ssim`` values.
    """
    import numpy as np
    from skimage.measure import compare_ssim

    batch_size = 4
    image_shape = (64, 64, 3)
    images0 = np.random.random((batch_size,) + image_shape)
    images1 = np.random.random((batch_size,) + image_shape)

    ssim_tf = tf.reduce_mean(structural_similarity(images0, images1))
    # The context manager closes the session (releasing its resources) when done.
    with tf.Session() as sess:
        ssim_tf = sess.run(ssim_tf)

    ssim_skimage = np.mean([compare_ssim(image0, image1, data_range=1.0, multichannel=True)
                            for image0, image1 in zip(images0, images1)])
    print(ssim_tf, ssim_skimage)


if __name__ == '__main__':
    main()
@brunopop
Copy link

Hello, I tried running your script as-is and the numbers don't match: ssim_tf is always 1.0. Are you sure you are flattening the batch correctly in _with_flat_batch?

@rakeshmahadasa
Copy link

I agree with brunopop. ssim_tf is always 1. Could you please explain the reason?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment