@jaymody
Created October 5, 2022 03:15
Comparing JAX training runtimes when the input data is converted with jnp.array / jnp.asarray vs np.array / np.asarray
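The script benchmarks the same SGD training loop on a 60,000-sample dummy dataset converted to an array eight different ways, timing the conversion itself, the first (jit-compiling) update call, and one full pass of mini-batch training. The essential difference being measured is how a nested Python list reaches the device: directly through jnp.array / jnp.asarray, or through NumPy first. A minimal sketch of just that conversion comparison, separate from the full script below (the 1,000-row size here is illustrative only, the gist itself uses 60,000 rows of width 784):

import random
import time

import numpy as np
import jax.numpy as jnp

data = [[random.random() for _ in range(784)] for _ in range(1000)]

t0 = time.time()
direct = jnp.asarray(data)  # Python list -> device array directly
t_direct = time.time() - t0

t0 = time.time()
via_numpy = jnp.asarray(np.asarray(data))  # list -> np.ndarray -> device array
t_via_numpy = time.time() - t0

print("jnp.asarray(list)             :", t_direct)
print("jnp.asarray(np.asarray(list)) :", t_via_numpy)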
import jax
import jax.numpy as jnp


def forward_fn(params, X):
    for W, b in params[:-1]:
        X = jax.nn.relu(X @ W + b)
    final_W, final_b = params[-1]
    return X @ final_W + final_b


def initialize_params(key, input_dim, hidden_dims, output_dim):
    sizes = [input_dim] + hidden_dims + [output_dim]
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)), jnp.zeros((n_out,)))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def loss_fn(params, X, y):
    # forward pass
    unnormalized_probs = forward_fn(params, X)

    # cross entropy loss
    batch_size = unnormalized_probs.shape[0]
    num_classes = unnormalized_probs.shape[-1]
    log_probs = jax.nn.log_softmax(unnormalized_probs, axis=-1)
    labels = jax.nn.one_hot(y, num_classes)
    loss = jnp.sum(labels * -log_probs) / batch_size

    return loss


@jax.jit
def update(params, X, y, lr):
    # compute loss and gradient
    loss, grad = jax.value_and_grad(loss_fn)(params, X, y)

    # good ole vanilla stochastic gradient descent
    params = jax.tree_map(lambda w, g: w - lr * g, params, grad)

    return loss, params


def train(params, X, y, batch_size, lr):
    for i in range(0, len(X), batch_size):
        loss, params = update(
            params,
            X[i : i + batch_size],
            y[i : i + batch_size],
            lr,
        )
        # print(f"loss at step {i} = {loss}")


def main(conversion):
    import random
    import time

    import numpy as np

    # create dummy data to simulate mnist
    dummy_X = [[random.random() for _ in range(784)] for _ in range(60000)]
    dummy_y = [random.randint(0, 9) for _ in range(60000)]

    # test convert and train times
    conversions = {
        "np.array": lambda x: np.array(x),
        "np.asarray": lambda x: np.asarray(x),
        "jnp.array": lambda x: jnp.array(x),
        "jnp.array + np.array": lambda x: jnp.array(np.array(x)),
        "jnp.array + np.asarray": lambda x: jnp.array(np.asarray(x)),
        "jnp.asarray": lambda x: jnp.asarray(x),
        "jnp.asarray + np.array": lambda x: jnp.asarray(np.array(x)),
        "jnp.asarray + np.asarray": lambda x: jnp.asarray(np.asarray(x)),
    }
    conversion_func = conversions[conversion]

    # initialize params
    params = initialize_params(jax.random.PRNGKey(123), 784, [128, 64], 10)

    # convert to np or jax array
    convert_start_time = time.time()
    X, y = conversion_func(dummy_X), conversion_func(dummy_y)
    convert_time = time.time() - convert_start_time

    # run update at least once so it jit compiles
    jit_start_time = time.time()
    update(params, X[:64], y[:64], 1e-3)
    jit_time = time.time() - jit_start_time

    # train
    train_start_time = time.time()
    train(params, X, y, 64, 1e-3)
    train_time = time.time() - train_start_time

    print("conversion =", conversion)
    print("convert_times =", convert_time)
    print("jit_time =", jit_time)
    print("train_times =", train_time)


if __name__ == "__main__":
    import sys

    main(sys.argv[1])
## no jit
# conversion                    convert_time (s)      train_time (s)
# np.array                      1.390981674194336     12.141575813293457
# np.asarray                    1.2834157943725586    11.19111704826355
# jnp.array                     95.75779509544373     21.965157985687256
# jnp.array + np.array          1.3955588340759277    22.44120502471924
# jnp.array + np.asarray        1.3871350288391113    22.535207986831665
# jnp.asarray                   89.70011687278748     21.396696090698242
# jnp.asarray + np.array        1.3765108585357666    21.16973900794983
# jnp.asarray + np.asarray      1.3343868255615234    21.99706506729126
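In these recorded runs (apparently gathered with the @jax.jit decorator disabled, going by the "no jit" header and the absence of jit_time lines), converting the raw Python lists directly with jnp.array or jnp.asarray took roughly 90-96 seconds, versus about 1.3-1.4 seconds for every path that goes through NumPy first, and training on NumPy inputs was roughly twice as fast as training on device arrays. One caveat, not part of the original gist: JAX dispatches work asynchronously, so wall-clock timings around JAX calls are tighter if the result is explicitly synchronized before stopping the timer. A sketch of a synchronized measurement, reusing update, params, X, and y from the script above:

import time

start = time.time()
loss, new_params = update(params, X[:64], y[:64], 1e-3)
loss.block_until_ready()  # wait for the async device computation to finish
print("synchronized update time =", time.time() - start)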