Natural Gradient Demo
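# Compares plain gradient descent with natural gradient descent on a toy
# logistic-regression problem. Run without flags for vanilla gradient descent,
# or pass --ng to precondition the update with the inverse empirical Fisher
# information matrix.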
import numpy as np
from sklearn.utils import shuffle
import random
import argparse
import matplotlib.pyplot as plt
import time
parser = argparse.ArgumentParser()
parser.add_argument('--ng', action='store_true')
parser.add_argument('--seed', type=int, default=10)
parser.add_argument('--num_data', type=int, default=500)
parser.add_argument('--dimension', type=int, default=4)
args = parser.parse_args()
#np.random.seed(args.seed)
#random.seed(args.seed)
num_data = int(args.num_data)
num_half_data = int(args.num_data/2)
dimension = int(args.dimension)
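# Synthetic two-class data: Gaussian blobs centred at -1 (label 0) and +1 (label 1)
# in every dimension; after shuffling, the last 100 examples are held out for testing.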
X0 = np.random.randn(num_half_data, dimension) - 1
X1 = np.random.randn(num_half_data, dimension) + 1
X = np.vstack([X0, X1])
t = np.vstack([np.zeros([num_half_data, 1]), np.ones([num_half_data, 1])])
X, t = shuffle(X, t)
X_train, X_test = X[:num_data-100], X[num_data-100:]
t_train, t_test = t[:num_data-100], t[num_data-100:]
# Model
W = np.random.randn(dimension, 1) * 0.01
def sigm(x):
    return 1/(1+np.exp(-x))
def NLL(y, t):
    return -np.mean(t*np.log(y) + (1-t)*np.log(1-y))
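# sigm is the logistic sigmoid; NLL is the binary cross-entropy, i.e. the mean
# negative log-likelihood of a Bernoulli model with success probability y.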
alpha = 0.1
losses = []
durations = []
# Training
for it in range(20):
    start = time.time()

    # Forward pass: logits are scaled by 0.1 to keep the sigmoid away from saturation
    z = X_train @ W
    z = z * 0.1
    y = sigm(z)
    loss = NLL(y, t_train)
    losses.append(loss)
    print(f'Loss: {loss:.3f}')

    # Backward pass: gradient of the NLL w.r.t. W
    # (the 0.1 logit scaling is not propagated here; omitting it only rescales the step size)
    m = y.shape[0]
    dy = (y-t_train)/(m * (y - y*y))
    dz = sigm(z)*(1-sigm(z))
    dW = X_train.T @ (dz * dy)
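    # dy * dz simplifies to (y - t_train) / m, the usual sigmoid + cross-entropy gradient,
    # so dW is equivalent to X_train.T @ (y - t_train) / m.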
    # Per-example gradients of the log-likelihood (score vectors), one row per data point
    grad_loglik_z = (t_train-y)/(y - y*y) * dz
    grad_loglik_W = grad_loglik_z * X_train
    # Empirical Fisher information: covariance of the per-example score vectors
    F = np.cov(grad_loglik_W.T)
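    # The natural gradient preconditions dW with F^{-1}: the update below follows the
    # steepest-descent direction measured by KL divergence between model distributions
    # rather than by Euclidean distance in parameter space.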
    # Step
    if args.ng:
        W = W - alpha * np.linalg.inv(F) @ dW
    else:
        W = W - alpha * dW

    duration = time.time() - start
    durations.append(duration)
    # print(W)
color = 'b-'
if args.ng:
    color = 'r-'
plt.plot(range(20), losses, color)
algo = ''
if args.ng:
    algo = 'Natural gradient descent'
else:
    algo = 'Gradient descent'
mean_duration = np.round(np.mean(durations), 4)
plt.title('Mean iteration time for ' + algo + ': ' + str(mean_duration) + ' s')
plt.ylabel('Training loss')
plt.xlabel('Iteration')
plt.show()
# Test-set accuracy: thresholding at 0.5 depends only on the sign of the logit,
# so the 0.1 scaling used during training does not affect the predicted class.
y = sigm(X_test @ W).ravel()
acc = np.mean((y >= 0.5) == t_test.ravel())
print(f'Accuracy: {acc:.3f}')
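# A possible refinement (not part of the original gist): explicitly inverting F can be
# slow and numerically fragile when F is near-singular. A common alternative is to add a
# small damping term and use a linear solve. The helper below is an illustrative sketch;
# the name natural_gradient_step and the damping value 1e-3 are assumptions, not the
# author's code.
def natural_gradient_step(W, dW, grad_loglik_W, alpha=0.1, damping=1e-3):
    # Empirical Fisher from per-example score vectors, as above
    F = np.cov(grad_loglik_W.T)
    # Solve (F + damping*I) x = dW instead of forming an explicit inverse
    step = np.linalg.solve(F + damping * np.eye(F.shape[0]), dW)
    return W - alpha * step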