Created
October 6, 2018 18:16
-
-
Save mihaild/d2e146f854a9398f64845c39ebfb3c8f to your computer and use it in GitHub Desktop.
simple median
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 57, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from tqdm import tqdm_notebook\n", | |
"import sklearn as sk\n", | |
"import sklearn.metrics\n", | |
"\n", | |
"import keras\n", | |
"from keras.models import Sequential\n", | |
"from keras.layers import Dense" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def mean_absolute_percentage_error(y_true, y_pred): \n", | |
" y_true, y_pred = np.array(y_true), np.array(y_pred)\n", | |
" return np.mean(np.abs((y_true - y_pred) / y_true)) * 100" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 88, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"np.random.seed(42)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 89, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"xs_train = np.random.uniform(-1, 1, (1000, 5))\n", | |
"ys_train = np.median(xs_train, axis=1)\n", | |
"\n", | |
"xs_test = np.random.uniform(-1, 1, (1000, 5))\n", | |
"ys_test = np.median(xs_test, axis=1)\n", | |
"\n", | |
"xs_test2 = np.random.normal(0, 1, (1000, 5))\n", | |
"ys_test2 = np.median(xs_test, axis=1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 84, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def get_model(layers):\n", | |
" model = Sequential()\n", | |
" model.add(Dense(layers[0], input_dim=5, activation='relu'))\n", | |
" for l in layers[1:]:\n", | |
" model.add(Dense(l, input_dim=5, activation='relu'))\n", | |
" model.add(Dense(1))\n", | |
" model.compile(loss='mean_squared_error', optimizer='adam')\n", | |
" return model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 91, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[5]\n", | |
"0.030256047280279485 21.129024587973653\n", | |
"0.031121650823280233 20.429932248712433\n", | |
"0.39859154903909766 26.796106251682545\n", | |
"\n", | |
"[10]\n", | |
"0.019739526289567665 21.59204996700235\n", | |
"0.02095126037665723 20.815109864833577\n", | |
"0.42457059391257224 27.499531821913436\n", | |
"\n", | |
"[100]\n", | |
"0.0010199594643338689 22.463057269252957\n", | |
"0.004249229495881521 21.669214427122103\n", | |
"0.4113864029702579 27.05308372155763\n", | |
"\n", | |
"[1000]\n", | |
"0.0004520376869600072 22.444454632977852\n", | |
"0.003649622967328898 21.669094535414356\n", | |
"0.40105973103053405 26.673923790970306\n", | |
"\n", | |
"[10, 10]\n", | |
"0.005699108450494615 22.17993472187734\n", | |
"0.010154578177297265 21.604647398146643\n", | |
"0.43389573135131215 27.958277738307064\n", | |
"\n", | |
"[50, 50]\n", | |
"0.00020974870240867818 22.31495489879563\n", | |
"0.003244904106442189 21.61388607850857\n", | |
"0.3824939181705547 26.157605809455546\n", | |
"\n", | |
"[100, 100]\n", | |
"1.4699913802082056e-05 22.42388236348745\n", | |
"0.002165714208292056 21.72246746460389\n", | |
"0.38703285454082936 26.46199116474197\n", | |
"\n", | |
"[1000, 1000]\n", | |
"7.010253801702123e-05 22.489723406664233\n", | |
"0.0011533135005442645 21.7244513095109\n", | |
"0.3937379193300744 26.73518432667773\n", | |
"\n", | |
"[10, 10, 10]\n", | |
"0.008072452595890461 22.120908518254954\n", | |
"0.020859176172920692 21.493421902845057\n", | |
"0.43488613612458044 28.599210984527723\n", | |
"\n", | |
"[100, 100, 100]\n", | |
"2.2458645510210152e-05 22.40489708352883\n", | |
"0.00233151248430418 21.660590168667024\n", | |
"0.38288814536745314 26.311009261936274\n", | |
"\n", | |
"[10, 10, 10, 10]\n", | |
"0.004796239732775853 22.181825668667766\n", | |
"0.011370873609471229 21.566316132684758\n", | |
"0.39455097127413946 26.70313833759559\n", | |
"\n", | |
"[100, 100, 100, 100]\n", | |
"5.834447390057562e-06 22.45040559059972\n", | |
"0.0022323232883944954 21.658926231938512\n", | |
"0.3864176872546934 26.44920075218356\n", | |
"\n" | |
] | |
} | |
], | |
"source": [ | |
"for arch in [\n", | |
" [5],\n", | |
" [10],\n", | |
" [100],\n", | |
" [1000],\n", | |
" [10, 10],\n", | |
" [50, 50],\n", | |
" [100, 100],\n", | |
" [1000, 1000],\n", | |
" [10, 10, 10],\n", | |
" [100, 100, 100],\n", | |
" [10, 10, 10, 10],\n", | |
" [100, 100, 100, 100],\n", | |
"]:\n", | |
" model = get_model(arch)\n", | |
" model.fit(xs_train, ys_train, batch_size=100, epochs=1000, verbose=0)\n", | |
" print(arch)\n", | |
" for xs, ys in [(xs_train, ys_train), (xs_test, ys_test), (xs_test2, ys_test2)]:\n", | |
" print(\n", | |
" sk.metrics.mean_squared_error(ys, model.predict(xs)),\n", | |
" mean_absolute_percentage_error(ys + 2, model.predict(xs) + 2),\n", | |
" )\n", | |
" print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment