Two Keras Layer subclasses for deep autoencoders: one implements weight tying, the other loads pretrained weights into a trainable decoder.
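Both layers implement the decoder half of an autoencoder as the transpose of an encoder Dense layer, computing activation((x - b) W^T) from the encoder's kernel W and bias b. In DenseTransposeTied, W and b stay tied to the encoder and are not trained in the decoder; in DenseTransposePretrained they only serve as the initial values of new, trainable decoder weights.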
# Note: keras.legacy and keras.engine target standalone Keras 2.x (pre-tf.keras)
import keras.backend as K
from keras.layers import Layer
from keras.legacy import interfaces
from keras.engine import InputSpec
from keras import activations, initializers, regularizers, constraints


class DenseTransposeTied(Layer):

    @interfaces.legacy_dense_support
    def __init__(self, units,
                 tied_to=None,  # Enter a layer as input to enforce weight-tying
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(DenseTransposeTied, self).__init__(**kwargs)
        self.units = units
        # We add these two properties to save the tied weights
        self.tied_to = tied_to
        self.tied_weights = self.tied_to.weights
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        # We create no kernel or bias here: the tied weights belong to the
        # tied layer and must not be trainable in this one
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
    def call(self, inputs, **kwargs):
        # Invert the tied layer's mapping: subtract its bias (if it has one),
        # then multiply by the transpose of its kernel
        if len(self.tied_weights) > 1:
            inputs = inputs - self.tied_weights[1]
        output = K.dot(inputs, K.transpose(self.tied_weights[0]))
        if self.activation is not None:
            output = self.activation(output)
        return output
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)
    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(DenseTransposeTied, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
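

# A minimal usage sketch, not part of the original gist; the layer names and
# dimensions (784 -> 64 -> 784) are assumptions for illustration. The encoder
# must be called (and therefore built) before the tied layer is constructed,
# because __init__ reads tied_to.weights.
def _tied_autoencoder_example():
    from keras.layers import Input, Dense
    from keras.models import Model
    inp = Input(shape=(784,))
    encoder = Dense(64, activation='relu')
    code = encoder(inp)
    # `units` must equal the encoder's input dimension, since the decoder
    # multiplies by the transpose of the encoder kernel
    decoded = DenseTransposeTied(784, tied_to=encoder, activation='sigmoid')(code)
    return Model(inp, decoded)
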

class DenseTransposePretrained(Layer):

    @interfaces.legacy_dense_support
    def __init__(self, units,
                 pretrained_from=None,  # Enter a layer as input to load pretrained weights
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(DenseTransposePretrained, self).__init__(**kwargs)
        self.units = units
        self.pretrained_from = pretrained_from
        self.pretrained_weights = self.pretrained_from.weights
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True
    def build(self, input_shape):
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        # Unlike the tied layer, create the weights here because we want them
        # to be trainable. The overridden add_weight below ignores `shape` and
        # gives the variable the shape of the pretrained weight: the kernel is
        # (self.units, input_dim) and the bias is (input_dim,), matching the
        # pretrained layer's kernel and bias.
        self.kernel = self.add_weight(shape=(self.units, input_dim),
                                      pretrained_weight=self.pretrained_weights[0],
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        if self.use_bias:
            self.bias = self.add_weight(shape=(input_dim,),
                                        pretrained_weight=self.pretrained_weights[1],
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True
    def call(self, inputs, **kwargs):
        # Transpose mapping with the (trainable) pretrained weights
        if self.bias is not None:
            inputs = inputs - self.bias
        output = K.dot(inputs, K.transpose(self.kernel))
        if self.activation is not None:
            output = self.activation(output)
        return output
    def compute_output_shape(self, input_shape):
        assert input_shape and len(input_shape) >= 2
        assert input_shape[-1]
        output_shape = list(input_shape)
        output_shape[-1] = self.units
        return tuple(output_shape)
    def get_config(self):
        config = {
            'units': self.units,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint)
        }
        base_config = super(DenseTransposePretrained, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    # Override the Layer.add_weight method so that, instead of an initializer,
    # it receives a pretrained weight as input
    @interfaces.legacy_add_weight_support
    def add_weight(self,
                   name,
                   shape,
                   dtype=None,
                   pretrained_weight=None,
                   regularizer=None,
                   trainable=True,
                   constraint=None):
        """Adds a weight variable to the layer, initialized from a pretrained weight.

        # Arguments
            name: String, the name for the weight variable.
            shape: The shape tuple of the weight (unused here; the variable
                takes the shape of `pretrained_weight`).
            dtype: The dtype of the weight.
            pretrained_weight: The pretrained weight (e.g. a weight variable
                of the pretrained layer) used as the variable's initial value.
            regularizer: An optional Regularizer instance.
            trainable: A boolean, whether the weight should
                be trained via backprop or not (assuming
                that the layer itself is also trainable).
            constraint: An optional Constraint instance.

        # Returns
            The created weight variable.
        """
        if dtype is None:
            dtype = K.floatx()
        weight = K.variable(pretrained_weight,
                            dtype=dtype,
                            name=name,
                            constraint=constraint)
        if regularizer is not None:
            self.add_loss(regularizer(weight))
        if trainable:
            self._trainable_weights.append(weight)
        else:
            self._non_trainable_weights.append(weight)
        return weight
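

# A minimal end-to-end sketch of how the pretrained variant is meant to be
# wired; not part of the original gist. The encoder, dimensions, and training
# setup below are assumptions for illustration: in practice `encoder` would
# carry weights from a previously trained model, so the decoder starts from
# those values but is fine-tuned further, unlike DenseTransposeTied whose
# weights stay tied to the encoder.
def _pretrained_autoencoder_example():
    from keras.layers import Input, Dense
    from keras.models import Model
    inp = Input(shape=(784,))
    encoder = Dense(64, activation='relu')  # assumed to hold pretrained weights
    code = encoder(inp)
    # Decoder: 64 -> 784, initialized from the encoder's kernel and bias
    decoded = DenseTransposePretrained(784, pretrained_from=encoder,
                                       activation='sigmoid')(code)
    autoencoder = Model(inp, decoded)
    autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
    return autoencoder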