Skip to content

Instantly share code, notes, and snippets.

@wielandbrendel
Created September 30, 2015 12:09
Show Gist options
  • Save wielandbrendel/5ee1457722cb6706240b to your computer and use it in GitHub Desktop.
Save wielandbrendel/5ee1457722cb6706240b to your computer and use it in GitHub Desktop.
Recurrent network simulation in Cython
cimport cython
cimport numpy as np
from libc.math cimport log, exp, sqrt, sin
import numpy as np
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def c_analytic_simulate(np.ndarray[double, ndim=2] A not None,
                        np.ndarray[double, ndim=2] T not None,
                        np.ndarray[double, ndim=2] Of not None,
                        np.ndarray[double, ndim=2] cs not None,
                        double dt, double z_leak=0.5, double r_leak=0.1,
                        double eta=1e-3):
    """Forward-Euler simulation of two coupled state vectors, updating T and Of in place.

    Each time step t performs, in this order:

        Az       = A @ z                         (previous z)
        Ofr      = Of @ r                        (previous r)
        Azc      = Az + cs[t, :I]
        z       += dt * (Azc - z_leak * z)
        I_teach  = T.T @ Azc                     (T not yet updated)
        Tz       = T.T @ z                       (T not yet updated, new z)
        r       += dt * (I_teach - Ofr - r_leak * r)
        T       += dt * eta * (outer(z, r) - T)
        Of      += dt * eta * (outer(Tz, r) - Of)

    The state vectors z (length I) and r (length N) start at zero and are
    discarded when the function returns.

    Parameters
    ----------
    A : float64 array, at least (I, I)
        Matrix applied to z. Read only; only the leading (I, I) block is used.
    T : float64 array, shape (I, N)
        Couples z and r. I and N are taken from its shape. Updated in place.
    Of : float64 array, at least (N, N)
        Matrix applied to r, entering the r update with a minus sign.
        Updated in place.
    cs : float64 array, shape (n_steps, >= I)
        Per-step external input to z. One row per time step; only the first
        I columns are read.
    dt : float
        Euler step size.
    z_leak, r_leak : float, optional
        Leak coefficients of z and r. Defaults 0.5 and 0.1 (the formerly
        hard-coded values).
    eta : float, optional
        Rate factor of the T / Of updates. Default 1e-3 (the formerly
        hard-coded value).

    Returns
    -------
    None
        The results are the in-place modifications of T and Of.

    Raises
    ------
    ValueError
        If A, Of or cs are too small for the sizes implied by T. Bounds
        checking is disabled, so such inputs would otherwise read or write
        out of bounds.
    """
    cdef Py_ssize_t I = T.shape[0]
    cdef Py_ssize_t N = T.shape[1]
    cdef Py_ssize_t n_steps = cs.shape[0]
    # Every loop index is typed: an untyped index would be a Python object and
    # turn buffer access into slow generic Python indexing.
    cdef Py_ssize_t t, i, j, n, m
    # Hoisted factor; evaluates exactly like the original `dt*1e-3*(...)`.
    cdef double step = dt * eta
    cdef np.ndarray[double, ndim=1] z
    cdef np.ndarray[double, ndim=1] r
    cdef np.ndarray[double, ndim=1] Az
    cdef np.ndarray[double, ndim=1] Azc
    cdef np.ndarray[double, ndim=1] Ofr
    cdef np.ndarray[double, ndim=1] I_teach
    cdef np.ndarray[double, ndim=1] Tz

    # boundscheck/wraparound are off, so undersized inputs must be rejected
    # up front instead of silently corrupting memory.
    if A.shape[0] < I or A.shape[1] < I:
        raise ValueError("A must be at least (%d, %d), got (%d, %d)"
                         % (I, I, A.shape[0], A.shape[1]))
    if Of.shape[0] < N or Of.shape[1] < N:
        raise ValueError("Of must be at least (%d, %d), got (%d, %d)"
                         % (N, N, Of.shape[0], Of.shape[1]))
    if cs.shape[1] < I:
        raise ValueError("cs must have at least %d columns, got %d"
                         % (I, cs.shape[1]))

    z = np.zeros(I)
    r = np.zeros(N)
    Az = np.zeros(I)
    Azc = np.zeros(I)
    Ofr = np.zeros(N)
    I_teach = np.zeros(N)
    Tz = np.zeros(N)

    for t in range(n_steps):
        # Recurrent terms from the previous step's z and r.
        for i in range(I):
            Az[i] = 0
            for j in range(I):
                Az[i] += A[i, j] * z[j]
        for n in range(N):
            Ofr[n] = 0
            for m in range(N):
                Ofr[n] += Of[n, m] * r[m]
        # Total input to z (recurrent + external), then the z update.
        for i in range(I):
            Azc[i] = Az[i] + cs[t, i]
            z[i] += dt * (Azc[i] - z_leak * z[i])
        # Project the total input and the new z through the not-yet-updated T.
        for n in range(N):
            I_teach[n] = 0
            Tz[n] = 0
            for i in range(I):
                I_teach[n] += T[i, n] * Azc[i]
                Tz[n] += T[i, n] * z[i]
        # rate updates
        for n in range(N):
            r[n] += dt * (I_teach[n] - Ofr[n] - r_leak * r[n])
        # weight updates: relax T toward outer(z, r) and Of toward outer(Tz, r)
        for i in range(I):
            for n in range(N):
                T[i, n] += step * (z[i] * r[n] - T[i, n])
        for n in range(N):
            for m in range(N):
                Of[n, m] += step * (Tz[n] * r[m] - Of[n, m])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment