Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Brain-State-in-a-Box Network
import matplotlib.pyplot as plt
import numpy as np
# Seed NumPy's global random generator so the later np.random calls draw the
# same numbers on every run.
np.random.seed(1000)

# Problem size and training / recall settings.
nb_patterns = 4
pattern_width, pattern_height = 4, 4  # each pattern is a 4x4 grid of units
max_iterations = 100                  # used for training epochs and recall sweeps
learning_rate = 0.5                   # step size of the Hebbian weight update

# Stored bipolar (+1 / -1) patterns, one flattened grid per row in row-major
# order, i.e. X[k].reshape((pattern_height, pattern_width)) gives the grid back.
# Note: X[3] is exactly -X[2], so under an outer-product (Hebbian) rule both
# rows contribute the same term.
X = np.array(
    [
        [-1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1],
        [-1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1],
        [-1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1, 1, 1, -1, -1],
        [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1],
    ],
    dtype=float,
)
# Show the stored patterns side by side, one grayscale grid per pattern.
# squeeze=False makes plt.subplots always return a 2-D array of Axes; taking
# row 0 leaves `ax` a 1-D array even when nb_patterns == 1, where the default
# squeezing would hand back a single Axes object and ax[i] would fail.
fig, ax = plt.subplots(1, nb_patterns, figsize=(10, 5), squeeze=False)
ax = ax[0]
for pattern, axis in zip(X, ax):
    axis.matshow(pattern.reshape((pattern_height, pattern_width)), cmap='gray')
    axis.set_xticks([])
    axis.set_yticks([])
plt.show()
# Initial weights: a square (pattern_width * pattern_height)-sided matrix drawn
# uniformly from [-0.1, 0.1), then added to its own transpose so that
# W[i, j] == W[j, i]; every entry therefore lies in [-0.2, 0.2).
W = np.random.uniform(low=-0.1, high=0.1,
                      size=(pattern_width * pattern_height,) * 2)
W = W.T + W
# Activation function: clip the net input into the BSB "box" [-1, 1].
def activation(x):
    """Clip ``x`` to the closed interval [-1, 1].

    Works element-wise on Python scalars, NumPy scalars and array-likes
    (the previous if/elif version raised on arrays).

    Args:
        x: Scalar or array-like net input.

    Returns:
        ``x`` with values above 1.0 replaced by 1.0 and values below -1.0
        replaced by -1.0; values already inside the interval pass through.
    """
    return np.clip(x, -1.0, 1.0)


# Kept for backward compatibility; ``activation`` is already element-wise.
act = np.vectorize(activation)
# Train the network with the Hebbian outer-product rule: for every stored
# pattern x, W[i, j] += learning_rate * x[i] * x[j] for all i, j.
# np.outer(x, x) is symmetric, so W stays symmetric with exactly one increment
# per entry. (The former element-wise loop mirrored W[j, i] = W[i, j] inside a
# full i/j sweep, which applied the increment twice to every off-diagonal
# entry and only once to the diagonal.)
# Nothing normalises W, so its entries grow linearly with max_iterations.
for _ in range(max_iterations):
    for n in range(nb_patterns):
        W += learning_rate * np.outer(X[n], X[n])
# Create a corrupted test pattern: several entries (0.7, -0.8, -0.75, 0.9)
# sit inside the [-1, 1] box instead of on a +/-1 corner.
x_test = np.array([1, -1, 0.7, 1, -0.8, -1, 1, 1, -1, 1, -0.75, -1, 1, 1, 0.9, 1])
# Recover a stored pattern by asynchronous updates: units are visited in index
# order and each is immediately replaced by its clipped net input, so later
# units in a sweep already see the new values of earlier ones.
A = x_test.copy()
for _ in range(max_iterations):
    changed = False
    for i in range(pattern_width * pattern_height):
        new_value = activation(np.dot(W[i], A))
        if new_value != A[i]:
            changed = True
        A[i] = new_value
    # A sweep that changes no unit started from a fixed point: the next sweep
    # would repeat exactly the same computations, so stopping leaves A as-is.
    if not changed:
        break
# Plot the corrupted input next to the state the network settled into,
# each reshaped back to its pattern_height x pattern_width grid.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
panels = ((x_test, 'Corrupted pattern'), (A, 'Recovered pattern'))
for axis, (state, title) in zip(ax, panels):
    axis.matshow(state.reshape(pattern_height, pattern_width), cmap='gray')
    axis.set_title(title)
    axis.set_xticks([])
    axis.set_yticks([])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.