Skip to content

Instantly share code, notes, and snippets.

@jolks
Forked from MaxHalford/ogd-in-sql.ipynb
Created March 8, 2023 04:10
Show Gist options
  • Save jolks/f356ce701fd735ca7f857d69c95663c2 to your computer and use it in GitHub Desktop.
Save jolks/f356ce701fd735ca7f857d69c95663c2 to your computer and use it in GitHub Desktop.
Online gradient descent written in SQL
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Online gradient descent (OGD) in SQL"
]
},
{
"cell_type": "code",
"execution_count": 188,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[*********************100%***********************] 1 of 1 completed\n",
"| Date | Open | High | Low | Close | Adj Close | Volume |\n",
"|:--------------------|--------:|--------:|--------:|--------:|------------:|---------:|\n",
"| 2021-12-27 00:00:00 | 6.01218 | 6.09808 | 6.04943 | 6.1254 | 6.07454 | 1.18628 |\n",
"| 2021-12-28 00:00:00 | 6.1164 | 6.12883 | 6.09931 | 6.09008 | 6.03951 | 1.25318 |\n",
"| 2021-12-29 00:00:00 | 6.08822 | 6.10517 | 6.08598 | 6.09314 | 6.04254 | 0.987236 |\n",
"| 2021-12-30 00:00:00 | 6.09298 | 6.10315 | 6.08427 | 6.05305 | 6.00279 | 0.946449 |\n",
"| 2021-12-31 00:00:00 | 6.04613 | 6.05785 | 6.05592 | 6.03165 | 5.98157 | 1.01437 |\n"
]
}
],
"source": [
"import yfinance as yf\n",
"\n",
"# Ask for unadjusted prices explicitly: recent yfinance releases default to\n",
"# auto_adjust=True, which drops the 'Adj Close' column the later cells rely on.\n",
"figures = yf.download(\n",
"    tickers=['AAPL'],\n",
"    start='2020-01-01',\n",
"    end='2022-01-01',\n",
"    auto_adjust=False\n",
")\n",
"\n",
"# Recent yfinance releases return (price, ticker) MultiIndex columns even for a\n",
"# single ticker; drop the ticker level so columns are addressed by price name.\n",
"if figures.columns.nlevels > 1:\n",
"    figures = figures.droplevel(1, axis=1)\n",
"\n",
"# Rescale every column to unit standard deviation.\n",
"figures = figures / figures.std()\n",
"print(figures.tail().to_markdown())"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean"
]
},
{
"cell_type": "code",
"execution_count": 200,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"┌───────┬───────────────────┬────────────────────┐\n",
"│ step │ x │ avg │\n",
"│ int64 │ double │ double │\n",
"├───────┼───────────────────┼────────────────────┤\n",
"│ 505 │ 5.981568542028378 │ 3.9577706471349923 │\n",
"│ 504 │ 6.002789566151079 │ 3.953755175121315 │\n",
"│ 503 │ 6.042539700173864 │ 3.949681548101375 │\n",
"│ 502 │ 6.039508125299193 │ 3.945512507957804 │\n",
"│ 501 │ 6.074541325571636 │ 3.941332875987063 │\n",
"└───────┴───────────────────┴────────────────────┘\n",
"\n"
]
}
],
"source": [
"import duckdb\n",
"\n",
"# `stream` numbers the rows of `figures`; `state` walks them one step at a\n",
"# time and updates the mean incrementally as avg + (x - avg) / step.\n",
"running_mean_query = '''\n",
"WITH RECURSIVE\n",
" stream AS (\n",
" SELECT ROW_NUMBER() OVER () AS step, \"Adj Close\" AS x\n",
" FROM figures\n",
" ORDER BY step\n",
" ),\n",
" state(step, x, avg) AS (\n",
" -- Initialize\n",
" SELECT step, x, x AS avg\n",
" FROM stream\n",
" WHERE step = 1\n",
" UNION ALL\n",
" -- Update\n",
" SELECT\n",
" stream.step,\n",
" stream.x,\n",
" state.avg + (stream.x - state.avg) / stream.step AS avg\n",
" FROM stream\n",
" INNER JOIN state ON state.step + 1 = stream.step\n",
" )\n",
"\n",
"SELECT *\n",
"FROM state\n",
"ORDER BY step DESC\n",
"LIMIT 5\n",
"'''\n",
"\n",
"duckdb.sql(running_mean_query).show()"
]
},
{
"cell_type": "code",
"execution_count": 206,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Date\n",
"2021-12-31 3.957771\n",
"2021-12-30 3.953755\n",
"2021-12-29 3.949682\n",
"2021-12-28 3.945513\n",
"2021-12-27 3.941333\n",
"Name: Adj Close, dtype: float64"
]
},
"execution_count": 206,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# An expanding window is the same as a rolling window as long as the series.\n",
"running_mean = figures['Adj Close'].expanding(min_periods=1).mean()\n",
"running_mean.tail()[::-1]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Covariance"
]
},
{
"cell_type": "code",
"execution_count": 192,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"┌───────┬────────────────────┐\n",
"│ step │ cov │\n",
"│ int64 │ double │\n",
"├───────┼────────────────────┤\n",
"│ 505 │ 0.9979967767965502 │\n",
"│ 504 │ 0.9918524780369538 │\n",
"│ 503 │ 0.985478504290919 │\n",
"│ 502 │ 0.9787158318485241 │\n",
"│ 501 │ 0.9719167545245742 │\n",
"└───────┴────────────────────┘\n",
"\n"
]
}
],
"source": [
"# The inner SELECT computes the previous and updated means for both series;\n",
"# the outer SELECT then folds the new observation into the running covariance.\n",
"running_cov_query = '''\n",
"WITH RECURSIVE\n",
" stream AS (\n",
" SELECT\n",
" ROW_NUMBER() OVER () AS step,\n",
" \"Adj Close\" AS x,\n",
" \"Close\" AS y\n",
" FROM figures\n",
" ),\n",
" state(step, x, x_avg, y, y_avg, cov) AS (\n",
" -- Initialize\n",
" SELECT\n",
" step,\n",
" x,\n",
" x AS x_avg,\n",
" y,\n",
" y AS y_avg,\n",
" 0::DOUBLE AS cov\n",
" FROM stream\n",
" WHERE step = 1\n",
" UNION ALL\n",
" -- Update\n",
" SELECT\n",
" step,\n",
" x,\n",
" x_new_avg AS x_avg,\n",
" y,\n",
" y_new_avg AS y_avg,\n",
" cov + ((x - x_prev_avg) * (y - y_new_avg) - cov) / step AS cov\n",
" FROM (\n",
" SELECT\n",
" stream.step,\n",
" stream.x,\n",
" stream.y,\n",
" state.x_avg AS x_prev_avg,\n",
" state.x_avg + (stream.x - state.x_avg) / stream.step AS x_new_avg,\n",
" state.y_avg AS y_prev_avg,\n",
" state.y_avg + (stream.y - state.y_avg) / stream.step AS y_new_avg,\n",
" state.cov\n",
" FROM stream\n",
" INNER JOIN state ON state.step + 1 = stream.step\n",
" )\n",
" )\n",
"\n",
"SELECT step, cov\n",
"FROM state\n",
"ORDER BY step DESC\n",
"LIMIT 5\n",
"'''\n",
"\n",
"duckdb.sql(running_cov_query).show()"
]
},
{
"cell_type": "code",
"execution_count": 217,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Date\n",
"2021-12-31 0.997997\n",
"2021-12-30 0.991852\n",
"2021-12-29 0.985479\n",
"2021-12-28 0.978716\n",
"2021-12-27 0.971917\n",
"Name: Adj Close, dtype: float64"
]
},
"execution_count": 217,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pairwise population covariance (ddof=0) over a window as long as the frame.\n",
"window = len(figures)\n",
"pairwise_cov = figures.rolling(window, min_periods=1).cov(ddof=0)\n",
"\n",
"# Keep only the 'Adj Close' column and, within it, the rows for 'Close'.\n",
"adj_close_vs_close = pairwise_cov['Adj Close'].loc[:, 'Close']\n",
"adj_close_vs_close.tail()[::-1]"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Handling many variables"
]
},
{
"cell_type": "code",
"execution_count": 221,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"| date | variable | value |\n",
"|:--------------------|:-----------|--------:|\n",
"| 2020-01-02 00:00:00 | Adj Close | 2.49235 |\n",
"| 2020-01-02 00:00:00 | Close | 2.55055 |\n",
"| 2020-01-02 00:00:00 | High | 2.54002 |\n",
"| 2020-01-02 00:00:00 | Low | 2.52122 |\n",
"| 2020-01-02 00:00:00 | Open | 2.51432 |\n",
"| 2020-01-02 00:00:00 | Volume | 2.14521 |\n",
"| 2020-01-03 00:00:00 | Adj Close | 2.46812 |\n",
"| 2020-01-03 00:00:00 | Close | 2.52576 |\n",
"| 2020-01-03 00:00:00 | High | 2.53985 |\n",
"| 2020-01-03 00:00:00 | Low | 2.53241 |\n"
]
}
],
"source": [
"# Long format: one row per (date, variable) pair.\n",
"long_columns = ['date', 'variable', 'value']\n",
"\n",
"figures_flat = figures.melt(ignore_index=False).reset_index()\n",
"figures_flat.columns = long_columns\n",
"figures_flat = figures_flat.sort_values(long_columns[:2])\n",
"\n",
"print(figures_flat.head(10).to_markdown(index=False))"
]
},
{
"cell_type": "code",
"execution_count": 222,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"┌───────┬───────────┬────────────────────┬────────────────────┐\n",
"│ step │ variable │ value │ avg │\n",
"│ int64 │ varchar │ double │ double │\n",
"├───────┼───────────┼────────────────────┼────────────────────┤\n",
"│ 505 │ Adj Close │ 5.981568542028378 │ 3.9577706471349923 │\n",
"│ 505 │ Close │ 6.03165394229666 │ 4.012373756823449 │\n",
"│ 505 │ High │ 6.057853942108038 │ 4.03765319364954 │\n",
"│ 505 │ Low │ 6.05591789308585 │ 3.985178489614261 │\n",
"│ 505 │ Open │ 6.046125216781687 │ 4.006746251814558 │\n",
"│ 505 │ Volume │ 1.0143664144585565 │ 1.9651814487272024 │\n",
"└───────┴───────────┴────────────────────┴────────────────────┘\n",
"\n"
]
}
],
"source": [
"# Same running mean as before, but the recursive join also matches on\n",
"# `variable`, so each variable carries its own average from step to step.\n",
"per_variable_mean_query = '''\n",
"WITH RECURSIVE\n",
" stream AS (\n",
" SELECT RANK_DENSE() OVER (ORDER BY date) AS step, *\n",
" FROM figures_flat\n",
" ORDER BY date\n",
" ),\n",
" state(step, variable, value, avg) AS (\n",
" -- Initialize\n",
" SELECT step, variable, value, value AS avg\n",
" FROM stream\n",
" WHERE step = 1\n",
" UNION ALL\n",
" -- Update\n",
" SELECT\n",
" stream.step,\n",
" stream.variable,\n",
" stream.value,\n",
" state.avg + (stream.value - state.avg) / stream.step AS avg\n",
" FROM stream\n",
" INNER JOIN state ON\n",
" state.step + 1 = stream.step AND\n",
" state.variable = stream.variable\n",
" )\n",
"\n",
"SELECT *\n",
"FROM state\n",
"WHERE step = (SELECT MAX(step) FROM state)\n",
"ORDER BY variable\n",
"'''\n",
"\n",
"duckdb.sql(per_variable_mean_query).show()"
]
},
{
"cell_type": "code",
"execution_count": 232,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"variable \n",
"Adj Close 2524 3.957771\n",
"Close 2019 4.012374\n",
"High 1009 4.037653\n",
"Low 1514 3.985178\n",
"Open 504 4.006746\n",
"Volume 3029 1.965181\n",
"Name: value, dtype: float64"
]
},
"execution_count": 232,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Running mean per variable over a window as long as the long-format frame.\n",
"window = len(figures_flat)\n",
"running_means = (\n",
"    figures_flat\n",
"    .groupby('variable')['value']\n",
"    .rolling(window, min_periods=1)\n",
"    .mean()\n",
")\n",
"\n",
"# Last running mean of each variable, ordered by index.\n",
"running_means.groupby('variable').tail(1)[::-1].sort_index()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Stochastic gradient descent\n",
"\n",
"Vanilla SGD, meaning\n",
"\n",
"- Constant learning rate\n",
"- Single epoch\n",
"- Squared loss\n",
"- No gradient clipping\n",
"- No regularisation\n",
"- No intercept"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
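"$$p_t = \\dot{w}_t \\cdot \\dot{x}_t = \\sum_{i=1}^{n} w_{t,i} \\, x_{t,i}$$\n",
"$$l_t = p_t - y_t$$\n",
"$$\\dot{g}_t = l_t \\, \\dot{x}_t$$\n",
"$$\\dot{w}_{t+1} = \\dot{w}_t - \\eta \\dot{g}_t$$\n",
"\n",
"Here $n$ is the number of features, and $l_t$ is the derivative of the squared loss $\\frac{1}{2}(p_t - y_t)^2$ with respect to the prediction $p_t$."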
]
},
{
"cell_type": "code",
"execution_count": 234,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"┌───────┬──────────┬──────────────────────┬───────────────────┬───────────────────┐\n",
"│ step │ variable │ weight │ target │ prediction │\n",
"│ int64 │ varchar │ double │ double │ double │\n",
"├───────┼──────────┼──────────────────────┼───────────────────┼───────────────────┤\n",
"│ 505 │ Close │ 0.2511547716803354 │ 5.981568542028378 │ 5.938875441702928 │\n",
"│ 505 │ High │ 0.24043897039853313 │ 5.981568542028378 │ 5.938875441702928 │\n",
"│ 505 │ Low │ 0.2447191283620627 │ 5.981568542028378 │ 5.938875441702928 │\n",
"│ 505 │ Open │ 0.23603830762609726 │ 5.981568542028378 │ 5.938875441702928 │\n",
"│ 505 │ Volume │ 0.057510279698874206 │ 5.981568542028378 │ 5.938875441702928 │\n",
"└───────┴──────────┴──────────────────────┴───────────────────┴───────────────────┘"
]
},
"execution_count": 234,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `X` holds the features and `y` the 'Adj Close' target, aligned by step.\n",
"# Each recursive step updates the weights from the previous step's prediction\n",
"# error, then predicts the current row with the updated weights.\n",
"sgd_query = '''\n",
"WITH RECURSIVE\n",
" X AS (\n",
" SELECT\n",
" RANK_DENSE() OVER (ORDER BY date) AS step, *\n",
" FROM figures_flat\n",
" WHERE variable != 'Adj Close'\n",
" ORDER BY date\n",
" ),\n",
" y AS (\n",
" SELECT\n",
" RANK_DENSE() OVER (ORDER BY date) AS step, *\n",
" FROM figures_flat\n",
" WHERE variable = 'Adj Close'\n",
" ORDER BY date\n",
" ),\n",
" stream AS (\n",
" SELECT X.*, y.value AS target\n",
" FROM X\n",
" INNER JOIN y ON X.step = y.step\n",
" ),\n",
" state AS (\n",
" -- Initialize\n",
" SELECT\n",
" step,\n",
" target,\n",
" variable,\n",
" value,\n",
" 0::DOUBLE AS weight,\n",
" 0::DOUBLE AS prediction\n",
" FROM stream\n",
" WHERE step = 1\n",
" UNION ALL\n",
" -- Update\n",
" SELECT\n",
" step,\n",
" target,\n",
" variable,\n",
" value,\n",
" weight,\n",
" SUM(weight * value) OVER () AS prediction\n",
" FROM (\n",
" SELECT\n",
" stream.step,\n",
" stream.target,\n",
" stream.variable,\n",
" stream.value,\n",
" state.prediction - state.target AS loss_gradient,\n",
" loss_gradient * state.value AS gradient,\n",
" state.weight - 0.01 * gradient AS weight\n",
" FROM stream\n",
" INNER JOIN state ON\n",
" state.step + 1 = stream.step AND\n",
" state.variable = stream.variable\n",
" )\n",
" )\n",
"\n",
"SELECT step, variable, weight, target, prediction\n",
"FROM state\n",
"WHERE step = (SELECT MAX(step) FROM state)\n",
"ORDER BY variable\n",
"'''\n",
"\n",
"duckdb.sql(sgd_query)"
]
},
{
"cell_type": "code",
"execution_count": 236,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'Close': 0.2511547716803354,\n",
" 'High': 0.2404389703985331,\n",
" 'Low': 0.2447191283620624,\n",
" 'Open': 0.23603830762609757,\n",
" 'Volume': 0.05751027969887417}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/max/.pyenv/versions/3.11.0/lib/python3.11/site-packages/sklearn/linear_model/_stochastic_gradient.py:1551: ConvergenceWarning: Maximum number of iteration reached before convergence. Consider increasing max_iter to improve the fit.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from pprint import pprint\n",
"from sklearn import linear_model\n",
"\n",
"# Every row of `figures` except the last; 'Adj Close' is split off as the target.\n",
"X = figures[:-1].copy()\n",
"y = X.pop('Adj Close')\n",
"\n",
"# Constant learning rate, a single unshuffled pass, no penalty, no intercept.\n",
"sgd_settings = dict(\n",
"    loss='squared_error',\n",
"    penalty=None,\n",
"    fit_intercept=False,\n",
"    learning_rate='constant',\n",
"    eta0=0.01,\n",
"    max_iter=1,\n",
"    shuffle=False,\n",
")\n",
"model = linear_model.SGDRegressor(**sgd_settings).fit(X, y)\n",
"\n",
"pprint(dict(zip(X.columns, model.coef_)))"
]
},
{
"cell_type": "code",
"execution_count": 237,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'Close': 0.2511547716803356,\n",
" 'High': 0.2404389703985331,\n",
" 'Low': 0.24471912836206253,\n",
" 'Open': 0.2360383076260972,\n",
" 'Volume': 0.057510279698874255}\n"
]
}
],
"source": [
"from river import linear_model\n",
"from river import optim\n",
"\n",
"\n",
"class ScikitLearnSquaredLoss:\n",
"    \"\"\"Squared loss whose gradient omits the leading factor of 2.\n",
"\n",
"    sklearn removes the leading 2 from the gradient of the squared loss, so the\n",
"    same convention is used here to keep the learned weights comparable.\n",
"    \"\"\"\n",
"\n",
"    def gradient(self, y_true, y_pred):\n",
"        \"\"\"Return the gradient of the loss with respect to the prediction.\"\"\"\n",
"        return y_pred - y_true\n",
"\n",
"\n",
"model = linear_model.LinearRegression(\n",
"    optimizer=optim.SGD(lr=0.01),\n",
"    loss=ScikitLearnSquaredLoss(),\n",
"    intercept_lr=0.0,\n",
"    l2=0.0\n",
")\n",
"\n",
"# Learn from one record at a time, leaving out the last row of `figures`.\n",
"# The enumerate() index was never used, so the records are iterated directly.\n",
"for x in figures[:-1].to_dict(orient='records'):\n",
"    y = x.pop('Adj Close')\n",
"    model.learn_one(x, y)\n",
"\n",
"pprint(model.weights)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment