This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def vae_loss(input_img, output):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    ``mean`` and ``log_stddev`` are not parameters; they are read from the
    enclosing scope. ``log_stddev`` is treated as log(sigma), the log of the
    standard deviation, not the log of the variance.

    Args:
        input_img: the original input tensor.
        output: the reconstructed tensor, compared element-wise to input_img.

    Returns:
        A scalar tensor: the mean of (reconstruction loss + KL loss).
    """
    # Squared reconstruction error summed over every axis, giving one scalar.
    reconstruction_loss = K.sum(K.square(output - input_img))
    # Closed-form KL(N(mean, sigma^2) || N(0, 1)):
    #   -0.5 * sum(1 + log(sigma^2) - mean^2 - sigma^2)
    # With log_stddev = log(sigma), log(sigma^2) is 2 * log_stddev; using
    # log_stddev alone under-weights the log-variance term.
    kl_loss = -0.5 * K.sum(
        1 + 2 * log_stddev - K.square(mean) - K.square(K.exp(log_stddev)),
        axis=-1,
    )
    # The scalar reconstruction term broadcasts onto kl_loss before averaging.
    # NOTE(review): reconstruction is a sum over all axes while KL is averaged,
    # so their relative weight may depend on batch size - confirm this is intended.
    total_loss = K.mean(reconstruction_loss + kl_loss)
    return total_loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train(model, dataloaders, criterion, optimizer, device, out_name, dlib_models=None, | |
validate=True, validate_every=10, num_epochs=100): | |
if validate: | |
assert len(dataloaders) == 2 | |
assert dlib_models is not None | |
# start at epoch 1, end at epoch num_epochs (inclusive) | |
for epoch in range(1, num_epochs+1): | |
# Training phase |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# returns distance (a float) if face detected in synthesized image, None otherwise | |
def distance_metric(y_pred, x, dlib_models): | |
dist = None | |
face_detector, landmark_detector, face_embedder = dlib_models | |
face_boxes = face_detector(y_pred, 1) | |
if len(face_boxes) == 1: | |
landmarks = landmark_detector(y_pred, face_boxes[0]) | |
embedding = face_embedder.compute_face_descriptor(y_pred, landmarks) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Generator(nn.Module): | |
def __init__(self, n_hidden, bottom_width=4, channels=512): | |
super().__init__() | |
self.channels = channels | |
self.bottom_width = bottom_width | |
self.linear = nn.Linear(n_hidden, bottom_width*bottom_width*channels) | |
self.dconv1 = nn.ConvTranspose2d(channels, channels // 2, 4, 2, 1) | |
self.dconv2 = nn.ConvTranspose2d(channels // 2, channels // 4, 4, 2, 1) | |
self.dconv3 = nn.ConvTranspose2d(channels // 4, channels // 8, 4, 2, 1) | |
self.dconv4 = nn.ConvTranspose2d(channels // 8, 3, 4, 2, 1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# build your encoder up to here. It can simply be a series of dense layers, a convolutional network
# or even an LSTM decoder. Once made, flatten out the final layer of the encoder, call it hidden.
# we use Keras to build the graph | |
latent_size = 5 | |
mean = Dense(latent_size)(hidden) | |
# we usually don't directly compute the stddev σ | |
# but the log of the stddev instead, which is log(σ) |