Skip to content

Instantly share code, notes, and snippets.

@matthieubulte
Created May 17, 2024 15:00
Show Gist options
  • Save matthieubulte/c67a1171147e309bbe54b8cfcd8a345a to your computer and use it in GitHub Desktop.
Save matthieubulte/c67a1171147e309bbe54b8cfcd8a345a to your computer and use it in GitHub Desktop.
import time
import jax
from jax import numpy as jnp
import numpy as np
# Matrix dimension used by the benchmark when run as a script.
N = 4096
# Approximate floating-point operation count of an N x N @ N x N matmul:
# N*N output entries, each needing N multiplies and ~N adds -> ~2*N**3.
flops_ish = N * N * 2 * N


def _tflops(flops, seconds):
    """Return throughput in TFLOP/s for `flops` operations done in `seconds`.

    Returns ``inf`` for a non-positive duration instead of dividing by zero
    (possible only for trivially small problems on a coarse clock).
    """
    if seconds <= 0:
        return float("inf")
    return flops / seconds * 1e-12


def _report(label, seconds, flops):
    """Print one result line: wall time in milliseconds and TFLOP/s."""
    # `seconds * 1000` is milliseconds; the original label said "us".
    print(f'{label} {seconds * 1000:.2f}ms \t {_tflops(flops, seconds):.2f} TFLOP/S')


def time_numpy_matmul(a, b):
    """Compute ``np.matmul(a, b)`` and time it.

    Returns:
        A tuple ``(product, elapsed_seconds)``.
    """
    # perf_counter is the high-resolution clock intended for measuring durations.
    start = time.perf_counter()
    product = np.matmul(a, b)
    return product, time.perf_counter() - start


def time_jax_matmul(a, b):
    """Move `a` and `b` to the default JAX device and time ``jnp.matmul``.

    The host->device transfer is forced to complete before the clock starts,
    so only the multiplication itself is measured.

    Note: with JAX's default configuration (``jax_enable_x64`` off), float64
    inputs are canonicalized to float32 on transfer, so this is not a
    like-for-like comparison with a float64 NumPy matmul.

    Returns:
        A tuple ``(product, elapsed_seconds)``; `product` is a JAX array.
    """
    # device_put dispatches asynchronously; without blocking here the transfer
    # could overlap the timed region below.
    a_dev = jax.device_put(a).block_until_ready()
    b_dev = jax.device_put(b).block_until_ready()
    start = time.perf_counter()
    # JAX also dispatches the matmul asynchronously; block so the timer
    # covers the actual computation, not just its enqueueing.
    product = jnp.matmul(a_dev, b_dev).block_until_ready()
    return product, time.perf_counter() - start


def run_benchmark(n=N, repeats=5):
    """Time NumPy vs JAX on random ``n x n`` float64 matrices, `repeats` times.

    Each repetition prints one ``NP`` line and one ``JAX`` line. The first JAX
    line includes one-off compilation/warm-up cost, so later repetitions are
    the more representative steady-state numbers.
    """
    flops = n * n * 2 * n
    for _ in range(repeats):
        # np.random.rand already returns float64; no astype copy needed.
        a = np.random.rand(n, n)
        b = np.random.rand(n, n)
        _, dt = time_numpy_matmul(a, b)
        _report('NP', dt, flops)
        _, dt = time_jax_matmul(a, b)
        _report('JAX', dt, flops)


if __name__ == "__main__":
    run_benchmark()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment