Skip to content

Instantly share code, notes, and snippets.

@altescy
Created November 30, 2018 14:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save altescy/5731095e8b26507937d27fbf65ba0254 to your computer and use it in GitHub Desktop.
Save altescy/5731095e8b26507937d27fbf65ba0254 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"SQLで機械学習する\n",
"===\n",
"\n",
"\n",
"### 準備\n",
"\n",
"```\n",
"$ docker run --rm -d -p3306:3306 \\\n",
" -e MYSQL_RANDOM_ROOT_PASSWORD='yes' \\\n",
" -e MYSQL_DATABASE='mysql' \\\n",
" -e MYSQL_USER='user' \\\n",
" -e MYSQL_PASSWORD='password' \\\n",
" mysql:8.0\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import warnings\n",
"from urllib.parse import quote_plus\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import pymysql\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sqlalchemy import create_engine\n",
"%matplotlib inline\n",
"\n",
"warnings.filterwarnings('ignore')\n",
"# Register PyMySQL under the MySQLdb module name so that SQLAlchemy's plain\n",
"# `mysql://` URLs are served by PyMySQL.\n",
"pymysql.install_as_MySQLdb()\n",
"\n",
"np.random.seed(128)\n",
"\n",
"# Defaults match the `docker run` command above; each value can be overridden\n",
"# with the corresponding environment variable instead of editing the notebook.\n",
"db_settings = {\n",
"    'host': os.environ.get('MYSQL_HOST', 'localhost'),\n",
"    'database': os.environ.get('MYSQL_DATABASE', 'mysql'),\n",
"    'user': os.environ.get('MYSQL_USER', 'user'),\n",
"    'password': os.environ.get('MYSQL_PASSWORD', 'password'),\n",
"    'port': int(os.environ.get('MYSQL_PORT', '3306')),\n",
"}\n",
"\n",
"# URL-encode the credentials so characters such as '@', ':' or '/' in them\n",
"# cannot corrupt the connection URL.\n",
"db_url = 'mysql://{user}:{password}@{host}:{port}/{database}'.format(\n",
"    user=quote_plus(db_settings['user']),\n",
"    password=quote_plus(db_settings['password']),\n",
"    host=db_settings['host'],\n",
"    port=db_settings['port'],\n",
"    database=db_settings['database'],\n",
")\n",
"\n",
"# The recursive CTEs below iterate up to 5000 times, more than MySQL's default\n",
"# cte_max_recursion_depth of 1000. The variable is per session, so raise it via\n",
"# `init_command` on every connection the engine's pool opens, rather than with a\n",
"# one-off statement that only affects the connection that happened to run it.\n",
"conn = create_engine(\n",
"    db_url,\n",
"    connect_args={'init_command': 'SET SESSION cte_max_recursion_depth = 10000'},\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 最急降下法\n",
"\n",
"$$\n",
"\\mathrm{argmin}_{x} \\,\\, f(x), \\,\\,\\,\\,\\, \\mathrm{where} \\,\\,\\, f(x) = x^2 - 6x + 9\n",
"$$\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\frac{\\partial f}{\\partial x} &= 2x - 6\\\\\n",
"x &\\leftarrow x - \\eta \\frac{\\partial f}{\\partial x} \\qquad (\\eta = 0.01)\n",
"\\end{aligned}\n",
"$$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimize f(x) = x^2 - 6x + 9 by gradient descent (eta = 0.01) starting at x = 0.\n",
"# Each recursion step of the CTE performs one update. Iteration stops once the\n",
"# next step would lower f by no more than 1e-6, or after 1000 steps.\n",
"query = '''\n",
"with\n",
"    recursive gd(iter, x, y) as (\n",
"        -- MySQL takes recursive-CTE column types from this non-recursive part,\n",
"        -- so these casts also fix the precision every later x and y is rounded to.\n",
"        select\n",
"            0, cast(0.0 as decimal(10,6)), cast(9 as decimal(10,6))\n",
"        union all\n",
"        select\n",
"            iter+1,\n",
"            x-0.01*(2.0*x-6.0),\n",
"            pow(x-0.01*(2.0*x-6.0),2)-6*(x-0.01*(2.0*x-6.0)) + 9\n",
"        from\n",
"            gd\n",
"        where\n",
"            iter < 1000\n",
"            -- Look ahead with the same learning rate as the update above.\n",
"            and y - (pow(x-0.01*(2.0*x-6.0),2)-6*(x-0.01*(2.0*x-6.0)) + 9) > 0.000001\n",
"    )\n",
"select * from gd order by iter desc limit 1\n",
"'''\n",
"\n",
"pd.read_sql(query, conn)"
]
},
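{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 線形回帰\n",
"\n",
"$$\n",
"\\mathrm{argmin}_{a, b} \\,\\, \\mathcal{L}, \\,\\,\\,\\,\\, \\mathrm{where} \\,\\,\\, \\mathcal{L} = \\frac{1}{2N} \\sum_{i=1}^{N}(a+bx_i-y_i)^2\n",
"$$\n",
"\n",
"$$\n",
"\\begin{aligned}\n",
"\\frac{\\partial\\mathcal{L}}{\\partial a} &= \\frac{1}{N} \\sum_{i=1}^{N}(a+bx_i-y_i) = \\frac{1}{N}\\left(Na + b\\sum_i x_i - \\sum_i y_i\\right)\\\\\n",
"\\frac{\\partial\\mathcal{L}}{\\partial b} &= \\frac{1}{N} \\sum_{i=1}^{N}(a+bx_i-y_i)x_i = \\frac{1}{N}\\left(a\\sum_i x_i + b\\sum_i x_i^2 - \\sum_i x_i y_i\\right)\\\\\n",
"a &\\leftarrow a - \\eta \\frac{\\partial\\mathcal{L}}{\\partial a}, \\qquad b \\leftarrow b - \\eta \\frac{\\partial\\mathcal{L}}{\\partial b} \\qquad (\\eta = 0.01)\n",
"\\end{aligned}\n",
"$$\n",
"\n",
"$$\n",
"\\mathcal{L} = \\frac{1}{2N}\\left(N a^2 + b^2 \\sum_i x_i^2 + \\sum_i y_i^2 + 2ab \\sum_i x_i - 2a \\sum_i y_i - 2b \\sum_i x_i y_i\\right)\n",
"$$\n"
]
},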
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"N = 30\n",
"# Ground truth used to generate the sample: y = TRUE_A + TRUE_B * x + noise.\n",
"TRUE_A, TRUE_B = 1, -3\n",
"\n",
"df = pd.DataFrame()\n",
"# np.random.uniform's first positional argument is `low`, so both bounds are\n",
"# passed explicitly to draw x from [0, 10).\n",
"df['x'] = np.random.uniform(0, 10, size=N)\n",
"df['y'] = TRUE_B * df['x'] + TRUE_A + np.random.normal(scale=3.0, size=df['x'].shape)\n",
"\n",
"# Upload the sample as table `lr`, which the SQL regression below reads.\n",
"df.to_sql(con=conn, name='lr', index=False, if_exists='replace')\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"ax.plot(df['x'], df['y'], 'o')\n",
"ax.set(title='Training data', xlabel='x', ylabel='y')\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fit y = a + b*x to table lr by batch gradient descent (eta = 0.01) inside a\n",
"# recursive CTE; each recursion step is one simultaneous update of (a, b).\n",
"query = '''\n",
"\n",
"with\n",
"    recursive agg as (\n",
"        -- Sufficient statistics of the sample. The gradient and the loss are\n",
"        -- expressed through these sums only, so the recursion never rescans lr.\n",
"        select\n",
"            sum(x) s_x,\n",
"            sum(y) s_y,\n",
"            sum(x*x) s_x2,\n",
"            sum(y*y) s_y2,\n",
"            sum(x*y) s_xy,\n",
"            count(*) n\n",
"        from lr\n",
"    ),\n",
"    gd(iter, a, b, loss) as (\n",
"        -- The anchor is a single start row (a, b) = (1, 1) with its actual loss\n",
"        -- (n + s_x2 + s_y2 + 2*s_x - 2*s_y - 2*s_xy)/(2n). It is selected from\n",
"        -- the one-row agg, not from lr, so there is one chain instead of one per sample.\n",
"        select\n",
"            0,\n",
"            cast(1.0 as decimal(10, 5)),\n",
"            cast(1.0 as decimal(10, 5)),\n",
"            cast((n + s_x2 + s_y2 + 2*s_x - 2*s_y - 2*s_xy)/(2*n) as decimal(10, 5))\n",
"        from agg\n",
"        union all\n",
"        -- One gradient step with learning rate 0.01, where\n",
"        --   a_new = a - 0.01*(n*a + b*s_x - s_y)/n\n",
"        --   b_new = b - 0.01*(a*s_x + b*s_x2 - s_xy)/n\n",
"        --   loss  = (n*a_new^2 + b_new^2*s_x2 + s_y2 + 2*a_new*b_new*s_x\n",
"        --            - 2*a_new*s_y - 2*b_new*s_xy)/(2n)\n",
"        select\n",
"            iter+1,\n",
"            a-0.01*(n*a+b*s_x-s_y)/n,\n",
"            b-0.01*(a*s_x+b*s_x2-s_xy)/n,\n",
"            (n*pow(a-0.01*(n*a+b*s_x-s_y)/n,2)\n",
"             +pow(b-0.01*(a*s_x+b*s_x2-s_xy)/n,2)*s_x2\n",
"             +s_y2\n",
"             +2*(a-0.01*(n*a+b*s_x-s_y)/n)*(b-0.01*(a*s_x+b*s_x2-s_xy)/n)*s_x\n",
"             -2*(a-0.01*(n*a+b*s_x-s_y)/n)*s_y\n",
"             -2*(b-0.01*(a*s_x+b*s_x2-s_xy)/n)*s_xy)/(2*n)\n",
"        from\n",
"            gd, agg\n",
"        where\n",
"            iter < 5000\n",
"            and loss > 0.1\n",
"            -- Continue while the next step would change the loss by more than 0.01\n",
"            -- (same expression as the loss column above).\n",
"            and abs(\n",
"                loss-(n*pow(a-0.01*(n*a+b*s_x-s_y)/n,2)\n",
"                 +pow(b-0.01*(a*s_x+b*s_x2-s_xy)/n,2)*s_x2\n",
"                 +s_y2\n",
"                 +2*(a-0.01*(n*a+b*s_x-s_y)/n)*(b-0.01*(a*s_x+b*s_x2-s_xy)/n)*s_x\n",
"                 -2*(a-0.01*(n*a+b*s_x-s_y)/n)*s_y\n",
"                 -2*(b-0.01*(a*s_x+b*s_x2-s_xy)/n)*s_xy)/(2*n)) > 0.01\n",
"    )\n",
"select * from gd order by iter desc limit 1\n",
"'''\n",
"\n",
"result = pd.read_sql(query, conn)\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Parameters fitted by the SQL gradient descent (the query returns a single row).\n",
"# float() guards against DECIMAL columns arriving as decimal.Decimal objects.\n",
"a = float(result['a'].iloc[-1])\n",
"b = float(result['b'].iloc[-1])\n",
"\n",
"# Draw the fitted line across the observed x range.\n",
"x_line = np.linspace(df['x'].min(), df['x'].max(), 100)\n",
"y_line = a + b * x_line\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 6))\n",
"ax.plot(df['x'], df['y'], 'o', label='data')\n",
"ax.plot(x_line, y_line, label='predict')\n",
"ax.set(title='Linear regression fitted in SQL', xlabel='x', ylabel='y')\n",
"# The line labels are only shown once a legend is drawn.\n",
"ax.legend()\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment