import numpy as np
from pymc3 import Model, MvNormal
import theano
import theano.tensor as T


def run_check_logdet():
    print('alpha0=1e4')
    check_logdet(alpha0=1e4)
    print('alpha0=1e3')
    check_logdet(alpha0=1e3)
    print('alpha0=1e2')
    check_logdet(alpha0=1e2)


def check_logdet(alpha0=1e4, n_groups=100):
    n_sensors = 200
    n_vertices = 100
    n_timepoints = 2

    # Sensor covariance I + sum_g (1 / alpha0) * G_g * G_g^T; smaller alpha0
    # inflates the covariance, shrinking the eigenvalues of the precision.
    cov = np.eye(n_sensors)
    bs = np.random.randn(n_timepoints, n_sensors)
    gs = [np.random.randn(n_sensors, n_vertices) for _ in range(n_groups)]
    gags = [(1.0 / alpha0) * g.dot(g.T) for g in gs]
    gag = np.sum(np.stack(gags), axis=0)
    bcov = cov + gag
    prec = np.linalg.inv(bcov)

    with Model() as model1:
        # Likelihood for observations, with pymc3's MvNormal (log(det()))
        MvNormal('l', mu=0.0, tau=T.as_tensor(prec), observed=bs)

    with Model() as model2:
        # Likelihood for observations, with MyMvNormal (logabsdet())
        MyMvNormal('l', mu=0.0, tau=T.as_tensor(prec), observed=bs)

    print('logp with log(det()) = {}'.format(model1.logp()))
    print('logp with logabsdet() = {}'.format(model2.logp()))


from theano.gof import Op, Apply


# The code is adapted from https://github.com/Theano/Theano/pull/3959
class LogAbsDet(Op):
    """Computes the logarithm of the absolute value of the determinant of
    a square matrix M, log(abs(det(M))), on CPU. Avoids the overflow/
    underflow that det(M) itself can hit.

    TODO: add GPU code!
    """
    def make_node(self, x):
        x = theano.tensor.as_tensor_variable(x)
        o = theano.tensor.scalar(dtype=x.dtype)
        return Apply(self, [x], [o])

    def perform(self, node, inputs, outputs):
        try:
            (x,) = inputs
            (z,) = outputs
            # log|det(x)| equals the sum of log singular values, so the
            # (possibly over/underflowing) determinant is never formed.
            s = np.linalg.svd(x, compute_uv=False)
            log_abs_det = np.sum(np.log(np.abs(s)))
            z[0] = np.asarray(log_abs_det, dtype=x.dtype)
        except Exception:
            print('Failed to compute logabsdet of {}.'.format(x))
            raise

    def grad(self, inputs, g_outputs):
        gz, = g_outputs
        x, = inputs
        # d log|det(x)| / dx = inv(x)^T
        return [gz * T.nlinalg.matrix_inverse(x).T]

    def __str__(self):
        return "LogAbsDet"

logabsdet = LogAbsDet()
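

# A minimal sanity check of the Op on its own (a sketch; the helper name
# _check_logabsdet_op is not part of the original gist). For a matrix whose
# determinant underflows to 0.0 in double precision, log(det()) gives -inf,
# while logabsdet stays finite by summing log singular values.
def _check_logabsdet_op():
    m = T.dmatrix('m')
    f = theano.function([m], logabsdet(m))
    x = 1e-3 * np.eye(500)           # det(x) = 1e-1500 underflows to 0.0
    print(np.log(np.linalg.det(x)))  # -inf (with a RuntimeWarning)
    print(f(x))                      # 500 * log(1e-3) ~ -3453.88, finite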
from scipy import stats
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples


class MyMvNormal(Continuous):
    r"""
    Multivariate normal log-likelihood.

    .. math::

        f(x \mid \mu, T) =
            \frac{|T|^{1/2}}{(2\pi)^{k/2}}
            \exp\left\{ -\frac{1}{2} (x - \mu)^{\prime} T (x - \mu) \right\}

    ========  ==========================
    Support   :math:`x \in \mathbb{R}^k`
    Mean      :math:`\mu`
    Variance  :math:`T^{-1}`
    ========  ==========================

    Parameters
    ----------
    mu : array
        Vector of means.
    tau : array
        Precision matrix.
    """
    def __init__(self, mu, tau, *args, **kwargs):
        super(MyMvNormal, self).__init__(*args, **kwargs)
        self.mean = self.median = self.mode = self.mu = mu
        self.tau = tau

    def random(self, point=None, size=None):
        mu, tau = draw_values([self.mu, self.tau], point=point)

        def _random(mean, cov, size=None):
            # FIXME: cov here is actually the precision matrix (tau);
            # it should be inverted before being passed to scipy.
            return stats.multivariate_normal.rvs(
                mean, cov, None if size == mean.shape else size)

        samples = generate_samples(_random,
                                   mean=mu, cov=tau,
                                   dist_shape=self.shape,
                                   broadcast_shape=mu.shape,
                                   size=size)
        return samples

    def logp(self, value):
        mu = self.mu
        tau = self.tau
        delta = value - mu
        k = tau.shape[0]
        # Stable replacement for the original
        #   result = k * T.log(2 * np.pi) + T.log(1. / det(tau))
        result = k * T.log(2 * np.pi) - logabsdet(tau)
        result += (delta.dot(tau) * delta).sum(axis=delta.ndim - 1)
        return -1 / 2. * result
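

# A NumPy/SciPy cross-check of the logp formula above (a sketch; the helper
# name _check_logp_against_scipy is not part of the original gist). For a
# small, well-conditioned precision matrix the two agree, so any discrepancy
# seen in check_logdet comes from the determinant computation alone.
def _check_logp_against_scipy():
    k = 3
    a = np.random.randn(k, k)
    tau = a.dot(a.T) + k * np.eye(k)  # symmetric positive definite precision
    x = np.random.randn(k)
    _, logdet_tau = np.linalg.slogdet(tau)
    logp = -0.5 * (k * np.log(2 * np.pi) - logdet_tau + x.dot(tau).dot(x))
    ref = stats.multivariate_normal.logpdf(x, np.zeros(k), np.linalg.inv(tau))
    print(logp, ref)                  # should match to float precision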
# In [1]: import sys; sys.path.insert(0, '/Users/taku-y/git/github/pymc3')
#
# In [2]: import test
#
# In [3]: test.run_check_logdet()
# alpha0=1e4
# logp with log(det()) = -613.857479372
# logp with logabsdet() = -613.857479372
# alpha0=1e3
# logp with log(det()) = -862.731539233
# logp with logabsdet() = -862.731539233
# alpha0=1e2
# logp with log(det()) = -1290.73899232   <- wrong in original, see next line
# logp with log(det()) = -inf
# logp with logabsdet() = -1290.73899232
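#
# Why the -inf at alpha0=1e2 (note added as a sketch, not part of the original
# run): smaller alpha0 inflates bcov, so the eigenvalues of prec shrink and
# det(prec) for the 200 x 200 matrix underflows below the smallest positive
# double (~4.9e-324); T.log(det(tau)) then evaluates log(0.0) = -inf, while
# LogAbsDet sums log singular values and stays finite. For example:
#
#     x = 1e-2 * np.eye(200)
#     np.linalg.det(x)          # (1e-2)**200 = 1e-400 underflows to 0.0
#     np.linalg.slogdet(x)[1]   # finite: 200 * log(1e-2) ~ -921.03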