Skip to content

Instantly share code, notes, and snippets.

@angelormrl
Created May 12, 2019 20:51
Show Gist options
  • Save angelormrl/7bbd02a526b03635044805e636bbd18b to your computer and use it in GitHub Desktop.
Save angelormrl/7bbd02a526b03635044805e636bbd18b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multilayer Perceptron with Backpropagation\n",
"This is an implementation of an MLP network with incremental backpropagation as the chosen learning algorithm."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Function Definitions"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"## ======== FUNCTION DEFINITIONS ======== ##\n",
"\n",
"def normalize_data(x):\n",
"    \"\"\"Rescale each feature (column) of x to the range [0, 1].\n",
"\n",
"    x is a 2-D collection of examples (rows) by features (columns).\n",
"    Returns a new float array of the same shape; x itself is not modified.\n",
"    A feature whose values are all identical is mapped to 0 instead of\n",
"    dividing by zero.\n",
"    \"\"\"\n",
"    # float conversion avoids integer division on int input (Python 2 kernel)\n",
"    data = np.asarray(x, dtype=float)\n",
"    if data.size == 0:\n",
"        return np.zeros(data.shape)\n",
"    min_val = data.min(axis=0)\n",
"    val_range = data.max(axis=0) - min_val\n",
"    # a constant feature has zero range; divide by 1 so it normalises to 0\n",
"    safe_range = np.where(val_range == 0, 1.0, val_range)\n",
"    return (data - min_val) / safe_range\n",
"\n",
"# print the title followed by the weights of each layer\n",
"def print_weights(weights, title):\n",
"    \"\"\"Print title, then print each element of weights in order.\"\"\"\n",
"    print(title)\n",
"    for layer_index in range(len(weights)):\n",
"        print(weights[layer_index])\n",
"\n",
"# initialises weights of the network to small random values\n",
"def init_weights(n_inputs, n_hiddens, n_outputs, rand):\n",
"    \"\"\"Create the weight matrices for a network with one hidden layer.\n",
"\n",
"    Every weight is drawn uniformly from [-rand, rand). Each row holds the\n",
"    weights of one neuron; the extra last column is that neuron's bias.\n",
"    Returns [hidden_layer, output_layer] with shapes\n",
"    (n_hiddens, n_inputs + 1) and (n_outputs, n_hiddens + 1), and prints\n",
"    the initial weights.\n",
"    \"\"\"\n",
"    # only `import numpy as np` is in scope, so qualify the RNG explicitly\n",
"    hidden_layer = np.random.uniform(-rand, rand, [n_hiddens, n_inputs + 1])\n",
"    output_layer = np.random.uniform(-rand, rand, [n_outputs, n_hiddens + 1])\n",
"    weights = [hidden_layer, output_layer]\n",
"    print_weights(weights, 'initial weights:')\n",
"    return weights\n",
"\n",
"# calculates the net input of a node: the weighted sum of its inputs plus its bias\n",
"def activate(weights, inputs):\n",
"    \"\"\"Return the sum of weights[i] * inputs[i] over all inputs, plus the bias weights[-1].\"\"\"\n",
"    total = 0\n",
"    for position, value in enumerate(inputs):\n",
"        total += weights[position] * value\n",
"    return total + weights[-1]\n",
"\n",
"# sigmoid (logistic) activation function, used by default\n",
"def sigmoid(x):\n",
"    \"\"\"Return the logistic function 1 / (1 + e**-x) evaluated at x.\"\"\"\n",
"    denominator = np.exp(-x) + 1.0\n",
"    return 1.0 / denominator\n",
"\n",
"# alternative activation: the hyperbolic tangent\n",
"def tanh(x):\n",
"    \"\"\"Return the hyperbolic tangent of x.\n",
"\n",
"    Uses np.tanh rather than (e**2x - 1) / (e**2x + 1): that formula\n",
"    overflows to inf / inf = nan once 2x exceeds roughly 709.\n",
"    \"\"\"\n",
"    return np.tanh(x)\n",
"\n",
"# runs an input example through every layer of the network\n",
"def forward_pass(weights, x):\n",
"    \"\"\"Propagate x through the network layer by layer.\n",
"\n",
"    Returns (outputs, activations): for each layer, the list of sigmoid\n",
"    outputs of its neurons and the list of their net inputs before the\n",
"    sigmoid. The outputs of one layer are the inputs of the next.\n",
"    \"\"\"\n",
"    outputs, activations = [], []\n",
"    signal = x\n",
"    for layer in weights:\n",
"        net_inputs = [activate(neuron, signal) for neuron in layer]\n",
"        signal = [sigmoid(net) for net in net_inputs]\n",
"        activations.append(net_inputs)\n",
"        outputs.append(signal)\n",
"    return outputs, activations\n",
"\n",
"def sigmoid_derivative(x):\n",
"    \"\"\"Return x * (1 - x), the sigmoid slope when x is the sigmoid's output (not its input).\"\"\"\n",
"    complement = 1.0 - x\n",
"    return x * complement\n",
"\n",
"def tanh_derivative(x):\n",
"    \"\"\"Return 1 - x**2; presumably x is the tanh output, mirroring sigmoid_derivative.\"\"\"\n",
"    squared = x ** 2\n",
"    return 1.0 - squared\n",
"\n",
"# calculates beta (error signal) values for nodes in the output and hidden layers\n",
"def backward_pass(weights, outputs, y):\n",
"    \"\"\"Backpropagate the error of one example through a two-layer network.\n",
"\n",
"    outputs is the first value returned by forward_pass (hidden outputs at\n",
"    index 0, output-layer outputs at index 1); y is the index of the true\n",
"    class, which is one-hot encoded as the desired output.\n",
"    Returns (betas, errors): betas[0] / betas[1] hold the error signals of\n",
"    the hidden / output neurons, errors the raw desired - actual differences\n",
"    of the output neurons.\n",
"    \"\"\"\n",
"    betas = [[], []]\n",
"    errors = []\n",
"    # np.zeros: only `import numpy as np` is in scope\n",
"    desired = np.zeros(len(outputs[1]))\n",
"    desired[y] = 1\n",
"    # betas for output neurons\n",
"    for i in range(len(outputs[1])):\n",
"        error = desired[i] - outputs[1][i]\n",
"        errors.append(error)\n",
"        betas[1].append(sigmoid_derivative(outputs[1][i]) * error)\n",
"    # betas for hidden neurons: output betas weighted by the connecting weights\n",
"    for i in range(len(outputs[0])):\n",
"        hidden_error = sum(weights[1][j][i] * betas[1][j] for j in range(len(outputs[1])))\n",
"        betas[0].append(sigmoid_derivative(outputs[0][i]) * hidden_error)\n",
"    return betas, errors\n",
"\n",
"# updates all weights in network according to learning rate and beta values\n",
"def update_weights(_weights, betas, learn_rate, inputs):\n",
"    \"\"\"Apply one delta-rule step to every weight, in place.\n",
"\n",
"    inputs[i] must be the vector fed INTO layer i (the example itself for\n",
"    the first layer, the previous layer's outputs afterwards), so that\n",
"    weight k of neuron j changes by learn_rate * betas[i][j] * inputs[i][k].\n",
"    The last weight of each neuron is its bias and changes by\n",
"    learn_rate * betas[i][j]. Returns the same (mutated) weights object.\n",
"    \"\"\"\n",
"    weights = _weights\n",
"    for i in range(len(weights)):\n",
"        for j in range(len(weights[i])):\n",
"            step = betas[i][j] * learn_rate\n",
"            for k in range(len(weights[i][j]) - 1):\n",
"                # scale by the k-th input of the layer, not by the neuron's own activation\n",
"                weights[i][j][k] += step * inputs[i][k]\n",
"            weights[i][j][-1] += step\n",
"    return weights\n",
"\n",
"# calculates the mean squared error to assess accuracy\n",
"def mean_squared_error(errors):\n",
"    \"\"\"Return the mean of the squared errors, or 0.0 when errors is empty.\"\"\"\n",
"    if len(errors) == 0:\n",
"        return 0.0\n",
"    # float() keeps the division exact for integer errors on a Python 2 kernel\n",
"    return sum(error ** 2 for error in errors) / float(len(errors))\n",
"\n",
"# train network on input data for set number of epochs\n",
"def train(x, y, n_hiddens, learn_rate=1, max_epoch=100, rand=0.5):\n",
"    \"\"\"Train a one-hidden-layer network with incremental (per-example) backpropagation.\n",
"\n",
"    x holds the examples, y the class index of each example. The network has\n",
"    len(x[0]) inputs, n_hiddens hidden neurons and one output per distinct\n",
"    value in y. Prints the mean squared error after every epoch and the\n",
"    final weights; returns the trained weights.\n",
"    \"\"\"\n",
"    # weights initialised according to size of x and y and number of hidden nodes\n",
"    weights = init_weights(len(x[0]), n_hiddens, len(set(y)), rand)\n",
"    for epoch in range(max_epoch):\n",
"        epoch_errors = []\n",
"        for j in range(len(x)):\n",
"            # outputs of hidden and output nodes calculated\n",
"            outputs, _ = forward_pass(weights, x[j])\n",
"            # betas calculated for each node by backpropagating the error signal\n",
"            betas, errors = backward_pass(weights, outputs, y[j])\n",
"            epoch_errors.extend(errors)\n",
"            # each layer is fed by the example itself or by the previous layer's outputs\n",
"            layer_inputs = [x[j]] + outputs[:-1]\n",
"            weights = update_weights(weights, betas, learn_rate, layer_inputs)\n",
"        epoch_mse = mean_squared_error(epoch_errors)\n",
"        print('mse:',epoch_mse)\n",
"    # print final weights\n",
"    print_weights(weights, 'updated weights:')\n",
"    return weights\n",
"\n",
"# predict class for input data\n",
"def predict(weights, x):\n",
"    \"\"\"Return the index of the output neuron with the largest response to x.\n",
"\n",
"    The raw output-layer values are printed before the winner is returned;\n",
"    on a tie the lowest index wins.\n",
"    \"\"\"\n",
"    layer_outputs, _ = forward_pass(weights, x)\n",
"    scores = layer_outputs[1]\n",
"    print(scores)\n",
"    best = max(scores)\n",
"    return scores.index(best)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Importing Dataset and Randomising Order"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 0.22222222 0.625 0.06779661 0.04166667]\n",
" [ 0.16666667 0.41666667 0.06779661 0.04166667]\n",
" [ 0.11111111 0.5 0.05084746 0.04166667]\n",
" [ 0.08333333 0.45833333 0.08474576 0.04166667]\n",
" [ 0.19444444 0.66666667 0.06779661 0.04166667]\n",
" [ 0.30555556 0.79166667 0.11864407 0.125 ]\n",
" [ 0.08333333 0.58333333 0.06779661 0.08333333]\n",
" [ 0.19444444 0.58333333 0.08474576 0.04166667]\n",
" [ 0.02777778 0.375 0.06779661 0.04166667]\n",
" [ 0.16666667 0.45833333 0.08474576 0. ]\n",
" [ 0.30555556 0.70833333 0.08474576 0.04166667]\n",
" [ 0.13888889 0.58333333 0.10169492 0.04166667]\n",
" [ 0.13888889 0.41666667 0.06779661 0. ]\n",
" [ 0. 0.41666667 0.01694915 0. ]\n",
" [ 0.41666667 0.83333333 0.03389831 0.04166667]\n",
" [ 0.38888889 1. 0.08474576 0.125 ]\n",
" [ 0.30555556 0.79166667 0.05084746 0.125 ]\n",
" [ 0.22222222 0.625 0.06779661 0.08333333]\n",
" [ 0.38888889 0.75 0.11864407 0.08333333]\n",
" [ 0.22222222 0.75 0.08474576 0.08333333]\n",
" [ 0.30555556 0.58333333 0.11864407 0.04166667]\n",
" [ 0.22222222 0.70833333 0.08474576 0.125 ]\n",
" [ 0.08333333 0.66666667 0. 0.04166667]\n",
" [ 0.22222222 0.54166667 0.11864407 0.16666667]\n",
" [ 0.13888889 0.58333333 0.15254237 0.04166667]\n",
" [ 0.19444444 0.41666667 0.10169492 0.04166667]\n",
" [ 0.19444444 0.58333333 0.10169492 0.125 ]\n",
" [ 0.25 0.625 0.08474576 0.04166667]\n",
" [ 0.25 0.58333333 0.06779661 0.04166667]\n",
" [ 0.11111111 0.5 0.10169492 0.04166667]\n",
" [ 0.13888889 0.45833333 0.10169492 0.04166667]\n",
" [ 0.30555556 0.58333333 0.08474576 0.125 ]\n",
" [ 0.25 0.875 0.08474576 0. ]\n",
" [ 0.33333333 0.91666667 0.06779661 0.04166667]\n",
" [ 0.16666667 0.45833333 0.08474576 0. ]\n",
" [ 0.19444444 0.5 0.03389831 0.04166667]\n",
" [ 0.33333333 0.625 0.05084746 0.04166667]\n",
" [ 0.16666667 0.45833333 0.08474576 0. ]\n",
" [ 0.02777778 0.41666667 0.05084746 0.04166667]\n",
" [ 0.22222222 0.58333333 0.08474576 0.04166667]\n",
" [ 0.19444444 0.625 0.05084746 0.08333333]\n",
" [ 0.05555556 0.125 0.05084746 0.08333333]\n",
" [ 0.02777778 0.5 0.05084746 0.04166667]\n",
" [ 0.19444444 0.625 0.10169492 0.20833333]\n",
" [ 0.22222222 0.75 0.15254237 0.125 ]\n",
" [ 0.13888889 0.41666667 0.06779661 0.08333333]\n",
" [ 0.22222222 0.75 0.10169492 0.04166667]\n",
" [ 0.08333333 0.5 0.06779661 0.04166667]\n",
" [ 0.27777778 0.70833333 0.08474576 0.04166667]\n",
" [ 0.19444444 0.54166667 0.06779661 0.04166667]\n",
" [ 0.75 0.5 0.62711864 0.54166667]\n",
" [ 0.58333333 0.5 0.59322034 0.58333333]\n",
" [ 0.72222222 0.45833333 0.66101695 0.58333333]\n",
" [ 0.33333333 0.125 0.50847458 0.5 ]\n",
" [ 0.61111111 0.33333333 0.61016949 0.58333333]\n",
" [ 0.38888889 0.33333333 0.59322034 0.5 ]\n",
" [ 0.55555556 0.54166667 0.62711864 0.625 ]\n",
" [ 0.16666667 0.16666667 0.38983051 0.375 ]\n",
" [ 0.63888889 0.375 0.61016949 0.5 ]\n",
" [ 0.25 0.29166667 0.49152542 0.54166667]\n",
" [ 0.19444444 0. 0.42372881 0.375 ]\n",
" [ 0.44444444 0.41666667 0.54237288 0.58333333]\n",
" [ 0.47222222 0.08333333 0.50847458 0.375 ]\n",
" [ 0.5 0.375 0.62711864 0.54166667]\n",
" [ 0.36111111 0.375 0.44067797 0.5 ]\n",
" [ 0.66666667 0.45833333 0.57627119 0.54166667]\n",
" [ 0.36111111 0.41666667 0.59322034 0.58333333]\n",
" [ 0.41666667 0.29166667 0.52542373 0.375 ]\n",
" [ 0.52777778 0.08333333 0.59322034 0.58333333]\n",
" [ 0.36111111 0.20833333 0.49152542 0.41666667]\n",
" [ 0.44444444 0.5 0.6440678 0.70833333]\n",
" [ 0.5 0.33333333 0.50847458 0.5 ]\n",
" [ 0.55555556 0.20833333 0.66101695 0.58333333]\n",
" [ 0.5 0.33333333 0.62711864 0.45833333]\n",
" [ 0.58333333 0.375 0.55932203 0.5 ]\n",
" [ 0.63888889 0.41666667 0.57627119 0.54166667]\n",
" [ 0.69444444 0.33333333 0.6440678 0.54166667]\n",
" [ 0.66666667 0.41666667 0.6779661 0.66666667]\n",
" [ 0.47222222 0.375 0.59322034 0.58333333]\n",
" [ 0.38888889 0.25 0.42372881 0.375 ]\n",
" [ 0.33333333 0.16666667 0.47457627 0.41666667]\n",
" [ 0.33333333 0.16666667 0.45762712 0.375 ]\n",
" [ 0.41666667 0.29166667 0.49152542 0.45833333]\n",
" [ 0.47222222 0.29166667 0.69491525 0.625 ]\n",
" [ 0.30555556 0.41666667 0.59322034 0.58333333]\n",
" [ 0.47222222 0.58333333 0.59322034 0.625 ]\n",
" [ 0.66666667 0.45833333 0.62711864 0.58333333]\n",
" [ 0.55555556 0.125 0.57627119 0.5 ]\n",
" [ 0.36111111 0.41666667 0.52542373 0.5 ]\n",
" [ 0.33333333 0.20833333 0.50847458 0.5 ]\n",
" [ 0.33333333 0.25 0.57627119 0.45833333]\n",
" [ 0.5 0.41666667 0.61016949 0.54166667]\n",
" [ 0.41666667 0.25 0.50847458 0.45833333]\n",
" [ 0.19444444 0.125 0.38983051 0.375 ]\n",
" [ 0.36111111 0.29166667 0.54237288 0.5 ]\n",
" [ 0.38888889 0.41666667 0.54237288 0.45833333]\n",
" [ 0.38888889 0.375 0.54237288 0.5 ]\n",
" [ 0.52777778 0.375 0.55932203 0.5 ]\n",
" [ 0.22222222 0.20833333 0.33898305 0.41666667]\n",
" [ 0.38888889 0.33333333 0.52542373 0.5 ]\n",
" [ 0.55555556 0.54166667 0.84745763 1. ]\n",
" [ 0.41666667 0.29166667 0.69491525 0.75 ]\n",
" [ 0.77777778 0.41666667 0.83050847 0.83333333]\n",
" [ 0.55555556 0.375 0.77966102 0.70833333]\n",
" [ 0.61111111 0.41666667 0.81355932 0.875 ]\n",
" [ 0.91666667 0.41666667 0.94915254 0.83333333]\n",
" [ 0.16666667 0.20833333 0.59322034 0.66666667]\n",
" [ 0.83333333 0.375 0.89830508 0.70833333]\n",
" [ 0.66666667 0.20833333 0.81355932 0.70833333]\n",
" [ 0.80555556 0.66666667 0.86440678 1. ]\n",
" [ 0.61111111 0.5 0.69491525 0.79166667]\n",
" [ 0.58333333 0.29166667 0.72881356 0.75 ]\n",
" [ 0.69444444 0.41666667 0.76271186 0.83333333]\n",
" [ 0.38888889 0.20833333 0.6779661 0.79166667]\n",
" [ 0.41666667 0.33333333 0.69491525 0.95833333]\n",
" [ 0.58333333 0.5 0.72881356 0.91666667]\n",
" [ 0.61111111 0.41666667 0.76271186 0.70833333]\n",
" [ 0.94444444 0.75 0.96610169 0.875 ]\n",
" [ 0.94444444 0.25 1. 0.91666667]\n",
" [ 0.47222222 0.08333333 0.6779661 0.58333333]\n",
" [ 0.72222222 0.5 0.79661017 0.91666667]\n",
" [ 0.36111111 0.33333333 0.66101695 0.79166667]\n",
" [ 0.94444444 0.33333333 0.96610169 0.79166667]\n",
" [ 0.55555556 0.29166667 0.66101695 0.70833333]\n",
" [ 0.66666667 0.54166667 0.79661017 0.83333333]\n",
" [ 0.80555556 0.5 0.84745763 0.70833333]\n",
" [ 0.52777778 0.33333333 0.6440678 0.70833333]\n",
" [ 0.5 0.41666667 0.66101695 0.70833333]\n",
" [ 0.58333333 0.33333333 0.77966102 0.83333333]\n",
" [ 0.80555556 0.41666667 0.81355932 0.625 ]\n",
" [ 0.86111111 0.33333333 0.86440678 0.75 ]\n",
" [ 1. 0.75 0.91525424 0.79166667]\n",
" [ 0.58333333 0.33333333 0.77966102 0.875 ]\n",
" [ 0.55555556 0.33333333 0.69491525 0.58333333]\n",
" [ 0.5 0.25 0.77966102 0.54166667]\n",
" [ 0.94444444 0.41666667 0.86440678 0.91666667]\n",
" [ 0.55555556 0.58333333 0.77966102 0.95833333]\n",
" [ 0.58333333 0.45833333 0.76271186 0.70833333]\n",
" [ 0.47222222 0.41666667 0.6440678 0.70833333]\n",
" [ 0.72222222 0.45833333 0.74576271 0.83333333]\n",
" [ 0.66666667 0.45833333 0.77966102 0.95833333]\n",
" [ 0.72222222 0.45833333 0.69491525 0.91666667]\n",
" [ 0.41666667 0.29166667 0.69491525 0.75 ]\n",
" [ 0.69444444 0.5 0.83050847 0.91666667]\n",
" [ 0.66666667 0.54166667 0.79661017 1. ]\n",
" [ 0.66666667 0.41666667 0.71186441 0.91666667]\n",
" [ 0.55555556 0.20833333 0.6779661 0.75 ]\n",
" [ 0.61111111 0.41666667 0.71186441 0.79166667]\n",
" [ 0.52777778 0.58333333 0.74576271 0.91666667]\n",
" [ 0.44444444 0.41666667 0.69491525 0.70833333]]\n"
]
}
],
"source": [
"from sklearn import datasets\n",
"\n",
"# fixed seed so the shuffle below (and any later np.random draw) is repeatable\n",
"SEED = 0\n",
"np.random.seed(SEED)\n",
"\n",
"iris = datasets.load_iris()\n",
"x = normalize_data(iris.data)\n",
"y = iris.target\n",
"\n",
"# show the shape and a small sample instead of dumping every row\n",
"print(x.shape)\n",
"print(x[:5])\n",
"\n",
"# shuffle examples and labels with the same permutation\n",
"indices = np.random.permutation(x.shape[0])\n",
"x_rand = x[indices]\n",
"y_rand = y[indices]\n",
"\n",
"# two halves of the shuffled indices\n",
"bins = np.array_split(indices, 2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Network Training and Testing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"initial weights:\n",
"[[-0.41432542 -0.06104027 0.34875917 0.07392805 0.08620412]\n",
" [-0.23418321 -0.44667212 -0.18794773 0.22702542 -0.19077594]\n",
" [-0.20468493 -0.3616823 0.06805208 0.16114287 0.05910071]\n",
" [-0.29013359 0.11301177 -0.16516486 -0.3799814 -0.01003766]\n",
" [ 0.20282773 -0.34839668 0.03236127 -0.23377424 -0.14829255]\n",
" [ 0.31936647 0.34000365 -0.22747626 0.41429474 0.125919 ]\n",
" [ 0.13889251 0.1564796 0.48611292 -0.09908335 -0.41318986]\n",
" [ 0.02173519 0.42857869 -0.31108009 -0.44043815 -0.38247929]]\n",
"[[-0.25044637 0.05358339 0.16990451 0.37468224 0.43706985 0.42735406\n",
" -0.12699772 0.46502247 -0.37991754]\n",
" [-0.14237327 -0.1583937 0.33256144 0.2671041 -0.14152726 0.44842334\n",
" -0.30612801 -0.12864753 -0.31378353]\n",
" [-0.39447518 -0.02335355 0.47218536 0.19400707 0.41905433 0.01613016\n",
" 0.27377546 -0.40191487 0.36633742]]\n",
"('mse:', 0.24234808059831084)\n",
"('mse:', 0.2512161206382148)\n",
"('mse:', 0.24921974640476177)\n",
"('mse:', 0.25094530074157395)\n",
"('mse:', 0.25194228867033142)\n",
"('mse:', 0.25240853749491238)\n",
"('mse:', 0.25268522544280231)\n",
"('mse:', 0.25288587774600263)\n",
"('mse:', 0.2530380081909297)\n",
"('mse:', 0.25315278909618089)\n",
"('mse:', 0.25323621136682867)\n",
"('mse:', 0.25329142797411686)\n",
"('mse:', 0.25331976679144685)\n",
"('mse:', 0.25332104422530388)\n",
"('mse:', 0.25329389343723702)\n",
"('mse:', 0.25323640169851636)\n",
"('mse:', 0.25314778402050797)\n",
"('mse:', 0.25303365342192657)\n",
"('mse:', 0.25292037655711208)\n",
"('mse:', 0.25287016604620938)\n",
"('mse:', 0.25279210411944142)\n",
"('mse:', 0.25223616043840463)\n",
"('mse:', 0.25174414746512008)\n",
"('mse:', 0.25148355169804776)\n",
"('mse:', 0.25097377119770203)\n",
"('mse:', 0.25065577385058835)\n",
"('mse:', 0.31473918881891566)\n",
"('mse:', 0.32855071617224313)\n",
"('mse:', 0.32848633126898263)\n",
"('mse:', 0.32842202319593949)\n",
"('mse:', 0.3283679867170271)\n",
"('mse:', 0.32832349695422369)\n",
"('mse:', 0.32828709249740906)\n",
"('mse:', 0.32825774946207442)\n",
"('mse:', 0.32823496585434114)\n",
"('mse:', 0.32821862559793019)\n",
"('mse:', 0.32820882030273313)\n",
"('mse:', 0.32820568369046288)\n",
"('mse:', 0.32820926068205863)\n",
"('mse:', 0.32821941966077611)\n",
"('mse:', 0.32823580803998292)\n",
"('mse:', 0.32825784559334509)\n",
"('mse:', 0.32828474730707574)\n",
"('mse:', 0.32831556758012487)\n",
"('mse:', 0.32834925900758372)\n",
"('mse:', 0.32838473986753947)\n",
"('mse:', 0.32842096382228725)\n",
"('mse:', 0.32845698393528938)\n",
"('mse:', 0.32849200270121781)\n",
"('mse:', 0.32852540179056677)\n",
"('mse:', 0.32855674942112639)\n",
"('mse:', 0.32858578801500726)\n",
"('mse:', 0.32861240817982557)\n",
"('mse:', 0.32863661611864453)\n",
"('mse:', 0.32865850060662211)\n",
"('mse:', 0.3286782036476969)\n",
"('mse:', 0.32869589681100919)\n",
"('mse:', 0.32871176362001209)\n",
"('mse:', 0.32872598738710884)\n",
"('mse:', 0.32873874345609755)\n",
"('mse:', 0.32875019474800471)\n",
"('mse:', 0.32876048963124094)\n",
"('mse:', 0.32876976133418273)\n",
"('mse:', 0.32877812831853243)\n",
"('mse:', 0.32878569520364009)\n",
"('mse:', 0.3287925539661693)\n",
"('mse:', 0.32879878523796036)\n",
"('mse:', 0.32880445959395205)\n",
"('mse:', 0.32880963876858227)\n",
"('mse:', 0.32881437676944025)\n",
"('mse:', 0.32881872087601249)\n",
"('mse:', 0.32882271252284145)\n",
"('mse:', 0.32882638807290543)\n",
"('mse:', 0.3288297794903316)\n",
"('mse:', 0.32883291492289368)\n",
"('mse:', 0.32883581920490262)\n",
"('mse:', 0.32883851429060601)\n",
"('mse:', 0.32884101962737378)\n",
"('mse:', 0.32884335247697921)\n",
"('mse:', 0.32884552819228968)\n",
"('mse:', 0.32884756045573699)\n",
"('mse:', 0.32884946148506794)\n",
"('mse:', 0.32885124221109763)\n",
"('mse:', 0.32885291243151032)\n",
"('mse:', 0.32885448094415815)\n",
"('mse:', 0.32885595566279863)\n",
"('mse:', 0.32885734371777808)\n",
"('mse:', 0.32885865154379734)\n",
"('mse:', 0.32885988495657925)\n",
"('mse:', 0.32886104921999182)\n",
"('mse:', 0.32886214910495648)\n",
"('mse:', 0.32886318894127609)\n",
"('mse:', 0.32886417266335927)\n",
"('mse:', 0.32886510385067808)\n",
"('mse:', 0.32886598576367948)\n",
"('mse:', 0.32886682137577572)\n",
"('mse:', 0.32886761340194737)\n",
"('mse:', 0.32886836432442851)\n",
"('mse:', 0.32886907641587293)\n",
"('mse:', 0.32886975176035171)\n",
"updated weights:\n",
"[[ 26.08143941 26.43472456 26.844524 26.56969287 -15.51773436]\n",
" [ 11.54508632 11.33259742 11.59132181 12.00629496 -8.63510052]\n",
" [ 5.51164824 5.35465086 5.78438524 5.87747603 -8.47921781]\n",
" [ 11.02470227 11.42784763 11.149671 10.93485446 -8.69632611]\n",
" [ 6.75520948 6.20398507 6.58474301 6.31860751 -7.9857965 ]\n",
" [ 15.92908533 15.94972251 15.3822426 16.0240136 -10.32194669]\n",
" [ 8.09051496 8.10810205 8.43773537 7.8525391 -8.0624775 ]\n",
" [ 25.56342749 25.97027099 25.23061222 25.10125415 -15.15966543]]\n",
"[[-24.88331866 -24.5792889 -24.46296778 -24.25819005 -24.19580244\n",
" -24.20551823 -24.75987001 -24.16784982 6.96992045]\n",
" [ -1.49319467 -1.50921511 -1.01825996 -1.0837173 -1.49234866\n",
" -0.90239807 -1.65694941 -1.47946893 -3.03673549]\n",
" [ -1.33628629 -0.96516465 -0.46962574 -0.74780404 -0.52275677\n",
" -0.92568095 -0.66803564 -1.34372598 -5.72733897]]\n",
"[1.5249161757911919e-82, 1.1616040616666544e-06, 3.0405851590956717e-06]\n",
"('predicited class:', 2)\n",
"('true class:', 1)\n"
]
}
],
"source": [
"n_hiddens = 8\n",
"\n",
"weights = train(x_rand, y_rand, n_hiddens)\n",
"\n",
"# np.random: only `import numpy as np` is in scope; picks an index in [0, len(x))\n",
"# note: this example was also part of the training data, so this is a sanity check, not a held-out test\n",
"test_index = np.random.randint(len(x))\n",
"\n",
"pred = predict(weights, x[test_index])\n",
"print('predicited class:',pred)\n",
"print('true class:', y[test_index])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment