This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import numpy as np | |
from numpy.lib.stride_tricks import as_strided as ast | |
def chunk_data(data,window_size,overlap_size=0,flatten_inside_window=True): | |
assert data.ndim == 1 or data.ndim == 2 | |
if data.ndim == 1: | |
data = data.reshape((-1,1)) | |
# get the number of overlapping windows that fit into the data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# distutils: language = c++ | |
from libcpp.vector cimport vector | |
def test_memviews_same(a, b):
    """Write through vectors of 1-D double memoryviews built from ``a`` and ``b``.

    Each argument is converted to a C++ ``vector[double[:]]``. Then 10.0 is
    stored at element ``[0][0]`` of the first vector and 20.0 at element
    ``[0][0]`` of the second. The order matters: the second write runs last.

    NOTE(review): this looks like a probe of whether memoryviews held in a
    vector alias the caller's buffers (e.g. when ``a`` and ``b`` share
    storage). Confirm the intended inputs against the caller.
    """
    cdef vector[double[:]] first_views = a
    cdef vector[double[:]] second_views = b
    first_views[0][0] = 10.
    second_views[0][0] = 20.
def test_memviews_diff(a,b): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
from jax import core | |
from jax.util import safe_map, safe_zip | |
import jax.linear_util as lu | |
# Shadow the builtins ``map``/``zip`` with jax.util's ``safe_map``/``safe_zip``
# for the rest of this module (presumably length-checked variants; see jax.util).
map, zip = safe_map, safe_zip
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax import core | |
# A primitive is just a name to which we associate rules. | |
sincos_p = core.Primitive('sincos') | |
# A primitive's "bind" is how it gets applied, in a way that interacts with the | |
# trace/transform machinery. As a convention we wrap them in Python functions | |
# like this: | |
def sincos(x): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import jax | |
import jax.numpy as np | |
from jax.scipy.special import logsumexp | |
from jax import lax, random | |
from jax import jit, grad | |
def log_normalizer(params, seq): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from jax.interpreters import ad | |
from jax.interpreters import partial_eval as pe | |
from jax import custom_transforms | |
from jax import core | |
from jax import grad | |
@custom_transforms | |
def f(x, y): | |
return x**2 + 3 * y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
### numpy version | |
import numpy as onp | |
# Build a 10x2 array of zeros, then overwrite rows 3 and 4 in place via
# slice assignment (NumPy arrays are mutable).
x = onp.zeros((10, 2))
x[3:5] = 5.
# Function-call form of print: valid on both Python 2 and Python 3, whereas
# the bare ``print x`` statement is a SyntaxError under Python 3.
print(x)
# [[0. 0.] | |
# [0. 0.] | |
# [0. 0.] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial | |
import numpy.random as npr | |
import jax.numpy as np | |
from jax import lax | |
from jax import grad, pjit, papply | |
### set up some synthetic data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
set PATH "$HOME"/bin /opt/local/Library/Frameworks/Python.framework/Versions/Current/bin/\ | |
/opt/local/bin /opt/local/sbin $PATH | |
set CDPATH . "$HOME" $CDPATH | |
set fish_greeting "" | |
function fish_user_key_bindings
    # Map the escape sequence ESC [ 1 ; 9 A to token-wise backward history
    # search. NOTE(review): this appears to be a modified Up-arrow as sent by
    # some terminals; confirm which key/terminal it targets.
    bind \e\[1\;9A history-token-search-backward
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types | |
def closure_conversion(f):
    """Split a closure into a code-only function maker and its captured environment.

    Args:
      f: a Python function, possibly closing over variables of enclosing scopes.

    Returns:
      A pair ``(f_maker, env)``:
        - ``env`` maps each free-variable name of ``f`` to the value currently
          held in the corresponding closure cell (empty if ``f`` has no free
          variables).
        - ``f_maker(env)`` builds a new function sharing ``f``'s code, globals
          and default arguments, with each free variable bound to a fresh cell
          holding ``env[name]``. It raises KeyError if a name is missing.
    """
    # The dunder attributes exist on Python 2.6+ and on Python 3; the older
    # ``func_code``/``func_globals``/``func_closure`` spellings are Python 2 only.
    code = f.__code__
    globs = f.__globals__
    freevars = code.co_freevars
    # ``__closure__`` is None (not an empty tuple) when ``f`` has no free
    # variables, so fall back to an empty tuple before iterating.
    cells = f.__closure__ or ()
    env = dict(zip(freevars, (c.cell_contents for c in cells)))
    # Carry defaults over so the rebuilt function accepts the same call forms.
    defaults = f.__defaults__
    kwdefaults = getattr(f, '__kwdefaults__', None)  # keyword-only defaults (Py3)

    def make_cell(val):
        # A lambda closing over ``val`` owns exactly one cell that holds it.
        return (lambda: val).__closure__[0]

    def f_maker(env):
        closure = tuple(make_cell(env[name]) for name in freevars)
        new_f = types.FunctionType(code, globs, argdefs=defaults, closure=closure)
        if kwdefaults is not None:
            # Copy so mutating one rebuilt function's kw-defaults can't affect f.
            new_f.__kwdefaults__ = dict(kwdefaults)
        return new_f

    return f_maker, env