@aseyboldt
Last active September 14, 2017 03:12
tensorflow-pymc3-experiment
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import collections"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# TODO We can use `tf.distributions` for the most part,\n",
"# but we might want to new ones, and maybe also add methods\n",
"# to them. Using raw `tf.distributions` objects for now.\n",
"class Distribution:\n",
" pass\n",
"\n",
"\n",
"class RandomVariable:\n",
" def __init__(self, name, dist, transform, dtype, shape):\n",
" if shape is None:\n",
" raise ValueError()\n",
" self._name = name\n",
" self._dtype = dtype\n",
" self._shape = shape\n",
" self._transform = transform\n",
" self._dist = dist\n",
" # check domain of dist vs transform.onto\n",
"\n",
" @property\n",
" def name(self):\n",
" return self._name\n",
" \n",
" @property\n",
" def dtype(self):\n",
" return self._dtype\n",
" \n",
" @property\n",
" def shape(self):\n",
" return self._shape\n",
" \n",
" def _logp(self, value, *, jacterms):\n",
" # TODO Pass transform to dists?\n",
" # That would allow specialized implementations.\n",
" transform = self._transform\n",
" if transform is not None:\n",
" value = transform.forward(value)\n",
" \n",
" logp = self._dist.log_prob(value)\n",
" if transform is not None and jacterms:\n",
" logp = logp + transform.jacdet(value)\n",
" return logp\n",
"\n",
" return self._dist.logp(self.var, self.trafo)\n",
" \n",
" def _logp_sum(self, value, *, jacterms):\n",
" # TODO give the distributions a chance to do this\n",
" # on their own.\n",
" return tf.reduce_sum(self._logp(value, jacterms=jacterms))\n",
"\n",
"\n",
"class Free(RandomVariable):\n",
" def __init__(self, name, dist, transform, dtype, shape):\n",
" super().__init__(name, dist, transform, dtype, shape)\n",
" self._var = tf.placeholder(dtype, shape, name)\n",
" \n",
" @property\n",
" def var(self):\n",
" return self._var\n",
" \n",
" def logp(self, *, jacterms=True):\n",
" return self._logp(self.var, jacterms=jacterms)\n",
" \n",
" def logp_sum(self, *, jacterms=True):\n",
" return self._logp_sum(self.var, jacterms=jacterms)\n",
"\n",
" \n",
"class Observed(RandomVariable):\n",
" def __init__(self, name, dist, transform, dtype, shape, observed):\n",
" super().__init__(name, dist, transform, dtype, shape)\n",
" if not isinstance(observed, Data):\n",
" observed = tf.constant(observed)\n",
" self._observed = observed\n",
" \n",
" @property\n",
" def observed(self):\n",
" return self._observed\n",
" \n",
" def logp(self, *, jacterms=True):\n",
" return self._logp(self.observed, jacterms=jacterms)\n",
"\n",
" def logp_sum(self, *, jacterms=True):\n",
" return self._logp_sum(self.observed, jacterms=jacterms)\n",
"\n",
"\n",
"class Data:\n",
" def __init__(self, name, dtype=None, shape=None, default=None):\n",
" # TODO\n",
" self._var = tf.Variable(dtype, shape, name)\n",
"\n",
"\n",
"class TfModel:\n",
" _context_stack = []\n",
"\n",
" def __init__(self):\n",
" self._free_vars = []\n",
" self._observed_vars = []\n",
" self._graph = tf.Graph()\n",
" self._graph_context = None\n",
" \n",
" def add_free_var(self, var):\n",
" self._free_vars.append(var)\n",
" \n",
" def add_observed_var(self, var):\n",
" self._observed_vars.append(var)\n",
" \n",
" def __enter__(self):\n",
" TfModel._context_stack.append(self)\n",
" self._graph_context = self._graph.as_default()\n",
" self._graph_context.__enter__()\n",
" return self\n",
" \n",
" def __exit__(self, *args, **kwargs):\n",
" old = TfModel._context_stack.pop()\n",
" self._graph_context.__exit__(*args, **kwargs)\n",
" assert old is self\n",
" \n",
" def _logp_sum(self, *, jacterms=True, reduce=True):\n",
" with self:\n",
" vars = self._free_vars + self._observed_vars\n",
" logp_free = [var.logp_sum(jacterms=jacterms)\n",
" for var in vars]\n",
" # TODO optional reduce?\n",
" return tf.reduce_sum(logp_free, name='logp__')\n",
" \n",
" def logp_function(self, *, jacterms=True, session=None):\n",
" raise NotImplementedError()\n",
" \n",
" def logp_dlogp_function(self, grad_vars=None, *, target=None, dtype=None,\n",
" jacterms=True, session_config=None, data=None):\n",
" if grad_vars is None:\n",
" # TODO check dtype\n",
" grad_vars = self._free_vars.copy()\n",
" cost = self._logp_sum(jacterms=jacterms)\n",
" return ValueGradFunction(cost, grad_vars, None, target=target,\n",
" dtype=dtype,\n",
" data=data, session_config=session_config)\n",
"\n",
"\n",
"def model_from_context(model):\n",
" if model is not None:\n",
" return model\n",
" if len(TfModel._context_stack) == 0:\n",
" raise ValueError('No model on context stack.')\n",
" return TfModel._context_stack[-1]\n",
"\n",
"\n",
"VarMap = collections.namedtuple(\n",
" 'VarMap', 'name, slice, shape, dtype')\n",
"\n",
"\n",
"class ArrayOrdering:\n",
" def __init__(self, vars, dtype, *, casting='no', data=None, session=None):\n",
" maps = []\n",
" total = 0\n",
" if session is None:\n",
" session = tf.get_default_session()\n",
" \n",
" for var in vars:\n",
" # TODO check casting\n",
" if var.dtype != dtype:\n",
" raise ValueError()\n",
" name = var.name\n",
" shape = var.shape\n",
" if shape is None or any(dim is None for dim in shape):\n",
" # TODO no session\n",
" shape = session.run(shape, data)\n",
" size = int(np.prod(shape))\n",
" slice_ = slice(total, total + size)\n",
" maps.append(VarMap(name, slice_, shape, var.dtype))\n",
" total += size\n",
" self.size = total\n",
" self.vmap = maps\n",
" self.dtype = dtype\n",
" \n",
" def array_to_dict(self, array):\n",
" vars = {}\n",
" for name, slice_, shape, dtype in slef.vmap:\n",
" # TODO check casting\n",
" vars[name] = array[slice_].astype(dtype)\n",
" return vars\n",
"\n",
" def dict_to_array(self, vars):\n",
" array = np.empty(self.size, dtype=self.dtype)\n",
" for var in vars:\n",
" data = vars[var.name]\n",
" data = np.asarray(data, order='C')\n",
" # TODO check casting\n",
" array[var.slice] = data\n",
" return array\n",
"\n",
"\n",
"class ValueGradFunction:\n",
" def __init__(self, cost, grad_vars, extra_vars=None, target=None,\n",
" dtype=None, casting='no', session_config=None, data=None):\n",
" if extra_vars is None:\n",
" extra_vars = []\n",
" \n",
" old_graph = cost.graph\n",
" cost_name = cost.name\n",
" grad_names = [var.name for var in grad_vars]\n",
" # TODO\n",
" assert not extra_vars\n",
" self._ordering = ArrayOrdering(grad_vars, dtype,\n",
" casting=casting, data=data)\n",
" \n",
" graph_def = old_graph.as_graph_def()\n",
" graph = tf.Graph()\n",
" with graph.as_default():\n",
" array = tf.placeholder(dtype, (self._ordering.size,), name='freeRV_array_')\n",
" self._array = array\n",
" var_slices = {}\n",
" for var in self._ordering.vmap:\n",
" var_slices[var.name] = tf.reshape(array[var.slice], var.shape)\n",
" cost_array, = tf.import_graph_def(graph_def, var_slices, [cost_name])\n",
" #cost_array = cost_array.outputs[0]\n",
" cost_grad = tf.gradients(cost_array, array)\n",
" sess = tf.Session(target=target, graph=graph, config=session_config)\n",
" self._session = sess\n",
" self._cost_array = cost_array\n",
" self._cost_grad = cost_grad\n",
" self._graph = graph\n",
" \n",
" def __call__(self, array):\n",
" return self._session.run([self._cost_array, self._cost_grad],\n",
" {self._array: array})"
]
},
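{
"cell_type": "markdown",
"metadata": {},
"source": [
"`RandomVariable._logp` above expects a transform object with `forward` and `jacdet` methods (and, per the comment in `__init__`, some notion of the domain it maps onto), but no transform class is defined in this gist. The cell below is only a sketch of what such an object might look like; the class name `ExpTransform` and the `onto` attribute are assumptions, not part of the original code. An object like this would be passed through the `transform` argument of the distribution factories below."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of the transform interface assumed by `RandomVariable._logp`:\n",
"# `forward` maps a free value on the real line into the distribution's support,\n",
"# and `jacdet` returns the log absolute Jacobian determinant of that map,\n",
"# evaluated at the already-transformed value (matching the call in `_logp`).\n",
"class ExpTransform:\n",
"    name = 'exp'\n",
"    onto = 'positive'  # hypothetical domain tag for the check mentioned in __init__\n",
"\n",
"    def forward(self, free_value):\n",
"        # unconstrained value on R -> positive support\n",
"        return tf.exp(free_value)\n",
"\n",
"    def jacdet(self, value):\n",
"        # with value = exp(x), log|d exp(x) / dx| = x = log(value)\n",
"        return tf.log(value)"
]
},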
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# TODO find a nice way to do this semi-automatically for\n",
"# different distributions in `tf.contrib.distributions`.\n",
"def Normal(name, mu, sigma, shape=None, dtype=None, transform=None, observed=None, model=None):\n",
" dist = tf.contrib.distributions.Normal(mu, sigma, name=name + '_dist__')\n",
" if shape is None:\n",
" shape = dist.batch_shape\n",
" if dtype is None:\n",
" dtype = tf.float32\n",
" model = model_from_context(model)\n",
"\n",
" if observed is None:\n",
" var = Free(name, dist, transform, dtype, shape)\n",
" model.add_free_var(var)\n",
" return var.var\n",
" else:\n",
" var = Observed(name, dist, transform, dtype, shape, observed)\n",
" model.add_observed_var(var)\n",
" return var.observed"
]
},
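{
"cell_type": "markdown",
"metadata": {},
"source": [
"The TODO above asks for a semi-automatic way of wrapping the classes in `tf.contrib.distributions`. One possible sketch (the helper name `wrap_distribution` is an assumption, not part of the original code): a higher-order factory that forwards its arguments to the distribution constructor and otherwise follows the same `Free`/`Observed` pattern as `Normal` above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def wrap_distribution(dist_cls):\n",
"    # Hypothetical helper: turn a tf.contrib.distributions class into a\n",
"    # model-aware factory following the same pattern as `Normal` above.\n",
"    def factory(name, *args, shape=None, dtype=None, transform=None,\n",
"                observed=None, model=None, **kwargs):\n",
"        dist = dist_cls(*args, name=name + '_dist__', **kwargs)\n",
"        if shape is None:\n",
"            shape = dist.batch_shape\n",
"        if dtype is None:\n",
"            dtype = tf.float32\n",
"        model = model_from_context(model)\n",
"\n",
"        if observed is None:\n",
"            var = Free(name, dist, transform, dtype, shape)\n",
"            model.add_free_var(var)\n",
"            return var.var\n",
"        var = Observed(name, dist, transform, dtype, shape, observed)\n",
"        model.add_observed_var(var)\n",
"        return var.observed\n",
"\n",
"    factory.__name__ = dist_cls.__name__\n",
"    return factory\n",
"\n",
"# e.g. StudentT = wrap_distribution(tf.contrib.distributions.StudentT)"
]
},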
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model = TfModel()\n",
"\n",
"with model:\n",
" a = Normal('a', 0., 10.)\n",
" c = Normal('c', 0., 5.)\n",
" b = Normal('b', a, 1., observed=[3., 4., 3.5])"
]
},
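{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of the objects registered above (again only a sketch, relying on the private `_observed_vars` list): build the summed log density of the observed variable `b` inside the model's graph and evaluate it in a throwaway session at a concrete value of `a`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with model:\n",
"    b_logp = model._observed_vars[0].logp_sum()\n",
"\n",
"with tf.Session(graph=model._graph) as sess:\n",
"    # b's log density only depends on the placeholder `a` (its mean),\n",
"    # so feeding `a` alone is enough here.\n",
"    print(sess.run(b_logp, {a: 0.}))"
]
},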
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"logp_dlogp = model.logp_dlogp_function(dtype='float32')"
]
},
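{
"cell_type": "markdown",
"metadata": {},
"source": [
"The callable returned above takes a single flat `float32` array with one entry per free value (here `a` and `c`, in registration order) and returns the summed log density together with its gradient. A small usage sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = np.zeros(2, dtype='float32')\n",
"value, grads = logp_dlogp(x)\n",
"# `value` is the scalar summed log density at a = 0, c = 0; `grads` is the\n",
"# list returned by tf.gradients, holding one array of shape (2,) with the\n",
"# derivative of the log density with respect to each entry of `x`."
]
},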
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"833 µs ± 45.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"%timeit logp_dlogp(np.zeros(2))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"800 µs ± 33.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"logp = model._logp_sum()\n",
"logp_grad = tf.gradients(logp, [var.var for var in model._free_vars])\n",
"\n",
"sess = tf.Session(graph=model._graph)\n",
"a_ = np.array(3.)\n",
"c_ = np.array(4.)\n",
"%timeit sess.run([logp, logp_grad], {a: a_, c: c_})"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}