Skip to content

Instantly share code, notes, and snippets.

@ihincks
Last active July 10, 2018 20:42
Show Gist options
  • Save ihincks/0b636988959f1226c2107e7a78516aff to your computer and use it in GitHub Desktop.
Save ihincks/0b636988959f1226c2107e7a78516aff to your computer and use it in GitHub Desktop.
Implements a bijection between full-rank density matrices (unit-trace positive-definite operators) and the real space R^(d^2-1)
from __future__ import division
import numpy as np
from scipy.special import expit, logit
def complex_contract(x, y):
    """
    Shrinks z=x+iy to the open unit disk, preserving its phase.

    The modulus `r = |z|` is mapped to `tanh(r / 2)`, which is exactly the
    half-expit map `2 * expit(r) - 1`, but evaluated without the catastrophic
    cancellation that `2 * expit(r) - 1` suffers for small `r`.

    :param x: Real part(s); scalar or array broadcastable against `y`.
    :param y: Imaginary part(s); scalar or array broadcastable against `x`.
    :return: Complex value(s) of modulus `tanh(|x+iy| / 2)` and the same
        phase as `x+iy`.
    """
    # hypot avoids the underflow of x**2 + y**2 for very small inputs
    r = np.hypot(x, y)
    phi = np.arctan2(y, x)
    return np.tanh(r / 2) * np.exp(1j * phi)
def complex_expand(z):
    """
    Inverse function to `complex_contract`.

    The modulus `|z|` is mapped to `2 * arctanh(|z|)`, which is exactly
    `logit(|z| / 2 + 0.5)` but keeps full precision for `|z|` close to zero.

    :param z: Complex value(s) inside the unit disk.
    :return: Tuple `(x, y)` of the real and imaginary parts of the expanded
        point, each with the same shape as `z`.
    """
    # |z| == 1 maps to inf and |z| > 1 to nan; silence the floating point
    # warnings so that the boundary is handled quietly
    with np.errstate(divide="ignore", invalid="ignore"):
        r = 2.0 * np.arctanh(np.abs(z))
    phi = np.angle(z)
    return r * np.cos(phi), r * np.sin(phi)
def vec_to_state(x, normalized=True):
    """
    The last axis of the array `x` is mapped to square density
    matrices.

    The last axis of `x` is storing elements of a lower Cholesky factorization
    of the density matrices, and has length `d^2-1` (only `normalized=True`
    is currently implemented). The first `d-1` are the diagonal elements,
    the next `d*(d-1)/2` are the real parts of the off diagonals, and
    the rest are the imaginary parts of the off diagonals. The stacking
    convention is decided by `np.tril_indices`.

    The unit trace condition is met by using a stick-breaking like trick, where
    a given value of `x` doesn't directly store a Cholesky factor matrix
    element, but rather, it stores (after being shrunk) the fraction of the
    remaining budget of the trace until `1`. Note that the trace of a positive
    operator is the square Hilbert-Schmidt norm of its Cholesky factor, so
    this is a 2-norm version of stick breaking.

    :param np.ndarray x: Array of shape `(...,d^2-1)`.
    :param bool normalized: Whether density matrices are trace 1. Only
        `True` is implemented.
    :return: Complex array of shape `(...,d,d)` holding the matrices
        `L L^dagger` built from the Cholesky factors `L`.
    :rtype: np.ndarray
    :raises ValueError: If the last axis of `x` does not have length
        `d^2-1` for some integer `d >= 1`.
    :raises NotImplementedError: If `normalized` is `False`.
    """
    if not normalized:
        # this is easy to do relative to the other case
        raise NotImplementedError(
            "vec_to_state is only implemented for normalized=True"
        )

    x = np.asarray(x)

    # define a few integers
    n = x.shape[:-1]
    dim = int(round(np.sqrt(x.shape[-1] + 1)))
    if dim ** 2 - 1 != x.shape[-1]:
        raise ValueError(
            "last axis of x must have length d^2-1 for an integer d, "
            "got {}".format(x.shape[-1])
        )
    n_offdiag = dim * (dim - 1) // 2
    a = dim - 1
    b = a + n_offdiag
    c = b + n_offdiag

    # extract components of the vector and contract them all to the unit disk
    # note that cholesky factors are only in bijection with positive operators
    # if we enforce that the diagonal is positive.
    diag = expit(x[..., :a])
    z = np.concatenate(
        [diag, complex_contract(x[..., a:b], x[..., b:c])], axis=-1
    )

    # compute the last remaining diagonal element by using a 2-norm stick
    # breaking procedure, which still works for complex numbers
    w = np.ones(n + (dim * (dim + 1) // 2,), dtype=np.complex128)
    w[..., 1:] = np.cumprod(np.sqrt(1 - np.abs(z) ** 2), axis=-1)
    w[..., :-1] *= z
    # the above puts the new element at the end; move it to the start
    w = np.roll(w, 1, axis=-1)

    # reshape w into the bottom triangle of matrices, putting zeroes elsewhere
    chol = np.zeros(n + (dim, dim), dtype=np.complex128)
    chol[(Ellipsis,) + np.diag_indices(dim)] = w[..., :dim]
    chol[(Ellipsis,) + np.tril_indices(dim, k=-1)] = w[..., dim:]

    # turn the cholesky factors into positive operators; swapping the last
    # two axes is the batched transpose
    return np.matmul(chol, np.swapaxes(chol, -1, -2).conj())
def state_to_vec(p, normalized=True):
    """
    Inverse function of `vec_to_state`.

    The Cholesky factor of each matrix is computed with
    `np.linalg.cholesky`, the stick-breaking construction is undone, and the
    resulting values are unconstrained with `logit` (diagonal entries) and
    `complex_expand` (off-diagonal entries).

    :param np.ndarray p: Array of shape `(...,d,d)` of matrices accepted by
        `np.linalg.cholesky`.
    :param bool normalized: Whether density matrices are trace 1. Only
        `True` is implemented.
    :return: Real array of shape `(...,d^2-1)`.
    :rtype: np.ndarray
    :raises NotImplementedError: If `normalized` is `False`.
    """
    if not normalized:
        # this is easy to do relative to the other case
        raise NotImplementedError(
            "state_to_vec is only implemented for normalized=True"
        )

    p = np.asarray(p)

    # just do the opposite of everything that is done in vec_to_state
    n = p.shape[:-2]
    dim = p.shape[-1]
    n_offdiag = dim * (dim - 1) // 2
    a = dim - 1
    b = a + n_offdiag
    c = b + n_offdiag

    # compute cholesky factors and populate w with the non-zero
    # side of the triangle. The first diagonal element is skipped because it
    # is the one determined by the unit trace condition.
    chol = np.linalg.cholesky(p)
    w = np.empty(n + (dim * (dim + 1) // 2 - 1,), dtype=np.complex128)
    w[..., :a] = chol[(Ellipsis,) + np.diag_indices(dim)][..., 1:dim]
    w[..., a:] = chol[(Ellipsis,) + np.tril_indices(dim, k=-1)]

    # the following simple formula turns out to be the inverse of the
    # stick-breaking used in vec_to_state; the right-hand side is evaluated
    # from the untouched values of w before the in-place division happens
    w[..., 1:] /= np.sqrt(1 - np.cumsum(np.abs(w[..., :-1]) ** 2, axis=-1))

    # now we just have to unconstrain from the unit disk.
    x = np.empty(n + (dim ** 2 - 1,))
    x[..., :a] = logit(np.real(w[..., :a]))
    x[..., a:b], x[..., b:c] = complex_expand(w[..., a:])
    return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment