Created
May 12, 2019 20:51
-
-
Save angelormrl/7bbd02a526b03635044805e636bbd18b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multilayer Perceptron with Backpropagation\n", | |
"This is an implementation of a multilayer perceptron (MLP) network that uses incremental (per-example) backpropagation as its learning algorithm." | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Function Definitions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"## ======== FUNCTION DEFINITIONS ======== ##\n", | |
"\n", | |
"# rescales every feature (column) of x linearly into [0, 1] (min-max normalisation)\n", | |
"#   x - 2-D array-like, one row per example and one column per feature\n", | |
"# returns a new float array of the same shape; x itself is not modified.\n", | |
"# a constant feature (max == min) maps to 0 instead of producing 0/0 = nan.\n", | |
"def normalize_data(x):\n", | |
"    data = np.asarray(x, dtype=float)\n", | |
"    min_vals = data.min(axis=0)\n", | |
"    val_range = data.max(axis=0) - min_vals\n", | |
"    # (x - min) is 0 for a constant column, so any non-zero divisor yields 0 there\n", | |
"    val_range[val_range == 0] = 1.0\n", | |
"    return (data - min_vals) / val_range\n", | |
"\n", | |
"# prints the title, then each layer's weight matrix in order, one print per item\n", | |
"def print_weights(weights, title):\n", | |
"    for item in [title] + list(weights):\n", | |
"        print(item)\n", | |
"\n", | |
"# initialises the weights of a two-layer network to uniform random values in [-rand, rand)\n", | |
"# each row is one neuron and its last column is that neuron's bias weight\n", | |
"#   n_inputs, n_hiddens, n_outputs - layer sizes\n", | |
"#   rand - half-width of the initialisation interval\n", | |
"#   seed - optional int for a reproducible draw; None uses numpy's global random state\n", | |
"# prints the initial weights and returns\n", | |
"#   [hidden_layer (n_hiddens x n_inputs+1), output_layer (n_outputs x n_hiddens+1)]\n", | |
"def init_weights(n_inputs, n_hiddens, n_outputs, rand, seed=None):\n", | |
"    rng = np.random if seed is None else np.random.RandomState(seed)\n", | |
"    weights = [\n", | |
"        rng.uniform(-rand, rand, [n_hiddens, n_inputs + 1]),\n", | |
"        rng.uniform(-rand, rand, [n_outputs, n_hiddens + 1]),\n", | |
"    ]\n", | |
"    print_weights(weights, 'initial weights:')\n", | |
"    return weights\n", | |
"\n", | |
"# net input of one neuron: sum of weight * input over all inputs, plus the bias\n", | |
"# (the bias is stored as the neuron's last weight, so weights is expected to be one longer than inputs)\n", | |
"def activate(weights, inputs):\n", | |
"    bias = weights[-1]\n", | |
"    net = 0\n", | |
"    for idx in range(len(inputs)):\n", | |
"        net = net + weights[idx] * inputs[idx]\n", | |
"    return net + bias\n", | |
"\n", | |
"# logistic sigmoid 1 / (1 + e^-x), the default activation; element-wise on numpy arrays\n", | |
"def sigmoid(x):\n", | |
"    denominator = np.exp(-x) + 1.0\n", | |
"    return 1.0 / denominator\n", | |
"\n", | |
"# hyperbolic tangent, an alternative activation to the sigmoid (output range [-1, 1])\n", | |
"# np.tanh is used because the explicit (e^2x - 1) / (e^2x + 1) form overflows to\n", | |
"# inf / inf = nan for large positive x (roughly x > 355)\n", | |
"def tanh(x):\n", | |
"    return np.tanh(x)\n", | |
"\n", | |
"# propagates one example through every layer using the sigmoid activation\n", | |
"#   weights - list of layer weight matrices, one row per neuron (bias in the last column)\n", | |
"#   x       - input vector for the first layer\n", | |
"# returns (outputs, activations), each a list with one entry per layer:\n", | |
"#   activations[l] - net inputs of layer l's neurons (before the sigmoid)\n", | |
"#   outputs[l]     - sigmoid of those net inputs; this is what layer l + 1 receives\n", | |
"def forward_pass(weights, x):\n", | |
"    outputs, activations = [], []\n", | |
"    layer_input = x\n", | |
"    for layer in weights:\n", | |
"        nets = [activate(neuron, layer_input) for neuron in layer]\n", | |
"        squashed = [sigmoid(net) for net in nets]\n", | |
"        activations.append(nets)\n", | |
"        outputs.append(squashed)\n", | |
"        layer_input = squashed\n", | |
"    return outputs, activations\n", | |
"\n", | |
"# derivative of the sigmoid expressed via its OUTPUT s = sigmoid(net): s * (1 - s)\n", | |
"def sigmoid_derivative(x):\n", | |
"    complement = 1.0 - x\n", | |
"    return x * complement\n", | |
"\n", | |
"# derivative of tanh, expecting the tanh OUTPUT t (mirrors sigmoid_derivative): 1 - t^2\n", | |
"def tanh_derivative(x):\n", | |
"    squared = x ** 2\n", | |
"    return 1.0 - squared\n", | |
"\n", | |
"# backpropagates the error of one example and returns the beta (delta) of every neuron\n", | |
"#   weights - [hidden_layer, output_layer] weight matrices (last column of each row is the bias)\n", | |
"#   outputs - [hidden_outputs, final_outputs] as produced by forward_pass\n", | |
"#   y       - integer class label; the target is one-hot with a 1 at index y\n", | |
"# returns (betas, errors): betas = [hidden_betas, output_betas] and\n", | |
"#   errors = target - output for each output neuron\n", | |
"def backward_pass(weights, outputs, y):\n", | |
"    hidden_out, final_out = outputs[0], outputs[1]\n", | |
"    n_out = len(final_out)\n", | |
"    desired = np.zeros(n_out)\n", | |
"    desired[y] = 1\n", | |
"    # output layer: beta = f'(o) * (target - o), with f' evaluated on the sigmoid output\n", | |
"    errors = [desired[k] - final_out[k] for k in range(n_out)]\n", | |
"    output_betas = [sigmoid_derivative(final_out[k]) * errors[k] for k in range(n_out)]\n", | |
"    # hidden layer: error is the beta-weighted sum over the output neurons each hidden neuron feeds;\n", | |
"    # the output layer's bias column is never read because i < number of hidden neurons\n", | |
"    hidden_betas = []\n", | |
"    for i in range(len(hidden_out)):\n", | |
"        back_error = sum(weights[1][k][i] * output_betas[k] for k in range(n_out))\n", | |
"        hidden_betas.append(sigmoid_derivative(hidden_out[i]) * back_error)\n", | |
"    return [hidden_betas, output_betas], errors\n", | |
"\n", | |
"# updates every weight in place with one incremental (per-example) gradient step\n", | |
"#   _weights   - [hidden_layer, output_layer]; each row is one neuron, last column is its bias\n", | |
"#   betas      - [hidden_betas, output_betas] from backward_pass\n", | |
"#   learn_rate - step size\n", | |
"#   inputs     - the vector fed INTO each layer: inputs[0] is the example given to the hidden\n", | |
"#                layer, inputs[1] is the hidden layer's output given to the output layer\n", | |
"# returns the same weights object it was given (mutated)\n", | |
"def update_weights(_weights, betas, learn_rate, inputs):\n", | |
"    weights = _weights\n", | |
"    for i in range(len(weights)):\n", | |
"        layer_input = inputs[i]\n", | |
"        for j in range(len(weights[i])):\n", | |
"            step = learn_rate * betas[i][j]\n", | |
"            # weight k of neuron j moves in proportion to the k-th input that neuron received\n", | |
"            for k in range(len(weights[i][j]) - 1):\n", | |
"                weights[i][j][k] += step * layer_input[k]\n", | |
"            # the bias acts as a weight on a constant input of 1\n", | |
"            weights[i][j][-1] += step\n", | |
"    return weights\n", | |
"\n", | |
"# mean of the squared errors (not the total sum); used to monitor training progress\n", | |
"def mean_squared_error(errors):\n", | |
"    return sum(error ** 2 for error in errors) / len(errors)\n", | |
"\n", | |
"# trains the network on (x, y) with incremental backpropagation for max_epoch epochs\n", | |
"#   x          - 2-D array of examples, one row per example\n", | |
"#   y          - integer class labels, one per example\n", | |
"#   n_hiddens  - number of hidden neurons\n", | |
"#   learn_rate - step size passed to update_weights\n", | |
"#   max_epoch  - number of passes over the data\n", | |
"#   rand       - half-width of the initial weight range\n", | |
"# prints the mean squared error after every epoch and returns the trained weights\n", | |
"def train(x, y, n_hiddens, learn_rate = 1, max_epoch=100, rand=0.5):\n", | |
"    # labels index the one-hot target directly, so the output layer needs max(y) + 1 neurons;\n", | |
"    # len(set(y)) is too small whenever some label value is missing from y\n", | |
"    n_outputs = int(max(y)) + 1\n", | |
"    weights = init_weights(len(x[0]), n_hiddens, n_outputs, rand)\n", | |
"    for epoch in range(max_epoch):\n", | |
"        epoch_errors = []\n", | |
"        for example, target in zip(x, y):\n", | |
"            outputs, _ = forward_pass(weights, example)\n", | |
"            betas, errors = backward_pass(weights, outputs, target)\n", | |
"            epoch_errors.extend(errors)\n", | |
"            # each layer learns from the vector fed into it, not from its own pre-sigmoid activations\n", | |
"            weights = update_weights(weights, betas, learn_rate, [example, outputs[0]])\n", | |
"        print('mse:', mean_squared_error(epoch_errors))\n", | |
"    print_weights(weights, 'updated weights:')\n", | |
"    return weights\n", | |
"\n", | |
"# predicts the class of a single example: the index of the output neuron with the largest\n", | |
"# value (ties go to the lowest index); no longer prints the raw outputs as a side effect\n", | |
"def predict(weights, x):\n", | |
"    outputs, _ = forward_pass(weights, x)\n", | |
"    return int(np.argmax(outputs[1]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Importing Dataset and Randomising Order" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[[ 0.22222222 0.625 0.06779661 0.04166667]\n", | |
" [ 0.16666667 0.41666667 0.06779661 0.04166667]\n", | |
" [ 0.11111111 0.5 0.05084746 0.04166667]\n", | |
" [ 0.08333333 0.45833333 0.08474576 0.04166667]\n", | |
" [ 0.19444444 0.66666667 0.06779661 0.04166667]\n", | |
" [ 0.30555556 0.79166667 0.11864407 0.125 ]\n", | |
" [ 0.08333333 0.58333333 0.06779661 0.08333333]\n", | |
" [ 0.19444444 0.58333333 0.08474576 0.04166667]\n", | |
" [ 0.02777778 0.375 0.06779661 0.04166667]\n", | |
" [ 0.16666667 0.45833333 0.08474576 0. ]\n", | |
" [ 0.30555556 0.70833333 0.08474576 0.04166667]\n", | |
" [ 0.13888889 0.58333333 0.10169492 0.04166667]\n", | |
" [ 0.13888889 0.41666667 0.06779661 0. ]\n", | |
" [ 0. 0.41666667 0.01694915 0. ]\n", | |
" [ 0.41666667 0.83333333 0.03389831 0.04166667]\n", | |
" [ 0.38888889 1. 0.08474576 0.125 ]\n", | |
" [ 0.30555556 0.79166667 0.05084746 0.125 ]\n", | |
" [ 0.22222222 0.625 0.06779661 0.08333333]\n", | |
" [ 0.38888889 0.75 0.11864407 0.08333333]\n", | |
" [ 0.22222222 0.75 0.08474576 0.08333333]\n", | |
" [ 0.30555556 0.58333333 0.11864407 0.04166667]\n", | |
" [ 0.22222222 0.70833333 0.08474576 0.125 ]\n", | |
" [ 0.08333333 0.66666667 0. 0.04166667]\n", | |
" [ 0.22222222 0.54166667 0.11864407 0.16666667]\n", | |
" [ 0.13888889 0.58333333 0.15254237 0.04166667]\n", | |
" [ 0.19444444 0.41666667 0.10169492 0.04166667]\n", | |
" [ 0.19444444 0.58333333 0.10169492 0.125 ]\n", | |
" [ 0.25 0.625 0.08474576 0.04166667]\n", | |
" [ 0.25 0.58333333 0.06779661 0.04166667]\n", | |
" [ 0.11111111 0.5 0.10169492 0.04166667]\n", | |
" [ 0.13888889 0.45833333 0.10169492 0.04166667]\n", | |
" [ 0.30555556 0.58333333 0.08474576 0.125 ]\n", | |
" [ 0.25 0.875 0.08474576 0. ]\n", | |
" [ 0.33333333 0.91666667 0.06779661 0.04166667]\n", | |
" [ 0.16666667 0.45833333 0.08474576 0. ]\n", | |
" [ 0.19444444 0.5 0.03389831 0.04166667]\n", | |
" [ 0.33333333 0.625 0.05084746 0.04166667]\n", | |
" [ 0.16666667 0.45833333 0.08474576 0. ]\n", | |
" [ 0.02777778 0.41666667 0.05084746 0.04166667]\n", | |
" [ 0.22222222 0.58333333 0.08474576 0.04166667]\n", | |
" [ 0.19444444 0.625 0.05084746 0.08333333]\n", | |
" [ 0.05555556 0.125 0.05084746 0.08333333]\n", | |
" [ 0.02777778 0.5 0.05084746 0.04166667]\n", | |
" [ 0.19444444 0.625 0.10169492 0.20833333]\n", | |
" [ 0.22222222 0.75 0.15254237 0.125 ]\n", | |
" [ 0.13888889 0.41666667 0.06779661 0.08333333]\n", | |
" [ 0.22222222 0.75 0.10169492 0.04166667]\n", | |
" [ 0.08333333 0.5 0.06779661 0.04166667]\n", | |
" [ 0.27777778 0.70833333 0.08474576 0.04166667]\n", | |
" [ 0.19444444 0.54166667 0.06779661 0.04166667]\n", | |
" [ 0.75 0.5 0.62711864 0.54166667]\n", | |
" [ 0.58333333 0.5 0.59322034 0.58333333]\n", | |
" [ 0.72222222 0.45833333 0.66101695 0.58333333]\n", | |
" [ 0.33333333 0.125 0.50847458 0.5 ]\n", | |
" [ 0.61111111 0.33333333 0.61016949 0.58333333]\n", | |
" [ 0.38888889 0.33333333 0.59322034 0.5 ]\n", | |
" [ 0.55555556 0.54166667 0.62711864 0.625 ]\n", | |
" [ 0.16666667 0.16666667 0.38983051 0.375 ]\n", | |
" [ 0.63888889 0.375 0.61016949 0.5 ]\n", | |
" [ 0.25 0.29166667 0.49152542 0.54166667]\n", | |
" [ 0.19444444 0. 0.42372881 0.375 ]\n", | |
" [ 0.44444444 0.41666667 0.54237288 0.58333333]\n", | |
" [ 0.47222222 0.08333333 0.50847458 0.375 ]\n", | |
" [ 0.5 0.375 0.62711864 0.54166667]\n", | |
" [ 0.36111111 0.375 0.44067797 0.5 ]\n", | |
" [ 0.66666667 0.45833333 0.57627119 0.54166667]\n", | |
" [ 0.36111111 0.41666667 0.59322034 0.58333333]\n", | |
" [ 0.41666667 0.29166667 0.52542373 0.375 ]\n", | |
" [ 0.52777778 0.08333333 0.59322034 0.58333333]\n", | |
" [ 0.36111111 0.20833333 0.49152542 0.41666667]\n", | |
" [ 0.44444444 0.5 0.6440678 0.70833333]\n", | |
" [ 0.5 0.33333333 0.50847458 0.5 ]\n", | |
" [ 0.55555556 0.20833333 0.66101695 0.58333333]\n", | |
" [ 0.5 0.33333333 0.62711864 0.45833333]\n", | |
" [ 0.58333333 0.375 0.55932203 0.5 ]\n", | |
" [ 0.63888889 0.41666667 0.57627119 0.54166667]\n", | |
" [ 0.69444444 0.33333333 0.6440678 0.54166667]\n", | |
" [ 0.66666667 0.41666667 0.6779661 0.66666667]\n", | |
" [ 0.47222222 0.375 0.59322034 0.58333333]\n", | |
" [ 0.38888889 0.25 0.42372881 0.375 ]\n", | |
" [ 0.33333333 0.16666667 0.47457627 0.41666667]\n", | |
" [ 0.33333333 0.16666667 0.45762712 0.375 ]\n", | |
" [ 0.41666667 0.29166667 0.49152542 0.45833333]\n", | |
" [ 0.47222222 0.29166667 0.69491525 0.625 ]\n", | |
" [ 0.30555556 0.41666667 0.59322034 0.58333333]\n", | |
" [ 0.47222222 0.58333333 0.59322034 0.625 ]\n", | |
" [ 0.66666667 0.45833333 0.62711864 0.58333333]\n", | |
" [ 0.55555556 0.125 0.57627119 0.5 ]\n", | |
" [ 0.36111111 0.41666667 0.52542373 0.5 ]\n", | |
" [ 0.33333333 0.20833333 0.50847458 0.5 ]\n", | |
" [ 0.33333333 0.25 0.57627119 0.45833333]\n", | |
" [ 0.5 0.41666667 0.61016949 0.54166667]\n", | |
" [ 0.41666667 0.25 0.50847458 0.45833333]\n", | |
" [ 0.19444444 0.125 0.38983051 0.375 ]\n", | |
" [ 0.36111111 0.29166667 0.54237288 0.5 ]\n", | |
" [ 0.38888889 0.41666667 0.54237288 0.45833333]\n", | |
" [ 0.38888889 0.375 0.54237288 0.5 ]\n", | |
" [ 0.52777778 0.375 0.55932203 0.5 ]\n", | |
" [ 0.22222222 0.20833333 0.33898305 0.41666667]\n", | |
" [ 0.38888889 0.33333333 0.52542373 0.5 ]\n", | |
" [ 0.55555556 0.54166667 0.84745763 1. ]\n", | |
" [ 0.41666667 0.29166667 0.69491525 0.75 ]\n", | |
" [ 0.77777778 0.41666667 0.83050847 0.83333333]\n", | |
" [ 0.55555556 0.375 0.77966102 0.70833333]\n", | |
" [ 0.61111111 0.41666667 0.81355932 0.875 ]\n", | |
" [ 0.91666667 0.41666667 0.94915254 0.83333333]\n", | |
" [ 0.16666667 0.20833333 0.59322034 0.66666667]\n", | |
" [ 0.83333333 0.375 0.89830508 0.70833333]\n", | |
" [ 0.66666667 0.20833333 0.81355932 0.70833333]\n", | |
" [ 0.80555556 0.66666667 0.86440678 1. ]\n", | |
" [ 0.61111111 0.5 0.69491525 0.79166667]\n", | |
" [ 0.58333333 0.29166667 0.72881356 0.75 ]\n", | |
" [ 0.69444444 0.41666667 0.76271186 0.83333333]\n", | |
" [ 0.38888889 0.20833333 0.6779661 0.79166667]\n", | |
" [ 0.41666667 0.33333333 0.69491525 0.95833333]\n", | |
" [ 0.58333333 0.5 0.72881356 0.91666667]\n", | |
" [ 0.61111111 0.41666667 0.76271186 0.70833333]\n", | |
" [ 0.94444444 0.75 0.96610169 0.875 ]\n", | |
" [ 0.94444444 0.25 1. 0.91666667]\n", | |
" [ 0.47222222 0.08333333 0.6779661 0.58333333]\n", | |
" [ 0.72222222 0.5 0.79661017 0.91666667]\n", | |
" [ 0.36111111 0.33333333 0.66101695 0.79166667]\n", | |
" [ 0.94444444 0.33333333 0.96610169 0.79166667]\n", | |
" [ 0.55555556 0.29166667 0.66101695 0.70833333]\n", | |
" [ 0.66666667 0.54166667 0.79661017 0.83333333]\n", | |
" [ 0.80555556 0.5 0.84745763 0.70833333]\n", | |
" [ 0.52777778 0.33333333 0.6440678 0.70833333]\n", | |
" [ 0.5 0.41666667 0.66101695 0.70833333]\n", | |
" [ 0.58333333 0.33333333 0.77966102 0.83333333]\n", | |
" [ 0.80555556 0.41666667 0.81355932 0.625 ]\n", | |
" [ 0.86111111 0.33333333 0.86440678 0.75 ]\n", | |
" [ 1. 0.75 0.91525424 0.79166667]\n", | |
" [ 0.58333333 0.33333333 0.77966102 0.875 ]\n", | |
" [ 0.55555556 0.33333333 0.69491525 0.58333333]\n", | |
" [ 0.5 0.25 0.77966102 0.54166667]\n", | |
" [ 0.94444444 0.41666667 0.86440678 0.91666667]\n", | |
" [ 0.55555556 0.58333333 0.77966102 0.95833333]\n", | |
" [ 0.58333333 0.45833333 0.76271186 0.70833333]\n", | |
" [ 0.47222222 0.41666667 0.6440678 0.70833333]\n", | |
" [ 0.72222222 0.45833333 0.74576271 0.83333333]\n", | |
" [ 0.66666667 0.45833333 0.77966102 0.95833333]\n", | |
" [ 0.72222222 0.45833333 0.69491525 0.91666667]\n", | |
" [ 0.41666667 0.29166667 0.69491525 0.75 ]\n", | |
" [ 0.69444444 0.5 0.83050847 0.91666667]\n", | |
" [ 0.66666667 0.54166667 0.79661017 1. ]\n", | |
" [ 0.66666667 0.41666667 0.71186441 0.91666667]\n", | |
" [ 0.55555556 0.20833333 0.6779661 0.75 ]\n", | |
" [ 0.61111111 0.41666667 0.71186441 0.79166667]\n", | |
" [ 0.52777778 0.58333333 0.74576271 0.91666667]\n", | |
" [ 0.44444444 0.41666667 0.69491525 0.70833333]]\n" | |
] | |
} | |
], | |
"source": [ | |
"from sklearn import datasets\n", | |
"\n", | |
"# fixed seed: makes the shuffle below (and later draws from numpy's global RNG) reproducible\n", | |
"np.random.seed(0)\n", | |
"\n", | |
"iris = datasets.load_iris()\n", | |
"x = normalize_data(iris.data)\n", | |
"y = iris.target\n", | |
"print(x[:5])  # preview only; the full array has 150 rows\n", | |
"\n", | |
"# shuffle so that incremental training does not see the examples grouped by class\n", | |
"indices = np.random.permutation(x.shape[0])\n", | |
"x_rand = x[indices]\n", | |
"y_rand = y[indices]\n", | |
"\n", | |
"# two halves of the shuffled indices: bins[0] for training, bins[1] held out for testing\n", | |
"bins = np.array_split(indices, 2)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Network Training and Testing" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"initial weights:\n", | |
"[[-0.41432542 -0.06104027 0.34875917 0.07392805 0.08620412]\n", | |
" [-0.23418321 -0.44667212 -0.18794773 0.22702542 -0.19077594]\n", | |
" [-0.20468493 -0.3616823 0.06805208 0.16114287 0.05910071]\n", | |
" [-0.29013359 0.11301177 -0.16516486 -0.3799814 -0.01003766]\n", | |
" [ 0.20282773 -0.34839668 0.03236127 -0.23377424 -0.14829255]\n", | |
" [ 0.31936647 0.34000365 -0.22747626 0.41429474 0.125919 ]\n", | |
" [ 0.13889251 0.1564796 0.48611292 -0.09908335 -0.41318986]\n", | |
" [ 0.02173519 0.42857869 -0.31108009 -0.44043815 -0.38247929]]\n", | |
"[[-0.25044637 0.05358339 0.16990451 0.37468224 0.43706985 0.42735406\n", | |
" -0.12699772 0.46502247 -0.37991754]\n", | |
" [-0.14237327 -0.1583937 0.33256144 0.2671041 -0.14152726 0.44842334\n", | |
" -0.30612801 -0.12864753 -0.31378353]\n", | |
" [-0.39447518 -0.02335355 0.47218536 0.19400707 0.41905433 0.01613016\n", | |
" 0.27377546 -0.40191487 0.36633742]]\n", | |
"('mse:', 0.24234808059831084)\n", | |
"('mse:', 0.2512161206382148)\n", | |
"('mse:', 0.24921974640476177)\n", | |
"('mse:', 0.25094530074157395)\n", | |
"('mse:', 0.25194228867033142)\n", | |
"('mse:', 0.25240853749491238)\n", | |
"('mse:', 0.25268522544280231)\n", | |
"('mse:', 0.25288587774600263)\n", | |
"('mse:', 0.2530380081909297)\n", | |
"('mse:', 0.25315278909618089)\n", | |
"('mse:', 0.25323621136682867)\n", | |
"('mse:', 0.25329142797411686)\n", | |
"('mse:', 0.25331976679144685)\n", | |
"('mse:', 0.25332104422530388)\n", | |
"('mse:', 0.25329389343723702)\n", | |
"('mse:', 0.25323640169851636)\n", | |
"('mse:', 0.25314778402050797)\n", | |
"('mse:', 0.25303365342192657)\n", | |
"('mse:', 0.25292037655711208)\n", | |
"('mse:', 0.25287016604620938)\n", | |
"('mse:', 0.25279210411944142)\n", | |
"('mse:', 0.25223616043840463)\n", | |
"('mse:', 0.25174414746512008)\n", | |
"('mse:', 0.25148355169804776)\n", | |
"('mse:', 0.25097377119770203)\n", | |
"('mse:', 0.25065577385058835)\n", | |
"('mse:', 0.31473918881891566)\n", | |
"('mse:', 0.32855071617224313)\n", | |
"('mse:', 0.32848633126898263)\n", | |
"('mse:', 0.32842202319593949)\n", | |
"('mse:', 0.3283679867170271)\n", | |
"('mse:', 0.32832349695422369)\n", | |
"('mse:', 0.32828709249740906)\n", | |
"('mse:', 0.32825774946207442)\n", | |
"('mse:', 0.32823496585434114)\n", | |
"('mse:', 0.32821862559793019)\n", | |
"('mse:', 0.32820882030273313)\n", | |
"('mse:', 0.32820568369046288)\n", | |
"('mse:', 0.32820926068205863)\n", | |
"('mse:', 0.32821941966077611)\n", | |
"('mse:', 0.32823580803998292)\n", | |
"('mse:', 0.32825784559334509)\n", | |
"('mse:', 0.32828474730707574)\n", | |
"('mse:', 0.32831556758012487)\n", | |
"('mse:', 0.32834925900758372)\n", | |
"('mse:', 0.32838473986753947)\n", | |
"('mse:', 0.32842096382228725)\n", | |
"('mse:', 0.32845698393528938)\n", | |
"('mse:', 0.32849200270121781)\n", | |
"('mse:', 0.32852540179056677)\n", | |
"('mse:', 0.32855674942112639)\n", | |
"('mse:', 0.32858578801500726)\n", | |
"('mse:', 0.32861240817982557)\n", | |
"('mse:', 0.32863661611864453)\n", | |
"('mse:', 0.32865850060662211)\n", | |
"('mse:', 0.3286782036476969)\n", | |
"('mse:', 0.32869589681100919)\n", | |
"('mse:', 0.32871176362001209)\n", | |
"('mse:', 0.32872598738710884)\n", | |
"('mse:', 0.32873874345609755)\n", | |
"('mse:', 0.32875019474800471)\n", | |
"('mse:', 0.32876048963124094)\n", | |
"('mse:', 0.32876976133418273)\n", | |
"('mse:', 0.32877812831853243)\n", | |
"('mse:', 0.32878569520364009)\n", | |
"('mse:', 0.3287925539661693)\n", | |
"('mse:', 0.32879878523796036)\n", | |
"('mse:', 0.32880445959395205)\n", | |
"('mse:', 0.32880963876858227)\n", | |
"('mse:', 0.32881437676944025)\n", | |
"('mse:', 0.32881872087601249)\n", | |
"('mse:', 0.32882271252284145)\n", | |
"('mse:', 0.32882638807290543)\n", | |
"('mse:', 0.3288297794903316)\n", | |
"('mse:', 0.32883291492289368)\n", | |
"('mse:', 0.32883581920490262)\n", | |
"('mse:', 0.32883851429060601)\n", | |
"('mse:', 0.32884101962737378)\n", | |
"('mse:', 0.32884335247697921)\n", | |
"('mse:', 0.32884552819228968)\n", | |
"('mse:', 0.32884756045573699)\n", | |
"('mse:', 0.32884946148506794)\n", | |
"('mse:', 0.32885124221109763)\n", | |
"('mse:', 0.32885291243151032)\n", | |
"('mse:', 0.32885448094415815)\n", | |
"('mse:', 0.32885595566279863)\n", | |
"('mse:', 0.32885734371777808)\n", | |
"('mse:', 0.32885865154379734)\n", | |
"('mse:', 0.32885988495657925)\n", | |
"('mse:', 0.32886104921999182)\n", | |
"('mse:', 0.32886214910495648)\n", | |
"('mse:', 0.32886318894127609)\n", | |
"('mse:', 0.32886417266335927)\n", | |
"('mse:', 0.32886510385067808)\n", | |
"('mse:', 0.32886598576367948)\n", | |
"('mse:', 0.32886682137577572)\n", | |
"('mse:', 0.32886761340194737)\n", | |
"('mse:', 0.32886836432442851)\n", | |
"('mse:', 0.32886907641587293)\n", | |
"('mse:', 0.32886975176035171)\n", | |
"updated weights:\n", | |
"[[ 26.08143941 26.43472456 26.844524 26.56969287 -15.51773436]\n", | |
" [ 11.54508632 11.33259742 11.59132181 12.00629496 -8.63510052]\n", | |
" [ 5.51164824 5.35465086 5.78438524 5.87747603 -8.47921781]\n", | |
" [ 11.02470227 11.42784763 11.149671 10.93485446 -8.69632611]\n", | |
" [ 6.75520948 6.20398507 6.58474301 6.31860751 -7.9857965 ]\n", | |
" [ 15.92908533 15.94972251 15.3822426 16.0240136 -10.32194669]\n", | |
" [ 8.09051496 8.10810205 8.43773537 7.8525391 -8.0624775 ]\n", | |
" [ 25.56342749 25.97027099 25.23061222 25.10125415 -15.15966543]]\n", | |
"[[-24.88331866 -24.5792889 -24.46296778 -24.25819005 -24.19580244\n", | |
" -24.20551823 -24.75987001 -24.16784982 6.96992045]\n", | |
" [ -1.49319467 -1.50921511 -1.01825996 -1.0837173 -1.49234866\n", | |
" -0.90239807 -1.65694941 -1.47946893 -3.03673549]\n", | |
" [ -1.33628629 -0.96516465 -0.46962574 -0.74780404 -0.52275677\n", | |
" -0.92568095 -0.66803564 -1.34372598 -5.72733897]]\n", | |
"[1.5249161757911919e-82, 1.1616040616666544e-06, 3.0405851590956717e-06]\n", | |
"('predicited class:', 2)\n", | |
"('true class:', 1)\n" | |
] | |
} | |
], | |
"source": [ | |
"n_hiddens = 8\n", | |
"\n", | |
"# train on one half of the shuffled data and keep the other half unseen for evaluation\n", | |
"train_idx, test_idx = bins[0], bins[1]\n", | |
"weights = train(x[train_idx], y[train_idx], n_hiddens)\n", | |
"\n", | |
"predictions = np.array([predict(weights, x[i]) for i in test_idx])\n", | |
"print('held-out accuracy:', np.mean(predictions == y[test_idx]))\n", | |
"\n", | |
"# spot-check a single held-out example\n", | |
"test_index = np.random.choice(test_idx)\n", | |
"pred = predict(weights, x[test_index])\n", | |
"print('predicted class:', pred)\n", | |
"print('true class:', y[test_index])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.14" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment