@PhilipVinc
Last active January 21, 2021 14:44
# Benchmark: per-call dispatch overhead of a bare jax-stax model, the same
# model wrapped as a flax linen module, and a native flax module.

# imports
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial
from typing import Any
from jax.experimental import stax
# stax-like batch-sum reduction layer: sums over the feature axis
def SumLayer():
    def init_fun(rng, input_shape):
        # the sum over the last axis drops the feature dimension
        output_shape = input_shape[:-1]
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun


SumLayer = SumLayer()
# expose a jax-stax (init_fun, apply_fun) pair through a flax-like init/apply interface
class JaxModel:
    def __init__(self, ifun, afun):
        self.ifun = ifun
        self.afun = afun

    def init(self, keys, inp):
        _, pars = self.ifun(keys['params'], inp.shape)
        return flax.core.FrozenDict({'params': pars})

    def apply(self, w, x):
        return self.afun(w['params'], x)
# wrap a jax-stax (init_fun, apply_fun) pair as a flax linen module
class JaxWrapModule(nn.Module):
    """Wrapper for bare Jax modules made of an init_fun and an apply_fun."""

    init_fun: Any
    apply_fun: Any

    @nn.compact
    def __call__(self, x):
        # promote a single sample to a batch of one
        if jnp.ndim(x) == 1:
            x = jnp.atleast_2d(x)
        pars = self.param(
            "jax", lambda rng, shape: self.init_fun(rng, shape)[1], x.shape
        )
        return self.apply_fun(pars, x)
# a simple RBM implemented natively in flax
class FlaxRBM(nn.Module):
    dtype: Any = np.float32
    activation: Any = nn.activation.sigmoid
    alpha: int = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            name="Dense",
            features=self.alpha * x.shape[-1],
            dtype=self.dtype,
            use_bias=self.use_bias,
        )(x)
        x = self.activation(x)
        return jnp.sum(x, axis=-1)
def build_JaxModel(L, alpha):
    ifun, afun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
    ma = JaxModel(ifun, afun)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w


def build_WrapJax(L, alpha):
    ifun, afun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
    ma = JaxWrapModule(ifun, afun)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w


def build_Flax(L, alpha):
    ma = FlaxRBM(alpha=alpha)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w
# jit the apply call; the module itself is passed as a static (hashable) argument
@partial(jax.jit, static_argnums=0)
def apply_fun(ma, w, x):
    return ma.apply(w, x)
# benchmark dispatch overhead of repeated jitted apply calls
L = 1
alpha = 1
batches = 1

j_ma, j_w = build_JaxModel(L, alpha)
w_ma, w_w = build_WrapJax(L, alpha)
f_ma, f_w = build_Flax(L, alpha)

x = jax.random.uniform(jax.random.PRNGKey(12), (batches, L))
# run many tiny forward passes through each model to expose per-call overhead
for _ in range(100000):
    _ = apply_fun(f_ma, f_w, x).block_until_ready()

for _ in range(100000):
    _ = apply_fun(j_ma, j_w, x).block_until_ready()

for _ in range(100000):
    _ = apply_fun(w_ma, w_w, x).block_until_ready()
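# --- Added sketch (not in the original gist): report wall-clock timings. ---
# A minimal, hedged example assuming the (model, weights) pairs built above;
# it reuses the jitted apply_fun and prints the mean time per call.
import time

def time_apply(ma, w, x, n=10_000):
    # warm-up call so compilation is not included in the timing
    apply_fun(ma, w, x).block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n):
        apply_fun(ma, w, x).block_until_ready()
    return (time.perf_counter() - t0) / n

for name, (ma, w) in [("stax", (j_ma, j_w)),
                      ("wrapped stax", (w_ma, w_w)),
                      ("flax", (f_ma, f_w))]:
    print(f"{name:>12s}: {time_apply(ma, w, x) * 1e6:.1f} us per call")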