@sussillo
Last active February 18, 2019 19:44
Are loops in JAX faster than those in LAX (excluding JIT times)?
from __future__ import print_function, division
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
import numpy as onp
import time

def keygen(key, nkeys):
  """Generate randomness that JAX can use by splitting the JAX keys.

  Args:
    key: the random.PRNGKey for JAX
    nkeys: how many keys in the key generator

  Returns:
    2-tuple (new key for further generators, key generator)
  """
  keys = random.split(key, nkeys + 1)
  return keys[0], (k for k in keys[1:])

def random_esn_params(key, u, n, m, tau=1.0, dt=0.1, g=1.0):
  """Generate random ESN parameters.

  Arguments:
    key: the random.PRNGKey for JAX
    u: dim of the input
    n: dim of the hidden state
    m: dim of the output
    tau: "neuronal" time constant
    dt: time between Euler integration updates
    g: scaling of the recurrent matrix in the reservoir

  Returns:
    A dictionary of parameters for the ESN.
  """
  key, skeys = keygen(key, 5)
  hscale = 0.25
  ifactor = 1.0 / np.sqrt(u)
  hfactor = g / np.sqrt(n)
  pfactor = 1.0 / np.sqrt(n)
  ffactor = 1.0  # Feedback factor, keep at 1 for now.
  return {'a0': random.normal(next(skeys), (n,)) * hscale,
          'wI': random.normal(next(skeys), (n, u)) * ifactor,
          'wR': random.normal(next(skeys), (n, n)) * hfactor,
          'wO': random.normal(next(skeys), (m, n)) * pfactor,
          'wF': random.normal(next(skeys), (n, m)) * ffactor,
          'dt_over_tau': dt / tau}

def esn(x, a, h, z, wI, wR, wF, wO, dtdivtau):
  """Run the continuous-time echo state network one step.

  The dynamics are da/dt = -a + wI x + wR h + wF z, with h = tanh(a) and
  z = wO h.

  Arguments:
    x: ndarray of input to the ESN
    a: ndarray of activations (pre-nonlinearity) from the previous time step
    h: ndarray of hidden states from the previous time step
    z: ndarray of output from the previous time step
    wI: ndarray, input matrix, shape (n, u)
    wR: ndarray, recurrent matrix, shape (n, n)
    wF: ndarray, feedback matrix, shape (n, m)
    wO: ndarray, output matrix, shape (m, n)
    dtdivtau: dt / tau

  Returns:
    3-tuple (a, h, z), the updated ESN state at this time step.
  """
  dadt = -a + np.dot(wI, x) + np.dot(wR, h) + np.dot(wF, z)
  a = a + dtdivtau * dadt
  h = np.tanh(a)
  z = np.dot(wO, h)
  return a, h, z
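
# A quick shape sanity check for a single esn step (illustrative only, not in
# the original gist), using tiny dimensions:
_p = random_esn_params(random.PRNGKey(1), u=2, n=5, m=3)
_h0 = np.tanh(_p['a0'])
_a, _h, _z = esn(np.zeros(2), _p['a0'], _h0, np.dot(_p['wO'], _h0),
                 _p['wI'], _p['wR'], _p['wF'], _p['wO'], _p['dt_over_tau'])
assert _a.shape == (5,) and _h.shape == (5,) and _z.shape == (3,)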

def esn_run_jax(params, x_t):
  """Run the echo state network forward for the number of steps in x_t.

  This implementation builds the outer time loop from a basic Python for
  loop over JAX operations.

  Arguments:
    params: dict of ESN params
    x_t: ndarray of input time series, shape (t, u)

  Returns:
    2-tuple of (h_t, z_t) after running the ESN
  """
  # Initialize the state from the initial activations.
  a = params['a0']
  h = np.tanh(a)
  wO = params['wO']
  z = np.dot(wO, h)
  h_t = []
  z_t = []
  wI = params['wI']
  wR = params['wR']
  wF = params['wF']
  dtdivtau = params['dt_over_tau']
  for tidx, x in enumerate(x_t):
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    h_t.append(h)
    z_t.append(z)
  h_t = np.array(h_t)
  z_t = np.array(z_t)
  return h_t, z_t

esn_run_jax_jit = jit(esn_run_jax)
from jax import lax
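
# Why the two versions can behave differently (a sketch, not part of the
# original gist): under jit, the plain Python loop in esn_run_jax is unrolled
# at trace time, so the traced program grows with the number of time steps,
# while lax.fori_loop (used below) stays a single rolled loop primitive. With
# a recent JAX, jax.make_jaxpr makes this visible on a toy function:
from jax import make_jaxpr

def _unrolled(x):
  for _ in range(5):
    x = np.tanh(x)
  return x

def _rolled(x):
  return lax.fori_loop(0, 5, lambda i, v: np.tanh(v), x)

# print(make_jaxpr(_unrolled)(np.ones(3)))  # five tanh equations, one per iteration
# print(make_jaxpr(_rolled)(np.ones(3)))    # a single loop primitive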

def esn_run_lax(params, x_t):
  """Run the echo state network ntime steps, where ntime is shape[0] of x_t.

  This implementation builds the outer time loop with the LAX fori_loop and
  dynamic_update_slice functions.

  Arguments:
    params: dict of ESN params
    x_t: ndarray of input time series, shape (t, u)

  Returns:
    2-tuple of (h_t, z_t) after running the ESN
  """
  ntime = x_t.shape[0]
  n = params['wR'].shape[0]  # hidden dim, taken from params rather than globals
  m = params['wO'].shape[0]  # output dim
  a0 = params['a0']
  h0 = np.tanh(a0)
  z0 = np.dot(params['wO'], h0)
  # Row 0 of each buffer holds the initial condition; the loop writes rows
  # 1..ntime, so the z read at step tidx is the output from the previous step.
  a_t = np.zeros((ntime + 1, n))
  h_t = np.zeros((ntime + 1, n))
  z_t = np.zeros((ntime + 1, m))
  a_t = lax.dynamic_update_slice(a_t, np.expand_dims(a0, 0), [0, 0])
  h_t = lax.dynamic_update_slice(h_t, np.expand_dims(h0, 0), [0, 0])
  z_t = lax.dynamic_update_slice(z_t, np.expand_dims(z0, 0), [0, 0])
  wI = params['wI']
  wR = params['wR']
  wF = params['wF']
  wO = params['wO']
  dtdivtau = params['dt_over_tau']

  def esn_body(tidx, inputs):
    a_t, h_t, z_t = inputs
    x = x_t[tidx]
    a = a_t[tidx]
    h = h_t[tidx]
    z = z_t[tidx]
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    a_t = lax.dynamic_update_slice(a_t, np.expand_dims(a, 0), [tidx + 1, 0])
    h_t = lax.dynamic_update_slice(h_t, np.expand_dims(h, 0), [tidx + 1, 0])
    z_t = lax.dynamic_update_slice(z_t, np.expand_dims(z, 0), [tidx + 1, 0])
    return a_t, h_t, z_t

  a_t, h_t, z_t = lax.fori_loop(0, ntime, esn_body, (a_t, h_t, z_t))
  h_t = h_t[1:]  # ditch the initial condition
  z_t = z_t[1:]
  return h_t, z_t

esn_run_lax_jit = jit(esn_run_lax)
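
# An alternative rolled-loop formulation (not in the original gist): lax.scan
# carries the (a, h, z) state and stacks the per-step outputs, avoiding the
# manual dynamic_update_slice bookkeeping. A sketch, assuming a JAX version
# recent enough to provide lax.scan (it may not exist in jax 0.1.19):
def esn_run_scan(params, x_t):
  wI, wR, wF, wO = params['wI'], params['wR'], params['wF'], params['wO']
  dtdivtau = params['dt_over_tau']
  a0 = params['a0']
  h0 = np.tanh(a0)
  z0 = np.dot(wO, h0)

  def step(carry, x):
    a, h, z = carry
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    return (a, h, z), (h, z)

  _, (h_t, z_t) = lax.scan(step, (a0, h0, z0), x_t)
  return h_t, z_t

esn_run_scan_jit = jit(esn_run_scan)  # could be timed the same way as the variants above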

# Basic parameters of the echo state networks.
key = random.PRNGKey(0)
T = 30           # total time
u = 1            # number of inputs (no real input yet, x_t is just zeros)
m = 20           # number of outputs
n = 1000         # size of the reservoir in the ESN
tau = 1.0        # neuron time constant
dt = tau / 10.0  # Euler integration step
times = np.arange(0, T, dt)  # all time points (avoids shadowing the time module)
ntime = times.shape[0]       # the number of time steps
x_t = np.zeros((ntime, u))   # just a stand-in in case folks want a real input later

### JAX ###
print("")
print("JAX speed")
# Create the ESN that will be trained with FORCE learning.
g = 1.5      # A lower g value was shown in the paper to be good for training.
alpha = 1e0  # Initial learning rate for RLS
params_seed = 100001
print("Params seed %d" % params_seed)
key = random.PRNGKey(params_seed)
params = random_esn_params(key, u, n, m, g=g)
# Run the untrained ESN once to jit-compile it.
h_t, z_t = esn_run_jax_jit(params, x_t)
# A little timing of the already-jitted routine, so compilation time is excluded.
ntorun = 100
start_time = time.time()
for i in range(ntorun):
  esn_run_jax_jit(params, x_t)
avg_time = (time.time() - start_time) / ntorun
print("JAX run {:0.4f} sec".format(avg_time))

### LAX ###
print("")
print("LAX speed")
# Run the untrained ESN once to jit-compile it.
h_t, z_t = esn_run_lax_jit(params, x_t)
ntorun = 100
start_time = time.time()
for i in range(ntorun):
  esn_run_lax_jit(params, x_t)
avg_time = (time.time() - start_time) / ntorun
print("LAX run {:0.4f} sec".format(avg_time))
sussillo commented Feb 18, 2019

I'm running with a GPU; JAX versions are below. I also ran on Colab with a later jaxlib, with the same results. The output is

JAX speed
Params seed 100001
JAX run 0.0160 sec

LAX speed
LAX run 0.0418 sec
> pip show jax
Name: jax
Version: 0.1.19
Summary: Differentiate, compile, and transform Numpy code.
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sussillo/jax
Requires: numpy, six, protobuf, absl-py, opt-einsum

>pip show jaxlib
Name: jaxlib
Version: 0.1.7
Summary: XLA library for JAX
Home-page: https://github.com/google/jax
Author: JAX team
Author-email: jax-dev@google.com
License: Apache-2.0
Location: /home/sussillo/jax/build
Requires: scipy, numpy, six, protobuf, absl-py
