Skip to content

Instantly share code, notes, and snippets.

def vae_loss(input_img, output):
    """Return the VAE loss (reconstruction + KL divergence) averaged over the batch.

    Args:
        input_img: the original images (passed as ``y_true`` when used as a
            Keras loss).
        output: the decoder's reconstruction (``y_pred``); same shape as
            ``input_img``.

    Returns:
        A scalar tensor: the mean, over images in the batch, of each image's
        summed squared reconstruction error plus its KL divergence.

    Note:
        ``mean`` and ``log_stddev`` (the encoder's latent mean and log(sigma)
        tensors) are read from the enclosing scope, not passed as arguments.
        ``K`` is presumably ``keras.backend``; confirm against the imports.
    """
    # Per-image squared error, summed over every non-batch axis -> shape (batch,).
    # Summing over the batch axis too would make this term grow with batch
    # size while the KL term below is averaged, skewing their relative weight.
    reconstruction_loss = K.sum(K.batch_flatten(K.square(output - input_img)), axis=-1)
    # Per-image KL(N(mean, sigma^2) || N(0, 1)) with log_stddev = log(sigma):
    #   -0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2), and log(sigma^2) = 2*log(sigma).
    kl_loss = -0.5 * K.sum(
        1 + 2 * log_stddev - K.square(mean) - K.square(K.exp(log_stddev)),
        axis=-1,
    )
    # Average the per-image totals over the batch.
    total_loss = K.mean(reconstruction_loss + kl_loss)
    return total_loss
@irhum
irhum / training.py
Last active January 20, 2019 12:49
# Train `model` for `num_epochs` epochs (numbered 1..num_epochs inclusive),
# optionally running validation.
#
# When `validate` is True, two preconditions are checked with `assert`
# (NOTE(review): asserts are stripped under `python -O`; consider raising
# ValueError instead):
#   * `dataloaders` has exactly two entries -- presumably the training and
#     validation loaders; confirm order/keys against the full file.
#   * `dlib_models` is provided -- presumably the same
#     (face_detector, landmark_detector, face_embedder) tuple that
#     `distance_metric` unpacks; confirm.
#
# NOTE(review): this gist preview is truncated just after the epoch loop
# begins. How `criterion`, `optimizer`, `device`, `out_name` and
# `validate_every` are used, and what (if anything) is returned, is not
# visible here. `validate_every` presumably counts epochs between validation
# passes -- TODO confirm.
def train(model, dataloaders, criterion, optimizer, device, out_name, dlib_models=None,
validate=True, validate_every=10, num_epochs=100):
if validate:
assert len(dataloaders) == 2
assert dlib_models is not None
# start at epoch 1, end at epoch num_epochs (inclusive)
for epoch in range(1, num_epochs+1):
# Training phase
# Returns a distance (a float) if exactly one face is detected in the
# synthesized image `y_pred`, and None otherwise (no face, or more than one).
#
# `dlib_models` is unpacked as (face_detector, landmark_detector,
# face_embedder) -- presumably dlib's frontal face detector, shape predictor
# and face recognition model; TODO confirm. `x` is presumably the reference
# image the synthesized face is compared against; its use is not visible here.
# NOTE(review): this gist preview is truncated -- the distance computation and
# the `return` are not shown, so the exact metric cannot be confirmed here.
def distance_metric(y_pred, x, dlib_models):
# Presumably the value returned when no single face is found.
dist = None
face_detector, landmark_detector, face_embedder = dlib_models
# Second argument is presumably dlib's upsample count (upsample once so
# smaller faces are found) -- confirm.
face_boxes = face_detector(y_pred, 1)
# Only images with exactly one detected face are scored; in the visible code
# `dist` is otherwise left as None.
if len(face_boxes) == 1:
landmarks = landmark_detector(y_pred, face_boxes[0])
embedding = face_embedder.compute_face_descriptor(y_pred, landmarks)
@irhum
irhum / model.py
Created January 19, 2019 16:07
A basic DCGAN implementation
class Generator(nn.Module):
    """Generator half of a basic DCGAN.

    A linear layer projects the ``n_hidden``-dimensional input onto a
    ``channels`` x ``bottom_width`` x ``bottom_width`` volume. Four transposed
    convolutions then each double the spatial size while shrinking the channel
    count: channels -> /2 -> /4 -> /8 -> 3.
    """

    def __init__(self, n_hidden, bottom_width=4, channels=512):
        super().__init__()
        self.channels = channels
        self.bottom_width = bottom_width

        seed_features = channels * bottom_width ** 2
        self.linear = nn.Linear(in_features=n_hidden, out_features=seed_features)

        def upsample(c_in, c_out):
            # kernel 4, stride 2, padding 1: out = (in - 1) * 2 - 2 + 4 = 2 * in
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)

        # Registered in the same order as before so state_dict keys and
        # parameter-initialisation order are unchanged.
        self.dconv1 = upsample(channels, channels // 2)
        self.dconv2 = upsample(channels // 2, channels // 4)
        self.dconv3 = upsample(channels // 4, channels // 8)
        self.dconv4 = upsample(channels // 8, 3)
# Build your encoder up to this point. It can simply be a series of dense layers, a convolutional
# network, or even an LSTM encoder. Once built, flatten the encoder's final layer and call the
# resulting tensor `hidden`.
# The graph is built with Keras.
# Dimensionality of the latent space (number of latent variables).
latent_size = 5
# Per-dimension mean of the latent distribution; `vae_loss` reads this as `mean` in its KL term.
mean = Dense(latent_size)(hidden)
# We usually don't compute the standard deviation σ directly,
# but its log, log(σ): the layer can then output any real value, while exp(log σ) is always positive.