Skip to content

Instantly share code, notes, and snippets.

In [2]: import flax
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-2-5854b141ca7d> in <module>
----> 1 import flax
~/.virtualenvs/flax2/lib/python3.7/site-packages/flax/__init__.py in <module>
32
33 # Allow `import flax`; `flax.nn.[...]`, and the same for `flax.optim.[...]`
---> 34 from . import nn
import numpy as onp
import jax.numpy as jnp
class ArrayContainer:
    """Minimal wrapper that exposes ``value`` via NumPy's ``__array__`` protocol.

    ``onp.asarray(ArrayContainer(x))`` converts the wrapped ``x`` into an
    ``ndarray``.
    """

    def __init__(self, value):
        # Stored as-is; conversion to an ndarray happens in ``__array__``.
        self.value = value

    def __array__(self, dtype=None, copy=None):
        """Return the wrapped value as an ``ndarray``.

        Args:
            dtype: Optional target dtype that NumPy forwards, e.g. from
                ``onp.asarray(container, dtype=...)``. ``None`` keeps the
                dtype NumPy infers from ``value``.
            copy: Copy request that NumPy >= 2 forwards. ``True`` always
                returns a fresh array. ``None`` and ``False`` return
                ``value`` itself when it already is an ndarray of the
                requested dtype; otherwise they convert. Note that
                ``False`` is treated like ``None`` here, i.e. a copy is
                still made when a conversion is unavoidable.

        Returns:
            An ``onp.ndarray`` built from ``self.value``.
        """
        # NumPy rejects non-ndarray results from __array__, so coerce lists,
        # scalars, etc. rather than handing back the raw value.
        if copy:
            return onp.array(self.value, dtype=dtype)
        return onp.asarray(self.value, dtype=dtype)
@bayerj
bayerj / jax.py
Created February 12, 2020 14:03
Learning step with non-differentiable parameters.
import jax
import jax.experimental.optimizers
from jax.api import _check_inexact_input_vjp
from jax import tree_util as tu
import numpy as onp
def make_resilient_step(loss, sample_params, split, join, optimizer):
sample_learn_params, non_learn_params = split(sample_params)
@bayerj
bayerj / gist:6064edd404e65189105dbd9f7945b3d3
Created February 12, 2020 14:03
JAX learning step w.r.t. non-differentiable parameters like integers.
import jax
import jax.experimental.optimizers
from jax.api import _check_inexact_input_vjp
from jax import tree_util as tu
import numpy as onp
def make_resilient_step(loss, sample_params, split, join, optimizer):
sample_learn_params, non_learn_params = split(sample_params)
@bayerj
bayerj / __main__.py
Created March 22, 2018 08:28
Skeleton for command line scripts using docopt.
"""
Usage:
run something somewhere
Options:
-h | --help Show this screen.
"""
import sys
import tensorflow as tf
import edward as ed
from collections import OrderedDict
def edges(rvs):
rvs = [*rvs]
edges = []
visited = set()
@bayerj
bayerj / guidelines.MD
Created September 28, 2017 07:19
Some recommended coding guidelines for scientific programming with Python

"Rules are to make you think before you break them." – Terry Pratchett

Some guidelines to follow for coding to increase collaboration efficiency.

  • Name variables meaningfully.
    • It pays off to think hard about it, because coming up with names is O(1) while trying to decipher names is O(“number of people who try to understand the code”).
    • Don’t be afraid of long variable names; it is much better to have a long obvious name than a short name that does not “flow” while reading code.
    • Avoid Greek-letter names such as epsilon, eta or gamma. The right variable names are often offset, step_rate or decay_factor. Even if the paper you are implementing uses certain Greek names, a follow-up won’t. Also, mathematical formulas in papers and source code are different things with different objectives.
class _SequentialAutoregressive(tf.contrib.distributions.Distribution):
def __init__(self, f_process, base_dist_cls, initial_dist,
n_time_steps=None,
dtype=tf.float32,
name='sequential_auto_regressive'):
self.f_process = f_process
self.base_dist_cls = base_dist_cls
self.initial_dist = initial_dist
self.n_time_steps = n_time_steps
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-2-913092c38985> in <module>()
1 x = tf.placeholder('float32', [None, 784])
----> 2 noised = x + tf.random_normal(x.get_shape())
/Users/bayerj/anaconda/lib/python2.7/site-packages/tensorflow/python/ops/random_ops.pyc in random_normal(shape, mean, stddev, dtype, seed, name)
44 """
45 with ops.op_scope([shape, mean, stddev], name, "random_normal") as name:
---> 46 shape_tensor = _ShapeTensor(shape)
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.9/MathJax.js?config=TeX-AMS_HTML"></script>
$$test$$