# Source: GitHub Gist bquast/4b35f0b95423bf86e6a5dd1d83530434
# (last active January 13, 2020 13:12); it can be saved locally or opened in
# GitHub Desktop.
# GitHub notice: this file may contain bidirectional Unicode text that could be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
## sigmoid function
# Logistic (sigmoid) function, applied elementwise.
#   x  : numeric vector or matrix of inputs
#   k  : steepness of the curve (default 1)
#   x0 : midpoint, where the output equals 0.5 (default 0)
# Returns 1 / (1 + exp(-k * (x - x0))).
sigmoid <- function(x, k = 1, x0 = 0) {
  1 / (1 + exp(-k * (x - x0)))
}
## tanh^2 function
# Square of the hyperbolic tangent, tanh(x)^2, applied elementwise; used below
# in the derivative term 1 - tanh(state)^2.
# Computed with base tanh() instead of the ratio ((e^2x - 1) / (e^2x + 1))^2:
# that form overflows to Inf / Inf = NaN once exp(2 * x) exceeds the double
# range (x above roughly 355), whereas tanh() saturates cleanly at +/-1.
tanhsq <- function(x) {
  tanh(x)^2
}
# define weights
# Each gate has input weights W* (1 x 2, one per input feature), a recurrent
# weight U* (1 x 1) applied to the previous output, and a bias b* (1 x 1).
# Gates: a = candidate activation (tanh), i = input gate, f = forget gate,
# o = output gate (all sigmoid).
Wa <- matrix(c(0.45, 0.25), nrow = 1)
Ua <- matrix(0.15)
ba <- matrix(0.20)

Wi <- matrix(c(0.95, 0.80), nrow = 1)
Ui <- matrix(0.80)
bi <- matrix(0.65)

Wf <- matrix(c(0.70, 0.45), nrow = 1)
Uf <- matrix(0.10)
bf <- matrix(0.15)

Wo <- matrix(c(0.60, 0.40), nrow = 1)
Uo <- matrix(0.25)
bo <- matrix(0.10)

# Stack the per-gate weights in the order a, i, f, o. t(W) and t(U) can then
# back-propagate the stacked gate deltas in a single matrix product.
W <- rbind(Wa, Wi, Wf, Wo)
U <- rbind(Ua, Ui, Uf, Uo)
# initialise
# Output and cell state before the first timestep (t = -1) are both zero.
out_minus1 <- 0
state_minus1 <- 0

# define input and output data
# A two-step sequence: each input x_t is a 2 x 1 column vector and each
# target y_t is a 1 x 1 matrix.
x0 <- matrix(c(1, 2))
y0 <- matrix(0.5)
x1 <- matrix(c(0.5, 3))
y1 <- matrix(1.25)
# forward @ t=0
# Each gate's pre-activation combines the current input x0, the previous
# output out_minus1 and the gate bias. The candidate a uses tanh; the
# i/f/o gates use the sigmoid.
a0 <- tanh(Wa %*% x0 + Ua %*% out_minus1 + ba)
i0 <- sigmoid(Wi %*% x0 + Ui %*% out_minus1 + bi)
f0 <- sigmoid(Wf %*% x0 + Uf %*% out_minus1 + bf)
o0 <- sigmoid(Wo %*% x0 + Uo %*% out_minus1 + bo)
# New cell state: input-gated candidate plus forget-gated previous state.
state0 <- a0 * i0 + f0 * state_minus1
# Output: squashed cell state, scaled by the output gate.
out0 <- tanh(state0) * o0
# forward @ t=1
# Same cell as t=0, now fed the new input x1 and the previous output out0.
a1 <- tanh(Wa %*% x1 + Ua %*% out0 + ba)
i1 <- sigmoid(Wi %*% x1 + Ui %*% out0 + bi)
f1 <- sigmoid(Wf %*% x1 + Uf %*% out0 + bf)
o1 <- sigmoid(Wo %*% x1 + Uo %*% out0 + bo)
# Cell state carries state0 forward through the forget gate f1.
state1 <- a1 * i1 + f1 * state0
out1 <- tanh(state1) * o1
# backward @ t=1
## difference using L2 loss
# Gradient of 0.5 * (out1 - y1)^2 with respect to out1.
Delta1 <- out1 - y1
Delta_out1 <- 0 # because last timestep: no later step sends gradient back
delta_out1 <- Delta1 + Delta_out1
# The general recurrence adds delta_state2 * f2; that term is zero here
# because t = 1 is the last timestep. With the example data this gives
# about -0.071.
delta_state1 <- delta_out1 * o1 * (1 - tanhsq(state1))
# Gate deltas via the chain rule through each gate's nonlinearity:
# tanh'(z) = 1 - a^2 for the candidate, sigmoid'(z) = s * (1 - s) for gates.
delta_a1 <- delta_state1 * i1 * (1 - a1^2)
delta_i1 <- delta_state1 * a1 * i1 * (1 - i1)
delta_f1 <- delta_state1 * state0 * f1 * (1 - f1)
delta_o1 <- delta_out1 * tanh(state1) * o1 * (1 - o1)
# Stacked in the same a, i, f, o order as the rows of W and U.
delta_gates1 <- rbind(delta_a1, delta_i1, delta_f1, delta_o1)
# Propagate to this step's input and to the previous step's output.
delta_x1 <- t(W) %*% delta_gates1
Delta_out0 <- t(U) %*% delta_gates1
# backward @ t=0
# Loss gradient at this step, plus the gradient flowing back from t=1
# through U (Delta_out0, computed in the t=1 backward pass).
Delta_0 <- out0 - y0
delta_out0 <- Delta_0 + Delta_out0
# Cell-state gradient: local term plus the state gradient from t=1,
# carried back through the forget gate f1.
delta_state0 <- delta_out0 * o0 * (1 - tanhsq(state0)) + delta_state1 * f1
delta_a0 <- delta_state0 * i0 * (1 - a0^2)
delta_i0 <- delta_state0 * a0 * i0 * (1 - i0)
# state_minus1 is the initial state, so this delta is zero when it is zero.
delta_f0 <- delta_state0 * state_minus1 * f0 * (1 - f0)
delta_o0 <- delta_out0 * tanh(state0) * o0 * (1 - o0)
delta_gates0 <- rbind(delta_a0, delta_i0, delta_f0, delta_o0)
delta_x0 <- t(W) %*% delta_gates0
Delta_out_minus1 <- t(U) %*% delta_gates0
# calculate weight updates
# Gradients are summed over both timesteps. The gates at step t read input
# x_t through W and the PREVIOUS output out_{t-1} through U. So:
#   deltaW = sum_t delta_gates_t (x) t(x_t)
#   deltaU = sum_t delta_gates_t (x) out_{t-1}
# out1 is never fed into the gates, so it must not appear in deltaU.
deltaW <- delta_gates0 %x% t(x0) + delta_gates1 %x% t(x1)
deltaU <- delta_gates0 %x% out_minus1 + delta_gates1 %x% out0
deltab <- delta_gates0 + delta_gates1
# update weights using Stochastic Gradient Descent (SGD)
lambda <- 0.1 # learning rate
# Rows 1-4 of deltaW / deltaU / deltab correspond to the gates a, i, f, o,
# the order in which the gate deltas were stacked.
Wa <- Wa - lambda * deltaW[1, ]
Wi <- Wi - lambda * deltaW[2, ]
Wf <- Wf - lambda * deltaW[3, ]
Wo <- Wo - lambda * deltaW[4, ]

Ua <- Ua - lambda * deltaU[1, ]
Ui <- Ui - lambda * deltaU[2, ]
Uf <- Uf - lambda * deltaU[3, ]
Uo <- Uo - lambda * deltaU[4, ]

ba <- ba - lambda * deltab[1, ]
bi <- bi - lambda * deltab[2, ]
bf <- bf - lambda * deltab[3, ]
bo <- bo - lambda * deltab[4, ]

# Rebuild the stacked matrices so W and U hold the updated weights rather
# than the stale pre-update values.
W <- rbind(Wa, Wi, Wf, Wo)
U <- rbind(Ua, Ui, Uf, Uo)
# Sign up for free to join this conversation on GitHub. Already have an
# account? Sign in to comment.