Created
August 13, 2016 01:26
Star
You must be signed in to star a gist
More general version of the Highway Network
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from keras.engine import InputSpec | |
from keras.layers import Dense | |
from keras.layers.wrappers import Wrapper, TimeDistributed | |
class Freeway(Wrapper): | |
def __init__(self, layer, gate=None, **kwargs): | |
self.supports_masking = True | |
self.gate = gate | |
super(Freeway, self).__init__(layer, **kwargs) | |
def build(self, input_shape=None): | |
assert len(input_shape) in [2, 3] | |
self.input_spec = [InputSpec(shape=input_shape)] | |
nb_output_dims = input_shape[-1] | |
if self.gate is None: | |
gate = Dense(nb_output_dims, activation='sigmoid') | |
if len(input_shape) == 3: | |
gate = TimeDistributed(gate) | |
self.gate = gate | |
super(Freeway, self).build(input_shape) | |
def get_output_shape_for(self, input_shape): | |
assert self.layer.get_output_shape_for(input_shape) == input_shape | |
assert self.gate.get_output_shape_for(input_shape) == input_shape | |
return input_shape | |
def call(self, x, mask=None): | |
return self.layer(x) * self.gate(x) + x * (1 - self.gate(x)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment