Skip to content

Instantly share code, notes, and snippets.

@martinsbruveris
Created August 4, 2021 19:55
Show Gist options
  • Save martinsbruveris/8b3808ed9bb7c50683e3b0fbc59b05da to your computer and use it in GitHub Desktop.
Save martinsbruveris/8b3808ed9bb7c50683e3b0fbc59b05da to your computer and use it in GitHub Desktop.
Baseline model for turbulence prediction
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09a9fb28",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"id": "2be4e4ad",
"metadata": {},
"source": [
"## Preprocess data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c68f19d",
"metadata": {},
"outputs": [],
"source": [
"# One-off preprocessing: standardize the raw RBC data, subsample it spatially\n",
"# by 4 and cache the result. Paths are relative to the notebook directory, the\n",
"# same convention as the `torch.load` call in \"Compute results\".\n",
"data = torch.load(\"../data/tf_net/rbc_data.pt\")\n",
"print(data.shape)  # (2000, 2, 256, 1792)\n",
"\n",
"# Standardization\n",
"# std = torch.std(data)\n",
"# avg = torch.mean(data)\n",
"# We computed this once on the whole dataset. This is a shortcut to save memory.\n",
"# `rmse_loss` uses the same constants to undo the standardization.\n",
"std = 4506.0068\n",
"avg = -0.7982\n",
"data = (data - avg) / std\n",
"\n",
"# Subsampling: keep every 4th pixel in height and width (maybe averaging would be better).\n",
"# `.clone()` copies the strided view so it no longer keeps the full-resolution storage alive.\n",
"data = data[:, :, ::4, ::4].clone().detach()\n",
"print(data.shape)  # (2000, 2, 64, 448)\n",
"torch.save(data, \"../data/tf_net/data.pt\")"
]
},
{
"cell_type": "markdown",
"id": "5138ca5c",
"metadata": {},
"source": [
"## Dataset loader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f560ec3a",
"metadata": {},
"outputs": [],
"source": [
"class RBCDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Sliding-window samples of column crops from a frame sequence.\n",
"\n",
"    Item (a, b) takes `sequence_length` consecutive frames starting at frame\n",
"    `a`, restricted to the `b`-th block of `region_width` columns. Within that\n",
"    sequence, `x` is the `input_length` frames just before `midpoint` and `y`\n",
"    is the `output_length` frames starting at `midpoint`.\n",
"\n",
"    Args:\n",
"        data: tensor indexed as (time, channel, height, width).\n",
"        sequence_length: number of consecutive frames per sample.\n",
"        input_length: number of input frames, ending just before `midpoint`.\n",
"        midpoint: index within the sequence of the first target frame.\n",
"        output_length: number of target frames.\n",
"        stack_x: if True, merge the time and channel dimensions of `x`.\n",
"        region_width: width of each non-overlapping column block; columns\n",
"            left over after the last full block are never used.\n",
"\n",
"    Raises:\n",
"        ValueError: if the requested windows do not fit, or region_width < 1.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        data,\n",
"        sequence_length,\n",
"        input_length,\n",
"        midpoint,\n",
"        output_length,\n",
"        stack_x,\n",
"        region_width=64,\n",
"    ):\n",
"        if sequence_length > data.shape[0]:\n",
"            raise ValueError(\"sequence_length larger than data.\")\n",
"        if input_length > midpoint:\n",
"            raise ValueError(\"input_length larger than midpoint.\")\n",
"        if midpoint + output_length > sequence_length:\n",
"            raise ValueError(\"output goes past end of sequence\")\n",
"        if region_width < 1:\n",
"            raise ValueError(\"region_width must be at least 1.\")\n",
"\n",
"        self.data = data  # e.g. (2000, 2, 64, 448) for the full preprocessed data\n",
"        self.sequence_length = sequence_length\n",
"        self.input_length = input_length\n",
"        self.midpoint = midpoint\n",
"        self.output_length = output_length\n",
"        self.stack_x = stack_x\n",
"        self.region_width = region_width\n",
"\n",
"        width = self.data.shape[3]\n",
"        # Every start frame whose whole sequence fits in `data`; the \"+ 1\"\n",
"        # keeps the last one, a = len(data) - sequence_length.\n",
"        n_starts = self.data.shape[0] - self.sequence_length + 1\n",
"        self.indices = [\n",
"            (a, b)\n",
"            for a in range(n_starts)\n",
"            for b in range(width // self.region_width)\n",
"        ]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.indices)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return (x, y) as float32 copies for the window at `index`.\"\"\"\n",
"        a, b = self.indices[index]\n",
"        w_from = b * self.region_width\n",
"        w_to = w_from + self.region_width\n",
"        sequence = self.data[a:a + self.sequence_length, :, :, w_from:w_to]\n",
"        x = sequence[(self.midpoint - self.input_length):self.midpoint]\n",
"        y = sequence[self.midpoint:(self.midpoint + self.output_length)]\n",
"        if self.stack_x:\n",
"            # Stack time and (u, v) dimension into one\n",
"            x = x.reshape(-1, x.shape[-2], x.shape[-1])\n",
"        # Copy so samples do not share storage with `data`.\n",
"        x = x.clone().detach()\n",
"        y = y.clone().detach()\n",
"        return x.float(), y.float()"
]
},
{
"cell_type": "markdown",
"id": "eb56311b",
"metadata": {},
"source": [
"## Baseline model"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "dcdca4db",
"metadata": {},
"outputs": [],
"source": [
"class BaselineModel:\n",
"    \"\"\"Nearest-neighbour (analog) forecaster over a reference frame sequence.\n",
"\n",
"    For each channel separately, the query `x` is compared by RMS difference\n",
"    with every same-sized window of `data`: every start frame that leaves room\n",
"    for `output_length` following frames, and every spatial offset on a grid\n",
"    with step `strides`. The frames that followed the `neighbors` closest\n",
"    windows are averaged with inverse-distance weights.\n",
"\n",
"    Args:\n",
"        data: reference tensor indexed as (time, channel, height, width);\n",
"            expected on `device`, since the forecast is accumulated there.\n",
"        output_length: number of frames to forecast.\n",
"        strides: (height_step, width_step) between candidate spatial offsets.\n",
"        neighbors: number of closest windows to average; must be >= 1.\n",
"        device: device of the forecast; defaults to CUDA when available.\n",
"        eps: floor on a neighbour's distance before inverting it, so an exact\n",
"            match yields a huge finite weight instead of inf/NaN.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, output_length, strides, neighbors, device=None, eps=1e-8):\n",
"        if neighbors < 1:\n",
"            raise ValueError(\"neighbors must be at least 1.\")\n",
"        self.data = data\n",
"        self.output_length = output_length\n",
"        self.strides = strides\n",
"        self.neighbors = neighbors\n",
"        self.eps = eps\n",
"\n",
"        if device is None:\n",
"            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"        self.device = device\n",
"\n",
"        self.nb_frames = self.data.shape[0]\n",
"        self.height = self.data.shape[2]\n",
"        self.width = self.data.shape[3]\n",
"\n",
"    def _distances(self, src, channel):\n",
"        \"\"\"RMS distance from `src` (time, height, width) to each candidate window.\n",
"\n",
"        Entry [i, j, k] of the result compares `src` with frames [i, i + L) of\n",
"        `channel` at spatial offset (j * strides[0], k * strides[1]), where L is\n",
"        the length of `src`.\n",
"        \"\"\"\n",
"        length, height, width = src.shape\n",
"        # +1: start i is valid as long as i + length + output_length <= nb_frames.\n",
"        n_starts = self.nb_frames - self.output_length - length + 1\n",
"        if n_starts < 1:\n",
"            raise ValueError(\"data too short for the input and output lengths.\")\n",
"        n_rows = (self.height - height) // self.strides[0] + 1\n",
"        n_cols = (self.width - width) // self.strides[1] + 1\n",
"        # (height, width, time) to line up with the windows built by `unfold`.\n",
"        src_hwt = src.permute(1, 2, 0)\n",
"        dists = torch.empty((n_starts, n_rows, n_cols), device=self.device)\n",
"        for j, k in itertools.product(range(n_rows), range(n_cols)):\n",
"            h_from = j * self.strides[0]\n",
"            w_from = k * self.strides[1]\n",
"            patch = self.data[\n",
"                :n_starts + length - 1, channel,\n",
"                h_from:h_from + height, w_from:w_from + width,\n",
"            ]\n",
"            # windows[i, :, :, t] == patch[i + t]: every start frame at once (a view).\n",
"            windows = patch.unfold(0, length, 1)\n",
"            sq_err = (windows - src_hwt) ** 2\n",
"            dists[:, j, k] = torch.sqrt(torch.mean(sq_err, dim=(1, 2, 3)))\n",
"        return dists\n",
"\n",
"    def __call__(self, x):\n",
"        \"\"\"Forecast `output_length` frames following the query `x`.\n",
"\n",
"        `x` is indexed as (time, channel, height, width) and must not be larger\n",
"        than `data` spatially. Returns a tensor of shape\n",
"        (output_length, channel, height, width) on `self.device`.\n",
"        \"\"\"\n",
"        length, n_channels, height, width = x.shape\n",
"        y = torch.zeros((self.output_length, n_channels, height, width), device=self.device)\n",
"        for channel in range(n_channels):\n",
"            src = x[:, channel].to(self.data.device)\n",
"            dists = self._distances(src, channel).cpu().numpy()\n",
"            # Indices of the `neighbors` smallest distances (fewer if fewer candidates).\n",
"            nearest = np.argsort(dists, axis=None)[:self.neighbors]\n",
"            total_weight = 0.0\n",
"            for flat_idx in nearest:\n",
"                i, j, k = np.unravel_index(flat_idx, dists.shape)\n",
"                t_from = i + length\n",
"                h_from = j * self.strides[0]\n",
"                w_from = k * self.strides[1]\n",
"                # Inverse-distance weight; the floor avoids 1/0 on an exact match.\n",
"                weight = 1.0 / max(float(dists[i, j, k]), self.eps)\n",
"                y[:, channel] += weight * self.data[\n",
"                    t_from:t_from + self.output_length, channel,\n",
"                    h_from:h_from + height, w_from:w_from + width,\n",
"                ]\n",
"                total_weight += weight\n",
"            y[:, channel] /= total_weight\n",
"        return y"
]
},
{
"cell_type": "markdown",
"id": "c3b7d145",
"metadata": {},
"source": [
"## Loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bef318ad",
"metadata": {},
"outputs": [],
"source": [
"def rmse_loss(x, y, std=4506.0068, avg=-0.7982):\n",
"    \"\"\"RMSE between `x` and `y` after undoing the standardization.\n",
"\n",
"    Inputs are mapped back with `v * std + avg`; the defaults are the constants\n",
"    used in preprocessing. The mean runs over the last three dimensions, so for\n",
"    inputs indexed as (time, channel, height, width) the result has one RMSE\n",
"    per time step.\n",
"    \"\"\"\n",
"    x = x * std + avg\n",
"    y = y * std + avg\n",
"    mean_sq = torch.mean((x - y) ** 2, dim=(-3, -2, -1))\n",
"    return torch.sqrt(mean_sq)"
]
},
{
"cell_type": "markdown",
"id": "f67f7437",
"metadata": {},
"source": [
"## Compute results"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b753d1c2",
"metadata": {},
"outputs": [],
"source": [
"# Run on the GPU when one is available, otherwise fall back to the CPU.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af80f164",
"metadata": {},
"outputs": [],
"source": [
"# Cached output of the preprocessing cell (standardized, subsampled frames).\n",
"data = torch.load(\"../data/tf_net/data.pt\")\n",
"data = data.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "8e0b8bfc",
"metadata": {},
"outputs": [],
"source": [
"def run_experiment(input_length, strides, neighbors):\n",
"    \"\"\"Evaluate the nearest-neighbour baseline on held-out frames.\n",
"\n",
"    Frames [0, 1100) of the global `data` are the search set of\n",
"    `BaselineModel`; frames [1100, 1500) supply the test windows. For each test\n",
"    window the model forecasts `output_length` frames and `rmse_loss` gives the\n",
"    error at every forecast step.\n",
"\n",
"    Args:\n",
"        input_length: number of context frames given to the model.\n",
"        strides: (height_step, width_step) of the spatial search grid.\n",
"        neighbors: number of nearest windows averaged by the model.\n",
"\n",
"    Returns:\n",
"        Array of shape (len(test_ds), output_length), also saved as a .npy file.\n",
"    \"\"\"\n",
"    # Other shared parameters\n",
"    output_length = 60\n",
"\n",
"    test_ds = RBCDataset(\n",
"        data=data[1100:1500],\n",
"        sequence_length=100,\n",
"        input_length=input_length,\n",
"        midpoint=40,\n",
"        output_length=output_length,\n",
"        stack_x=False,\n",
"    )\n",
"\n",
"    model = BaselineModel(\n",
"        data=data[0:1100],\n",
"        output_length=output_length,\n",
"        strides=strides,\n",
"        neighbors=neighbors,\n",
"    )\n",
"\n",
"    rmse = np.zeros((len(test_ds), output_length))\n",
"    # Nothing is trained here, so skip autograd bookkeeping.\n",
"    with torch.no_grad():\n",
"        # `total` lets tqdm show progress and an ETA for the enumerate() iterator.\n",
"        for j, (x, y) in tqdm(enumerate(test_ds), total=len(test_ds)):\n",
"            y_pred = model(x)\n",
"            rmse[j] = rmse_loss(y, y_pred).cpu().numpy()\n",
"\n",
"    # NOTE(review): only strides[1] is in the file name, so runs that differ\n",
"    # only in strides[0] overwrite each other's results.\n",
"    np.save(\n",
"        f\"../data/tf_net/baseline_rmse_n{neighbors}_i{input_length}_s{strides[1]}.npy\",\n",
"        rmse,\n",
"    )\n",
"\n",
"    return rmse"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67a3a9fb",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Reference run: one input frame, search every row and every 8th column, 20 neighbours.\n",
"rmse = run_experiment(neighbors=20, input_length=1, strides=(1, 8))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c6c299b9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAac0lEQVR4nO3de3hc9X3n8fdXMxrdJUuWLNvyRTY2GDuAcYQxgU3IpQbcbk3aJIV0E5fQOO3CLnmW7i4kzxPatN1eNpdutilPnYQCKeHSXDZelifEeMmShGJbBmN8xfIF32TrZls3SxrNfPePOTKDsbEsyRrNnM/reeaZc77nnJnfD8afOfrNmd+YuyMiIuGQl+kGiIjI+FHoi4iEiEJfRCREFPoiIiGi0BcRCZFophvwXqqrq72+vj7TzRARySqbN29uc/eac22b0KFfX19PY2NjppshIpJVzOyt823T8I6ISIgo9EVEQkShLyISIgp9EZEQuWDom9lMM3vRzHaY2XYzuy+o/6mZHTGzLcFtRdoxD5pZk5ntNrNb0uq3BrUmM3vg0nRJRETOZzhX7wwC97v7q2ZWBmw2s3XBtm+6+9fSdzazhcAdwCJgOvCCmV0ebP428BvAYWCTma119x1j0REREbmwC4a+uzcDzcFyl5ntBOre45CVwFPu3g/sN7MmYGmwrcnd9wGY2VPBvgp9EZFxclFj+mZWD1wLbAhK95rZVjN7xMwqg1odcCjtsMNB7Xz1s59jtZk1mllja2vrxTRPREQuYNihb2alwI+AL7p7J/AwcBmwmNRfAl8fiwa5+xp3b3D3hpqac36hTERERmhY38g1s3xSgf+Eu/8YwN2Pp23/DvBssHoEmJl2+IygxnvURURkHAzn6h0DvgfsdPdvpNWnpe32cWBbsLwWuMPMCsxsDjAf2AhsAuab2Rwzi5H6sHft2HRDRESGYzhn+jcCnwHeMLMtQe1LwJ1mthhw4ADwBQB3325mz5D6gHYQuMfdEwBmdi/wPBABHnH37WPWExERuSCbyL+R29DQ4JpwTUTk4pjZZndvONc2fSNXRCREFPoiIiGi0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhIhCX0QkRBT6IiIhotAXEQkRhb6ISIgo9EVEQkShLyISIgp9EZEQUeiLiISIQl9EJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiIKfRGREFHoi4iEiEJfRCREFPoiIiGi0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhMgFQ9/MZprZi2a2w8y2m9l9Qb3KzNaZ2Z7gvjKom5l9y8yazGyrmS1Je6xVwf57zGzVpeuWiIicy3DO9AeB+919IbAMuMfMFgIPAOvdfT6wPlgHuA2YH9xWAw9D6k0CeAi4HlgKPDT0RiEiIuPjgqHv7s3u/mqw3AXsBOqAlcBjwW6PAbcHyyuBxz3lFWCSmU0DbgHWuXuHu58A1gG3jmVnRETkvV3UmL6Z1QPXAhuAWndvDjYdA2qD5TrgUNphh4Pa+eoiIjJOhh36ZlYK/Aj4ort3pm9zdwd8LBpkZqvNrNHMGltbW8fiIUVEJDCs0DezfFKB/4S7/zgoHw+GbQjuW4L6EWBm2uEzgtr56u/g7mvcvcHdG2pqai6mLyIicgHDuXrHgO8BO939G2mb1gJDV+CsAn6aVv9scBXPMuBUMAz0PLDczCqDD3CXBzURERkn0WHscyPwGeANM9sS1L4E/DXwjJndDbwFfCrY9hywAmgCeoG7ANy9w8z+HNgU7PdVd+8Yi06IiMjwWGo4fmJqaGjwxsbGTDdDRCSrmNlmd2841zZ9I1dEJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiIKfRGREFHoi4iEiEJfRCREFPoiIiGi0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhIhCX0QkRBT6IiIhotAXEQ
kRhb6ISIgo9EVEQkShLyISIgp9EZEQUeiLiISIQl9EJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiIKfRGREFHoi4iEiEJfRCRELhj6ZvaImbWY2ba02p+a2REz2xLcVqRte9DMmsxst5ndkla/Nag1mdkDY98VERG5kOGc6T8K3HqO+jfdfXFwew7AzBYCdwCLgmP+wcwiZhYBvg3cBiwE7gz2FRGRcRS90A7u/pKZ1Q/z8VYCT7l7P7DfzJqApcG2JnffB2BmTwX77rj4JouIyEiNZkz/XjPbGgz/VAa1OuBQ2j6Hg9r56u9iZqvNrNHMGltbW0fRPBEROdtIQ/9h4DJgMdAMfH2sGuTua9y9wd0bampqxuphRUSyQnt3P7/Y3cILO45fkse/4PDOubj7mdaY2XeAZ4PVI8DMtF1nBDXeoy4iEkrt3f28ceQU246cYuvh1P3RU30ALJhaxscW1o75c44o9M1smrs3B6sfB4au7FkL/MDMvgFMB+YDGwED5pvZHFJhfwfw6dE0XEQkm7R09bH9aCfbzxHwAHOqS3h/fRV/UFfO++oqWDS94pK044Khb2ZPAjcD1WZ2GHgIuNnMFgMOHAC+AODu283sGVIf0A4C97h7Inice4HngQjwiLtvH+vOiIhkWjyR5EBbD7uPd7HjaCfbj3ayo7mT1q7+M/ukB/xVdZNYVFdOeWH+uLTP3H1cnmgkGhoavLGxMdPNEBF5l5O9Axxo7+Wt9h72t/Ww53g3bx7vYn9bD4PJVK5G84x5U0pZNL2ChdPLWTitfFwC3sw2u3vDubaNaHhHRCRXJZLOqdNxOnoGaOvu59ipPo519qXuT/Vx9NRpDrT10Nk3eOYYM5hVVcz8KaV8bGEtl9eWMn9KGfNrSymIRjLYm3dT6IvIhDeYSDKQSNIfT9I/mGRgMEn/YIL+wbfXBxKp+3gidesPlvvjb28b2q93YJCe/gQ9/YP0BMudp+N09A5w6nSccw2AlBVEqa0oZFpFISsX1zF7cjGzJ5cwe3Ixs6qKKcyfWOF+Pgp9Ebko7s5AIknfQJLe+CC9AwlODyQ4HU+cWe6Lp9aH6v3xtwO6fzBxJrz7ztSD+/jby31pxySSYzMMnWdQEI1QHItQXBChJBalpCBKWWGUusoiKovzqSqOUVkSo6okxuSSAqZWFDK1opDSgtyIy9zohYicEU8kOdkb52TvACd645zoHeBUb5yegUH64qkw7RtM0DeQCM6A/czZ8dCZ8Ntn00kG0kJ4KMRHksGF+XkURCMURPMoyM+jMBqhIK1WUhBNq71dj0XzKMwPjovmUZAfIRZJPUYsktoeC7bFIhHyo0Yskkd+JK0eTe0bjWiOSYW+SJZJJp0jJ0+zr62HA209HDl5OnU7cZqjJ0/T2t1/zuGJdLFoHoVpARqL5pEfMfKDsCyI5lFWGKU6+nbYFsYiFOUHt1iEwvzgjDmoF8eiFMXyKMqPUpS271CIm9n4/AeS96TQF5mgkknnYEcvu451setYJ3tautnb0s2B9h764skz+8WiedRNKqJuUhE3X1HDtIoiJpfGmFQco7I4n8riGBVF+ZQURFMhHM0jL08BHFYKfZEJoC+eYNexLrYdOXXmuu49x7voHUgAb18dMq+mlH8zv5q5NaXMrS5hTk0JNaUFOouWYVPoi4yz/sEEu5q72Hr4JFsOpb6Z2dTafebDyoqifK6cVsbvXTeTBVPLWDC1nMtryyiKZcfVITKxKfRFLrGjJ0/T+NYJNh/oYMuhk+xs7mIgkRqeqS6NcVVdBcsX1bJoegWLppczo7JIZ+5yySj0RcaQu7OnpZsN+9rZdOAEjQc6zsyvUhyLcPWMCu66qZ5rZkzimpmTmF5RqICXcaXQFxmFoZB/ZV87r+xrZ8O+Dtp7BgCoLS+gob6Kz8+u5Lr6KhZMLdMlg5JxCn2Ri3T4RC+/bmrj103tvLy3nbbu1ERa0ysK+dAVNSybO5llcyYzs0rDND
LxKPRFLqCjZ4B/3dvOr5raeHlvG2+19wJQXVrAjfMmc+Nl1dxw2WSNxUtWUOiLnKWrL07jgRP86752frWnjR3NnUBq7pXr51bxBx+o58Z51cyfUqqQl6yj0JfQ6+yLs/nAiTPj8m8cOUXSIRbJY8nsSfzJ8sv5wLxqrq6r0Ji8ZD2Ffo5xd3oHEnT2xek8PUhnX5yuvjgDg0kGk04i6cQTTiKZxMwoiAZzlETyyI/mURyLUF6YT1lhlPKifEpikZw6m00mnb2t3bx28CSvHjzBqwdPsKelG3fIjxjXzqzk3g/PY9ncyVw7q1LXxkvOUehnmf7BBIc6etnflvrxhuZgju+h+b6Pd/ad+QGHsZBnqS8LTS4tYHJJjOqyAqpLYlSXFlBbXsiU8gKmlBVSW15AZXFsQn29vy+eYPexLnY0d7IzuO1q7qKrPzUPekVRPtfOmsRvXjWd6+orFfISCgr9CSqRdPa39bD96Cl2HO1k17HUL/IcPtH7jhkOi/IjTAumfr1+ThVTygupLM6nvCj/zBl7WWGUwvwI0TwjkmdE8/KIRIxkMjW74kAiSXzQz8wz3tU3SOfpePBXwiAnegdo707ddh7tpK27/x0/IDEkP2LUlBYwpTz1JjD0ZlBdWpCapra0gOrS1P1o/4JwdzpPD9LW009bVz8tXf0c7OjlYHsvB9p7ONjRy7HOvjMTj5XEIiyYVs7Ka6dz9YxJLJlVydzqkgn1JiUyHhT6E0RHzwAb93ewcX8Hrx06wa7mLk7HU/OuxCJ5zJtSytUzKrh98XTm1JQwp7qU+snFVBTlZ2T4pS+eoDUI25bO1F8Yxzr7aenqo7Wrn/1tPWzY38HJ3vg5j4/kGSWxCGWF+ZQWRCkpiFAQjRDJM/LyjIil9nEnNQ1wPJmap30wQW9/gvaefuKJd/9FU11awOzJxdxw2WRmV5VwxdRSrpxWzszKYgW8CAr9jOkdGOSlN1v5VVMbG/d38ObxbgAKonlcPaOC37tuJouml7NoegXzppQSi06sDxAL8yPMrCpmZlXxe+7XF0/Q0ZP6K6Gtp5+O7tRP0HX1DdLdn/qroqd/kK7+eDB/u5Pw1Nh7IumYpZ6rKD9CZXE+BfkRSmKRt4ebSlN/SVSXxZhRWZwzP3QhcqnoX8g4OtEzwAs7j/P89uP8ck8r/YNJSmIR3l9fxcrFdVw/p4qrZlRMuN/UHI3C/AjTJxUxfVJRppsiIij0L7nBRJIXdh7niQ0HeXlvO4mkM72ikDuXzmL5olqW1lfpMkARGTcK/UukrbufpzYe5IkNB2k+1UfdpCL+6ENzuWXRVK6qq8ipyyBFJHso9MfYwfZe/u6FN3l2azMDiSQ3zavmz357ER+9spaIPkgUkQxT6I+RU6fjfPvFJh799QEiecanr5/Fv1s2m3lTSjPdNBGRMxT6oxRPJHly40G+ue5NTp6O84klM/iTW66gtrww000TEXkXhf4ovHrwBP/5X15nb2sPN8ydzJd/80reV1eR6WaJiJyXQn8EkklnzS/38bXnd1NbXsh3PtvAx66cog9nRWTCU+hfpLbufv7TM6/z0putrLhqKn/1O1dTUZSf6WaJiAyLQv8ivNzUxn1Pb6HzdJy//Pj7+PTSWTq7F5GsotAfpu/+ch9/+dxO5laX8P27l7JganmmmyQictEU+hfg7nz952/y9y82seKqqXztk9dQHNN/NhHJTkqv95BMOl9Zu41/fuUgdy6dxV/c/j59wUpEstoFJ30xs0fMrMXMtqXVqsxsnZntCe4rg7qZ2bfMrMnMtprZkrRjVgX77zGzVZemO2NnYDDJfU9v4Z9fOcgffegy/tvHFfgikv2GM9PXo8CtZ9UeANa7+3xgfbAOcBswP7itBh6G1JsE8BBwPbAUeGjojWIiOj2QYPX3G/nfrx/lv966gAduW6APbEUkJ1ww9N39JaDjrPJK4LFg+THg9rT6457yCjDJzKYBtwDr3L3D3U8A63j3G8mE0D
+Y4POPN/L/3mzlr37nKv745ssy3SQRkTEz0jl9a929OVg+BtQGy3XAobT9Dge189XfxcxWm1mjmTW2traOsHkjk0w69z/zOr9qauNvf/dq7lw6a1yfX0TkUhv1RO7u7sCY/RK3u69x9wZ3b6ipqRmrhx3O8/LVZ3fw7NZmHrhtAZ9smDluzy0iMl5GGvrHg2EbgvuWoH4ESE/LGUHtfPUJ4x9+sZdHXz7A3TfN4QsfnJvp5oiIXBIjDf21wNAVOKuAn6bVPxtcxbMMOBUMAz0PLDezyuAD3OVBbUJ4etNB/vvzu7l98XS+vOJKfWgrIjnrgtfpm9mTwM1AtZkdJnUVzl8Dz5jZ3cBbwKeC3Z8DVgBNQC9wF4C7d5jZnwObgv2+6u5nfzicES/sOM6DP36DD15ew99+4hrydFmmiOQwSw3JT0wNDQ3e2Nh4yR7/jcOn+OQ/vswVtWX84PPLKCnQd9VEJPuZ2WZ3bzjXttD+IvexU3384eObmFxSwHdXXafAF5FQCGXS9fQPcvdjm+jpT/DDP15KTVlBppskIjIuQnemn0g6X3x6CzubO/mfn75Ws2WKSKiELvT/5me7WLfjOF/5rYV8+IopmW6OiMi4ClXoP7XxIGte2sdnls1m1QfqM90cEZFxF5rQ/z9bm/nST1KXZj70bxfqWnwRCaVQhP7Pth3jPz71Gu+fXcnDv7+EaCQU3RYReZecT78XdhznPzz5KlfPqOCf7lqqSzNFJNRyOvR/sbuFf//EqyycVs5jn1tKqQJfREIuZ0P/l3taWf39zVw+tZTHP3c95YX5mW6SiEjG5WTo723t5g8fa+SymlK+/7nrqShW4IuIQI5+I3dudQn3L7+c310yg8qSWKabIyIyYeRk6JsZqz+onzkUETlbTg7viIjIuSn0RURCRKEvIhIiCn0RkRBR6IuIhIhCX0QkRBT6IiIhotAXEQkRhb6ISIgo9EVEQkShLyISIgp9EZEQUeiLiISIQl9EJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiIKfRGREBlV6JvZATN7w8y2mFljUKsys3Vmtie4rwzqZmbfMrMmM9tqZkvGogMiIjJ8Y3Gm/2F3X+zuDcH6A8B6d58PrA/WAW4D5ge31cDDY/DcIiJyES7F8M5K4LFg+THg9rT6457yCjDJzKZdgucXEZHzGG3oO/BzM9tsZquDWq27NwfLx4DaYLkOOJR27OGg9g5mttrMGs2ssbW1dZTNExGRdNFRHn+Tux8xsynAOjPblb7R3d3M/GIe0N3XAGsAGhoaLupYERF5b6M603f3I8F9C/ATYClwfGjYJrhvCXY/AsxMO3xGUBMRkXEy4tA3sxIzKxtaBpYD24C1wKpgt1XAT4PltcBng6t4lgGn0oaBRERkHIxmeKcW+ImZDT3OD9z9Z2a2CXjGzO4G3gI+Fez/HLACaAJ6gbtG8dwiIjICIw59d98HXHOOejvw0XPUHbhnpM8nIiKjp2/kioiEiEJfRCREFPoiIiGi0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhIhCX0QkRBT6IiIhotAXEQkRhb6ISIgo9EVEQkShLyISIgp9EZEQUeiLiISIQl9EJEQU+iIiIaLQFxEJEYW+iEiIKPRFREJEoS8iEiIKfRGREFHoi4iEiEJfRCREFPoiIiGi0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhMi4h76Z3Wpmu82sycweGO/nFxEJs3ENfTOLAN8GbgMWAnea2cLxbIOISJiN95n+UqDJ3fe5+wDwFLBynNsgIhJa0XF+vjrgUNr6YeD69B3MbDWwOljtNrPdo3i+aqBtFMdPJLnUF8it/uRSX0D9mciG25fZ59sw3qF/Qe6+BlgzFo9lZo3u3jAWj5VpudQXyK3+5FJfQP2ZyMaiL+M9vHMEmJm2PiOoiYjIOBjv0N8EzDezOWYWA+4A1o5zG0REQmtch3
fcfdDM7gWeByLAI+6+/RI+5ZgME00QudQXyK3+5FJfQP2ZyEbdF3P3sWiIiIhkAX0jV0QkRBT6IiIhkpOhn+1TPZjZI2bWYmbb0mpVZrbOzPYE95WZbONwmdlMM3vRzHaY2XYzuy+oZ2t/Cs1so5m9HvTnz4L6HDPbELzmng4uVMgKZhYxs9fM7NlgPZv7csDM3jCzLWbWGNSy8rUGYGaTzOyHZrbLzHaa2Q2j7U/OhX6OTPXwKHDrWbUHgPXuPh9YH6xng0HgfndfCCwD7gn+f2Rrf/qBj7j7NcBi4FYzWwb8DfBNd58HnADuzlwTL9p9wM609WzuC8CH3X1x2vXs2fpaA/gfwM/cfQFwDan/T6Prj7vn1A24AXg+bf1B4MFMt2sE/agHtqWt7wamBcvTgN2ZbuMI+/VT4DdyoT9AMfAqqW+VtwHRoP6O1+BEvpH6rsx64CPAs4Bla1+C9h4Aqs+qZeVrDagA9hNccDNW/cm5M33OPdVDXYbaMpZq3b05WD4G1GayMSNhZvXAtcAGsrg/wXDIFqAFWAfsBU66+2CwSza95v4O+C9AMlifTPb2BcCBn5vZ5mBKF8je19ocoBX4p2D47btmVsIo+5OLoZ/zPPUWn1XX2ppZKfAj4Ivu3pm+Ldv64+4Jd19M6ix5KbAgsy0aGTP7LaDF3Tdnui1j6CZ3X0JqePceM/tg+sYse61FgSXAw+5+LdDDWUM5I+lPLoZ+rk71cNzMpgEE9y0Zbs+wmVk+qcB/wt1/HJSztj9D3P0k8CKpIZBJZjb0Zcdsec3dCPy2mR0gNePtR0iNIWdjXwBw9yPBfQvwE1Jvytn6WjsMHHb3DcH6D0m9CYyqP7kY+rk61cNaYFWwvIrU2PiEZ2YGfA/Y6e7fSNuUrf2pMbNJwXIRqc8ndpIK/08Eu2VFf9z9QXef4e71pP6d/F93/32ysC8AZlZiZmVDy8ByYBtZ+lpz92PAITO7Iih9FNjBaPuT6Q8rLtEHICuAN0mNtX450+0ZQfufBJqBOKl3+7tJjbWuB/YALwBVmW7nMPtyE6k/P7cCW4Lbiizuz9XAa0F/tgFfCepzgY1AE/AvQEGm23qR/boZeDab+xK0+/Xgtn3o3362vtaCti8GGoPX2/8CKkfbH03DICISIrk4vCMiIueh0BcRCRGFvohIiCj0RURCRKEvIhIiCn0RkRBR6IuIhMj/B1wviBY5mdloAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"fig, ax = plt.subplots()\n",
"# Per-step RMSE averaged over all test samples.\n",
"ax.plot(rmse.mean(axis=0))\n",
"ax.set_ylim((0, 2600))\n",
"ax.set_xlabel(\"Forecast step\")\n",
"ax.set_ylabel(\"RMSE (original units)\")\n",
"ax.set_title(\"Nearest-neighbour baseline\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment