Created July 30, 2017 04:03
-
-
Save joelouismarino/b78257a716df15fb1886442927cc6d72 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LRN(Layer):
    """Local Response Normalization across channels (AlexNet-style).

    Each activation is divided by

        (k + (alpha / n) * sum_of_squares_over_n_adjacent_channels) ** beta

    The window of ``n`` channels is placed around the current channel; the
    squared input is zero-padded along the channel axis, so channels beyond
    the edges contribute zero.

    NOTE(review): ``call`` unpacks ``x.shape`` into four values and pads
    axis 1, so it assumes a 4-D channels-first input (batch, channels,
    rows, cols) -- confirm against callers. ``T`` (presumably
    ``theano.tensor``) and ``Layer`` (presumably the Keras base layer) are
    not imported in the visible snippet; the enclosing module must provide
    them.

    Args:
        alpha: scaling factor for the sum of squares (divided by ``n``).
        k: additive offset inside the denominator.
        beta: exponent applied to the denominator.
        n: number of channels in the normalization window. With an even
            ``n`` the window is not symmetric around the current channel
            (one more channel before it than after it).
        **kwargs: forwarded unchanged to ``Layer.__init__``.
    """

    def __init__(self, alpha=0.0001, k=1, beta=0.75, n=5, **kwargs):
        self.alpha = alpha
        self.k = k
        self.beta = beta
        self.n = n
        super(LRN, self).__init__(**kwargs)

    def call(self, x, mask=None):
        """Return ``x`` divided by its local cross-channel response.

        ``mask`` is accepted for API compatibility and is not used.
        """
        b, ch, r, c = x.shape
        half_n = self.n // 2  # half the local region
        input_sqr = T.sqr(x)  # square the input
        # Zero-pad the squared input along the channel axis so that every
        # n-channel slice taken in the loop below stays in bounds.
        extra_channels = T.alloc(0., b, ch + 2 * half_n, r, c)
        input_sqr = T.set_subtensor(
            extra_channels[:, half_n:half_n + ch, :, :], input_sqr)
        # Force float division: under Python 2, int alpha / int n would
        # truncate (e.g. alpha=1, n=5 -> 0) and silently disable the term.
        norm_alpha = self.alpha / float(self.n)
        scale = self.k  # offset for the scale
        # Summing n shifted slices of the padded tensor accumulates, for each
        # output channel, the squares of its n-channel neighbourhood.
        for i in range(self.n):
            scale += norm_alpha * input_sqr[:, i:i + ch, :, :]
        scale = scale ** self.beta
        return x / scale

    def get_config(self):
        """Return the base layer config merged with the LRN hyper-parameters."""
        config = {"alpha": self.alpha,
                  "k": self.k,
                  "beta": self.beta,
                  "n": self.n}
        base_config = super(LRN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.