Skip to content

Instantly share code, notes, and snippets.

@wielandbrendel
Last active September 30, 2015 11:33
Show Gist options
  • Save wielandbrendel/c03c91f78b9a209f2a8d to your computer and use it in GitHub Desktop.
Save wielandbrendel/c03c91f78b9a209f2a8d to your computer and use it in GitHub Desktop.
Updated simulation code (no devec + inplace BLAS operations)
# Pin BLAS to a single thread, presumably so the @time measurement at the end of
# the script is not affected by BLAS multithreading on these small matrices.
# NOTE(review): `blas_set_num_threads` is the old (Julia 0.4-era) name; newer Julia
# versions expose this as `BLAS.set_num_threads(1)` after `using LinearAlgebra` —
# confirm which Julia version this script targets.
blas_set_num_threads(1)
"""
    train_network(A, T, Of, cs, dt)

Simulate the network with forward-Euler steps of size `dt`, one step per row of
`cs`, and update the weight matrices `T` and `Of` in place. Returns `nothing`.

# Arguments
- `A`: `I x I` matrix applied to the state vector `z` (`Az = A*z`).
- `T`: `I x N` weight matrix; mutated in place by the learning rule.
- `Of`: `N x N` weight matrix; mutated in place by the learning rule.
- `cs`: input matrix with one row per time step and `I` columns; row `t` is the
  input vector added to `A*z` at step `t`.
- `dt`: integration step size; both weight updates use learning rate `dt*1e-3`.

Per step, with `u = A*z + cs[t, :]` computed from the pre-update `z`:
- `z += dt*(u - 0.5z)`
- `r += dt*(T'u - Of*r - 0.1r)`      (pre-update `T`, `Of`, `r`)
- `T += dt*1e-3*(z*r' - T)`           (updated `z`, `r`)
- `Of += dt*1e-3*((T'z)*r' - Of)`     (`T'z` uses pre-update `T`, updated `z`)

# Throws
- `DimensionMismatch` if `size(cs, 2) != size(T, 1)`.
"""
function train_network(A, T, Of, cs, dt)
    I, N = size(T)
    # cs is only ever indexed (inside @inbounds), so validate its width up front;
    # the shapes of A and Of are checked by the gemv! calls below.
    size(cs, 2) == I || throw(DimensionMismatch(
        "cs must have $I columns (one per row of T), got $(size(cs, 2))"))
    lr = dt * 1e-3          # learning rate shared by both weight updates

    # State vectors and work buffers, allocated once and reused every step.
    z = zeros(I)
    r = zeros(N)
    Az = zeros(I)
    u = zeros(I)            # A*z + cs[t, :]
    Ofr = zeros(N)
    I_teach = zeros(N)
    Tz = zeros(N)

    @inbounds for t in 1:size(cs, 1)
        # precompute from the pre-update state
        BLAS.gemv!('N', 1.0, A, z, 0.0, Az)     # Az = A*z
        BLAS.gemv!('N', 1.0, Of, r, 0.0, Ofr)   # Ofr = Of*r
        # compute training signal; u holds the full input row for this step
        # (the old `Az + cs[t]` added only the scalar cs[t, 1] and allocated)
        for i in 1:I
            u[i] = Az[i] + cs[t, i]
            z[i] += dt * (u[i] - 0.5 * z[i])
        end
        BLAS.gemv!('T', 1.0, T, u, 0.0, I_teach)  # I_teach = T'*u
        BLAS.gemv!('T', 1.0, T, z, 0.0, Tz)       # Tz = T'*z (updated z, old T)
        # rate updates
        for n in 1:N
            r[n] += dt * (I_teach[n] - Ofr[n] - 0.1 * r[n])
        end
        # weight updates: T += lr*(z*r' - T); `i` is the inner loop so memory is
        # walked down each column (Julia arrays are column-major)
        for n in 1:N, i in 1:I
            T[i, n] += lr * (z[i] * r[n] - T[i, n])
        end
        # Of += lr*(Tz*r' - Of), done in place as a scale followed by a rank-1 update
        BLAS.scal!(N * N, 1 - lr, Of, 1)
        BLAS.ger!(lr, Tz, r, Of)
    end
    return nothing
end
# init parameters
N, I = 20, 2          # N: length of the rate vector r, I: length of the state z
dt = 1e-3
# init weights
T = rand(I, N)*N
A = rand(I, I)
Of = rand(N, N)/N
# simulation time & input
sim_T = 2000
ts = 0:dt:sim_T
cs = randn(length(ts), I)   # one input row (I values) per time step
# compile: train_network mutates T and Of, so warm up on copies with a short
# slice of the input; the timed run below then starts from the freshly
# initialised weights and the warm-up no longer repeats the whole simulation
train_network(A, copy(T), copy(Of), cs[1:10, :], dt)
# test network simulation
@time train_network(A, T, Of, cs, dt)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment