Created
January 8, 2022 14:03
-
-
Save aewhite/14db960f9e832bce4041bf185cdc9615 to your computer and use it in GitHub Desktop.
Inverse of extract Image Patches for TF 2.x
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Solution based on https://stackoverflow.com/a/51785735/278836
import tensorflow as tf | |
def extract_patches(images):
    """Return all 3x3 patches of `images`, flattened into the last axis.

    The window moves one pixel at a time with no dilation, and "VALID"
    padding is used, so only windows lying fully inside the image are
    produced.

    Args:
        images: Tensor forwarded unchanged to `tf.image.extract_patches`
            (the 4-element window spec implies a 4-D input).

    Returns:
        The tensor produced by `tf.image.extract_patches`.
    """
    window = (1, 3, 3, 1)
    unit = (1, 1, 1, 1)  # used for both the stride and the dilation rate
    return tf.image.extract_patches(
        images=images,
        sizes=window,
        strides=unit,
        rates=unit,
        padding="VALID",
    )
@tf.function
def extract_patches_inverse(shape, patches):
    """Rebuild images of `shape` from the output of `extract_patches`.

    Patch extraction is linear, so its gradient with respect to the input
    scatters every patch value back onto the pixel it was copied from and
    sums the contributions of overlapping patches.  Dividing that sum by
    the number of patches covering each pixel yields the per-pixel average,
    which equals the original image when the patches are unmodified.

    Args:
        shape: Shape of the images the patches were extracted from.
        patches: Tensor shaped like `extract_patches(tf.zeros(shape))`.
            Integer (or other non-floating) patches are cast to float32
            because gradients are only defined for floating dtypes.

    Returns:
        A floating-point tensor of shape `shape`.  Pixels that are not
        covered by any patch are set to 0.
    """
    patches = tf.convert_to_tensor(patches)
    if not patches.dtype.is_floating:
        patches = tf.cast(patches, tf.float32)

    # The dummy input must share the patches' dtype; a hard-coded float32
    # placeholder breaks the gradient/division for float64 or integer data.
    _x = tf.zeros(shape, dtype=patches.dtype)
    _y = extract_patches(_x)

    # With the default grad_ys (all ones) the gradient counts how many
    # patches each input pixel appears in.
    coverage = tf.gradients(_y, _x)[0]
    # Using the patches as grad_ys sums every patch value per input pixel.
    summed = tf.gradients(_y, _x, grad_ys=patches)[0]

    # divide_no_nan returns 0 where coverage is 0 instead of producing NaN.
    return tf.math.divide_no_nan(summed, coverage)
def main():
    """Round-trip a random image batch through patches and print the MSE.

    Extracts patches from uniformly random images, reconstructs the images
    with `extract_patches_inverse`, and prints the mean squared difference
    between the reconstruction and the originals.
    """
    image_shape = (10, 28, 28, 3)
    originals = tf.random.uniform(image_shape, minval=0.0, maxval=1.0)

    restored = extract_patches_inverse(image_shape, extract_patches(originals))

    mse = tf.reduce_mean(tf.math.squared_difference(restored, originals))
    print(mse)
# Run the reconstruction demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for sharing this, but it just won't work for me for some reason. The example you give works fine, but actual images don't. Any ideas what might be causing this behaviour?
The code I ran:
The error generated: