Two Keras Layer subclass definitions for deep autoencoders: one that implements weight tying, and one that initializes a layer from pretrained weights.
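Background, as context for the code below: in a tied autoencoder the decoder reuses the encoder's parameters instead of learning its own. If an encoder Dense layer computes h = f(xW + b), the DenseTransposeTied layer maps the code back with x̂ = g((h − b)Wᵀ), using the encoder's own kernel W and bias b as shared variables, so the same parameters receive gradients from both the encoder and decoder paths. DenseTransposePretrained applies the same transposed mapping but with fresh trainable variables whose initial values are copied from an already-trained layer, the usual setup when fine-tuning after greedy layer-wise pretraining. A usage sketch follows the code.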
import keras.backend as K
from keras.layers import Layer
from keras.legacy import interfaces
from keras.engine import InputSpec
from keras import activations, initializers, regularizers, constraints
class DenseTransposeTied(Layer):

    @interfaces.legacy_dense_support
    def __init__(self, units,
                 tied_to=None,  # The Dense layer whose weights are reused (transposed)
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(DenseTransposeTied, self).__init__(**kwargs)
        self.units = units
        # Store the tied layer and its weight variables; the variables are
        # shared, not copied, so the tying holds throughout training.
        # The tied layer must already be built at this point.
        self.tied_to = tied_to
        self.tied_weights = self.tied_to.weights
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        # Unlike Dense.build, no weights are created here: this layer has no
        # trainable parameters of its own and reuses the tied layer's variables
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
    def call(self, inputs, **kwargs):
        # Apply the transposed mapping with the tied layer's shared variables:
        # subtract its bias, then multiply by the transpose of its kernel
        # (assumes the tied layer was built with use_bias=True)
        output = K.dot(inputs - self.tied_weights[1], K.transpose(self.tied_weights[0]))
        if self.activation is not None:
            output = self.activation(output)
        return output
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)
    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(DenseTransposeTied, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class DenseTransposePretrained(Layer):

    @interfaces.legacy_dense_support
    def __init__(self, units,
                 pretrained_from=None,  # The (already built) layer whose weights initialize this one
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(DenseTransposePretrained, self).__init__(**kwargs)
        self.units = units
        # Keep the pretrained layer's weight variables; their values are used
        # in build() as the initial values of this layer's own trainable weights
        self.pretrained_from = pretrained_from
        self.pretrained_weights = self.pretrained_from.weights
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        # Create trainable weights initialized from the pretrained values.
        # Note: the overridden add_weight below ignores `shape`; the variables
        # keep the pretrained shapes, i.e. the kernel is (units, input_dim) and
        # the bias is (input_dim,), which is why call() applies the transpose
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      pretrained_weight=self.pretrained_weights[0],
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(self.units,),
                                        pretrained_weight=self.pretrained_weights[1],
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
    def call(self, inputs, **kwargs):
        # Same transposed mapping as the tied layer, but with this layer's
        # own trainable copies of the weights (skip the bias if use_bias=False)
        if self.bias is not None:
            inputs = inputs - self.bias
        output = K.dot(inputs, K.transpose(self.kernel))
        if self.activation is not None:
            output = self.activation(output)
        return output
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)
    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(DenseTransposePretrained, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    # Override the Layer.add_weight method so that it takes a pretrained
    # weight as the variable's initial value instead of an initializer
    @interfaces.legacy_add_weight_support
    def add_weight(self,
                   name,
                   shape,
                   dtype=None,
                   pretrained_weight=None,
                   regularizer=None,
                   trainable=True,
                   constraint=None):
        """Adds a weight variable to the layer, initialized from a pretrained weight.

        # Arguments
            name: String, the name for the weight variable.
            shape: The shape tuple of the weight (unused here; the shape
                is taken from `pretrained_weight`).
            dtype: The dtype of the weight.
            pretrained_weight: The pretrained weight whose value initializes
                the variable (replaces the usual `initializer` argument).
            regularizer: An optional Regularizer instance.
            trainable: A boolean, whether the weight should
                be trained via backprop or not (assuming
                that the layer itself is also trainable).
            constraint: An optional Constraint instance.

        # Returns
            The created weight variable.
        """
        if dtype is None:
            dtype = K.floatx()
        weight = K.variable(pretrained_weight,
                            dtype=dtype,
                            name=name,
                            constraint=constraint)
        if regularizer is not None:
            self.add_loss(regularizer(weight))
        if trainable:
            self._trainable_weights.append(weight)
        else:
            self._non_trainable_weights.append(weight)
        return weight
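A minimal usage sketch, not part of the original gist: the layer sizes, variable names, and the 784-dimensional (flattened 28x28) input are illustrative assumptions, and a Keras 2.x environment (where keras.legacy exists) is assumed. Note that the encoder must be built, i.e. called on an input, before either custom layer is constructed, because their __init__ methods read the encoder's weight variables.

from keras.layers import Input, Dense
from keras.models import Model

inputs = Input(shape=(784,))               # assumed input size, e.g. flattened 28x28 images
encoder = Dense(128, activation='relu')
encoded = encoder(inputs)                  # builds the encoder so its weights exist

# Decoder shares the encoder's kernel and bias via weight tying;
# `units` must equal the encoder's input dimension (784 here)
decoded = DenseTransposeTied(784, tied_to=encoder,
                             activation='sigmoid')(encoded)

autoencoder = Model(inputs, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# autoencoder.fit(x_train, x_train, epochs=10, batch_size=256)

# Alternatively, warm-start an untied decoder from the (trained) encoder
# weights; its copies are trainable independently of the encoder
decoded_ft = DenseTransposePretrained(784, pretrained_from=encoder,
                                      activation='sigmoid')(encoded)
finetune_model = Model(inputs, decoded_ft)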