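# JAX basics: snippets collected from a GitHub Gist.
# (Page header text captured with the paste, kept only as a comment so the
#  file parses: "Skip to content" / "Instantly share code, notes, and snippets.")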

import timeit

import jax
import jax.numpy as jnp
import numpy as np
from jax import device_put
size = 5000
# Build a (size, size) float32 array of standard-normal samples and move it to
# the default JAX device.
# Sampling float32 directly from a seeded Generator avoids the float64
# temporary (plus a second copy from .astype) that
# np.random.normal(...).astype(np.float32) creates, which is ~200 MB extra at
# 5000x5000. The fixed seed also makes the data reproducible across runs.
x = device_put(
    np.random.default_rng(0).standard_normal((size, size), dtype=np.float32)
)
from functools import partial


@partial(jax.jit, static_argnums=0)
def f(boolean, x):
    """Negate ``x`` when ``boolean`` is true; otherwise return ``x`` as is.

    Argument 0 (``boolean``) is marked static, so it reaches the function as a
    concrete Python value during tracing. That is what lets the ordinary
    ``if`` below branch on it inside ``jax.jit``.
    """
    if boolean:
        return -x
    return x


f(True, 1)
@jax.jit
def f(boolean, x):
    """Variant of ``f`` jitted WITHOUT marking ``boolean`` static.

    Under plain ``jax.jit`` every argument is traced, so ``boolean`` is an
    abstract tracer and Python's ``if``/``else`` cannot branch on it. Calling
    this function therefore raises ``jax.errors.ConcretizationTypeError``.
    """
    return -x if boolean else x


# Demonstrate the failure without aborting the rest of the script. Newer JAX
# versions raise TracerBoolConversionError, a subclass of
# ConcretizationTypeError, so catching the base class covers both.
try:
    f(True, 1)
except jax.errors.ConcretizationTypeError as e:
    print(f"{type(e).__name__}: {e}")
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
def sum_logistic(x):
    """Return the sum of the logistic function 1 / (1 + exp(-x)) over ``x``.

    The ``print`` is deliberate. When the function is traced (e.g. by
    ``jax.make_jaxpr`` below), it runs at trace time and shows the abstract
    tracer rather than concrete values.
    """
    print("printed x:", x)
    denominator = 1.0 + jnp.exp(-x)
    return jnp.sum(1.0 / denominator)


x_small = jnp.arange(6.0)
print(jax.make_jaxpr(sum_logistic)(x_small))
def test_fn(sample_rate=3000, frequency=3):
    """Return the dot product of a sample-index ramp with a sampled sine wave.

    Builds ``x = [0, 1, ..., sample_rate - 1]`` and
    ``y = sin(2*pi*frequency*x/sample_rate)``, i.e. ``frequency`` full cycles
    over the ``sample_rate`` samples. Returns ``dot(x, y)`` as a 0-d array.

    Args:
        sample_rate: number of samples to generate.
        frequency: number of sine cycles across those samples.
    """
    x = jnp.arange(sample_rate)
    # x / sample_rate (not frequency / sample_rate) so y varies per sample.
    # jnp.sin (not np.sin) so it also works on tracers under jax.jit.
    y = jnp.sin(2 * jnp.pi * frequency * (x / sample_rate))
    return jnp.dot(x, y)


# Defined after test_fn: jitting it earlier raised NameError.
test_fn_jit = jax.jit(test_fn)


def _best_us_per_loop(fn, number, repeat=5):
    """Return the fastest of ``repeat`` timings of ``fn``, in µs per call.

    Stdlib stand-in for IPython's ``%timeit`` "best of N" report, which is
    not valid syntax in a plain Python file.
    """
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number * 1e6


# Compile once up front so the jit timing excludes tracing/compilation.
test_fn_jit().block_until_ready()
# block_until_ready() on BOTH variants: JAX dispatches work asynchronously,
# so without it a timing can end before the computation has finished.
jit_us = _best_us_per_loop(lambda: test_fn_jit().block_until_ready(), 1000)
print(f"test_fn_jit: best of 5: {jit_us:.2f} µs per loop")
# Originally reported with %timeit (before the sine-wave fix):
# best of 5: 4.54 µs per loop
eager_us = _best_us_per_loop(lambda: test_fn().block_until_ready(), 100)
print(f"test_fn: best of 5: {eager_us:.2f} µs per loop")
# Originally reported with %timeit (before the sine-wave fix):
# best of 5: 76.1 µs per loop
# Defined first: the .at[] example below previously used scores_array before
# it existed.
scores = [50, 60, 70, 30, 25]
scores_array = jnp.array(scores)

# JAX arrays are immutable, so NumPy-style in-place slice assignment fails.
# Catch the error so the rest of the script keeps running.
try:
    scores_array[0:3] = [20, 40, 90]
except TypeError as e:
    print(f"TypeError: {e}")
# TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment.
# JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[]
# method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
# (The class name in the message depends on the JAX version.)

# Functional alternative: .at[...].set(...) returns an updated copy and leaves
# scores_array unchanged.
new_scores_array = scores_array.at[0:3].set([20, 40, 90])
print(repr(new_scores_array))
# DeviceArray([20, 40, 90, 30, 25], dtype=int32)   (newer JAX prints Array(...))
# jnp.sum does not accept a plain Python list and raises TypeError instead;
# convert with jnp.asarray(...) first when array semantics are wanted.
try:
    jnp.sum([1, 2, 3])
except TypeError as err:
    print("TypeError: " + str(err))
# TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.
# NOTE(review): `matrix` must already exist and hold exactly 16 elements
# (required by reshape(4, 4)). It is not created in this part of the file;
# confirm where it is defined.
matrix = matrix.reshape(4, 4)
# These were bare expressions, which only display in a notebook cell and do
# nothing in a script; print each result instead.
print("max:", jnp.max(matrix))
print("argmax:", jnp.argmax(matrix))  # index into the flattened matrix
print("min:", jnp.min(matrix))
print("argmin:", jnp.argmin(matrix))  # index into the flattened matrix
print("sum:", jnp.sum(matrix))
print("sqrt:", jnp.sqrt(matrix))  # elementwise square root
print("transpose:", matrix.transpose())