Skip to content

Instantly share code, notes, and snippets.

@giuseppebonaccorso
Created September 20, 2017 12:33
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save giuseppebonaccorso/36b6a7eac1a3e54063ea96aa48ba4acf to your computer and use it in GitHub Desktop.
Save giuseppebonaccorso/36b6a7eac1a3e54063ea96aa48ba4acf to your computer and use it in GitHub Desktop.
Hopfield Network
import matplotlib.pyplot as plt
import numpy as np
# Set random seed for reproducibility
np.random.seed(1000)
nb_patterns = 4
pattern_width = 4
pattern_height = 4
max_iterations = 10
# Initialize the patterns
X = np.zeros((nb_patterns, pattern_width * pattern_height))
X[0] = [-1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1]
X[1] = [-1, -1, -1, -1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1]
X[2] = [-1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1, 1, 1, -1, -1]
X[3] = [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1]
# Show the patterns
fig, ax = plt.subplots(1, nb_patterns, figsize=(10, 5))
for i in range(nb_patterns):
ax[i].matshow(X[i].reshape((pattern_height, pattern_width)), cmap='gray')
ax[i].set_xticks([])
ax[i].set_yticks([])
plt.show()
# Train the network
W = np.zeros((pattern_width * pattern_height, pattern_width * pattern_height))
for i in range(pattern_width * pattern_height):
for j in range(pattern_width * pattern_height):
if i == j or W[i, j] != 0.0:
continue
w = 0.0
for n in range(nb_patterns):
w += X[n, i] * X[n, j]
W[i, j] = w / X.shape[0]
W[j, i] = W[i, j]
# Create a corrupted test pattern
x_test = np.array([1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1])
# Recover the original patterns
A = x_test.copy()
for _ in range(max_iterations):
for i in range(pattern_width * pattern_height):
A[i] = 1.0 if np.dot(W[i], A) > 0 else -1.0
# Show corrupted and recovered patterns
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].matshow(x_test.reshape(pattern_height, pattern_width), cmap='gray')
ax[0].set_title('Corrupted pattern')
ax[0].set_xticks([])
ax[0].set_yticks([])
ax[1].matshow(A.reshape(pattern_height, pattern_width), cmap='gray')
ax[1].set_title('Recovered pattern')
ax[1].set_xticks([])
ax[1].set_yticks([])
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment