Skip to content

Instantly share code, notes, and snippets.

@MokkeMeguru
Created November 17, 2019 07:58
Show Gist options
  • Save MokkeMeguru/69f02d1ee837007a04de93aaf7e2b0ff to your computer and use it in GitHub Desktop.
Save MokkeMeguru/69f02d1ee837007a04de93aaf7e2b0ff to your computer and use it in GitHub Desktop.
import tensorflow_probability as tfp
import tensorflow as tf
tfb, tfd = tfp.bijectors, tfp.distributions
class Blockwise2DwithBatch(tfb.Bijector):
    """Apply a list of bijectors to consecutive blocks of the last axis.

    The input is split along its last axis (``w``). Each block is passed
    through the matching bijector, and the results are concatenated back
    along the same axis.

    Shapes:
        input                     [batch_size, h, w]
        output                    [batch_size, h, w]
        inverse_log_det_jacobian  [batch_size]
        forward_log_det_jacobian  [batch_size]

    Sub-bijectors are called with ``event_ndims=2``. Some bijectors (e.g.
    ones with a constant Jacobian such as ``tfb.Identity()``) may return a
    scalar log-det-Jacobian instead of a ``[batch_size]`` tensor.

    When ``broadcast_det`` is True, the per-block terms are accumulated onto
    a ``[batch_size]`` zero vector, so scalar terms are broadcast. A
    ``ValueError("Your ildjs' shape is unexpected")`` (or the ``fldjs``
    variant) is raised if the accumulated result is not rank 1.
    """

    def __init__(self,
                 bijectors,
                 block_sizes=None,
                 broadcast_det=True,
                 validate_args=False,
                 name=None):
        """Build the blockwise bijector.

        Args:
            bijectors: list of tfb.Bijector, one per block of the last axis.
            block_sizes: list of block widths along the last axis, one per
                bijector. If None, the last axis is split into
                ``len(bijectors)`` equal blocks (``tf.split`` requires it to
                be divisible).
            broadcast_det: bool. If True and some bijector's
                log_det_jacobian has shape [], broadcast it to [batch_size].
                ex. batch_size = 3, log_det_jacobian = 1 -> [1., 1., 1.]
            validate_args: see tfb.Bijector's reference.
            name: the name of this instance. Defaults to a name built from
                the sub-bijectors' names.

        Raises:
            ValueError: if ``bijectors`` is empty/None, or if
                ``block_sizes`` does not have one entry per bijector.

        ex1:
            bijectors = [tfb.Exp(), tfb.Identity()]
            block_sizes = [8, 8]
            broadcast_det = False
            validate_args = False
            name = 'myblockwise2D'
        ex2: (same as ex1 when the last axis has width 16)
            bijectors = [tfb.Exp(), tfb.Identity()]
            block_sizes = None
            broadcast_det = False
            validate_args = False
            name = 'myblockwise2D'
        """
        # Validate before `bijectors` is iterated to build the default name,
        # so a None/empty list gives the intended ValueError.
        if not bijectors:
            raise ValueError('`bijectors` must not be empty.')
        # zip() in the forward/inverse passes would otherwise silently drop
        # the blocks that have no matching bijector.
        if block_sizes is not None and len(block_sizes) != len(bijectors):
            raise ValueError(
                '`block_sizes` must have one entry per bijector; got {} '
                'block sizes for {} bijectors.'.format(
                    len(block_sizes), len(bijectors)))
        if not name:
            name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
            name = name.replace('/', '')
        self._bijectors = bijectors
        super(Blockwise2DwithBatch, self).__init__(
            forward_min_event_ndims=2,
            validate_args=validate_args,
            name=name,
        )
        self._block_sizes = block_sizes
        self._broadcast_det = broadcast_det

    @property
    def broadcast_det(self):
        return self._broadcast_det

    @property
    def bijectors(self):
        return self._bijectors

    @property
    def block_sizes(self):
        return self._block_sizes

    def _split_blocks(self, t):
        """Split ``t`` along its last axis into one block per bijector."""
        if self.block_sizes is None:
            return tf.split(t, len(self.bijectors), axis=-1)
        return tf.split(t, self.block_sizes, axis=-1)

    def _combine_log_det_jacobians(self, ldjs, reference, label):
        """Sum per-block log-det-Jacobians.

        Args:
            ldjs: list of per-block log-det-Jacobian tensors.
            reference: the tensor the Jacobians were computed at. Its
                dynamic leading dimension and dtype define the
                [batch_size] accumulator when broadcasting.
            label: 'fldjs' or 'ildjs', used in the shape error message.
        """
        if not self.broadcast_det:
            return sum(ldjs)
        # Dynamic batch size (works when the static dim is None inside
        # tf.function) and matching dtype (works for float64 inputs).
        total = tf.zeros([tf.shape(reference)[0]], dtype=reference.dtype)
        for ldj in ldjs:
            total += ldj
        rank = total.shape.rank
        if rank is not None and rank != 1:
            raise ValueError("Your {}' shape is unexpected".format(label))
        return total

    def _forward(self, x):
        split_y = [b.forward(x_)
                   for b, x_ in zip(self.bijectors, self._split_blocks(x))]
        return tf.concat(split_y, axis=-1)

    def _inverse(self, y):
        split_x = [b.inverse(y_)
                   for b, y_ in zip(self.bijectors, self._split_blocks(y))]
        return tf.concat(split_x, axis=-1)

    def _forward_log_det_jacobian(self, x):
        fldjs = [
            b.forward_log_det_jacobian(x_, event_ndims=2)
            for b, x_ in zip(self.bijectors, self._split_blocks(x))
        ]
        return self._combine_log_det_jacobians(fldjs, x, 'fldjs')

    def _inverse_log_det_jacobian(self, y):
        ildjs = [
            b.inverse_log_det_jacobian(y_, event_ndims=2)
            for b, y_ in zip(self.bijectors, self._split_blocks(y))
        ]
        return self._combine_log_det_jacobians(ildjs, y, 'ildjs')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment