Skip to content

Instantly share code, notes, and snippets.

@mihaild
Created October 6, 2018 18:16
Show Gist options
  • Save mihaild/d2e146f854a9398f64845c39ebfb3c8f to your computer and use it in GitHub Desktop.
Save mihaild/d2e146f854a9398f64845c39ebfb3c8f to your computer and use it in GitHub Desktop.
simple median
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from tqdm import tqdm_notebook\n",
"import sklearn as sk\n",
"import sklearn.metrics\n",
"\n",
"import keras\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"def mean_absolute_percentage_error(y_true, y_pred): \n",
" y_true, y_pred = np.array(y_true), np.array(y_pred)\n",
" return np.mean(np.abs((y_true - y_pred) / y_true)) * 100"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)"
]
},
{
"cell_type": "code",
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
"xs_train = np.random.uniform(-1, 1, (1000, 5))\n",
"ys_train = np.median(xs_train, axis=1)\n",
"\n",
"xs_test = np.random.uniform(-1, 1, (1000, 5))\n",
"ys_test = np.median(xs_test, axis=1)\n",
"\n",
"xs_test2 = np.random.normal(0, 1, (1000, 5))\n",
"ys_test2 = np.median(xs_test, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"def get_model(layers):\n",
" model = Sequential()\n",
" model.add(Dense(layers[0], input_dim=5, activation='relu'))\n",
" for l in layers[1:]:\n",
" model.add(Dense(l, input_dim=5, activation='relu'))\n",
" model.add(Dense(1))\n",
" model.compile(loss='mean_squared_error', optimizer='adam')\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 91,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[5]\n",
"0.030256047280279485 21.129024587973653\n",
"0.031121650823280233 20.429932248712433\n",
"0.39859154903909766 26.796106251682545\n",
"\n",
"[10]\n",
"0.019739526289567665 21.59204996700235\n",
"0.02095126037665723 20.815109864833577\n",
"0.42457059391257224 27.499531821913436\n",
"\n",
"[100]\n",
"0.0010199594643338689 22.463057269252957\n",
"0.004249229495881521 21.669214427122103\n",
"0.4113864029702579 27.05308372155763\n",
"\n",
"[1000]\n",
"0.0004520376869600072 22.444454632977852\n",
"0.003649622967328898 21.669094535414356\n",
"0.40105973103053405 26.673923790970306\n",
"\n",
"[10, 10]\n",
"0.005699108450494615 22.17993472187734\n",
"0.010154578177297265 21.604647398146643\n",
"0.43389573135131215 27.958277738307064\n",
"\n",
"[50, 50]\n",
"0.00020974870240867818 22.31495489879563\n",
"0.003244904106442189 21.61388607850857\n",
"0.3824939181705547 26.157605809455546\n",
"\n",
"[100, 100]\n",
"1.4699913802082056e-05 22.42388236348745\n",
"0.002165714208292056 21.72246746460389\n",
"0.38703285454082936 26.46199116474197\n",
"\n",
"[1000, 1000]\n",
"7.010253801702123e-05 22.489723406664233\n",
"0.0011533135005442645 21.7244513095109\n",
"0.3937379193300744 26.73518432667773\n",
"\n",
"[10, 10, 10]\n",
"0.008072452595890461 22.120908518254954\n",
"0.020859176172920692 21.493421902845057\n",
"0.43488613612458044 28.599210984527723\n",
"\n",
"[100, 100, 100]\n",
"2.2458645510210152e-05 22.40489708352883\n",
"0.00233151248430418 21.660590168667024\n",
"0.38288814536745314 26.311009261936274\n",
"\n",
"[10, 10, 10, 10]\n",
"0.004796239732775853 22.181825668667766\n",
"0.011370873609471229 21.566316132684758\n",
"0.39455097127413946 26.70313833759559\n",
"\n",
"[100, 100, 100, 100]\n",
"5.834447390057562e-06 22.45040559059972\n",
"0.0022323232883944954 21.658926231938512\n",
"0.3864176872546934 26.44920075218356\n",
"\n"
]
}
],
"source": [
"for arch in [\n",
" [5],\n",
" [10],\n",
" [100],\n",
" [1000],\n",
" [10, 10],\n",
" [50, 50],\n",
" [100, 100],\n",
" [1000, 1000],\n",
" [10, 10, 10],\n",
" [100, 100, 100],\n",
" [10, 10, 10, 10],\n",
" [100, 100, 100, 100],\n",
"]:\n",
" model = get_model(arch)\n",
" model.fit(xs_train, ys_train, batch_size=100, epochs=1000, verbose=0)\n",
" print(arch)\n",
" for xs, ys in [(xs_train, ys_train), (xs_test, ys_test), (xs_test2, ys_test2)]:\n",
" print(\n",
" sk.metrics.mean_squared_error(ys, model.predict(xs)),\n",
" mean_absolute_percentage_error(ys + 2, model.predict(xs) + 2),\n",
" )\n",
" print()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment