Last active: February 18, 2019, 19:44
Are loops in JAX faster than those in LAX (excluding JIT times)?
from __future__ import print_function, division
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import grad, jit, vmap
from jax import random
import numpy as onp
import time

def keygen(key, nkeys):
  """Generate randomness that JAX can use by splitting the JAX keys.

  Args:
    key : the random.PRNGKey for JAX
    nkeys : how many keys in key generator

  Returns:
    2-tuple (new key for further generators, key generator)
  """
  keys = random.split(key, nkeys+1)
  return keys[0], (k for k in keys[1:])

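# A quick illustration of the keygen pattern above (hypothetical shapes, not
# used by the rest of this script): split one PRNGKey into a fresh key plus a
# generator of subkeys, so each random draw below gets its own key.
_demo_key, _demo_skeys = keygen(random.PRNGKey(0), 2)
_demo_w = random.normal(next(_demo_skeys), (3, 3))  # uses the first subkey
_demo_b = random.normal(next(_demo_skeys), (3,))    # uses the second subkey
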
def random_esn_params(key, u, n, m, tau=1.0, dt=0.1, g=1.0):
  """Generate random ESN parameters.

  Arguments:
    key: random.PRNGKey used to initialize the parameters
    u: dim of the input
    n: dim of the hidden state
    m: dim of the output
    tau: "neuronal" time constant
    dt: time between Euler integration updates
    g: scaling of the recurrent matrix in the reservoir

  Returns:
    A dictionary of parameters for the ESN.
  """
  key, skeys = keygen(key, 5)
  hscale = 0.25
  ifactor = 1.0 / np.sqrt(u)
  hfactor = g / np.sqrt(n)
  pfactor = 1.0 / np.sqrt(n)
  ffactor = 1.0  # Feedback factor, keep at 1 for now.
  return {'a0' : random.normal(next(skeys), (n,)) * hscale,
          'wI' : random.normal(next(skeys), (n,u)) * ifactor,
          'wR' : random.normal(next(skeys), (n,n)) * hfactor,
          'wO' : random.normal(next(skeys), (m,n)) * pfactor,
          'wF' : random.normal(next(skeys), (n,m)) * ffactor,
          'dt_over_tau' : dt / tau}

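# Rough, purely illustrative check of the g / np.sqrt(n) scaling above (the
# parameters and dimensions here are made up for the example): a large random
# matrix with i.i.d. N(0, g^2/n) entries has spectral radius approximately g,
# which is what makes g the "gain" of the reservoir.
_chk_params = random_esn_params(random.PRNGKey(1), u=1, n=200, m=1, g=1.5)
_chk_radius = onp.max(onp.abs(onp.linalg.eigvals(onp.array(_chk_params['wR']))))
# _chk_radius should come out close to 1.5 for this n.
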
def esn(x, a, h, z, wI, wR, wF, wO, dtdivtau):
  """Run the continuous-time Echostate network one step.

  tau da/dt = -a + wI x + wR h + wF z

  Arguments:
    x: ndarray of input to ESN
    a: ndarray of activations (pre nonlinearity) from prev time step
    h: ndarray of hidden states from prev time step
    z: ndarray of output from prev time step
    wI: ndarray, input matrix, shape (n, u)
    wR: ndarray, recurrent matrix, shape (n, n)
    wF: ndarray, feedback matrix, shape (n, m)
    wO: ndarray, output matrix, shape (m, n)
    dtdivtau: dt / tau

  Returns:
    The updated state (a, h, z) of the ESN at this time step.
  """
  dadt = -a + np.dot(wI, x) + np.dot(wR, h) + np.dot(wF, z)
  a = a + dtdivtau * dadt
  h = np.tanh(a)
  z = np.dot(wO, h)
  return a, h, z

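# The update above is a forward-Euler step of the leaky dynamics in the
# docstring: with tau da/dt = -a + wI x + wR h + wF z, one step of size dt is
#   a <- a + (dt/tau) * (-a + wI x + wR h + wF z)
#   h <- tanh(a),  z <- wO h
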
def esn_run_jax(params, x_t):
  """Run the Echostate network forward a number of steps, the length of x_t.

  This implementation builds the outer time loop from a basic Python for loop
  and lets JAX trace through it.

  Arguments:
    params: dict of ESN params
    x_t: ndarray of input time series, shape (t, u)

  Returns:
    2-tuple of h_t, z_t, after running ESN
  """
  # Initial state of the ESN.
  a = params['a0']
  h = np.tanh(a)
  wO = params['wO']
  z = np.dot(wO, h)
  h_t = []
  z_t = []
  wI = params['wI']
  wR = params['wR']
  wF = params['wF']
  dtdivtau = params['dt_over_tau']
  for tidx, x in enumerate(x_t):
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    h_t.append(h)
    z_t.append(z)
  h_t = np.array(h_t)
  z_t = np.array(z_t)
  return h_t, z_t

esn_run_jax_jit = jit(esn_run_jax)
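# Note: under jit, the Python for loop in esn_run_jax is unrolled at trace
# time into ntime separate copies of the esn step, so compilation time grows
# with the number of time steps even though the compiled call itself can be
# fast. That is why the comparison below excludes JIT times.
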
from jax import lax

def esn_run_lax(params, x_t):
  """Run the Echostate network ntime steps, where ntime is shape[0] of x_t.

  This implementation builds the outer time loop with lax.fori_loop and
  writes each step's results with lax.dynamic_update_slice.

  Arguments:
    params: dict of ESN params
    x_t: ndarray of input time series, shape (t, u)

  Returns:
    2-tuple h_t, z_t, after running ESN
  """
  ntime = x_t.shape[0]
  n = params['wR'].shape[0]
  m = params['wO'].shape[0]
  wI = params['wI']
  wR = params['wR']
  wF = params['wF']
  wO = params['wO']
  dtdivtau = params['dt_over_tau']
  a0 = params['a0']
  h0 = np.tanh(a0)
  z0 = np.dot(wO, h0)
  # Index 0 holds the initial condition; index tidx+1 holds step tidx's state.
  a_t = np.zeros((ntime+1, n))
  h_t = np.zeros((ntime+1, n))
  z_t = np.zeros((ntime+1, m))
  a_t = lax.dynamic_update_slice(a_t, np.expand_dims(a0, 0), [0, 0])
  h_t = lax.dynamic_update_slice(h_t, np.expand_dims(h0, 0), [0, 0])
  z_t = lax.dynamic_update_slice(z_t, np.expand_dims(z0, 0), [0, 0])
  def esn_body(tidx, inputs):
    a_t, h_t, z_t = inputs
    x = x_t[tidx]
    # Read the previous step's state (stored at index tidx) ...
    a = a_t[tidx]
    h = h_t[tidx]
    z = z_t[tidx]
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    # ... and write this step's state at index tidx+1, matching esn_run_jax.
    a_t = lax.dynamic_update_slice(a_t, np.expand_dims(a, 0), [tidx+1, 0])
    h_t = lax.dynamic_update_slice(h_t, np.expand_dims(h, 0), [tidx+1, 0])
    z_t = lax.dynamic_update_slice(z_t, np.expand_dims(z, 0), [tidx+1, 0])
    return a_t, h_t, z_t
  a_t, h_t, z_t = lax.fori_loop(0, ntime, esn_body, (a_t, h_t, z_t))
  h_t = h_t[1:]  # ditch initial condition
  z_t = z_t[1:]  # ditch initial condition
  return h_t, z_t

esn_run_lax_jit = jit(esn_run_lax)
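# A sketch of a third variant using lax.scan, which is the usual idiom for a
# sequential loop with a carried state. This is an illustrative alternative,
# not part of the original JAX-vs-LAX comparison, and it assumes a JAX version
# that provides lax.scan. scan carries (a, h, z) and stacks the per-step
# outputs, avoiding both Python-loop unrolling and the explicit
# dynamic_update_slice bookkeeping above.
def esn_run_scan(params, x_t):
  wI, wR, wF, wO = params['wI'], params['wR'], params['wF'], params['wO']
  dtdivtau = params['dt_over_tau']
  a0 = params['a0']
  h0 = np.tanh(a0)
  z0 = np.dot(wO, h0)

  def step(carry, x):
    a, h, z = carry
    a, h, z = esn(x, a, h, z, wI, wR, wF, wO, dtdivtau)
    return (a, h, z), (h, z)  # new carry, per-step outputs to stack

  _, (h_t, z_t) = lax.scan(step, (a0, h0, z0), x_t)
  return h_t, z_t

esn_run_scan_jit = jit(esn_run_scan)
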
# Basic parameters of the Echostate networks
key = random.PRNGKey(0)
T = 30           # total time
u = 1            # number of inputs (didn't bother to set up zero, just put in zeros)
m = 20           # number of outputs
n = 1000         # size of the reservoir in the ESN
tau = 1.0        # neuron time constant
dt = tau / 10.0  # Euler integration step
times = np.arange(0, T, dt)  # all time points (named so it doesn't shadow the time module)
ntime = times.shape[0]       # the number of time steps
x_t = np.zeros((ntime, u))   # Just a stand-in in case folks want a real input later

### JAX ###
print("")
print("JAX speed")

# Create the ESN that will be trained with FORCE learning.
g = 1.5          # Lower g value was shown to be good in the paper for training.
alpha = 1e0      # Initial learning rate for RLS
params_seed = 100001
print("Params seed %d" % (params_seed))
key = random.PRNGKey(params_seed)
params = random_esn_params(key, u, n, m, g=g)

# Run the untrained ESN once to trigger JIT compilation.
h_t, z_t = esn_run_jax_jit(params, x_t)

# Time the compiled routine; the warm-up call above keeps JIT compilation
# time out of the measurement.
ntorun = 100
start_time = time.time()
for i in range(ntorun):
  esn_run_jax_jit(params, x_t)
jax_run_time = (time.time() - start_time) / ntorun
print("JAX run {:0.4f} sec".format(jax_run_time))

### LAX ###
print("")
print("LAX speed")

# Run the untrained ESN once to trigger JIT compilation.
h_t, z_t = esn_run_lax_jit(params, x_t)

# Time the compiled routine, again excluding JIT compilation time.
ntorun = 100
start_time = time.time()
for i in range(ntorun):
  esn_run_lax_jit(params, x_t)
lax_run_time = (time.time() - start_time) / ntorun
print("LAX run {:0.4f} sec".format(lax_run_time))
I'm running with a GPU; my jax versions are below. I also ran on Colab with a later jaxlib, with the same results. The output is