def img_scaler(image, max_dim=512):
    """Resize ``image`` so that its longer spatial side equals ``max_dim``.

    The aspect ratio is preserved. Images whose longer side is already
    smaller than ``max_dim`` are upscaled, not left unchanged, because the
    scale factor is always ``max_dim / longer_side``.

    Args:
        image: A 3-D ``(height, width, channels)`` or 4-D
            ``(batch, height, width, channels)`` image tensor.
        max_dim: Target length, in pixels, of the longer spatial side.

    Returns:
        The image resized with ``tf.image.resize`` to the computed
        ``(new_height, new_width)``.
    """
    # Take (height, width) from the spatial axes only. Slicing from the end
    # keeps this correct for both unbatched (3-D) and batched (4-D) inputs.
    original_shape = tf.cast(tf.shape(image)[-3:-1], tf.float32)
    # tf.reduce_max (rather than the builtin max) also works inside
    # tf.function, where a symbolic tensor cannot be iterated.
    scale_ratio = max_dim / tf.reduce_max(original_shape)
    # Round instead of truncating: float error can leave the longer side at
    # e.g. 511.99997, which a plain cast would floor to 511. Clamp to 1 so a
    # very thin image never collapses to a zero-sized axis.
    new_shape = tf.maximum(
        tf.cast(tf.round(original_shape * scale_ratio), tf.int32), 1
    )
    return tf.image.resize(image, new_shape)
