Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
keras lambda layer supporting masking
class Lambda(Layer):
    '''Used for evaluating an arbitrary Theano / TensorFlow expression
    on the output of the previous layer.

    # Examples

    ```python
        # add a x -> x^2 layer
        model.add(Lambda(lambda x: x ** 2))
    ```
    ```python
        # add a layer that returns the concatenation
        # of the positive part of the input and
        # the opposite of the negative part

        def antirectifier(x):
            x -= K.mean(x, axis=1, keepdims=True)
            x = K.l2_normalize(x, axis=1)
            pos = K.relu(x)
            neg = K.relu(-x)
            return K.concatenate([pos, neg], axis=1)

        def antirectifier_output_shape(input_shape):
            shape = list(input_shape)
            assert len(shape) == 2  # only valid for 2D tensors
            shape[-1] *= 2
            return tuple(shape)

        model.add(Lambda(antirectifier, output_shape=antirectifier_output_shape))
    ```

    # Arguments
        function: The function to be evaluated.
            Takes one argument: the output of previous layer
        output_shape: Expected output shape from function.
            Could be a tuple or a function of the shape of the input
        mask_function: A function which takes as input x and mask
            and returns a new mask
        arguments: optional dictionary of keyword arguments to be passed
            to the function.

    # Input shape
        Arbitrary. Use the keyword argument input_shape
        (tuple of integers, does not include the samples axis)
        when using this layer as the first layer in a model.

    # Output shape
        Specified by `output_shape` argument.
    '''

    def __init__(self, function, output_shape=None, mask_function=None,
                 arguments=None, **kwargs):
        self.function = function
        # `arguments=None` instead of `arguments={}`: a mutable default dict
        # would be shared by every Lambda instance in the process.
        self.arguments = arguments if arguments is not None else {}

        if output_shape is None:
            self._output_shape = None
        elif type(output_shape) in {tuple, list}:
            self._output_shape = tuple(output_shape)
        else:
            if not hasattr(output_shape, '__call__'):
                raise Exception('In Lambda, `output_shape` '
                                'must be a list, a tuple, or a function.')
            self._output_shape = output_shape

        if mask_function is None:
            self._mask_function = None
            # no mask function: this layer does not propagate masks
            self.supports_masking = False
        elif hasattr(mask_function, '__call__'):
            self._mask_function = mask_function
            self.supports_masking = True
        else:
            raise Exception("In Lambda, `mask_function` "
                            "must be a function that computes the new mask")
        super(Lambda, self).__init__(**kwargs)

    def get_output_shape_for(self, input_shape):
        '''Return the output shape for `input_shape`.

        If no `output_shape` was given, try to infer it by running the
        function on a placeholder (TensorFlow only), otherwise fall back
        to the input shape.
        '''
        if self._output_shape is None:
            # if TensorFlow, we can infer the output shape directly:
            if K._BACKEND == 'tensorflow':
                if type(input_shape) is list:
                    xs = [K.placeholder(shape=shape) for shape in input_shape]
                    x = self.call(xs)
                else:
                    x = K.placeholder(shape=input_shape)
                    x = self.call(x)
                if type(x) is list:
                    return [K.int_shape(x_elem) for x_elem in x]
                else:
                    return K.int_shape(x)
            # otherwise, we default to the input shape
            return input_shape
        elif type(self._output_shape) in {tuple, list}:
            # user gave a static shape (without the samples axis)
            nb_samples = input_shape[0] if input_shape else None
            return (nb_samples,) + tuple(self._output_shape)
        else:
            shape = self._output_shape(input_shape)
            if type(shape) not in {list, tuple}:
                raise Exception('output_shape function must return a tuple')
            return tuple(shape)

    def call(self, x, mask=None):
        # Copy so that injecting 'mask' does not mutate self.arguments:
        # previously the mask key leaked into get_config() and persisted
        # across calls.
        arguments = dict(self.arguments)
        arg_spec = inspect.getargspec(self.function)
        if 'mask' in arg_spec.args:
            # the wrapped function explicitly accepts a mask argument
            arguments['mask'] = mask
        return self.function(x, **arguments)

    def compute_mask(self, x, mask=None):
        '''Compute the output mask: delegate to the user-supplied
        mask function if one was given, else pass the mask through.'''
        if self._mask_function is not None:
            return self._mask_function(x, mask)
        return mask

    def get_config(self):
        py3 = sys.version_info[0] == 3

        if isinstance(self.function, python_types.LambdaType):
            # lambdas have no importable name: serialize the bytecode
            if py3:
                function = marshal.dumps(self.function.__code__).decode('raw_unicode_escape')
            else:
                function = marshal.dumps(self.function.func_code).decode('raw_unicode_escape')
            function_type = 'lambda'
        else:
            function = self.function.__name__
            function_type = 'function'

        if isinstance(self._mask_function, python_types.LambdaType):
            if py3:
                mask_func = marshal.dumps(self._mask_function.__code__).decode('raw_unicode_escape')
            else:
                mask_func = marshal.dumps(self._mask_function.func_code).decode('raw_unicode_escape')
            mask_func_type = 'lambda'
        elif callable(self._mask_function):
            mask_func = self._mask_function.__name__
            mask_func_type = 'function'
        else:
            # no mask function was set (self._mask_function is None)
            mask_func = 'unknown'
            mask_func_type = 'unknown'

        if isinstance(self._output_shape, python_types.LambdaType):
            # NOTE(review): unlike the function/mask branches, this is not
            # .decode()'d, so the config holds raw bytes — fine for an
            # in-process round trip, but not JSON-serializable. Confirm
            # whether decoding was intended.
            if py3:
                output_shape = marshal.dumps(self._output_shape.__code__)
            else:
                output_shape = marshal.dumps(self._output_shape.func_code)
            output_shape_type = 'lambda'
        elif callable(self._output_shape):
            output_shape = self._output_shape.__name__
            output_shape_type = 'function'
        else:
            output_shape = self._output_shape
            output_shape_type = 'raw'

        config = {'function': function,
                  'function_type': function_type,
                  'mask_function': mask_func,
                  'mask_function_type': mask_func_type,
                  'output_shape': output_shape,
                  'output_shape_type': output_shape_type,
                  'arguments': self.arguments}
        base_config = super(Lambda, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        '''Rebuild a Lambda layer from its `get_config()` dictionary.'''
        function_type = config.pop('function_type')
        if function_type == 'function':
            function = globals()[config['function']]
        elif function_type == 'lambda':
            function = marshal.loads(config['function'].encode('raw_unicode_escape'))
            function = python_types.FunctionType(function, globals())
        else:
            raise Exception('Unknown function type: ' + function_type)

        mask_function_type = config.pop('mask_function_type')
        if mask_function_type == 'function':
            mask_function = globals()[config['mask_function']]
        elif mask_function_type == 'lambda':
            mask_function = marshal.loads(config['mask_function'].encode('raw_unicode_escape'))
            mask_function = python_types.FunctionType(mask_function, globals())
        elif mask_function_type == 'unknown':
            # get_config() stores 'unknown' when no mask function was set;
            # restore None so a plain Lambda survives the round trip
            # (previously this fell through to the raise below).
            mask_function = None
        else:
            raise Exception('Unknown function type: ' + mask_function_type)

        output_shape_type = config.pop('output_shape_type')
        if output_shape_type == 'function':
            output_shape = globals()[config['output_shape']]
        elif output_shape_type == 'lambda':
            output_shape = marshal.loads(config['output_shape'])
            output_shape = python_types.FunctionType(output_shape, globals())
        else:
            # 'raw': a plain tuple/list/None, stored as-is
            output_shape = config['output_shape']

        config['function'] = function
        config['output_shape'] = output_shape
        config['mask_function'] = mask_function
        return cls(**config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment