Skip to content

Instantly share code, notes, and snippets.

@modanesh
Last active October 17, 2022 14:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save modanesh/2c9c40d183c04c448c1c44a557b17c32 to your computer and use it in GitHub Desktop.
Save modanesh/2c9c40d183c04c448c1c44a557b17c32 to your computer and use it in GitHub Desktop.
Value of `float(img.mean())` differs before and after `jax_disable_jit`
import os
import urllib
import urllib.request

import cv2
import imageio
import jax
import jax.numpy as jnp
import numpy as np
def resize_and_center_crop(image, image_size=224, crop_padding=32):
    """Center-crop a square from `image`, then resize it with bicubic interpolation.

    The square's side is
    ``image_size / (image_size + crop_padding) * min(height, width)``,
    truncated to an integer. With the defaults this is 224/256 of the shorter
    side. The crop is centered; any odd leftover pixel goes to the
    top/left offset. The crop is then resized to ``image_size x image_size``.

    Args:
        image: array indexed as ``image[row, col, ...]``, i.e. height first.
        image_size: side length of the square output, in pixels.
        crop_padding: padding used to shrink the crop relative to the shorter
            side; 0 crops the full shorter side.

    Returns:
        The resized crop, as returned by ``cv2.resize``.
    """
    image_height, image_width = image.shape[0], image.shape[1]
    # Scale the shorter side in float32, then truncate to int32.
    scale = image_size / (image_size + crop_padding)
    crop_size = (
        scale * np.minimum(image_height, image_width).astype(np.float32)
    ).astype(np.int32)
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2
    cropped = image[offset_height:offset_height + crop_size,
                    offset_width:offset_width + crop_size]
    # cv2.resize takes dsize as (width, height); both are image_size here.
    return cv2.resize(cropped, (image_size, image_size),
                      interpolation=cv2.INTER_CUBIC)
def normalize(im):
    """Standardize `im` per channel using the ImageNet RGB mean and std.

    The ImageNet statistics are given on a 0-1 scale and are multiplied by 255
    here, which presumes `im` holds 0-255 pixel values. The length-3 channel
    vectors broadcast against the trailing axis of `im`.
    """
    channel_mean = np.array([c * 255 for c in (0.485, 0.456, 0.406)])
    channel_std = np.array([c * 255 for c in (0.229, 0.224, 0.225)])
    centered = im - channel_mean
    return centered / channel_std
# Reproduction: compute float(img.mean()) on the same JAX array with JIT enabled,
# then again after turning on `jax_disable_jit`, and require bitwise-equal results.
image_path = "dog.jpg"
if not os.path.exists(image_path):
    # `urllib.request` must be imported explicitly; `import urllib` alone does
    # not guarantee the submodule is loaded.
    urllib.request.urlretrieve(
        'https://storage.googleapis.com/perceiver_io/dalmation.jpg', image_path)
with open(image_path, 'rb') as f:
    img = imageio.imread(f)
img = resize_and_center_crop(img)
img = normalize(img)
# [None] prepends a batch axis of size 1.
# NOTE(review): under JAX's default config (x64 disabled) jnp.array likely
# downcasts the float64 NumPy result to float32 -- confirm; this may be
# relevant to the observed mismatch.
img = jnp.array(img)[None]
before_disabling = float(img.mean())
jax.config.update('jax_disable_jit', True)
after_disabling = float(img.mean())
# Exact equality is the point of this repro. Raise explicitly instead of using
# `assert`, so the check still runs under `python -O`.
if before_disabling != after_disabling:
    raise AssertionError(
        f"values are different by: {after_disabling - before_disabling}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment