Skip to content

Instantly share code, notes, and snippets.

@303248153
Created July 3, 2018 13:19
Show Gist options
  • Save 303248153/497253a05df94b787be4fde9c5f64328 to your computer and use it in GitHub Desktop.
Save 303248153/497253a05df94b787be4fde9c5f64328 to your computer and use it in GitHub Desktop.
import numpy as np
# todo:
# - train bias
# - increase alpha when learning progress is slow
# - fix bugs
# - produce nan when weights_expected are all 0
# Print floats in fixed-point notation (no scientific notation) and make the
# random draws below reproducible from run to run.
np.set_printoptions(suppress=True)
np.random.seed(seed=1)
class PerceptronNetwork(object):
    """Single-layer network with an identity (linear) transfer function.

    Only ``weight`` is trained; ``bias`` is stored and used by ``predict``
    but is never updated, and ``alpha_bias`` is not read by ``train``.
    """

    def __init__(self, n_input, n_output):
        """Create a network mapping ``n_input`` features to ``n_output`` values.

        Weights are drawn uniformly from [-1, 1) with NumPy's global RNG; the
        bias starts at 0 and both learning rates start at 1.
        """
        self.n_input = n_input
        self.n_output = n_output
        self.weight = np.random.random((n_input, n_output)) * 2 - 1
        self.bias = 0
        self.alpha_weight = 1
        self.alpha_bias = 1
        # Number of accepted weight updates performed by ``train``.
        self.cycles = 0

    def train(self, arr_input, arr_output, max_cycles):
        """Adjust ``self.weight`` so ``predict(arr_input)`` approaches ``arr_output``.

        Parameters
        ----------
        arr_input : 2-D array, shape (n_samples, n_input)
        arr_output : 2-D array, shape (n_samples, n_output)
        max_cycles : int
            Upper bound on loop iterations. Both accepted steps and rejected
            (learning-rate-halving) steps count toward it.

        Each accepted step increments ``self.cycles``. Training stops early
        once the mean absolute prediction error is close to zero. When a
        proposed step is larger in total magnitude than the previous accepted
        step, ``self.alpha_weight`` is halved and the step is retried.
        """
        ratio_input = self.ratio(arr_input.T)
        # Errors are scaled relative to the target magnitude.  A target of 0
        # would divide by zero and poison the weights with inf/nan, so the raw
        # error is used for those entries instead.
        target_scale = np.abs(arr_output)
        target_scale = np.where(target_scale == 0, 1, target_scale)
        weight_update_prev = None
        for _ in range(max_cycles):
            l1 = self.predict(arr_input)
            l1_error = arr_output - l1
            # Detect convergence on the *absolute* error: with a signed mean,
            # positive and negative errors cancel and stop training too early.
            if np.isclose(np.abs(l1_error).mean(), 0):
                break
            l1_delta = l1_error / target_scale
            weight_update = ratio_input.dot(l1_delta) * self.alpha_weight
            # Detect divergence: if this step is bigger than the last accepted
            # one, halve the learning rate and retry without touching weights.
            if (weight_update_prev is not None and
                    np.abs(weight_update).sum() >
                    np.abs(weight_update_prev).sum()):
                self.alpha_weight /= 2.0
                continue
            weight_update_prev = weight_update
            self.weight += weight_update
            self.cycles += 1

    def predict(self, arr_input):
        """Return ``transfer(arr_input @ weight + bias)``."""
        l0 = arr_input
        l1 = self.transfer(l0.dot(self.weight) + self.bias)
        return l1

    def transfer(self, x):
        """Identity activation: the network is purely linear."""
        return x

    def ratio(self, x):
        """Divide each row of ``x`` by its sum.

        Rows that sum to zero (e.g. an all-zero feature) are returned
        unchanged instead of producing 0/0 -> nan.
        """
        totals = x.sum(axis=1, keepdims=True)
        return x / np.where(totals == 0, 1, totals)
# parameters
weights_expected = [1, -2, 3]
learn_cases = 30
number_range = 100  # inputs are whole numbers drawn from [0, number_range)
max_cycles = 10000
# test network
X = np.floor(np.random.random((
    learn_cases, len(weights_expected))) * number_range)
Y = (X * np.array(weights_expected)).sum(axis=1, keepdims=True)
network = PerceptronNetwork(X.shape[1], Y.shape[1])
network.train(X, Y, max_cycles)
A = network.predict(X)
# Relative error per case; a target of exactly 0 would divide by zero, so the
# absolute error is reported for those cases instead.
error = np.abs((Y - A) / np.where(Y == 0, 1, Y))
print("error:\n",
      "mean:", np.mean(error, keepdims=True), "\n",
      "max:", np.max(error, keepdims=True))
print("cycles:\n", network.cycles)
print("alpha_weight:\n", network.alpha_weight)
print("alpha_bias:\n", network.alpha_bias)
print("weight:\n", network.weight)
print("bias:\n", network.bias)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment