Skip to content

Instantly share code, notes, and snippets.

@drorata
Created November 7, 2021 09:40
Show Gist options
  • Save drorata/90f7ae87165afab948e248789feb8477 to your computer and use it in GitHub Desktop.
Save drorata/90f7ae87165afab948e248789feb8477 to your computer and use it in GitHub Desktop.
FB predictions using LSTM
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from numpy import array\n",
"from keras.models import Sequential\n",
"from keras.layers import LSTM\n",
"from keras.layers import Dense, Dropout, Activation\n",
"from tensorflow import keras\n",
"import tensorflow as tf\n",
"from sklearn import metrics\n",
"import datetime\n",
"pd.options.plotting.backend = \"plotly\"\n",
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pd.read_parquet(\"./data/raw_historical_data/FB.parquet\")\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# split a univariate sequence into (window, next value) samples\n",
"def split_sequence(sequence, n_steps):\n",
"    \"\"\"Split a univariate sequence into supervised-learning samples.\n",
"\n",
"    Each sample pairs a window of ``n_steps`` consecutive values with the\n",
"    value that immediately follows it.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    sequence : array-like\n",
"        1-D sequence (list, NumPy array or pandas Series). It is converted\n",
"        to a NumPy array first so every access is positional: on a Series,\n",
"        ``sequence[k]`` is a label lookup while ``sequence[i:j]`` slices by\n",
"        position, and the two disagree unless the index is a 0-based\n",
"        RangeIndex.\n",
"    n_steps : int\n",
"        Window length.\n",
"\n",
"    Returns\n",
"    -------\n",
"    X : numpy.ndarray\n",
"        Shape (n_samples, n_steps) when n_samples > 0.\n",
"    y : numpy.ndarray\n",
"        Shape (n_samples,), where n_samples = max(len(sequence) - n_steps, 0).\n",
"    \"\"\"\n",
"    values = np.asarray(sequence)\n",
"    X, y = list(), list()\n",
"    # the last window must leave one element after it to serve as the target\n",
"    for i in range(len(values) - n_steps):\n",
"        end_ix = i + n_steps\n",
"        X.append(values[i:end_ix])\n",
"        y.append(values[end_ix])\n",
"    return array(X), array(y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# choose a number of time steps\n",
"n_steps = 50\n",
"# split into samples\n",
"X, y = split_sequence(df.adjClose, n_steps)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_size = int(0.7 * X.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train, X_validate, y_validate = X[:train_size], y[:train_size], X[train_size + n_steps + 1:], y[train_size + n_steps + 1:]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reshape from [samples, timesteps] into [samples, timesteps, features]\n",
"n_features = 1\n",
"X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], n_features)\n",
"X_validate = X_validate.reshape(X_validate.shape[0], X_validate.shape[1], n_features)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"log_dir = \"logs/fit/\" + datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
"tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)\n",
"\n",
"# define model\n",
"model = Sequential()\n",
"model.add(LSTM(50, activation=keras.activations.relu, input_shape=(n_steps, n_features)))\n",
"# model.add(Dropout(0.2, name='lstm_dropout_0'))\n",
"# model.add(Dense(64, name='dense_0'))\n",
"# model.add(Activation('sigmoid', name='sigmoid_0'))\n",
"model.add(Dense(1))\n",
"model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.01), loss='mse')\n",
"# fit model\n",
"model.fit(X_train, y_train, epochs=200, verbose=1, callbacks=[tensorboard_callback])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# demonstrate prediction\n",
"yhat = model.predict(X_validate, verbose=0)\n",
"yhat = yhat.reshape(yhat.shape[0])\n",
"\n",
"y_vals = pd.DataFrame(\n",
" {\n",
" \"true\": y_validate,\n",
" \"pred\": yhat\n",
" }\n",
")\n",
"y_vals.head()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"y_vals.plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(\n",
" f\"Mean absolute error: {metrics.mean_absolute_error(y_validate, yhat)} / \" + \n",
" f\"Mean squared error: {metrics.mean_squared_error(y_validate, yhat)}\"\n",
")"
]
}
],
"metadata": {
"interpreter": {
"hash": "ab3782f7a3367a0ac78615a23ccde97a675b35a20d9404343327068ff2cfd379"
},
"kernelspec": {
"display_name": "Python 3.9.6 64-bit ('stocks-opportunities-monitor': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment