Skip to content

Instantly share code, notes, and snippets.

@BenTenmann
Last active April 19, 2021 20:46
Show Gist options
  • Save BenTenmann/e1d2e29a4fd05785ef72d6d7228e170d to your computer and use it in GitHub Desktop.
Save BenTenmann/e1d2e29a4fd05785ef72d6d7228e170d to your computer and use it in GitHub Desktop.
import numpy as np
def is_binary(x):
    """Check that every element of ``x`` is +1 or -1 and return ``x`` as-is.

    The check is ``abs(element) == 1`` for every element. An empty input
    passes trivially. The input object itself is returned; it is not
    copied or converted to an array.

    Args:
        x: array-like of numbers (list, nested list, or ``np.ndarray``).

    Returns:
        The very same object that was passed in.

    Raises:
        AttributeError: if any element does not have absolute value 1.
    """
    # Guard clause: reject first, so the happy path is a plain return.
    if not np.all(np.abs(x) == 1):
        raise AttributeError('Inputs need to be binary: 1 or -1')
    return x
class ClassicHopfield:
    """Binary Hopfield network with Hebbian (outer-product) weights.

    Args:
        memories: patterns to store, one per row, with entries 1 or -1.
            A single 1-D pattern is treated as one stored memory.

    Attributes:
        memories: the ``memories`` argument as passed (validated by
            ``is_binary``, not copied).
        connections: square weight matrix of shape (n_units, n_units),
            the sum of ``outer(m, m)`` over all memories, with a zero
            diagonal (no self-connections).

    Raises:
        AttributeError: if ``memories`` contains entries other than 1 or -1
            (raised by ``is_binary``).
    """

    def __init__(self, memories):
        self.memories = is_binary(memories)
        self.connections = self._generate_connections()

    def _generate_connections(self):
        """Return the Hebbian weight matrix with its diagonal set to 0.

        Stacking the patterns as rows of ``P`` (n_memories, n_units), the sum
        of the per-pattern outer products ``outer(m, m)`` equals ``P.T @ P``.
        """
        # Treat a lone 1-D pattern as a single memory instead of iterating
        # over its scalar entries.
        patterns = np.atleast_2d(np.asarray(self.memories))
        weights = patterns.T @ patterns
        # np.fill_diagonal modifies `weights` in place and returns None, so
        # the matrix itself must be returned afterwards.
        np.fill_diagonal(weights, 0)
        return weights

    def update(self, query, n_updates=1):
        """Run the network dynamics starting from ``query``.

        The state is first obtained by one synchronous update of all units
        from ``query``. Then ``n_updates`` asynchronous sweeps follow. Each
        sweep visits units in index order and sets unit ``i`` to +1 when its
        local field ``connections[i] @ state`` is >= 0, and to -1 otherwise.
        Each unit sees the already-updated values of earlier units.

        Args:
            query: probe pattern with entries 1 or -1; its length must match
                the number of units in ``connections``.
            n_updates: number of asynchronous sweeps after the initial
                synchronous step (0 returns the synchronous result only).

        Returns:
            np.ndarray of ints (1 or -1): the final network state.

        Raises:
            AttributeError: if ``query`` contains entries other than 1 or -1.
        """
        query = is_binary(query)
        # NOTE(review): behavior kept from the original. Classic asynchronous
        # Hopfield recall usually starts from `query` itself rather than from
        # a synchronous update of it; confirm this initial step is intended.
        state = np.where(self.connections @ query >= 0, 1, -1)
        for _ in range(n_updates):
            for i in range(len(state)):
                field = self.connections[i, :] @ state
                # Ties (field == 0) resolve to +1, matching the step above.
                state[i] = -1 if field < 0 else 1
        return state
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment