# @andersbll
# Created September 5, 2016 07:12

from copy import copy, deepcopy
import numpy as np
import cudarray as ca
import deeppy as dp
import deeppy.expr as expr
from util import ScaleGradient, WeightedParameter
from ae import GaussianNegLogLikelihood

class AppendSpatially(expr.base.Binary):
    """Broadcast a feature vector across the spatial dimensions of an image
    batch and append it along the channel axis."""
    def __call__(self, imgs, feats):
        self.imgs = imgs
        self.feats = feats
        self.inputs = [imgs, feats]
        return self

    def setup(self):
        b, c, h, w = self.imgs.out_shape
        b_, f = self.feats.out_shape
        if b != b_:
            raise ValueError('batch size mismatch')
        self.out_shape = (b, c+f, h, w)
        self.out = ca.empty(self.out_shape)
        self.out_grad = ca.empty(self.out_shape)
        self.tmp = ca.zeros((b, f, h, w))

    def fprop(self):
        # Broadcast feats from (b, f) to (b, f, h, w), then append it to the
        # images along the channel axis.
        self.tmp.fill(0.0)
        feats = ca.reshape(self.feats.out, self.feats.out.shape + (1, 1))
        ca.add(feats, self.tmp, out=self.tmp)
        ca.extra.concatenate(self.imgs.out, self.tmp, axis=1, out=self.out)

    def bprop(self):
        # Route the image part of the gradient back to imgs; the feature part
        # is written to the scratch buffer and discarded.
        ca.extra.split(self.out_grad, a_size=self.imgs.out_shape[1], axis=1,
                       out_a=self.imgs.out_grad, out_b=self.tmp)
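
# Minimal NumPy sketch (illustration only, not used by the model) of what
# AppendSpatially computes in fprop: the (b, f) feature vector is broadcast
# to (b, f, h, w) and appended to the images along the channel axis.
def _append_spatially_reference(imgs, feats):
    b, c, h, w = imgs.shape
    tiled = np.broadcast_to(feats[:, :, None, None],
                            (b, feats.shape[1], h, w))
    return np.concatenate([imgs, tiled], axis=1)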

class ConditionalSequential(expr.Sequential):
    def __call__(self, x, y=None):
        for op in self.collection:
            if isinstance(op, (expr.Concatenate, AppendSpatially)):
                if y is None:
                    raise ValueError('No y given to concatenate with')
                x = op(x, y)
            else:
                x = op(x)
        return x
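
# Hypothetical usage sketch (the op names are illustrative, not from this
# gist): ops that consume a conditioning input, i.e. expr.Concatenate or
# AppendSpatially, receive the labels y; all other ops are applied to x alone.
#
#   net = ConditionalSequential([conv_op, AppendSpatially(), another_conv_op])
#   out = net(x_expr, y_expr)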

class AEGAN(dp.base.Model, dp.base.CollectionMixin):
    def __init__(self, encoder, latent_encoder, decoder, discriminator,
                 recon_vs_gan_weight=1.0, recon_depth=0, sample_z=True):
        self.encoder = encoder
        self.latent_encoder = latent_encoder
        self.discriminator = discriminator
        self.recon_vs_gan_weight = recon_vs_gan_weight
        self.recon_depth = recon_depth
        self.sample_z = sample_z
        self.eps = 1e-4
        self.latent_encode = True
        # Per-sample weighting of the discriminator loss. setup() checks this
        # attribute, so initialize it here; set it externally to enable the
        # weighted branch.
        self.dis_weights = None
        self.recon_error = GaussianNegLogLikelihood()
        self.decoder = decoder
        self.collection = [self.encoder, self.latent_encoder, self.decoder,
                           self.discriminator]
        decoder.params = [WeightedParameter(p, self.recon_vs_gan_weight,
                                            -(1.0 - self.recon_vs_gan_weight))
                          for p in decoder.params]
        self.decoder_neggrad = deepcopy(decoder)
        self.decoder_neggrad.params = [p.share() for p in decoder.params]
        self.collection += [self.decoder_neggrad]
        if recon_depth > 0:
            # Measure reconstruction error on features from the first
            # recon_depth discriminator layers instead of raw pixels.
            recon_layers = discriminator.collection[:recon_depth]
            print('Reconstruction error at layer #%i: %s'
                  % (recon_depth, recon_layers[-1].__class__.__name__))
            dis_layers = discriminator.collection[recon_depth:]
            discriminator.collection = recon_layers
            discriminator.params = [WeightedParameter(p, 1.0, 0.0)
                                    for p in discriminator.params]
            self.discriminator_recon = deepcopy(discriminator)
            self.discriminator_recon.params = [p.share() for p in
                                               discriminator.params]
            discriminator.collection += dis_layers
            self.collection += [self.discriminator_recon]
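
    # Reading of the parameter sharing above (WeightedParameter comes from
    # the companion util module, so its exact semantics are an assumption):
    # the decoder's parameters are wrapped as WeightedParameter(p, w, -(1-w))
    # with w = recon_vs_gan_weight, and decoder_neggrad shares them. The
    # reconstruction path then contributes its gradient scaled by +w while
    # the GAN path through decoder_neggrad contributes scaled by -(1-w), so
    # the decoder descends the reconstruction loss while ascending the
    # discriminator's loss, blended by recon_vs_gan_weight.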

    def _encode_expr(self, x, batch_size, y):
        enc = self.encoder(x, y=y)
        z, encoder_loss = self.latent_encoder.encode(enc, batch_size)
        return z

    def _decode_expr(self, z, batch_size, y):
        return self.decoder(z, y=y)

    def setup(self, x_shape, y_shape):
        batch_size = x_shape[0]
        self.x_src = expr.Source(x_shape)
        self.y_src = expr.Source(y_shape)
        loss = 0
        # Encode
        enc = self.encoder(self.x_src, self.y_src)
        z, self.encoder_loss = self.latent_encoder.encode(enc, batch_size)
        loss += self.encoder_loss
        # Decode
        x_tilde = self.decoder(z, self.y_src)
        y = self.y_src
        if self.recon_depth > 0:
            # Reconstruction error measured in discriminator feature space
            x = expr.Concatenate(axis=0)(x_tilde, self.x_src)
            y = expr.Concatenate(axis=0)(self.y_src, self.y_src)
            d = self.discriminator_recon(x, y=y)
            d = expr.Reshape((batch_size*2, -1))(d)
            d_x_tilde, d_x = expr.Slices([batch_size])(d)
            loss += self.recon_error(d_x_tilde, d_x)
        else:
            loss += self.recon_error(x_tilde, self.x_src)
        # Discriminate
        z = ScaleGradient(0.0)(z)
        gen_size = batch_size
        if self.sample_z:
            gen_size += batch_size
            if self.recon_depth == 0:
                y = expr.Concatenate(axis=0)(y, self.y_src)
            z_samples = self.latent_encoder.samples(batch_size)
            z = expr.Concatenate(axis=0)(z, z_samples)
        x_tilde = self.decoder_neggrad(z, y=y)
        x = expr.Concatenate(axis=0)(self.x_src, x_tilde)
        y = expr.Concatenate(axis=0)(y, self.y_src)
        if self.dis_weights is not None:
            # Reweight discriminator gradients so real and generated samples
            # contribute dis_weights : (1 - dis_weights).
            weights = np.ones((batch_size + gen_size, 1))
            norm = float(batch_size + gen_size) / batch_size
            # The normalization factors cancel for the real samples.
            real_weight = self.dis_weights
            gen_weight = ((1.0 - self.dis_weights)
                          * (float(batch_size + gen_size) / gen_size) / norm)
            weights[:batch_size] = real_weight
            weights[batch_size:] = gen_weight
            self._grad_weights = ca.array(weights)
            shape = [(batch_size + gen_size) if i == 0 else 1
                     for i in range(len(x_shape))]
            weights = 1.0 / np.reshape(weights, shape)
            self._grad_weights_inv = ca.array(weights)
            x = ScaleGradient(self._grad_weights_inv)(x)
        d = self.discriminator(x, y=y)
        if self.dis_weights is not None:
            d = ScaleGradient(self._grad_weights)(d)
        d = expr.clip(d, self.eps, 1.0 - self.eps)
        # sign/offset turn log(d*sign + offset) into log(d) for the real
        # samples and log(1 - d) for the generated samples.
        sign = np.ones((gen_size + batch_size, 1), dtype=ca.float_)
        sign[batch_size:] = -1.0
        offset = np.zeros_like(sign)
        offset[batch_size:] = 1.0
        self.gan_loss = expr.log(d*sign + offset)
        self._graph = expr.ExprGraph(expr.sum(loss) + expr.sum(-self.gan_loss))
        self._graph.out_grad = ca.array(1.0)
        self._graph.setup()
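
    # Note on the dis_weights branch in setup(): with n = batch_size +
    # gen_size, real samples keep weight dis_weights while generated samples
    # get (1 - dis_weights) * batch_size / gen_size, so the two groups
    # contribute in the ratio dis_weights : (1 - dis_weights) to the
    # discriminator gradient regardless of gen_size. The inverse weights
    # applied to x appear to undo this scaling on the gradient path back
    # into the decoder.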

    @property
    def params(self):
        enc_params = self.encoder.params + self.latent_encoder.params
        dec_params = self.decoder.params
        dis_params = self.discriminator.params
        return enc_params, dec_params, dis_params

    def update(self, x, y):
        self.x_src.out = x
        self.y_src.out = y
        self._graph.fprop()
        self._graph.bprop()
        encoder_loss = np.array(self.encoder_loss.out)
        gan_loss = -np.array(self.gan_loss.out)
        batch_size = x.shape[0]
        d_x_loss = float(np.mean(gan_loss[:batch_size]))
        d_z_loss = float(np.mean(gan_loss[batch_size:]))
        return d_x_loss, d_z_loss, encoder_loss

    def _batchwise(self, input, expr_fun):
        input = dp.input.Input.from_any(input)
        src = expr.Source(input.x_shape)
        y_src = expr.Source(input.y_shape)
        graph = expr.ExprGraph(expr_fun(src, input.batch_size, y=y_src))
        graph.setup()
        z = []
        for batch in input.batches():
            src.out = batch['x']
            y_src.out = batch['y']
            graph.fprop()
            z.append(np.array(graph.out))
        z = np.concatenate(z)[:input.n_samples]
        return z

    def encode(self, input):
        """ Input to hidden. """
        return self._batchwise(input, self._encode_expr)

    def decode(self, input):
        """ Hidden to input. """
        return self._batchwise(input, self._decode_expr)

    def likelihood(self, input):
        """ Likelihood of input. """
        # Note: _likelihood_expr is not defined in this gist.
        return self._batchwise(input, self._likelihood_expr)
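
# Self-contained NumPy sketch (illustration only, not used by the model) of
# the sign/offset trick from AEGAN.setup(): log(d*sign + offset) evaluates to
# log(d) for the real samples and log(1 - d) for the generated ones.
def _gan_loss_reference(d, batch_size):
    sign = np.ones((d.shape[0], 1))
    sign[batch_size:] = -1.0
    offset = np.zeros_like(sign)
    offset[batch_size:] = 1.0
    return np.log(d*sign + offset)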

class GradientDescent(dp.GradientDescent):
    def __init__(self, model, input, learn_rule, margin=0.4, equilibrium=0.68):
        super(GradientDescent, self).__init__(model, input, learn_rule)
        self.margin = margin
        self.equilibrium = equilibrium

    def reset(self):
        self.input.reset()
        self.model.setup(**self.input.shapes)
        self.params_enc, self.params_dec, self.params_dis = self.model.params

        def states(params):
            return [self.learn_rule.init_state(p) for p in params]
        self.lstates_enc = states(self.params_enc)
        self.lstates_dec = states(self.params_dec)
        self.lstates_dis = states(self.params_dis)

    def train_epoch(self):
        batch_costs = []
        for batch in self.input.batches():
            real_cost, fake_cost, encoder = self.model.update(**batch)
            batch_costs.append((real_cost, fake_cost, encoder))
            dec_update = True
            dis_update = True
            if self.margin is not None:
                # Keep the two players near equilibrium: skip the
                # discriminator update when it is winning and the decoder
                # update when it is losing; if both would be skipped, update
                # both.
                if real_cost < self.equilibrium - self.margin or \
                        fake_cost < self.equilibrium - self.margin:
                    dis_update = False
                if real_cost > self.equilibrium + self.margin or \
                        fake_cost > self.equilibrium + self.margin:
                    dec_update = False
                if not (dec_update or dis_update):
                    dec_update = True
                    dis_update = True
            for param, state in zip(self.params_enc, self.lstates_enc):
                self.learn_rule.step(param, state)
            if dec_update:
                for param, state in zip(self.params_dec, self.lstates_dec):
                    self.learn_rule.step(param, state)
            if dis_update:
                for param, state in zip(self.params_dis, self.lstates_dis):
                    self.learn_rule.step(param, state)
        real_cost = np.mean([cost[0] for cost in batch_costs])
        fake_cost = np.mean([cost[1] for cost in batch_costs])
        encoder = np.mean([c[2] for c in batch_costs])
        return real_cost + fake_cost + encoder
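
# Pure-Python restatement (illustration only) of the gating heuristic in
# GradientDescent.train_epoch(): freeze the discriminator while it is winning
# and the decoder while it is losing, relative to an equilibrium band.
def _update_gates(real_cost, fake_cost, margin=0.4, equilibrium=0.68):
    dec_update = dis_update = True
    if margin is not None:
        if (real_cost < equilibrium - margin or
                fake_cost < equilibrium - margin):
            dis_update = False
        if (real_cost > equilibrium + margin or
                fake_cost > equilibrium + margin):
            dec_update = False
        if not (dec_update or dis_update):
            dec_update = dis_update = True
    return dec_update, dis_update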