"""Implementation of a Domain Adversarial Network (DANN) in Keras using the
TensorFlow back-end."""

import numpy as np
import tensorflow as tf
import keras.backend as K
from keras.layers import Dense, Input, Dropout
from keras.models import Model
from keras.engine import Layer
# Reverse gradient layer from https://github.com/michetonu/gradient_reversal_keras_tf/blob/master/flipGradientTF.py
# - Added compute_output_shape for Keras 2 compatibility
# - Fixed bug where RegisterGradient was raising a KeyError
def reverse_gradient(X, hp_lambda):
    """Flips the sign of the incoming gradient during training."""
    try:
        reverse_gradient.num_calls += 1
    except AttributeError:
        reverse_gradient.num_calls = 1

    # RegisterGradient raises a KeyError if the name is already taken, so
    # increment the counter until a free name is found.
    while True:
        try:
            grad_name = "GradientReversal%d" % reverse_gradient.num_calls

            @tf.RegisterGradient(grad_name)
            def _flip_gradients(op, grad):
                return [tf.negative(grad) * hp_lambda]

            break
        except KeyError:
            reverse_gradient.num_calls += 1

    g = K.get_session().graph
    with g.gradient_override_map({"Identity": grad_name}):
        y = tf.identity(X)
    return y
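
# A minimal sanity check (a sketch, assuming the TF 1.x graph mode this gist
# targets): the forward pass is the identity, while the backward gradient is
# negated and scaled by hp_lambda.
#
#   x = tf.constant([1.0, 2.0, 3.0])
#   y = reverse_gradient(x, hp_lambda=0.5)
#   dy_dx = tf.gradients(y, x)[0]
#   sess = K.get_session()
#   print(sess.run(y))      # -> [1. 2. 3.]        (identity forward pass)
#   print(sess.run(dy_dx))  # -> [-0.5 -0.5 -0.5]  (flipped, scaled gradient)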

class GradientReversal(Layer):
    """Flips the sign of the gradient during training."""

    def __init__(self, hp_lambda, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.supports_masking = False
        self.hp_lambda = hp_lambda

    def build(self, input_shape):
        self.trainable_weights = []

    def call(self, x, mask=None):
        return reverse_gradient(x, self.hp_lambda)

    def get_output_shape_for(self, input_shape):
        # Keras 1 name for compute_output_shape, kept for compatibility
        return input_shape

    def compute_output_shape(self, input_shape):
        return input_shape

    def get_config(self):
        config = {"hp_lambda": self.hp_lambda}
        base_config = super(GradientReversal, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
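
# Example (sketch): GradientReversal drops into the functional API like any
# other layer. Its forward output equals its input, so predictions are
# unchanged; only the gradients flowing back through it are negated.
#
#   inp = Input((4,))
#   rev = GradientReversal(hp_lambda=1.0)(inp)
#   out = Dense(1)(rev)
#   model = Model(inp, out)  # Dense trains against a reversed gradient signal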

class DANN:
    def __init__(
        self,
        x_dim,
        y_dim,
        z_dim,
        h1=100,
        h2=50,
        z_inv_factor=1,
        loss_weights=None,
        dropout_rate=.1,
        y_loss="binary_crossentropy",
        y_activation="sigmoid",
        z_loss="binary_crossentropy",
        z_activation="sigmoid",
        optimizer="adam"
    ):
        """
        Args:
            x_dim (int): number of features/columns in the input data X
            y_dim (int): dimension of the predicted variable
            z_dim (int): dimension of the confounding variable
            h1 (int): number of neurons in the first hidden layer (from X to e)
            h2 (int): number of neurons in the second set of hidden layers
                (from e to y and from e to z)
            z_inv_factor (float): scaling factor applied to the gradient at the
                reversal step
            loss_weights (dict or None): dictionary giving the weights of the
                losses on y and z (e.g. dict(y=10, z=3)). The default (None)
                sets all weights to 1.
            dropout_rate (float between 0 and 1): fraction of input features
                randomly dropped between X and e at fit time
            {y,z}_loss: loss function associated with this variable (defaults
                to binary cross-entropy)
            {y,z}_activation: activation function associated with this variable
                (defaults to sigmoid)
            optimizer: optimizer to use during training
        """
        self.x_dim = x_dim
        self.y_dim = y_dim
        self.z_dim = z_dim
        self.h1, self.h2 = h1, h2
        self.z_inv_factor = z_inv_factor
        self.loss_weights = loss_weights
        self.dropout_rate = dropout_rate
        self.y_loss = y_loss
        self.y_activation = y_activation
        self.z_loss = z_loss
        self.z_activation = z_activation
        self.optimizer = optimizer

    def _build_model(self):
        x_input = Input((self.x_dim,), name="x_input")
        e = Dropout(self.dropout_rate)(x_input)
        e = Dense(self.h1, activation="relu", kernel_regularizer="l2", name="e")(e)

        # Predict y
        l = Dense(self.h2, activation="relu", kernel_regularizer="l2")(e)
        y = Dense(self.y_dim, name="y", activation=self.y_activation)(l)

        # Predict z with gradient reversal
        l = GradientReversal(self.z_inv_factor)(e)
        l = Dense(self.h2, activation="relu", kernel_regularizer="l2")(l)
        z = Dense(self.z_dim, name="z", activation=self.z_activation)(l)

        # Create the full model and compile it
        self.model = Model(x_input, [y, z])
        self.model.compile(
            optimizer=self.optimizer,
            loss=[self.y_loss, self.z_loss],
            loss_weights=self.loss_weights
        )

        # Expose a model that predicts the target variable only
        self.clf = Model(x_input, y)

    def fit(self, *args, **kwargs):
        # Reset the TensorFlow graph to avoid resource exhaustion
        K.clear_session()
        # Build a fresh model
        self._build_model()
        self.h = self.model.fit(*args, **kwargs)
        return self.h

    def predict(self, *args, **kwargs):
        return self.clf.predict(*args, **kwargs)
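
# Usage example: a minimal end-to-end sketch on synthetic data. The dimensions,
# sample counts, and training settings below are illustrative assumptions, not
# values from the original gist.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    X = rng.randn(256, 20)
    y = (X[:, 0] > 0).astype("float32")  # target variable
    z = (X[:, 1] > 0).astype("float32")  # confounding variable to unlearn

    dann = DANN(x_dim=20, y_dim=1, z_dim=1, h1=32, h2=16)
    dann.fit(X, [y, z], epochs=5, batch_size=32, verbose=0)
    print("y predictions shape:", dann.predict(X).shape)  # -> (256, 1)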