This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# JAX does not raise IndexError for an out-of-bounds index here: the index is
# clamped into range, so index 20 on this 16-element (1-D, despite the name)
# array silently returns the last element.
matrix = jnp.arange(1, 17)
matrix[20]
# DeviceArray(16, dtype=int32)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import jax

# A few pytrees of different structure: flat and nested lists/tuples, dicts,
# an empty tuple, an arbitrary object, and a bare array.
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
    leaves = jax.tree_util.tree_leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Set this config at the beginning of the program. x64 mode is off by default,
# so float64 values would otherwise be truncated to float32.
# Note: `from jax.config import config` imports the jax.config *submodule*,
# which is deprecated and removed in recent JAX releases; import the config
# object from `jax` instead.
from jax import config
config.update("jax_enable_x64", True)
x = jnp.float64(1.25844)
x
# DeviceArray(1.25844, dtype=float64)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# With x64 mode disabled (the default), requesting float64 emits a UserWarning
# and the value is stored as float32 instead:
x = jnp.float64(1.25844)
# UserWarning: Explicitly requested dtype float64 requested in array is not
#   available, and will be truncated to dtype float32. To enable more dtypes,
#   set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell
#   environment variable. See https://github.com/google/jax#current-gotchas
#   for more.
# DeviceArray(1.25844, dtype=float32)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# jax_debug_nans makes JAX raise FloatingPointError as soon as an operation
# produces a NaN, instead of silently returning it.
# (`from jax.config import config` is deprecated/removed in recent JAX;
# import the config object from `jax` instead.)
from jax import config

config.update("jax_debug_nans", True)
try:
    jnp.divide(0.0, 0.0)
except FloatingPointError as err:
    # FloatingPointError: invalid value (nan) encountered in div
    print(f"FloatingPointError: {err}")
finally:
    # Switch the check back off so this demonstration neither aborts the rest
    # of the file nor turns later, intentional NaN results into errors.
    config.update("jax_debug_nans", False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Without jax_debug_nans enabled, 0/0 does not raise: the result is silently NaN.
jnp.divide(0.0, 0.0)
# DeviceArray(nan, dtype=float32, weak_type=True)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
x = np.arange(5)
w = np.array([2., 3., 4.])


def convolve(x, w):
    """Slide ``w`` along ``x`` and return the dot product at each position.

    Only "valid" positions (where ``w`` fits entirely inside ``x``) are
    produced, so the result has ``len(x) - len(w) + 1`` elements and is empty
    when ``w`` is longer than ``x``. ``w`` is not flipped, so strictly this is
    a cross-correlation.

    Args:
        x: Input sequence (assumed 1-D).
        w: Kernel (assumed 1-D); any length is accepted.

    Returns:
        A 1-D ``jnp`` array of the windowed dot products.
    """
    # Window width follows the kernel instead of being hard-coded to 3.
    width = len(w)
    output = [jnp.dot(x[start:start + width], w)
              for start in range(len(x) - width + 1)]
    return jnp.array(output)


convolve(x, w)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): `key` is a PRNG key defined outside this snippet — confirm it
# exists before this runs. Both arrays are drawn from the same key; fine for a
# demo, but real code should split it (jax.random.split) to get independent draws.
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))


def apply_matrix(v):
    """Multiply the global (150, 100) ``mat`` by ``v``."""
    return jnp.dot(mat, v)


@jax.jit
def vmap_batched_apply_matrix(v_batched):
    """Apply ``apply_matrix`` to every row of ``v_batched`` in one jitted call.

    ``jax.vmap`` maps over the leading axis, so a (10, 100) batch produces a
    (10, 150) result.
    """
    return jax.vmap(apply_matrix)(v_batched)


print('Auto-vectorized with vmap')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@jax.jit
def sum_logistic(x):
    """Return ``(sum(logistic(x)), x + 1)``.

    The first element is the scalar that ``jax.grad`` differentiates; the
    second is an auxiliary output that ``jax.grad(..., has_aux=True)`` returns
    alongside the gradient without differentiating it.
    """
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)


x_small = jnp.arange(6.)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))
# (DeviceArray([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
#               0.00664806], dtype=float32), DeviceArray([1., 2., 3., 4., 5., 6.], dtype=float32))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@jax.jit
def sum_logistic(x):
    """Return the sum of the logistic sigmoid ``1 / (1 + exp(-x))`` over ``x``."""
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


x_small = jnp.arange(6.)
# Gradient of the scalar sum w.r.t. each element of x_small.
derivative_fn = jax.grad(sum_logistic)
print(derivative_fn(x_small))