Skip to content

Instantly share code, notes, and snippets.

View smooth_labels.py
# Label-smoothing demo. The printed result is consistent with
# (1 - alpha) * labels + alpha / num_classes, e.g. 0.6 * 0.30 + 0.4 / 5 = 0.26.
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.smooth_labels(labels, alpha=0.4)
# DeviceArray([0.2 , 0.26, 0.14, 0.2 , 0.2 ], dtype=float32)
View log_cosh.py
# Log-cosh loss demo. Each output element matches log(cosh(prediction - target)),
# e.g. log(cosh(0.50 - 0.20)) ~= 0.04434.
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.log_cosh(predictions, targets)
# DeviceArray([0.04434085, 0.04434085, 0.17013526, 0.00499171, 0.00124949], dtype=float32)
View l2_loss.py
# L2 loss demo. Each output element matches 0.5 * (prediction - target) ** 2,
# e.g. 0.5 * (0.50 - 0.20) ** 2 = 0.045.
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.l2_loss(predictions, targets)
# DeviceArray([0.045 , 0.045 , 0.17999998, 0.005 , 0.00125 ], dtype=float32)
View huber_loss.py
# Huber loss demo. Huber is a regression loss, so the inputs are named
# predictions/targets (not logits/labels) to match the l2_loss and log_cosh
# examples. Every absolute error here is at most 0.6, and the printed values
# equal 0.5 * error ** 2 -- i.e. the quadratic branch of the Huber loss
# (presumably because all errors fall below the default delta; confirm in the
# optax docs).
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.huber_loss(predictions, targets)
# DeviceArray([0.045 , 0.045 , 0.17999998, 0.005 , 0.00125 ], dtype=float32)
View cosine_similarity.py
# Cosine similarity demo with a deliberately large epsilon.
# dot(predictions, targets) = 0.46, ||predictions|| ~= 1.119, ||targets|| ~= 0.469.
# The printed value is consistent with epsilon acting as a floor on each norm:
# 0.46 / (1.119 * 0.5) ~= 0.822, whereas the unclamped similarity would be ~= 0.876.
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.cosine_similarity(predictions, targets, epsilon=0.5)
# DeviceArray(0.8220514, dtype=float32)
View cosine_distance.py
# Cosine distance demo (1 - cosine similarity) with a deliberately large epsilon.
# dot(predictions, targets) = 0.46, ||predictions|| ~= 1.119, ||targets|| ~= 0.469.
# The printed value is consistent with epsilon acting as a floor on each norm:
# 1 - 0.46 / (1.119 * 0.7) ~= 0.4128.
predictions = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
targets = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.cosine_distance(predictions, targets, epsilon=0.7)
# DeviceArray(0.4128204, dtype=float32)
View softmax_cross_entropy.py
# Softmax cross-entropy demo with soft (probability) labels that sum to 1.
# The printed value matches -sum(labels * log_softmax(logits)):
# logsumexp(logits) ~= 2.0941 and sum(labels * logits) = 0.46, giving ~= 1.6341.
logits = jnp.array([0.50, 0.60, 0.70, 0.30, 0.25])
labels = jnp.array([0.20, 0.30, 0.10, 0.20, 0.20])
optax.softmax_cross_entropy(logits, labels)
# DeviceArray(1.6341426, dtype=float32)
View sigmoid_binary_cross_entropy.py
# Sigmoid binary cross-entropy for a single logit of 0.5 and a label of 0.
# The printed value matches -log(1 - sigmoid(0.5)) = log(1 + e**0.5) ~= 0.974077.
optax.sigmoid_binary_cross_entropy(logits=0.5, labels=0.0)
# DeviceArray(0.974077, dtype=float32, weak_type=True)
View custom_sigmoid_binary_cross_entropy.py
import jax
def custom_sigmoid_binary_cross_entropy(logits, labels):
    """Return the element-wise sigmoid binary cross-entropy from raw logits.

    Computes ``-labels * log(p) - (1 - labels) * log(1 - p)`` with
    ``p = sigmoid(logits)``.

    Args:
        logits: Unnormalized scores (Python scalar or array).
        labels: Target values, combined element-wise with ``logits``.

    Returns:
        The loss for each logit/label pair, not reduced.
    """
    # log(sigmoid(x)) is taken directly rather than as log(sigmoid(x)).
    log_p = jax.nn.log_sigmoid(logits)
    # 1 - sigmoid(x) == sigmoid(-x), so log(1 - p) is log_sigmoid(-x).
    log_not_p = jax.nn.log_sigmoid(-logits)
    return -labels * log_p - (1. - labels) * log_not_p
# Same inputs as the optax.sigmoid_binary_cross_entropy example (logit 0.5, label 0).
custom_sigmoid_binary_cross_entropy(logits=0.5, labels=0.0)
# DeviceArray(0.974077, dtype=float32, weak_type=True)
View load_data.py
import tensorflow as tf
# Hide all GPUs from TensorFlow so that it does not grab all of the GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
import tensorflow_datasets as tfds
data_dir = '/tmp/tfds'
# Fetch full datasets for evaluation
# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)