Instantly share code, notes, and snippets.

# giuseppebonaccorso/oja.py

Last active June 8, 2019 00:19
Star You must be signed in to star a gist
Oja's rule (Hebbian Learning)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 import numpy as np from sklearn.datasets import make_blobs from sklearn.preprocessing import StandardScaler # Set random seed for reproducibility np.random.seed(1000) # Create and scale dataset X, _ = make_blobs(n_samples=500, centers=2, cluster_std=5.0, random_state=1000) scaler = StandardScaler(with_std=False) Xs = scaler.fit_transform(X) # Compute eigenvalues and eigenvectors Q = np.cov(Xs.T) eigu, eigv = np.linalg.eig(Q) # Apply the Oja's rule W_oja = np.random.normal(scale=0.25, size=(2, 1)) prev_W_oja = np.ones((2, 1)) learning_rate = 0.0001 tolerance = 1e-8 while np.linalg.norm(prev_W_oja - W_oja) > tolerance: prev_W_oja = W_oja.copy() Ys = np.dot(Xs, W_oja) W_oja += learning_rate * np.sum(Ys*Xs - np.square(Ys)*W_oja.T, axis=0).reshape((2, 1))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
# Eigenvalues
print(eigu)
# [ 0.67152209  1.33248593]

# Eigenvectors
print(eigv)
# [[-0.70710678 -0.70710678]
#  [ 0.70710678 -0.70710678]]

# W_oja at the end of the training process
print(W_oja)
# [[-0.70710658]
#  [-0.70710699]]

### biggzlar commented Aug 21, 2018

Doesn't work due to overflow. Where does the `np.square(Ys)` come from?