Skip to content

Instantly share code, notes, and snippets.

@danielegrattarola
Created October 15, 2019 22:10
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save danielegrattarola/c663346b529e758f0224c8313818ad77 to your computer and use it in GitHub Desktop.
Save danielegrattarola/c663346b529e758f0224c8313818ad77 to your computer and use it in GitHub Desktop.
Implementing a Network-based Model of Epilepsy with Numpy and Numba. Code for https://danielegrattarola.github.io/posts/2019-10-03/epilepsy-model.html
import datetime
import time
import matplotlib.pyplot as plt
import numpy as np
from numba import njit
import networkx as nx
def degree_power(adj, pow):
    """
    Computes D^{p} from the given adjacency matrix.

    D is the diagonal matrix holding the row sums of `adj`. Entries that
    become infinite (a zero row sum raised to a negative power) are set to 0
    instead, so zero-degree nodes get a 0 on the diagonal.
    NOTE: no need to JIT compile because it only runs once.
    :param adj: rank 2 array.
    :param pow: exponent to which elevate the degree matrix.
    :return: the exponentiated degree matrix, as a rank 2 diagonal array.
    """
    # Zero degrees with a negative exponent are expected (nodes without
    # edges) and are handled explicitly below, so silence NumPy's
    # divide-by-zero warning instead of letting it leak to the user.
    with np.errstate(divide='ignore'):
        degrees = np.power(adj.sum(1), pow).reshape(-1)
    degrees[np.isinf(degrees)] = 0.
    D = np.diag(degrees)
    return D
def normalized_adjacency(adj):
    """
    Return the symmetrically normalized adjacency matrix D^{-1/2} A D^{-1/2}.

    D is the diagonal degree matrix produced by `degree_power` from the row
    sums of `adj`; zero-degree entries end up as 0 rather than inf.
    NOTE: not JIT compiled, since it only runs once.
    :param adj: rank 2 array.
    :return: the normalized adjacency matrix.
    """
    d_inv_sqrt = degree_power(adj, -0.5)
    # Left-to-right product, i.e. (D^{-1/2} A) D^{-1/2}.
    return d_inv_sqrt @ adj @ d_inv_sqrt
@njit
def f(z, lamb=0., omega=1):
    """Eq. 1 in the paper, the deterministic update function of each node.

    Computes (lamb - 1 + j*omega) * z + 2 * z * |z|^2 - z * |z|^4
    element-wise.
    :param z: complex (scalar or array), the current state.
    :param lamb: hyperparameter to control the attractors of each node.
    :param omega: frequency of the ictal spike wave discharge (SWD) in rad/s.
    :return: complex, the deterministic part of dz/dt for each node.
    """
    # NOTE: a discarded `np.complex(0, 1)` statement used to live here; it
    # had no effect and `np.complex` no longer exists in NumPy >= 1.24.
    return ((lamb - 1 + complex(0, omega)) * z
            + (2 * z * np.abs(z) ** 2)
            - (z * np.abs(z) ** 4))
@njit
def delta_wiener(size, dt):
    """Draw the increments of a real Wiener process (Brownian motion) over
    one time step.

    Each entry is a standard normal sample scaled by sqrt(dt).
    :param size: shape tuple of the output array (star-unpacked into randn).
    :param dt: float, time increment in seconds.
    :return: array of shape `size` with the random increments.
    """
    std = np.sqrt(dt)
    samples = np.random.randn(*size)
    return std * samples
@njit
def complex_delta_wiener(size, dt):
    """Draw the increments of a complex Wiener process (Brownian motion) over
    one time step.

    The process is calculated as u(t) + jv(t), where u and v are real Wiener
    processes; their increments are drawn with two separate calls.
    :param size: desired shape of the output array.
    :param dt: float, time increment in seconds.
    :return: complex array of shape `size`.
    """
    # The real part is drawn first, then the imaginary part.
    real_part = delta_wiener(size, dt)
    imag_part = delta_wiener(size, dt)
    return real_part + 1j * imag_part
@njit
def step(z):
    """
    Compute one (Euler-Maruyama) time step of the system.

    Reads the module-level globals A_norm, lamb, omega, beta, alpha and dt.
    NOTE: Numba treats global values as compile-time constants, so changing
    them after the first call has no effect on this compiled function.
    :param z: complex array, the current state of every node.
    :return: complex array, the change from the current state to the next
        s.t. z[t+1] = z[t] + step(z[t]).
    """
    # delta_z[j, i] = z[j] - z[i]: pairwise differences between nodes
    delta_z = z.reshape(-1, 1) - z.reshape(1, -1)
    # Diffusive coupling into node i: sum_j A_norm[j, i] * (z[j] - z[i]).
    # Same as np.diag(A_norm.T.dot(delta_z)), but O(N^2) instead of
    # computing a full O(N^3) matrix product just to keep its diagonal.
    diffusive_coupling = (A_norm * delta_z).sum(axis=0)
    # Compute change in state: deterministic drift scaled by dt plus noise
    update_from_self = f(z, lamb=lamb, omega=omega)
    update_from_others = beta * diffusive_coupling
    noise = alpha * complex_delta_wiener(z.shape, dt)
    dz = (update_from_self + update_from_others) * dt + noise
    return dz
@njit
def evolve_system(z0, steps):
    """
    Evolve the system starting from the given initial state (z0) for a given
    number of time steps (steps), printing progress about every 1%.
    :param z0: complex, the initial state.
    :param steps: int, number of steps to evolve the system for.
    :return: list, the sequence of states (steps + 1 entries, z0 first).
    """
    # TODO fastest way is to pre-allocate list and assign elements, but it
    # TODO didn't seem to bring actual improvements in practice.
    # Integer reporting interval (at least 1): a float interval such as
    # steps / 100 only hits `i % interval == 0` exactly when steps is a
    # multiple of 100, so progress would otherwise silently stop printing.
    report_every = max(steps // 100, 1)
    z = [z0]
    for i in range(steps):
        if i % report_every == 0:
            print(100.0 * i / steps, '%')
        dz = step(z[-1])
        z.append(z[-1] + dz)
    return z
################################################################################
# Configuration
################################################################################
# Simulation settings
N = 2                      # Number of nodes in the system
seconds_to_generate = 100  # Length of the simulation, in seconds
dt = 1e-4                  # Integration step / sampling period (1/freq), in seconds
plots = True               # Whether to produce and save plots of the results
fmt = 'png'                # File format of the saved figures

# Hyper-parameters of the model
omega = 20     # Frequency of the oscillations, in rad/s
alpha = 0.125  # Intensity (scale) of the noise term
lamb = 0.5     # Controls the attractors of each node (see f)
beta = 0.1     # Strength of the diffusive coupling between nodes

# Random 0/1 adjacency matrix (not necessarily symmetric) without self-loops.
# Its normalized version is cast to complex to match the node states.
A = np.random.randint(low=0, high=2, size=(N, N))
np.fill_diagonal(A, 0)
A_norm = normalized_adjacency(A).astype(np.complex128)
################################################################################
# Evolve system
################################################################################
z0 = np.zeros(N, dtype=np.complex128)  # Starting conditions of each node
# Round instead of truncating: seconds_to_generate / dt can land just below
# an integer in floating point (e.g. 0.3 / 0.1 == 2.9999999999999996).
steps = int(round(seconds_to_generate / dt))  # Number of steps to generate
# Evolve system and time it. perf_counter is monotonic and high-resolution,
# unlike time.time. NOTE: the first call to a @njit function also includes
# its JIT compilation, so the reported rate is a lower bound.
start = time.perf_counter()
timesteps = evolve_system(z0, steps)
timer = time.perf_counter() - start
print('Elapsed: {} ({} iter/sec)'
      .format(datetime.timedelta(seconds=timer), steps / timer))
# Stack the list of per-step state arrays into one array
timesteps = np.array(timesteps)
################################################################################
# Plots
################################################################################
if not plots:
    # `exit()` is injected by the `site` module and may be missing (e.g.
    # under `python -S` or embedded interpreters); SystemExit always exists
    # and, with no argument, ends the script with status 0.
    raise SystemExit
# Plot adjacency matrix as a grayscale image, one tick per node index
node_ticks = np.arange(N)
plt.figure(figsize=(4, 4))
plt.title('Adjacency matrix')
plt.imshow(A, cmap='gray')
plt.xticks(node_ticks)
plt.yticks(node_ticks)
plt.tight_layout()
plt.savefig('adjacency_matrix.{}'.format(fmt), bbox_inches='tight')
# Plot graph, labelling nodes 1..N to match the other figures
plt.figure()
G = nx.DiGraph(A)
pos = nx.spring_layout(G)
node_labels = {node: node + 1 for node in range(N)}
nodes = nx.draw_networkx_nodes(G, pos, node_color='orange', node_size=1000)
# networkx uses node_size here to position the edges/arrowheads; keep it
# equal to the size used for drawing the nodes.
edges = nx.draw_networkx_edges(
    G, pos, node_size=1000, arrowsize=10, width=2, edge_color='gray'
)
labels = nx.draw_networkx_labels(G, pos, node_labels)
plt.savefig('graph.{}'.format(fmt), bbox_inches='tight')
# Plot evolution of nodes in complex plane, one subplot per node.
# plt.subplot needs integer grid sizes, and the grid must have at least N
# cells (floor(sqrt(N)) rows are too few for e.g. N = 3 or N = 7), so use
# cols = ceil(sqrt(N)) and rows = ceil(N / cols).
cols = int(np.ceil(np.sqrt(N)))
rows = -(-N // cols)  # ceil(N / cols) in integer arithmetic
plt.figure(figsize=(cols * 4, rows * 4))
for n in range(N):
    plt.subplot(rows, cols, n + 1)
    plt.title('Node {}'.format(n + 1))
    plt.plot(timesteps[..., n].real, timesteps[..., n].imag, linewidth=0.5)
    plt.ylabel('Im(z)')
    plt.xlabel('Re(z)')
    plt.ylim(-1.6, 1.6)
    plt.xlim(-1.6, 1.6)
plt.tight_layout()
plt.savefig('nodes_complex_plane.{}'.format(fmt), bbox_inches='tight')
# Plot Re of each node's state against the step index, one row per node,
# then display all figures.
plt.figure(figsize=(16, N * 2))
for node in range(N):
    plt.subplot(N, 1, node + 1)
    real_trace = timesteps[:, node].real
    plt.plot(real_trace, linewidth=0.5)
    plt.ylabel('Node {}'.format(node + 1))
    plt.ylim(-1.6, 1.6)
plt.tight_layout()
plt.savefig('nodes_re_v_time.{}'.format(fmt), bbox_inches='tight')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment