Skip to content

Instantly share code, notes, and snippets.

@oguiza
Created March 28, 2019 20:44
Show Gist options
  • Save oguiza/7f698965863916b60fee8e5f7f8906a3 to your computer and use it in GitHub Desktop.
Save oguiza/7f698965863916b60fee8e5f7f8906a3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def get_wf_idxs(arr, wf_folds, train_val_sz_ratio, test_set=False, plot_chart=True):\n",
"    \"\"\"Build walk-forward train/valid(/test) index folds over the last axis of `arr`.\n",
"\n",
"    The axis is divided into `1 + int(train_val_sz_ratio) + test_set + (wf_folds - 1)`\n",
"    equal slots. Each fold uses `train_val_sz_ratio` slots for training, the next\n",
"    slot for validation and, if requested, one more slot for testing. Every fold\n",
"    starts one (truncated) slot after the previous one.\n",
"\n",
"    Args:\n",
"        arr: object with a `shape`; only the length of its last axis is used.\n",
"        wf_folds: number of walk-forward folds to generate.\n",
"        train_val_sz_ratio: length of the train span measured in validation spans.\n",
"        test_set: truthy to add a test span after every validation span.\n",
"        plot_chart: if True, draw the folds as a stacked horizontal bar chart.\n",
"\n",
"    Returns:\n",
"        Tuple `(train_idx, valid_idx, test_idx)` of lists with one index array\n",
"        per fold. `test_idx` is an empty list when `test_set` is False and None\n",
"        when `test_set` is None.\n",
"    \"\"\"\n",
"    has_test = bool(test_set)  # lets test_set=None through the arithmetic below\n",
"    n_samples = arr.shape[-1]\n",
"    n_slots = int(1 + int(train_val_sz_ratio) + has_test + (wf_folds - 1))\n",
"    fold_sz = n_samples / n_slots\n",
"    # Span boundaries relative to the start of a fold. Computing them once makes\n",
"    # the end of each span coincide with the start of the next one.\n",
"    train_end = int(fold_sz * train_val_sz_ratio)\n",
"    valid_end = int(fold_sz * (train_val_sz_ratio + 1))\n",
"    test_end = int(fold_sz * (train_val_sz_ratio + 2))\n",
"    idxs = np.arange(n_samples)\n",
"\n",
"    train_idx, valid_idx, test_idx = [], [], []\n",
"    for fold in range(wf_folds):\n",
"        start = fold * int(fold_sz)\n",
"        train_idx.append(idxs[start:start + train_end])\n",
"        valid_idx.append(idxs[start + train_end:start + valid_end])\n",
"        if has_test:\n",
"            test_idx.append(idxs[start + valid_end:start + test_end])\n",
"\n",
"    if plot_chart:\n",
"        _plot_wf_folds(train_idx, valid_idx, test_idx, n_samples)\n",
"    if test_set is None:\n",
"        test_idx = None\n",
"    return train_idx, valid_idx, test_idx\n",
"\n",
"\n",
"def _plot_wf_folds(train_idx, valid_idx, test_idx, n_samples):\n",
"    \"\"\"Draw one stacked horizontal bar per fold showing its train/valid(/test) spans.\"\"\"\n",
"    rows = []\n",
"    for fold, (trn, val) in enumerate(zip(train_idx, valid_idx)):\n",
"        edges = [0, np.min(trn), np.max(trn), np.max(val)]\n",
"        if test_idx:\n",
"            edges.append(np.max(test_idx[fold]))\n",
"        rows.append(edges + [n_samples])\n",
"    if not rows:\n",
"        return\n",
"    # Differences between consecutive edges are the widths of the bar segments.\n",
"    # DataFrame.append was removed in pandas 2.0, so the frame is built in one go.\n",
"    df = pd.DataFrame(rows).diff(axis=1)\n",
"    df = df.drop(columns=df.columns[0])\n",
"    colors = ['lightgray', 'purple', 'orange', 'lightgray']\n",
"    labels = [None, 'Train', 'Valid']\n",
"    if test_idx:\n",
"        colors.insert(3, 'skyblue')\n",
"        labels.append('Test')\n",
"    # style.context keeps the 'classic' look from leaking into later plots.\n",
"    with plt.style.context('classic'):\n",
"        df.plot(kind='barh', stacked=True, color=colors, figsize=(10, 5))\n",
"        plt.legend(labels)\n",
"        plt.gca().invert_yaxis()\n",
"        plt.title('Walk-forward folds')\n",
"        plt.show()\n",
"\n",
"\n",
"def sliding_window(X, y, ws, ss=None, y_step=1):\n",
"    \"\"\"Cut (window, target) samples out of a series with a sliding window.\n",
"\n",
"    Args:\n",
"        X: 2-d array indexed as [feature, time step].\n",
"        y: 1-d array of targets aligned with the time axis of `X`.\n",
"        ws: window size, i.e. number of time steps per sample.\n",
"        ss: slide size between consecutive window starts; defaults to `ws`\n",
"            (no overlap).\n",
"        y_step: how many steps after the end of the window the target lies\n",
"            (1 = the step right after the window).\n",
"\n",
"    Returns:\n",
"        Tuple `(X_, y_)` of arrays where `X_[i]` is `X[:, s:s + ws]` and `y_[i]`\n",
"        is `y[s + ws + y_step - 1]` with `s = i * ss`. Both are empty when the\n",
"        series is too short for one window plus its target.\n",
"    \"\"\"\n",
"    if ss is None:\n",
"        ss = ws\n",
"    # Last start that still leaves room for the window and its target. The old\n",
"    # bound `(len - y_step) // ss - 1` ignored `ws`: it could index past the end\n",
"    # of `y` for windows much larger than the slide and dropped the last valid\n",
"    # window when windows did not overlap.\n",
"    last_start = y.shape[-1] - ws - y_step\n",
"    starts = range(0, last_start + 1, ss)\n",
"    X_ = [X[:, s:s + ws] for s in starts]\n",
"    y_ = [y[s + ws + y_step - 1] for s in starts]\n",
"    return np.array(X_), np.array(y_)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# Get your data\n",
"n_feat = 3\n",
"total_len = 10000\n",
"\n",
"# One row of labelled dummy values per feature: f1_0, f1_1, ... / f2_0, f2_1, ...\n",
"# The labels make it easy to see which time steps end up in each split or window.\n",
"X = np.array([[f'f{feat}_{step}' for step in range(total_len)]\n",
"              for feat in range(1, n_feat + 1)])\n",
"y = X[0].copy()  # the target series is the first feature"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 800x400 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Walk-forward\n",
"wf_folds = 5\n",
"train_val_sz_ratio = 4\n",
"test_set = False\n",
"train, valid, test = get_wf_idxs(X, wf_folds, train_val_sz_ratio, test_set)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((3, 8000), (8000,))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# This might be part of a training loop\n",
"X_train = X[:, train[0]]\n",
"y_train = y[train[0]]\n",
"X_train.shape, y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"ws = 50 # window size\n",
"ss = 25 # slide size (if less than ws, then there is some overlapping data)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"X_train, y_train = sliding_window(X[:, train[0]], y[train[0]], ws, ss)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((176, 3, 50), (176,))"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train.shape, y_train.shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([['f1_0', 'f1_1', 'f1_2', 'f1_3', 'f1_4', 'f1_5', 'f1_6', 'f1_7',\n",
" 'f1_8', 'f1_9', 'f1_10', 'f1_11', 'f1_12', 'f1_13', 'f1_14',\n",
" 'f1_15', 'f1_16', 'f1_17', 'f1_18', 'f1_19', 'f1_20', 'f1_21',\n",
" 'f1_22', 'f1_23', 'f1_24', 'f1_25', 'f1_26', 'f1_27', 'f1_28',\n",
" 'f1_29', 'f1_30', 'f1_31', 'f1_32', 'f1_33', 'f1_34', 'f1_35',\n",
" 'f1_36', 'f1_37', 'f1_38', 'f1_39', 'f1_40', 'f1_41', 'f1_42',\n",
" 'f1_43', 'f1_44', 'f1_45', 'f1_46', 'f1_47', 'f1_48', 'f1_49'],\n",
" ['f2_0', 'f2_1', 'f2_2', 'f2_3', 'f2_4', 'f2_5', 'f2_6', 'f2_7',\n",
" 'f2_8', 'f2_9', 'f2_10', 'f2_11', 'f2_12', 'f2_13', 'f2_14',\n",
" 'f2_15', 'f2_16', 'f2_17', 'f2_18', 'f2_19', 'f2_20', 'f2_21',\n",
" 'f2_22', 'f2_23', 'f2_24', 'f2_25', 'f2_26', 'f2_27', 'f2_28',\n",
" 'f2_29', 'f2_30', 'f2_31', 'f2_32', 'f2_33', 'f2_34', 'f2_35',\n",
" 'f2_36', 'f2_37', 'f2_38', 'f2_39', 'f2_40', 'f2_41', 'f2_42',\n",
" 'f2_43', 'f2_44', 'f2_45', 'f2_46', 'f2_47', 'f2_48', 'f2_49'],\n",
" ['f3_0', 'f3_1', 'f3_2', 'f3_3', 'f3_4', 'f3_5', 'f3_6', 'f3_7',\n",
" 'f3_8', 'f3_9', 'f3_10', 'f3_11', 'f3_12', 'f3_13', 'f3_14',\n",
" 'f3_15', 'f3_16', 'f3_17', 'f3_18', 'f3_19', 'f3_20', 'f3_21',\n",
" 'f3_22', 'f3_23', 'f3_24', 'f3_25', 'f3_26', 'f3_27', 'f3_28',\n",
" 'f3_29', 'f3_30', 'f3_31', 'f3_32', 'f3_33', 'f3_34', 'f3_35',\n",
" 'f3_36', 'f3_37', 'f3_38', 'f3_39', 'f3_40', 'f3_41', 'f3_42',\n",
" 'f3_43', 'f3_44', 'f3_45', 'f3_46', 'f3_47', 'f3_48', 'f3_49']],\n",
" dtype='<U7'), 'f1_50')"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train[0], y_train[0]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([['f1_4375', 'f1_4376', 'f1_4377', 'f1_4378', 'f1_4379', 'f1_4380',\n",
" 'f1_4381', 'f1_4382', 'f1_4383', 'f1_4384', 'f1_4385', 'f1_4386',\n",
" 'f1_4387', 'f1_4388', 'f1_4389', 'f1_4390', 'f1_4391', 'f1_4392',\n",
" 'f1_4393', 'f1_4394', 'f1_4395', 'f1_4396', 'f1_4397', 'f1_4398',\n",
" 'f1_4399', 'f1_4400', 'f1_4401', 'f1_4402', 'f1_4403', 'f1_4404',\n",
" 'f1_4405', 'f1_4406', 'f1_4407', 'f1_4408', 'f1_4409', 'f1_4410',\n",
" 'f1_4411', 'f1_4412', 'f1_4413', 'f1_4414', 'f1_4415', 'f1_4416',\n",
" 'f1_4417', 'f1_4418', 'f1_4419', 'f1_4420', 'f1_4421', 'f1_4422',\n",
" 'f1_4423', 'f1_4424'],\n",
" ['f2_4375', 'f2_4376', 'f2_4377', 'f2_4378', 'f2_4379', 'f2_4380',\n",
" 'f2_4381', 'f2_4382', 'f2_4383', 'f2_4384', 'f2_4385', 'f2_4386',\n",
" 'f2_4387', 'f2_4388', 'f2_4389', 'f2_4390', 'f2_4391', 'f2_4392',\n",
" 'f2_4393', 'f2_4394', 'f2_4395', 'f2_4396', 'f2_4397', 'f2_4398',\n",
" 'f2_4399', 'f2_4400', 'f2_4401', 'f2_4402', 'f2_4403', 'f2_4404',\n",
" 'f2_4405', 'f2_4406', 'f2_4407', 'f2_4408', 'f2_4409', 'f2_4410',\n",
" 'f2_4411', 'f2_4412', 'f2_4413', 'f2_4414', 'f2_4415', 'f2_4416',\n",
" 'f2_4417', 'f2_4418', 'f2_4419', 'f2_4420', 'f2_4421', 'f2_4422',\n",
" 'f2_4423', 'f2_4424'],\n",
" ['f3_4375', 'f3_4376', 'f3_4377', 'f3_4378', 'f3_4379', 'f3_4380',\n",
" 'f3_4381', 'f3_4382', 'f3_4383', 'f3_4384', 'f3_4385', 'f3_4386',\n",
" 'f3_4387', 'f3_4388', 'f3_4389', 'f3_4390', 'f3_4391', 'f3_4392',\n",
" 'f3_4393', 'f3_4394', 'f3_4395', 'f3_4396', 'f3_4397', 'f3_4398',\n",
" 'f3_4399', 'f3_4400', 'f3_4401', 'f3_4402', 'f3_4403', 'f3_4404',\n",
" 'f3_4405', 'f3_4406', 'f3_4407', 'f3_4408', 'f3_4409', 'f3_4410',\n",
" 'f3_4411', 'f3_4412', 'f3_4413', 'f3_4414', 'f3_4415', 'f3_4416',\n",
" 'f3_4417', 'f3_4418', 'f3_4419', 'f3_4420', 'f3_4421', 'f3_4422',\n",
" 'f3_4423', 'f3_4424']], dtype='<U7'), 'f1_4425')"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_train[-1], y_train[-1]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"X_valid, y_valid = sliding_window(X[:, valid[0]], y[valid[0]], ws, ss)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((43, 3, 50), (43,))"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_valid.shape, y_valid.shape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([['f1_4444', 'f1_4445', 'f1_4446', 'f1_4447', 'f1_4448', 'f1_4449',\n",
" 'f1_4450', 'f1_4451', 'f1_4452', 'f1_4453', 'f1_4454', 'f1_4455',\n",
" 'f1_4456', 'f1_4457', 'f1_4458', 'f1_4459', 'f1_4460', 'f1_4461',\n",
" 'f1_4462', 'f1_4463', 'f1_4464', 'f1_4465', 'f1_4466', 'f1_4467',\n",
" 'f1_4468', 'f1_4469', 'f1_4470', 'f1_4471', 'f1_4472', 'f1_4473',\n",
" 'f1_4474', 'f1_4475', 'f1_4476', 'f1_4477', 'f1_4478', 'f1_4479',\n",
" 'f1_4480', 'f1_4481', 'f1_4482', 'f1_4483', 'f1_4484', 'f1_4485',\n",
" 'f1_4486', 'f1_4487', 'f1_4488', 'f1_4489', 'f1_4490', 'f1_4491',\n",
" 'f1_4492', 'f1_4493'],\n",
" ['f2_4444', 'f2_4445', 'f2_4446', 'f2_4447', 'f2_4448', 'f2_4449',\n",
" 'f2_4450', 'f2_4451', 'f2_4452', 'f2_4453', 'f2_4454', 'f2_4455',\n",
" 'f2_4456', 'f2_4457', 'f2_4458', 'f2_4459', 'f2_4460', 'f2_4461',\n",
" 'f2_4462', 'f2_4463', 'f2_4464', 'f2_4465', 'f2_4466', 'f2_4467',\n",
" 'f2_4468', 'f2_4469', 'f2_4470', 'f2_4471', 'f2_4472', 'f2_4473',\n",
" 'f2_4474', 'f2_4475', 'f2_4476', 'f2_4477', 'f2_4478', 'f2_4479',\n",
" 'f2_4480', 'f2_4481', 'f2_4482', 'f2_4483', 'f2_4484', 'f2_4485',\n",
" 'f2_4486', 'f2_4487', 'f2_4488', 'f2_4489', 'f2_4490', 'f2_4491',\n",
" 'f2_4492', 'f2_4493'],\n",
" ['f3_4444', 'f3_4445', 'f3_4446', 'f3_4447', 'f3_4448', 'f3_4449',\n",
" 'f3_4450', 'f3_4451', 'f3_4452', 'f3_4453', 'f3_4454', 'f3_4455',\n",
" 'f3_4456', 'f3_4457', 'f3_4458', 'f3_4459', 'f3_4460', 'f3_4461',\n",
" 'f3_4462', 'f3_4463', 'f3_4464', 'f3_4465', 'f3_4466', 'f3_4467',\n",
" 'f3_4468', 'f3_4469', 'f3_4470', 'f3_4471', 'f3_4472', 'f3_4473',\n",
" 'f3_4474', 'f3_4475', 'f3_4476', 'f3_4477', 'f3_4478', 'f3_4479',\n",
" 'f3_4480', 'f3_4481', 'f3_4482', 'f3_4483', 'f3_4484', 'f3_4485',\n",
" 'f3_4486', 'f3_4487', 'f3_4488', 'f3_4489', 'f3_4490', 'f3_4491',\n",
" 'f3_4492', 'f3_4493']], dtype='<U7'), 'f1_4494')"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X_valid[0], y_valid[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment