Created
October 15, 2019 22:10
-
-
Save danielegrattarola/c663346b529e758f0224c8313818ad77 to your computer and use it in GitHub Desktop.
Implementing a Network-based Model of Epilepsy with Numpy and Numba. Code for https://danielegrattarola.github.io/posts/2019-10-03/epilepsy-model.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import time | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from numba import njit | |
import networkx as nx | |
def degree_power(adj, pow):
    """
    Computes D^{p} from the given adjacency matrix.

    The degree of each node is the sum of its row in ``adj``. Nodes with zero
    degree would produce ``inf`` for negative exponents; their entry is set
    to 0 instead.
    NOTE: no need to JIT compile because it only runs once.
    :param adj: rank 2 array (dense ndarray, np.matrix, or any object whose
        ``sum(1)`` returns the row sums).
    :param pow: exponent to which elevate the degree matrix.
    :return: the exponentiated degree matrix, a dense (N, N) array.
    """
    # np.asarray turns np.matrix / sparse row sums into a plain array, so that
    # reshape(-1) gives shape (N,). On a matrix it would give (1, N), and
    # np.diag would then extract one element instead of building an (N, N)
    # matrix.
    row_sums = np.asarray(adj.sum(1)).reshape(-1)
    # Integer degrees raised to a negative integer power raise ValueError in
    # NumPy, so promote integer/bool sums to float. Float and complex inputs
    # keep their dtype.
    if not np.issubdtype(row_sums.dtype, np.inexact):
        row_sums = row_sums.astype(np.float64)
    # 0 ** negative -> inf plus a divide-by-zero warning. Those entries are
    # zeroed explicitly below, so the warning is noise.
    with np.errstate(divide='ignore'):
        degrees = np.power(row_sums, pow)
    degrees[np.isinf(degrees)] = 0.
    return np.diag(degrees)
def normalized_adjacency(adj):
    """
    Return the symmetrically normalized adjacency matrix D^{-1/2} A D^{-1/2}.

    D is the degree matrix produced by ``degree_power`` from ``adj``.
    NOTE: executed once during setup, so it is not JIT compiled.
    :param adj: rank 2 array.
    :return: the normalized adjacency matrix.
    """
    d_inv_sqrt = degree_power(adj, -0.5)
    # Scale rows first (D^{-1/2} A), then columns ((D^{-1/2} A) D^{-1/2}).
    row_scaled = d_inv_sqrt.dot(adj)
    return row_scaled.dot(d_inv_sqrt)
@njit
def f(z, lamb=0., omega=1):
    """Eq. 1 in the paper, the deterministic update function of each node.

    Computes (lamb - 1 + i*omega) * z + 2 * z * |z|^2 - z * |z|^4,
    element-wise when ``z`` is an array of node states. ``step`` multiplies
    the result by the time increment dt.
    :param z: complex, the current state (scalar or array of node states).
    :param lamb: hyperparameter to control the attractors of each node.
    :param omega: frequency of the ictal spike wave discharge (SWD) in rad/s.
    :return: complex, the deterministic rate of change of z.
    """
    # The original evaluated the unused expression `np.complex(0, 1)` here.
    # `np.complex` was deprecated in NumPy 1.20 and removed in 1.24, so that
    # dead line made this function fail to compile on current NumPy.
    return ((lamb - 1 + complex(0, omega)) * z
            + (2 * z * np.abs(z) ** 2)
            - (z * np.abs(z) ** 4))
@njit
def delta_wiener(size, dt):
    """Sample the increments of a real Wiener process (Brownian motion) over
    a single time step.

    Each increment is a standard normal draw scaled by sqrt(dt).
    :param size: desired shape of the output array (unpacked into randn).
    :param dt: float, time increment in seconds.
    :return: array of shape ``size`` holding the sampled increments.
    """
    std_dev = np.sqrt(dt)
    standard_normal = np.random.randn(*size)
    return std_dev * standard_normal
@njit
def complex_delta_wiener(size, dt):
    """Sample the increments of a complex Wiener process over a single time
    step.

    The complex process is u(t) + jv(t), where u and v are simple Wiener
    processes, each sampled with a separate call to ``delta_wiener``.
    :param size: desired shape of the output array.
    :param dt: float, time increment in seconds.
    :return: complex array of shape ``size``.
    """
    real_increment = delta_wiener(size, dt)
    imag_increment = delta_wiener(size, dt)
    return real_increment + 1j * imag_increment
@njit
def step(z):
    """
    Compute one Euler-Maruyama time step of the system.

    Reads the module-level globals A_norm, lamb, omega, beta, alpha and dt.
    NOTE: numba treats global values as compile-time constants, so changing
    them after the first call does not affect the compiled function.
    :param z: complex array, the current state of every node.
    :return: complex array, the change from the current state to the next
        s.t. z[t+1] = z[t] + step(z[t]).
    """
    # Matrix with pairwise differences of nodes: delta_z[k, i] = z[k] - z[i]
    delta_z = z.reshape(-1, 1) - z.reshape(1, -1)
    # Diffusive coupling: coupling[i] = sum_k A_norm[k, i] * (z[k] - z[i]).
    # This is exactly diag(A_norm.T @ delta_z). Computing it element-wise
    # costs O(N^2), instead of forming the full O(N^3) matrix product only to
    # keep its diagonal.
    diffusive_coupling = (A_norm * delta_z).sum(axis=0)
    # Deterministic drift and stochastic (Wiener) contributions
    update_from_self = f(z, lamb=lamb, omega=omega)
    update_from_others = beta * diffusive_coupling
    noise = alpha * complex_delta_wiener(z.shape, dt)
    dz = (update_from_self + update_from_others) * dt + noise
    return dz
@njit
def evolve_system(z0, steps):
    """
    Evolve the system starting from the given initial state (z0) for a given
    number of time steps (steps), printing progress about every 1%.
    :param z0: complex, the initial state.
    :param steps: int, number of steps to evolve the system for.
    :return: list, the sequence of states (steps + 1 entries, z0 first).
    """
    # TODO fastest way is to pre-allocate list and assign elements, but it
    # TODO didn't seem to bring actual improvements in practice.
    # Use an integer reporting interval. The old float interval (steps / 100)
    # seldom divides the step index exactly when steps is not a multiple of
    # 100 (e.g. 999999), so progress could stop being reported after 0%.
    # The printed value i * 100 / steps equals the old i / (steps / 100).
    report_every = max(steps // 100, 1)
    z = [z0]
    for i in range(steps):
        if i % report_every == 0:
            print(i * 100 / steps, '%')
        dz = step(z[-1])
        z.append(z[-1] + dz)
    return z
################################################################################
# Configuration
################################################################################
N = 2  # Number of nodes in the system
seconds_to_generate = 100  # Number of seconds to generate
dt = 0.0001  # Sampling period (1/freq), in seconds
plots = True  # Whether to produce plots with results
fmt = 'png'  # Format of the saved figure
seed = None  # Set to an int for reproducible runs (graph and noise)
# Hyper-parameters
omega = 20  # Frequency of oscillations (rad/s)
alpha = 0.125  # Intensity of the noise
lamb = 0.5  # Controls the attractors of each node (lamb in Eq. 1, see f)
beta = 0.1  # Coupling strength b/w nodes


@njit
def _seed_numba_rng(value):
    # Numba-compiled code draws from its own random generator, separate from
    # NumPy's. It can only be seeded by calling np.random.seed in jitted code.
    np.random.seed(value)


if seed is not None:
    np.random.seed(seed)  # seeds the random adjacency matrix below
    _seed_numba_rng(seed)  # seeds the Wiener-process noise drawn in step()
# Random directed graph without self-loops
A = np.random.randint(0, 2, (N, N))
np.fill_diagonal(A, 0)
A_norm = normalized_adjacency(A).astype(np.complex128)
################################################################################
# Evolve system
################################################################################
z0 = np.zeros(N).astype(np.complex128)  # Starting conditions of each node
# Round rather than truncate: seconds / dt can land a hair below the intended
# integer in floating point (e.g. 0.3 / 0.1 == 2.9999999999999996), and
# int() would then drop a step.
steps = int(round(seconds_to_generate / dt))  # Number of steps to generate
# Evolve system and time it.
# NOTE: evolve_system is compiled by numba on this first call, so the elapsed
# time (and iter/sec) below includes JIT compilation time.
start = time.perf_counter()
timesteps = evolve_system(z0, steps)
timer = time.perf_counter() - start
print('Elapsed: {} ({} iter/sec)'
      .format(datetime.timedelta(seconds=timer), steps / timer))
timesteps = np.array(timesteps)  # shape (steps + 1, N)
################################################################################
# Plots
################################################################################
if not plots:
    # SystemExit directly instead of the site-provided exit(), which is not
    # guaranteed to exist (e.g. under `python -S` or embedded interpreters).
    raise SystemExit
# Plot adjacency matrix
plt.figure(figsize=(4, 4))
plt.title('Adjacency matrix')
plt.imshow(A, cmap='gray')
plt.xticks(np.arange(N))
plt.yticks(np.arange(N))
plt.tight_layout()
plt.savefig('adjacency_matrix.{}'.format(fmt), bbox_inches='tight')
# Plot graph (directed, node labels are 1-based)
plt.figure()
G = nx.DiGraph(A)
pos = nx.spring_layout(G)
nx.draw_networkx_nodes(G, pos, node_color='orange', node_size=1000)
nx.draw_networkx_edges(G, pos, node_size=1000, arrowsize=10, width=2,
                       edge_color='gray')
nx.draw_networkx_labels(G, pos, {i: i + 1 for i in range(N)})
plt.savefig('graph.{}'.format(fmt), bbox_inches='tight')
# Plot evolution of nodes in complex plane.
# Integer grid with at least N cells: cols = ceil(sqrt(N)), rows =
# ceil(N / cols). With floor(sqrt(N)) rows there are too few cells for e.g.
# N = 3 or N = 7, and matplotlib rejects non-integer grid sizes.
cols = int(np.ceil(np.sqrt(N)))
rows = int(np.ceil(N / cols))
plt.figure(figsize=(cols * 4, rows * 4))
for n in range(N):
    plt.subplot(rows, cols, n + 1)
    plt.title('Node {}'.format(n + 1))
    plt.plot(timesteps[..., n].real, timesteps[..., n].imag, linewidth=0.5)
    plt.ylabel('Im(z)')
    plt.xlabel('Re(z)')
    plt.ylim(-1.6, 1.6)
    plt.xlim(-1.6, 1.6)
plt.tight_layout()
plt.savefig('nodes_complex_plane.{}'.format(fmt), bbox_inches='tight')
# Plot Re of nodes vs. time (x axis is the step index; multiply by dt for s)
plt.figure(figsize=(16, N * 2))
for n in range(N):
    plt.subplot(N, 1, n + 1)
    plt.plot(timesteps[..., n].real, linewidth=0.5)
    plt.ylabel('Node {}'.format(n + 1))
    plt.ylim(-1.6, 1.6)
plt.tight_layout()
plt.savefig('nodes_re_v_time.{}'.format(fmt), bbox_inches='tight')
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment