import jax
import jax.numpy as jnp

# Out-of-bounds indexing does not raise an error in JAX; the index is clamped to the last element.
matrix = jnp.arange(1, 17)
matrix[20]
# DeviceArray(16, dtype=int32)
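# A minimal sketch (not in the original gist) of making the out-of-bounds behaviour
# explicit rather than relying on silent clamping, using the .at[...].get() indexing API.
# The fill value -1 is an arbitrary sentinel chosen for illustration.
matrix.at[20].get(mode='fill', fill_value=-1)   # returns the sentinel instead of clamping
matrix.at[20].get(mode='clip')                  # explicitly clamps to the last element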
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]
# Let's see how many leaves they have:
for pytree in example_trees:
    leaves = jax.tree_util.tree_leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
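# A short sketch (not in the original gist) of the usual next step: applying a function
# to every leaf of a pytree with jax.tree_util.tree_map. numeric_tree is a made-up example.
numeric_tree = {'a': 2, 'b': (2, 3)}
jax.tree_util.tree_map(lambda leaf: leaf * 2, numeric_tree)
# {'a': 4, 'b': (4, 6)}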
# set this config at the beginning of the program
from jax.config import config
config.update("jax_enable_x64", True)
x = jnp.float64(1.25844)
x
# DeviceArray(1.25844, dtype=float64)

# without jax_enable_x64, the requested float64 is silently truncated to float32:
x = jnp.float64(1.25844)
# /usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py:1806: UserWarning: Explicitly requested dtype float64 requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
#   lax_internal._check_user_dtype_supported(dtype, "array")
# DeviceArray(1.25844, dtype=float32)
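# A small sketch (not in the original gist) of checking what the flag changes:
# with jax_enable_x64 enabled, default floating-point arrays become 64-bit.
jnp.arange(3.0).dtype
# dtype('float64')   (float32 with the flag off)
# The JAX_ENABLE_X64=True shell environment variable mentioned in the warning above
# is an alternative way to set the same option before the program starts.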
from jax.config import config
config.update("jax_debug_nans", True)
# with jax_debug_nans enabled, producing a NaN raises immediately:
jnp.divide(0.0, 0.0)
# FloatingPointError: invalid value (nan) encountered in div

# with jax_debug_nans disabled (the default), the NaN propagates silently:
jnp.divide(0.0, 0.0)
# DeviceArray(nan, dtype=float32, weak_type=True)
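# A hedged sketch (not from the original gist): jax_debug_nans also catches NaNs produced
# inside jit-compiled functions; JAX re-runs the offending function op-by-op to locate
# the primitive that produced the NaN. bad_fn is a hypothetical example.
@jax.jit
def bad_fn(a, b):
    return a / b          # 0.0 / 0.0 -> NaN
bad_fn(0.0, 0.0)
# FloatingPointError: invalid value (nan) encountered in div   (with jax_debug_nans on)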
import numpy as np

x = np.arange(5)
w = np.array([2., 3., 4.])

def convolve(x, w):
    # dot product of w with each length-3 window of x (valid-mode 1D convolution)
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1:i + 2], w))
    return jnp.array(output)

convolve(x, w)
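# A short sketch of batching the same function with jax.vmap; xs and ws are hypothetical
# batched inputs stacked along a new leading axis, which vmap maps over by default.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
# one convolution result per row of xs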
key = jax.random.PRNGKey(0)
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

@jax.jit
def vmap_batched_apply_matrix(v_batched):
    return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
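# A hedged sketch of how one would typically time the batched call; block_until_ready()
# forces JAX's asynchronous dispatch to finish so the measurement is meaningful.
import timeit
timeit.timeit(lambda: vmap_batched_apply_matrix(batched_x).block_until_ready(), number=100)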
@jax.jit
def sum_logistic(x):
    # returns (value to differentiate, auxiliary data); has_aux=True tells grad
    # to differentiate only the first element and pass the second through unchanged
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))
# (DeviceArray([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
#               0.00664806], dtype=float32), DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float32))
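# A related sketch (not in the original gist): jax.value_and_grad with has_aux=True
# returns ((value, aux), grad) in a single call, avoiding a second evaluation.
value_and_grad_fn = jax.value_and_grad(sum_logistic, has_aux=True)
(value, aux), grads = value_and_grad_fn(x_small)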
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic)
print(derivative_fn(x_small))
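# A final sketch (not in the original gist): second-order derivatives of the same
# scalar-valued function are available by composing transformations, e.g. jax.hessian.
hessian_fn = jax.hessian(sum_logistic)
print(hessian_fn(x_small).shape)
# (6, 6)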