Created
November 17, 2019 07:58
-
-
Save MokkeMeguru/69f02d1ee837007a04de93aaf7e2b0ff to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow_probability as tfp | |
import tensorflow as tf | |
tfb, tfd = tfp.bijectors, tfp.distributions | |
class Blockwise2DwithBatch(tfb.Bijector):
    """Apply a list of bijectors to consecutive blocks of the last axis.

    The input is split along its last axis into one block per bijector,
    each bijector transforms its own block, and the transformed blocks are
    concatenated back along the last axis.

    Shapes (as intended by the author):
        input                     [batch_size, h, w]
        output                    [batch_size, h, w]
        inverse_log_det_jacobian  [batch_size]
        forward_log_det_jacobian  [batch_size]

    When `broadcast_det` is True, the summed log-det-Jacobian is broadcast
    against a `[batch_size]` vector, so a bijector that reports a scalar
    log-det-Jacobian (rank 0, e.g. `tfb.Identity()`) still yields a
    `[batch_size]` result. The result is then required to be rank 1;
    otherwise an `AssertionError` with the message
    "Your ildjs' shape is unexpected" (or "fldjs'") is raised.
    """

    def __init__(self,
                 bijectors,
                 block_sizes=None,
                 broadcast_det=True,
                 validate_args=False,
                 name=None):
        """Create the blockwise bijector.

        Args:
            bijectors: non-empty list of `tfb.Bijector` instances, one per
                block.
            block_sizes: sizes of the blocks along the last axis, one per
                bijector. If None, the last axis is split into
                `len(bijectors)` equal parts.
            broadcast_det: bool. If True, the summed log-det-Jacobian is
                broadcast to `[batch_size]`,
                e.g. batch_size = 3, log_det_jacobian = 1 -> [1., 1., 1.].
            validate_args: forwarded to `tfb.Bijector.__init__`.
            name: name of this instance. If falsy, a name is derived from
                the names of `bijectors`.

        Raises:
            ValueError: if `bijectors` is empty, or if `block_sizes` is a
                list/tuple whose length differs from `len(bijectors)`.

        Example 1:
            bijectors = [tfb.Exp(), tfb.Identity()]
            block_sizes = [8, 8]
            broadcast_det = False
            validate_args = False
            name = 'myblockwise3D'

        Example 2 (same as example 1 when the last axis has size 16):
            bijectors = [tfb.Exp(), tfb.Identity()]
            block_sizes = None
            broadcast_det = False
            validate_args = False
            name = 'myblockwise3D'
        """
        if not bijectors:
            raise ValueError('`bijectors` must not be empty.')
        # A length mismatch would otherwise be silently truncated by `zip`
        # in the forward/inverse passes. Only Python sequences are checked
        # here; anything else is handed to `tf.split` unchanged.
        if (isinstance(block_sizes, (list, tuple))
                and len(block_sizes) != len(bijectors)):
            raise ValueError(
                '`block_sizes` must have one entry per bijector: got %d '
                'sizes for %d bijectors.'
                % (len(block_sizes), len(bijectors)))
        if not name:
            name = 'blockwise_of_' + '_and_'.join([b.name for b in bijectors])
            name = name.replace('/', '')
        super(Blockwise2DwithBatch, self).__init__(
            forward_min_event_ndims=2,
            validate_args=validate_args,
            name=name,
        )
        self._bijectors = bijectors
        self._block_sizes = block_sizes
        self._broadcast_det = broadcast_det

    @property
    def broadcast_det(self):
        """Whether log-det-Jacobians are broadcast to `[batch_size]`."""
        return self._broadcast_det

    @property
    def bijectors(self):
        """The bijectors applied to the blocks, in block order."""
        return self._bijectors

    @property
    def block_sizes(self):
        """The block sizes along the last axis, or None for an even split."""
        return self._block_sizes

    def _split_blocks(self, tensor):
        """Split `tensor` along its last axis into one block per bijector."""
        if self.block_sizes is None:
            return tf.split(tensor, len(self.bijectors), axis=-1)
        return tf.split(tensor, self.block_sizes, axis=-1)

    def _summed_log_det(self, tensor, method_name, label):
        """Sum the per-block log-det-Jacobians obtained via `method_name`.

        Args:
            tensor: the tensor to split into blocks.
            method_name: 'forward_log_det_jacobian' or
                'inverse_log_det_jacobian'; looked up on each bijector.
            label: 'fldjs' or 'ildjs'; only used in the error message.

        Returns:
            The sum of the per-block log-det-Jacobians. When
            `broadcast_det` is True it is additionally broadcast against
            a vector of length `tf.shape(tensor)[0]`.

        Raises:
            AssertionError: if `broadcast_det` is True and the result has
                a statically known rank other than 1.
        """
        blocks = self._split_blocks(tensor)
        total = sum(
            getattr(bijector, method_name)(block, event_ndims=2)
            for bijector, block in zip(self.bijectors, blocks))
        if not self.broadcast_det:
            return total

        total = tf.convert_to_tensor(total)
        # Broadcast against a [batch_size] vector of zeros. The zeros take
        # the dtype of the log-det (not a fixed float32) and the batch size
        # is read dynamically, so an unknown static batch dimension works.
        total += tf.zeros(tf.shape(tensor)[:1], dtype=total.dtype)
        rank = total.shape.rank
        # Raised explicitly (instead of using `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept as
        # AssertionError for backward compatibility.
        if rank is not None and rank != 1:
            raise AssertionError("Your %s' shape is unexpected" % label)
        return total

    def _forward(self, x):
        """Apply each bijector's `forward` to its block of `x`."""
        blocks = self._split_blocks(x)
        transformed = [
            bijector.forward(block)
            for bijector, block in zip(self.bijectors, blocks)
        ]
        return tf.concat(transformed, axis=-1)

    def _inverse(self, y):
        """Apply each bijector's `inverse` to its block of `y`."""
        blocks = self._split_blocks(y)
        restored = [
            bijector.inverse(block)
            for bijector, block in zip(self.bijectors, blocks)
        ]
        return tf.concat(restored, axis=-1)

    def _forward_log_det_jacobian(self, x):
        """Sum of the per-block forward log-det-Jacobians of `x`."""
        return self._summed_log_det(x, 'forward_log_det_jacobian', 'fldjs')

    def _inverse_log_det_jacobian(self, y):
        """Sum of the per-block inverse log-det-Jacobians of `y`."""
        return self._summed_log_det(y, 'inverse_log_det_jacobian', 'ildjs')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment