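# Benchmark the per-call dispatch overhead of the same small network expressed
# three ways: a bare jax-stax (init_fun, apply_fun) pair, the same pair wrapped
# in a flax linen module, and a native flax linen module.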
# imports
import jax
from jax import numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial
from typing import Any
from jax.experimental import stax
# stax-like batch-sum reduction layer
def SumLayer():
    def init_fun(rng, input_shape):
        # summing over the last axis drops that dimension
        output_shape = input_shape[:-1]
        return output_shape, ()

    def apply_fun(params, inputs, **kwargs):
        return inputs.sum(axis=-1)

    return init_fun, apply_fun

SumLayer = SumLayer()
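# Illustrative usage (assumed shapes, not part of the benchmark): applied to a
# (batch, features) array the layer sums away the feature axis, e.g.
#   _, sum_apply = SumLayer
#   sum_apply((), jnp.ones((4, 3)))  # -> shape (4,) array of 3.0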
# expose a bare jax-stax model through a flax-like init/apply interface
class JaxModel:
    def __init__(self, ifun, afun):
        self.ifun = ifun
        self.afun = afun

    def init(self, keys, inp):
        _, pars = self.ifun(keys['params'], inp.shape)
        return flax.core.FrozenDict({'params': pars})

    def apply(self, w, x):
        return self.afun(w['params'], x)
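# Note: JaxModel mirrors the flax linen init/apply calling convention, so the
# jitted apply_fun below can treat all three model kinds uniformly.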
# wrap a jax-stax model as a proper flax module
class JaxWrapModule(nn.Module):
    """Wrapper for bare jax modules defined by an (init_fun, apply_fun) pair."""

    init_fun: Any
    apply_fun: Any

    @nn.compact
    def __call__(self, x):
        # promote single samples to a batch of one
        if jnp.ndim(x) == 1:
            x = jnp.atleast_2d(x)
        pars = self.param(
            "jax", lambda rng, shape: self.init_fun(rng, shape)[1], x.shape
        )
        return self.apply_fun(pars, x)
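# Note: the whole stax parameter pytree is stored as a single flax parameter
# named "jax", so the resulting variables are {'params': {'jax': <stax pytree>}}.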
# a simple RBM
class FlaxRBM(nn.Module):
    dtype: Any = np.float32
    activation: Any = nn.activation.sigmoid
    alpha: int = 1
    use_bias: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(
            name="Dense",
            features=self.alpha * x.shape[-1],
            dtype=self.dtype,
            use_bias=self.use_bias,
        )(x)
        x = self.activation(x)
        return jnp.sum(x, axis=-1)
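# For an input x this computes sum_i sigma((W x + b)_i) with hidden-unit density
# alpha (sigma = sigmoid by default), an RBM-style ansatz.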
def build_JaxModel(L, alpha):
    ifun, afun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
    ma = JaxModel(ifun, afun)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w

def build_WrapJax(L, alpha):
    ifun, afun = stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
    ma = JaxWrapModule(ifun, afun)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w

def build_Flax(L, alpha):
    ma = FlaxRBM(alpha=alpha)
    x = jnp.zeros((1, L))
    w = ma.init({'params': jax.random.PRNGKey(0)}, x)
    return ma, w
@partial(jax.jit, static_argnums=0)
def apply_fun(ma, w, x):
    return ma.apply(w, x)
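# Note: static_argnums=0 hashes the model object itself, so each distinct model
# instance is traced and compiled once; subsequent calls only pay python-side
# dispatch overhead, which is what the loop below measures.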
# benchmark dispatch overhead
L = 1
alpha = 1
batches = 1

j_ma, j_w = build_JaxModel(L, alpha)
w_ma, w_w = build_WrapJax(L, alpha)
f_ma, f_w = build_Flax(L, alpha)

x = jax.random.uniform(jax.random.PRNGKey(12), (batches, L))

# run each variant in turn: bare stax, flax-wrapped stax, native flax
for ma, w in ((j_ma, j_w), (w_ma, w_w), (f_ma, f_w)):
    for i in range(100000):
        _ = apply_fun(ma, w, x).block_until_ready()
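# Optional timing sketch (an addition, not part of the original benchmark):
# wrap the loop in time.perf_counter to get a rough per-call dispatch cost.
# import time
# for name, ma, w in (("stax", j_ma, j_w), ("wrapped", w_ma, w_w), ("flax", f_ma, f_w)):
#     apply_fun(ma, w, x).block_until_ready()  # warm up / compile outside the timer
#     t0 = time.perf_counter()
#     for _ in range(10_000):
#         apply_fun(ma, w, x).block_until_ready()
#     print(name, (time.perf_counter() - t0) / 10_000, "s per call")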