Last active October 17, 2022 14:10
Save modanesh/2c9c40d183c04c448c1c44a557b17c32 to your computer and use it in GitHub Desktop.
The value of `float(img.mean())` differs before and after enabling `jax_disable_jit`
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import urllib
import urllib.request

import cv2
import imageio
import jax
import jax.numpy as jnp
import numpy as np
def resize_and_center_crop(image, image_size=224, crop_padding=32):
    """Center-crop ``image`` to a square, then resize it to ``image_size``.

    The square's side is ``image_size / (image_size + crop_padding)`` of the
    shorter image side (224 / 256 = 0.875 with the defaults), truncated to an
    integer. The crop is centred, and the offset rounds up when the leftover
    margin is odd. The crop is then resized to ``(image_size, image_size)``
    with bicubic interpolation.

    Args:
      image: array indexed as ``image[row, col, ...]``. Only the first two
        axes are used for cropping.
      image_size: side length of the square output. Defaults to 224.
      crop_padding: padding added to ``image_size`` to derive the crop
        fraction. Defaults to 32.

    Returns:
      The array produced by ``cv2.resize`` for the cropped region.
    """
    image_height, image_width = image.shape[:2]
    crop_fraction = image_size / (image_size + crop_padding)
    # Same float32 -> int32 truncation as before, so default results are unchanged.
    crop_size = (crop_fraction *
                 np.minimum(image_height, image_width).astype(np.float32)).astype(np.int32)
    # (margin + 1) // 2 rounds the offset up for an odd margin. This presumably
    # mirrors the TF `crop_to_bounding_box` preprocessing the code was ported
    # from (TODO confirm).
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2
    cropped = image[offset_height:offset_height + crop_size,
                    offset_width:offset_width + crop_size]
    # cv2.resize takes dsize as (width, height); the output here is square.
    return cv2.resize(cropped, (image_size, image_size), interpolation=cv2.INTER_CUBIC)
def normalize(im):
    """Standardize an image with the ImageNet per-channel mean and stddev.

    The statistics are expressed on the 0-255 scale, so ``im`` is expected in
    that range. They broadcast over the last axis. ``im`` is presumably
    channel-last with 3 channels in RGB order (confirm against the loader).

    Returns:
      ``(im - mean) / stddev``.
    """
    scale = 255
    channel_mean = np.array((0.485 * scale, 0.456 * scale, 0.406 * scale))
    channel_std = np.array((0.229 * scale, 0.224 * scale, 0.225 * scale))
    centered = im - channel_mean
    return centered / channel_std
# Reproduction script: compute the mean of one preprocessed image with JIT
# enabled, then again with JIT disabled, and fail if the two values differ.
image_path = "dog.jpg"
if not os.path.exists(image_path):
    # Download the sample image once and reuse the local copy on later runs.
    urllib.request.urlretrieve(
        'https://storage.googleapis.com/perceiver_io/dalmation.jpg', image_path)
with open(image_path, 'rb') as f:
    img = imageio.imread(f)
img = resize_and_center_crop(img)
img = normalize(img)
# Add a leading batch axis.
# NOTE(review): with JAX defaults (jax_enable_x64 off), the float64 output of
# `normalize` is stored as float32 here. Confirm x64 is not enabled elsewhere.
img = jnp.array(img)[None]
before_disabling = float(img.mean())
# NOTE: this flips a process-wide flag that stays set after this line.
# `with jax.disable_jit():` would scope it if that is ever needed.
jax.config.update('jax_disable_jit', True)
after_disabling = float(img.mean())
# Exact equality is deliberate: the script exists to expose the discrepancy.
# NOTE(review): a tiny difference is plausibly float32 summation-order noise
# (compiled vs. op-by-op reduction). Use math.isclose if only agreement
# within floating-point tolerance matters.
# Raise explicitly instead of `assert` so the check also runs under `python -O`.
if before_disabling != after_disabling:
    raise AssertionError(
        f"values are different by: {after_disabling - before_disabling}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment