Skip to content

Instantly share code, notes, and snippets.

@justheuristic
Last active August 1, 2021 08:52
Show Gist options
  • Save justheuristic/1118a14a798b2b6d47789f7e6f511abd to your computer and use it in GitHub Desktop.
Save justheuristic/1118a14a798b2b6d47789f7e6f511abd to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jheuristic/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
" from ._conv import register_converters as _register_converters\n"
]
}
],
"source": [
"import tensorflow as tf\n",
"import tfnn\n",
"from tfnn.layers import Dense\n",
"from concrete_gate import ConcreteGate\n",
"\n",
"class MyAutoencoder:\n",
" def __init__(self, name, inp_size, hid_size, **kwargs):\n",
" with tf.variable_scope(name):\n",
" self.first = Dense('first', inp_size, hid_size, activ=tf.tanh)\n",
" self.gate = ConcreteGate('gate', shape=[1, hid_size], **kwargs)\n",
" self.second = Dense('second', hid_size, inp_size, activ=lambda x:x)\n",
" \n",
" def __call__(self, x):\n",
" return self.second(self.gate(self.first(x)))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"env: CUDA_VISIBLE_DEVICES=\n"
]
}
],
"source": [
"%env CUDA_VISIBLE_DEVICES=\n",
"tf.reset_default_graph()\n",
"sess = tf.InteractiveSession()\n",
"\n",
"my_net = MyAutoencoder('network', 64, 100, l0_penalty=1e-3, hard=False)\n",
"x = tf.random_normal([5, 64])\n",
"x_rec = my_net(x)\n",
"loss = tf.reduce_mean(tf.squared_difference(x, x_rec))\n",
"\n",
"reg = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))\n",
"# or manually: reg = my_net.gate.get_penalty()\n",
"\n",
"step = tf.train.AdamOptimizer(learning_rate=1e-2).minimize(loss + reg)\n",
"\n",
"sess.run(tf.global_variables_initializer())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"rate_of_nonzero_activations = 0.66\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7fc3bc7d7a20>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 10000/10000 [00:10<00:00, 913.38it/s]\n"
]
}
],
"source": [
"from tqdm import trange\n",
"from IPython.display import clear_output\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"loss_history, reg_history = [], []\n",
"\n",
"for t in trange(10000):\n",
" loss_t, reg_t, _ = sess.run([loss, reg, step])\n",
" loss_history.append(loss_t)\n",
" reg_history.append(reg_t)\n",
" if t % 1000 == 0:\n",
" clear_output(True)\n",
" num_nonzero = sess.run(my_net.gate.get_sparsity_rate())\n",
" print(\"rate_of_nonzero_activations =\", num_nonzero)\n",
" \n",
" plt.subplot(1,2,1)\n",
" plt.plot(loss_history, label='loss')\n",
" plt.legend()\n",
" plt.subplot(1,2,2)\n",
" plt.plot(reg_history, label='reg')\n",
" plt.legend()\n",
" plt.show()\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[1., 0., 0., 1., 1., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1.,\n",
" 1., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 0.,\n",
" 1., 1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0.,\n",
" 1., 1., 0., 1., 1., 1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1.,\n",
" 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
" 0., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 0.,\n",
" 0., 1., 1., 0.]], dtype=float32)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gate values\n",
"sess.run(my_net.gate.get_gates(is_train=False))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ],\n",
" [ 0.02172022, -0. , 0. , 0.05970415, -0.03229865,\n",
" -0.01130804, 0. , -0.02124784, -0. , 0.02389894,\n",
" -0. , 0.01233097, 0.00665306, -0.00404538, -0. ,\n",
" 0.00455578, 0.00603381, 0. , 0.02199194, -0.03769897,\n",
" -0. , -0.00714258, 0.05062252, -0. , -0.02904983,\n",
" 0. , 0.049015 , -0.00445654, 0. , -0. ,\n",
" -0.01899791, -0. , 0.01498988, -0.00361649, -0.01428505,\n",
" 0.0200627 , 0.00454252, -0.02961647, 0. , -0.04975929,\n",
" -0. , 0.00315028, -0.03230491, 0. , -0.03524412,\n",
" 0.01989217, 0. , 0. , 0.00102865, -0.01898451,\n",
" -0. , -0.04948218, -0.03765366, 0.00472582, -0. ,\n",
" -0. , -0. , 0. , 0.01454964, -0.03137352,\n",
" -0.00582995, -0.04133054, -0. , -0.01280909, 0.01734644,\n",
" -0.02355555, 0.04760959, 0.02267313, 0.00704623, -0.05425708,\n",
" 0. , 0.01986646, -0.02759497, -0.05245164, -0.03126373,\n",
" -0.00284979, 0.02161182, 0.02729553, -0.04157891, -0.04360487,\n",
" -0. , -0. , 0.0135207 , -0.00187147, 0.0609222 ,\n",
" -0.07456608, 0.02430784, 0.00149097, 0.02840535, 0. ,\n",
" 0.07865206, -0.04211108, -0. , 0. , 0. ,\n",
" -0. , -0. , -0.05287053, 0.0263629 , -0. ]],\n",
" dtype=float32)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# gated activations\n",
"sess.run(my_net.gate(my_net.first(x), is_train=False), {x: tf.zeros(x.shape).eval()})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
import tensorflow as tf
from warnings import warn
import tfnn
class ConcreteGate:
    """
    A gate made of stretched concrete distribution (using experimental Stretchable Concrete™)
    Can be applied to sparsify neural network activations or weights.
    Usage example: https://gist.github.com/justheuristic/1118a14a798b2b6d47789f7e6f511abd
    :param name: name used for the variable scope (gate variable) and name scope (gate ops)
    :param shape: shape of gate variable. can be broadcasted.
        e.g. if you want to apply gate to tensor [batch, length, units] over units axis,
        your shape should be [1, 1, units]
    :param temperature: concrete sigmoid temperature, should be in (0, 1] range
        lower values yield better approximation to actual discrete gate but train longer
    :param stretch_limits: min and max value of gate before it is clipped to [0, 1]
        min value should be negative in order to compute l0 penalty as in https://arxiv.org/pdf/1712.01312.pdf
        however, you can also use tf.nn.sigmoid(log_a) as regularizer if min, max = 0, 1
    :param l0_penalty: coefficient on the regularizer that minimizes l0 norm of gated value
    :param l2_penalty: coefficient on the regularizer that minimizes l2 norm of gated value
    :param eps: a small additive value used to avoid NaNs
    :param hard: if True, gates are binarized to {0, 1} but backprop is still performed as if they were concrete
    :param local_rep: if True, training noise is sampled with the shape of the gated values
        (a different noise sample per element, e.g. per sample in batch);
        by default, noise is sampled using shape param as size.
    """

    def __init__(self, name, shape, temperature=0.33, stretch_limits=(-0.1, 1.1),
                 l0_penalty=0.0, l2_penalty=0.0, eps=1e-6, hard=False, local_rep=False):
        self.name = name
        self.temperature, self.stretch_limits, self.eps = temperature, stretch_limits, eps
        self.l0_penalty, self.l2_penalty = l0_penalty, l2_penalty
        self.hard, self.local_rep = hard, local_rep
        with tf.variable_scope(name):
            # log_a is the trainable location parameter of the concrete distribution;
            # its initializer is whatever tfnn.ops.get_model_variable defaults to (not visible here)
            self.log_a = tfnn.ops.get_model_variable("log_a", shape=shape)

    def __call__(self, values, is_train=None, axis=None, reg_collection=tf.GraphKeys.REGULARIZATION_LOSSES):
        """
        Applies gate to values; if is_train, adds regularizer to reg_collection.
        :param values: tensor to be gated; must be broadcast-compatible with the gate shape
        :param is_train: python bool; if None, uses tfnn.ops.is_dropout_enabled()
        :param axis: forwarded to get_penalty (axes summed before averaging)
        :param reg_collection: graph collection that receives the penalty tensor during training
        :returns: values * gates
        """
        is_train = tfnn.ops.is_dropout_enabled() if is_train is None else is_train
        gates = self.get_gates(is_train, shape=tf.shape(values) if self.local_rep else None)
        # the penalty is only ever used during training, so don't build it otherwise
        if is_train and (self.l0_penalty != 0 or self.l2_penalty != 0):
            reg = self.get_penalty(values=values, axis=axis)
            tf.add_to_collection(reg_collection, tf.identity(reg, name='concrete_gate_reg'))
        return values * gates

    def get_gates(self, is_train, shape=None):
        """
        Samples gate activations in [0, 1] interval.
        :param is_train: if True, samples noisy (relaxed) gates; otherwise uses deterministic sigmoid(log_a)
        :param shape: shape of the noise tensor (training only); defaults to the shape of log_a
        :returns: gates clipped to [0, 1]; if self.hard, forward values are binarized to {0, 1}
        """
        low, high = self.stretch_limits
        with tf.name_scope(self.name):
            if is_train:
                shape = tf.shape(self.log_a) if shape is None else shape
                # keep uniform noise away from 0 and 1 so both logs below stay finite
                noise = tf.random_uniform(shape, self.eps, 1.0 - self.eps)
                # log(u) - log(1 - u) is logistic noise; shifting by log_a and tempering gives a binary-concrete sample
                concrete = tf.nn.sigmoid((tf.log(noise) - tf.log(1 - noise) + self.log_a) / self.temperature)
            else:
                concrete = tf.nn.sigmoid(self.log_a)
            # stretch (0, 1) to (low, high) and clip back, so gates can become exactly 0 or 1
            stretched_concrete = concrete * (high - low) + low
            clipped_concrete = tf.clip_by_value(stretched_concrete, 0, 1)
            if self.hard:
                # straight-through: forward pass sees the binarized gate, gradient flows through clipped_concrete
                hard_concrete = tf.to_float(tf.greater(clipped_concrete, 0.5))
                clipped_concrete = clipped_concrete + tf.stop_gradient(hard_concrete - clipped_concrete)
            return clipped_concrete

    def get_penalty(self, values=None, axis=None):
        """
        Computes l0 and l2 penalties. For l2 penalty one must also provide the sparsified values
        (usually activations or weights) before they are multiplied by the gate
        Returns the regularizer value that should be MINIMIZED (negative logprior)
        :param values: ungated tensor; required for l2 penalty. For l0, the open-probability is
            broadcast to its shape so that every gated element is counted
        :param axis: axes summed over before the remaining dimensions are averaged; None sums everything
        :returns: scalar regularizer tensor (python 0.0 if both penalties are zero)
        """
        if self.l0_penalty == self.l2_penalty == 0:
            warn("get_penalty() is called with both penalties set to 0")
        low, high = self.stretch_limits
        assert low < 0.0, "p_gate_closed can be computed only if lower stretch limit is negative"
        with tf.name_scope(self.name):
            # p(gate is open) = 1 - cdf(stretched_concrete <= 0), closed form from arXiv:1712.01312
            p_open = tf.nn.sigmoid(self.log_a - self.temperature * tf.log(-low / high))
            p_open = tf.clip_by_value(p_open, self.eps, 1.0 - self.eps)

            total_reg = 0.0
            if self.l0_penalty != 0:
                # broadcast to values' shape WITHOUT rebinding p_open, which the l2 term reuses below
                p_open_l0 = (p_open + tf.zeros_like(values)) if values is not None else p_open
                l0_reg = self.l0_penalty * tf.reduce_sum(p_open_l0, axis=axis)
                total_reg += tf.reduce_mean(l0_reg)

            if self.l2_penalty != 0:
                assert values is not None, "l2 penalty requires values"
                # expected squared l2 norm of gated values: sum of p_open * value ** 2 over axis
                l2_reg = 0.5 * self.l2_penalty * tf.reduce_sum(p_open * values ** 2, axis=axis)
                total_reg += tf.reduce_mean(l2_reg)

            return total_reg

    def get_sparsity_rate(self, is_train=False):
        """
        Computes the fraction of gates which are now active (non-zero).
        Note: despite the name, this is the density of open gates (1 - sparsity).
        """
        is_nonzero = tf.not_equal(self.get_gates(is_train), 0.0)
        return tf.reduce_mean(tf.to_float(is_nonzero))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment