Skip to content

Instantly share code, notes, and snippets.

@aewhite
Created January 8, 2022 14:03
Show Gist options
  • Save aewhite/14db960f9e832bce4041bf185cdc9615 to your computer and use it in GitHub Desktop.
Save aewhite/14db960f9e832bce4041bf185cdc9615 to your computer and use it in GitHub Desktop.
Inverse of extract_image_patches (tf.image.extract_patches) for TF 2.x
# Solution based on https://stackoverflow.com/a/51785735/278836
import tensorflow as tf
def extract_patches(images):
    """Extract every 3x3 patch from a batch of images.

    Uses stride 1, dilation rate 1 and ``"VALID"`` padding, so neighbouring
    patches overlap and only patches lying fully inside the image are
    produced. ``extract_patches_inverse`` relies on exactly these settings.

    Args:
        images: 4-D tensor handed straight to ``tf.image.extract_patches``
            (which expects ``[batch, rows, cols, depth]``).

    Returns:
        The patch tensor returned by ``tf.image.extract_patches``.
    """
    return tf.image.extract_patches(
        images,
        sizes=(1, 3, 3, 1),
        strides=(1, 1, 1, 1),
        rates=(1, 1, 1, 1),
        padding="VALID",
    )
@tf.function
def _extract_patches_inverse_graph(shape, patches):
    """Graph-mode core of ``extract_patches_inverse``.

    ``tf.gradients`` only works in graph mode, hence ``@tf.function``.
    ``shape`` must already be hashable (e.g. a tuple) and ``patches`` must
    have a floating dtype.
    """
    # Zero "image" of the target shape, built in the patches' dtype so that
    # grad_ys below has the same dtype as _y.
    _x = tf.zeros(shape, dtype=patches.dtype)
    _y = extract_patches(_x)
    # With the default all-ones upstream gradient, each pixel's gradient is
    # the number of patches that contain it.
    counts = tf.gradients(_y, _x)[0]
    # Feeding the real patches as the upstream gradient sums every copy of a
    # pixel; dividing by the count averages them back to the pixel value.
    summed = tf.gradients(_y, _x, grad_ys=patches)[0]
    return summed / counts


def extract_patches_inverse(shape, patches):
    """Rebuild images of ``shape`` from the output of ``extract_patches``.

    Args:
        shape: Shape of the original image batch. Either a tuple/list or a
            fully defined ``tf.TensorShape`` (e.g. ``img.shape`` of a tensor).
            A ``TensorShape`` is converted to a tuple first, because
            ``tf.function`` cannot hash it when building its call signature.
        patches: Patches produced by ``extract_patches`` from an image batch
            of ``shape``.

    Returns:
        The reconstructed image batch. It has the dtype of ``patches`` when
        that dtype is floating, otherwise float32. Integer patches (e.g. from
        a uint8 image) are cast because gradients are only defined for
        floating types.
    """
    if isinstance(shape, tf.TensorShape):
        shape = tuple(shape.as_list())
    patches = tf.convert_to_tensor(patches)
    if not patches.dtype.is_floating:
        patches = tf.cast(patches, tf.float32)
    return _extract_patches_inverse_graph(shape, patches)
def main():
    """Round-trip random images through patch extraction and print the MSE.

    The printed mean squared error between the original and reconstructed
    images should be ~0 if ``extract_patches_inverse`` correctly inverts
    ``extract_patches``.
    """
    shape = (10, 28, 28, 3)
    images = tf.random.uniform(shape, 0.0, 1.0)
    patches = extract_patches(images)
    images_reconstructed = extract_patches_inverse(shape, patches)
    error = tf.reduce_mean(
        tf.math.squared_difference(images_reconstructed, images)
    )
    print(error)
if __name__ == '__main__':
    main()
@ajwl27
Copy link

ajwl27 commented Apr 11, 2022

Thanks for sharing this, but it just won't work for me for some reason. The example you give works fine, but it fails on actual images. Any ideas what might be causing this behaviour?

The code I ran:

import tensorflow as tf
import cv2
def extract_patches(images):
    """Extract 3x3 patches (stride 1, rate 1, VALID padding) from images."""
    return tf.image.extract_patches(
        images,
        (1, 3, 3, 1),
        (1, 1, 1, 1),
        (1, 1, 1, 1),
        padding="VALID")

@tf.function
def extract_patches_inverse(shape, patches):
    """Rebuild images of ``shape`` from the output of ``extract_patches``.

    ``tf.gradients`` requires graph mode, hence ``@tf.function``.
    """
    # Zero image of the target shape (tf.zeros defaults to float32, so
    # patches presumably must be float32 as well — TODO confirm).
    _x = tf.zeros(shape)
    _y = extract_patches(_x)
    # With the default all-ones upstream gradient, each pixel's gradient is
    # the number of patches that contain it.
    grad = tf.gradients(_y, _x)[0]
    # Feeding the real patches as the upstream gradient sums every copy of a
    # pixel; dividing by the count averages them back to the pixel value.
    return tf.gradients(_y, _x, grad_ys=patches)[0] / grad


image_name = "/mnt/dump2/test.png"
# NOTE(review): cv2.imread presumably yields a uint8 HxWxC (BGR) array, or
# None if the file cannot be read — confirm before relying on the dtype.
img = cv2.imread(image_name)
img = tf.expand_dims(img,0) # To create the batch information
# NOTE(review): img is a tf.Tensor here, so img.shape is a tf.TensorShape, not
# a tuple. The AssertionError in the traceback is raised while tf.function
# builds a hashable signature from its arguments, which points at this value;
# passing tuple(img.shape) is likely the fix — verify.
shape = img.shape

patches = extract_patches(img)
# NOTE(review): extract_patches_inverse builds float32 zeros, so integer
# patches (e.g. from a uint8 image) likely also fail as grad_ys; consider
# tf.cast(img, tf.float32) before extracting patches.
images_reconstructed = extract_patches_inverse(shape, patches)
# NOTE(review): `images` is never defined in this snippet (NameError once the
# call above succeeds); presumably `img` was meant, cast to the same dtype as
# images_reconstructed.
error = tf.reduce_mean(tf.math.squared_difference(images_reconstructed, images))
print(error)

The error generated:


---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-3-b937d1cb8f44> in <module>
     23 
     24 patches = extract_patches(img)
---> 25 images_reconstructed = extract_patches_inverse(shape, patches)
     26 error = tf.reduce_mean(tf.math.squared_difference(images_reconstructed, images))
     27 print(error)

/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/traceback_utils.py in error_handler(*args, **kwargs)
    151     except Exception as e:
    152       filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153       raise e.with_traceback(filtered_tb) from None
    154     finally:
    155       del filtered_tb

/usr/local/lib/python3.8/dist-packages/tensorflow/python/eager/function.py in _make_input_signature_hashable(elem)
    127     # TFE_Py_EncodeArg weakrefs arguments it does not recognize, and we expect
    128     # all recognized types to be hashable.
--> 129     assert isinstance(elem, weakref.ReferenceType)
    130     v = elem()
    131 

AssertionError: 

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment