Skip to content

Instantly share code, notes, and snippets.

@martinsbruveris
Created August 4, 2021 19:55
Show Gist options
  • Save martinsbruveris/8b3808ed9bb7c50683e3b0fbc59b05da to your computer and use it in GitHub Desktop.
Save martinsbruveris/8b3808ed9bb7c50683e3b0fbc59b05da to your computer and use it in GitHub Desktop.
Baseline model for turbulence prediction
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "09a9fb28",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import math\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "markdown",
"id": "2be4e4ad",
"metadata": {},
"source": [
"## Preprocess data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c68f19d",
"metadata": {},
"outputs": [],
"source": [
"# One-off preprocessing: subsample the raw simulation, standardize it and cache the result.\n",
"# The paths use the same ../data/tf_net/ directory as the cells further down, so the\n",
"# notebook can be run top to bottom from a single working directory.\n",
"# NOTE(review): torch.load unpickles its input; only load files you trust.\n",
"data = torch.load(\"../data/tf_net/rbc_data.pt\")\n",
"print(data.shape) # (2000, 2, 256, 1792)\n",
"\n",
"# Subsampling (maybe averaging would be better).\n",
"# Done before the standardization so that the arithmetic below only allocates a tensor\n",
"# 1/16 the size of the raw data; the resulting values are the same either way.\n",
"data = data[:, :, ::4, ::4]\n",
"\n",
"# Standardization\n",
"# std = torch.std(data)\n",
"# avg = torch.mean(data)\n",
"# We computed this once on the whole (full-resolution) dataset. This is a shortcut to save memory\n",
"std = 4506.0068\n",
"avg = -0.7982\n",
"data = ((data - avg) / std).detach()\n",
"print(data.shape) # (2000, 2, 64, 448)\n",
"torch.save(data, \"../data/tf_net/data.pt\")"
]
},
{
"cell_type": "markdown",
"id": "5138ca5c",
"metadata": {},
"source": [
"## Dataset loader"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f560ec3a",
"metadata": {},
"outputs": [],
"source": [
"class RBCDataset(torch.utils.data.Dataset):\n",
"    \"\"\"Sliding-window dataset of (input, target) clips cut from a frame sequence.\n",
"\n",
"    ``data`` is indexed as (frame, channel, height, width). A sample is identified by a\n",
"    start frame and a vertical strip of ``region_width`` columns. Inside the window of\n",
"    ``sequence_length`` frames that begins at the start frame, the ``input_length``\n",
"    frames ending just before ``midpoint`` are the input and the ``output_length``\n",
"    frames beginning at ``midpoint`` are the target.\n",
"\n",
"    Args:\n",
"        data: 4-D tensor of frames.\n",
"        sequence_length: Number of consecutive frames in each window.\n",
"        input_length: Number of input frames taken right before ``midpoint``.\n",
"        midpoint: Offset inside the window at which the target starts.\n",
"        output_length: Number of target frames.\n",
"        stack_x: If true, merge the time and channel axes of the input into one.\n",
"        region_width: Width in columns of each strip. Columns left over after the\n",
"            last full strip are never sampled.\n",
"\n",
"    Raises:\n",
"        ValueError: If the lengths do not fit inside ``data`` or the window, or if\n",
"            ``region_width`` is not positive.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        data,\n",
"        sequence_length,\n",
"        input_length,\n",
"        midpoint,\n",
"        output_length,\n",
"        stack_x,\n",
"        region_width=64,\n",
"    ):\n",
"        if sequence_length > data.shape[0]:\n",
"            raise ValueError(\"sequence_length larger than data.\")\n",
"        if input_length > midpoint:\n",
"            raise ValueError(\"input_length larger than midpoint.\")\n",
"        if midpoint + output_length > sequence_length:\n",
"            raise ValueError(\"output goes past end of sequence\")\n",
"        if region_width <= 0:\n",
"            raise ValueError(\"region_width must be positive\")\n",
"\n",
"        self.data = data  # e.g. (2000, 2, 64, 448)\n",
"        self.sequence_length = sequence_length\n",
"        self.input_length = input_length\n",
"        self.midpoint = midpoint\n",
"        self.output_length = output_length\n",
"        self.stack_x = stack_x\n",
"        self.region_width = region_width\n",
"\n",
"        # +1: the window that ends exactly on the last frame is a valid sample too.\n",
"        n_starts = self.data.shape[0] - self.sequence_length + 1\n",
"        n_regions = self.data.shape[3] // self.region_width\n",
"        # One entry per (start frame, strip number) pair.\n",
"        self.indices = [\n",
"            (a, b)\n",
"            for a in range(n_starts)\n",
"            for b in range(n_regions)\n",
"        ]\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of (start frame, strip) samples.\"\"\"\n",
"        return len(self.indices)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return the float32 ``(x, y)`` pair for sample ``index``.\"\"\"\n",
"        a, b = self.indices[index]\n",
"        w_from = b * self.region_width\n",
"        w_to = w_from + self.region_width\n",
"        sequence = self.data[a:a + self.sequence_length, :, :, w_from:w_to]\n",
"        x = sequence[(self.midpoint - self.input_length):self.midpoint]\n",
"        y = sequence[self.midpoint:(self.midpoint + self.output_length)]\n",
"        if self.stack_x:\n",
"            # Stack time and (u, v) dimension into one\n",
"            x = x.reshape(-1, x.shape[-2], x.shape[-1])\n",
"        # Copy so that callers cannot modify the shared backing tensor through a view.\n",
"        x = x.clone().detach()\n",
"        y = y.clone().detach()\n",
"        return x.float(), y.float()"
]
},
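{
"cell_type": "markdown",
"id": "eb56311b",
"metadata": {},
"source": [
"## Baseline model\n",
"\n",
"A nearest-neighbour baseline with no trainable parameters. For each channel of a query clip it scans the reference data for the space-time windows closest to the query in root-mean-square distance, then predicts the inverse-distance weighted average of the frames that followed the closest windows."
]
},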
{
"cell_type": "code",
"execution_count": 3,
"id": "dcdca4db",
"metadata": {},
"outputs": [],
"source": [
"class BaselineModel:\n",
"    \"\"\"Nearest-neighbour baseline with no trainable parameters.\n",
"\n",
"    For every channel of a query clip, the model scans the stored ``data`` for the\n",
"    space-time windows that are closest to the query in root-mean-square distance and\n",
"    predicts the inverse-distance weighted average of the frames that followed the\n",
"    ``neighbors`` closest windows.\n",
"\n",
"    Args:\n",
"        data: Reference tensor indexed as (frame, channel, height, width).\n",
"        output_length: Number of frames to predict.\n",
"        strides: ``(row_stride, column_stride)`` of the spatial search grid.\n",
"        neighbors: Number of nearest windows to average; must be at least 1.\n",
"        device: Device used for the computation. Defaults to CUDA when available,\n",
"            otherwise the CPU.\n",
"\n",
"    Raises:\n",
"        ValueError: If ``neighbors`` is smaller than 1.\n",
"    \"\"\"\n",
"\n",
"    # Added to every distance so that an exact match (distance 0) gets a large but\n",
"    # finite weight instead of a division by zero that turns the prediction into NaN.\n",
"    eps = 1e-8\n",
"\n",
"    def __init__(self, data, output_length, strides, neighbors, device=None):\n",
"        if neighbors < 1:\n",
"            raise ValueError(\"neighbors must be at least 1\")\n",
"        if device is None:\n",
"            device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"        self.device = device\n",
"\n",
"        # Keep the reference data on the compute device so it can be compared with queries.\n",
"        self.data = data.to(device)\n",
"        self.output_length = output_length\n",
"        self.strides = strides\n",
"        self.neighbors = neighbors\n",
"\n",
"        self.nb_frames = self.data.shape[0]\n",
"        self.height = self.data.shape[2]\n",
"        self.width = self.data.shape[3]\n",
"\n",
"    def __call__(self, x):\n",
"        \"\"\"Predict the ``output_length`` frames that follow the clip ``x``.\n",
"\n",
"        Args:\n",
"            x: Query clip of shape (N, C, H, W); H and W must not exceed the spatial\n",
"                size of the reference data.\n",
"\n",
"        Returns:\n",
"            Tensor of shape (output_length, C, H, W) on ``self.device``.\n",
"\n",
"        Raises:\n",
"            ValueError: If the reference data has too few frames or is spatially\n",
"                smaller than the query.\n",
"        \"\"\"\n",
"        x = x.to(self.device)\n",
"        n_in, n_channels, patch_h, patch_w = x.shape\n",
"        stride_h, stride_w = self.strides\n",
"\n",
"        # +1: a candidate window may end exactly output_length frames before the end.\n",
"        n_starts = self.nb_frames - self.output_length - n_in + 1\n",
"        if n_starts < 1 or self.height < patch_h or self.width < patch_w:\n",
"            raise ValueError(\"reference data is too small for this query\")\n",
"        n_rows = (self.height - patch_h) // stride_h + 1\n",
"        n_cols = (self.width - patch_w) // stride_w + 1\n",
"\n",
"        y = torch.zeros(\n",
"            (self.output_length, n_channels, patch_h, patch_w), device=self.device\n",
"        )\n",
"        for channel in range(n_channels):\n",
"            src = x[:, channel]  # (N, H, W)\n",
"\n",
"            # dists[t, r, c]: RMS distance between the query and the window that starts\n",
"            # at frame t, row r * stride_h and column c * stride_w.\n",
"            dists = torch.zeros((n_starts, n_rows, n_cols), device=self.device)\n",
"            for start in range(n_starts):\n",
"                clip = self.data[start:start + n_in, channel]\n",
"                # All spatial windows at once: (N, n_rows, n_cols, H, W).\n",
"                windows = clip.unfold(1, patch_h, stride_h).unfold(2, patch_w, stride_w)\n",
"                sq_err = (windows - src[:, None, None]) ** 2\n",
"                dists[start] = torch.sqrt(sq_err.mean(dim=(0, 3, 4)))\n",
"\n",
"            dists = dists.cpu().numpy()\n",
"            nearest = np.argsort(dists, axis=None, kind=\"stable\")[:self.neighbors]\n",
"            # Inverse-distance weights, normalized up front so they sum to one.\n",
"            weights = 1.0 / (dists.flatten()[nearest].astype(np.float64) + self.eps)\n",
"            weights /= weights.sum()\n",
"\n",
"            for idx, weight in zip(nearest, weights):\n",
"                start, row, col = (int(v) for v in np.unravel_index(idx, dists.shape))\n",
"                t_from = start + n_in\n",
"                t_to = t_from + self.output_length\n",
"                h_from = row * stride_h\n",
"                w_from = col * stride_w\n",
"                future = self.data[\n",
"                    t_from:t_to,\n",
"                    channel,\n",
"                    h_from:h_from + patch_h,\n",
"                    w_from:w_from + patch_w,\n",
"                ]\n",
"                y[:, channel] += float(weight) * future\n",
"        return y"
]
},
{
"cell_type": "markdown",
"id": "c3b7d145",
"metadata": {},
"source": [
"## Loss"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "bef318ad",
"metadata": {},
"outputs": [],
"source": [
"def rmse_loss(x, y, std=4506.0068, avg=-0.7982):\n",
"    \"\"\"Root-mean-square error between two standardized tensors, in original units.\n",
"\n",
"    Both inputs are mapped back with ``value * std + avg`` before they are compared.\n",
"    The squared error is averaged over the last three axes, so one value is returned\n",
"    per index of the remaining leading axes.\n",
"\n",
"    Args:\n",
"        x: Tensor with at least three dimensions.\n",
"        y: Tensor broadcastable against ``x``.\n",
"        std: Scale that was divided out when the data was standardized.\n",
"        avg: Offset that was subtracted when the data was standardized.\n",
"\n",
"    Returns:\n",
"        Tensor holding the RMSE, with the last three axes reduced away.\n",
"    \"\"\"\n",
"    x = x * std + avg\n",
"    y = y * std + avg\n",
"    return torch.sqrt(torch.mean((x - y) ** 2, dim=(-3, -2, -1)))"
]
},
{
"cell_type": "markdown",
"id": "f67f7437",
"metadata": {},
"source": [
"## Compute results"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "b753d1c2",
"metadata": {},
"outputs": [],
"source": [
"# Prefer the GPU when one is visible to PyTorch; otherwise compute on the CPU.\n",
"device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "af80f164",
"metadata": {},
"outputs": [],
"source": [
"# Load the preprocessed frames straight onto the compute device. map_location also lets\n",
"# a file that was saved from a GPU session be opened on a CPU-only machine.\n",
"# NOTE(review): torch.load unpickles its input; only load files you trust.\n",
"data = torch.load(\"../data/tf_net/data.pt\", map_location=device)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "8e0b8bfc",
"metadata": {},
"outputs": [],
"source": [
"def run_experiment(input_length, strides, neighbors):\n",
"    \"\"\"Evaluate the baseline on held-out frames and save the per-step RMSE.\n",
"\n",
"    Frames 0-1099 of the global ``data`` tensor are the reference set of the model and\n",
"    frames 1100-1499 provide the test clips. Every test clip is predicted 60 frames\n",
"    ahead and compared with the ground truth.\n",
"\n",
"    Args:\n",
"        input_length: Number of frames given to the model as the query.\n",
"        strides: ``(row_stride, column_stride)`` of the model's spatial search grid.\n",
"        neighbors: Number of nearest neighbours averaged by the model.\n",
"\n",
"    Returns:\n",
"        Array of shape (number of test clips, 60) with the RMSE of every predicted\n",
"        frame. The same array is written to ``../data/tf_net/`` as a ``.npy`` file.\n",
"    \"\"\"\n",
"    # Other shared parameters\n",
"    output_length = 60\n",
"\n",
"    test_ds = RBCDataset(\n",
"        data=data[1100:1500],\n",
"        sequence_length=100,\n",
"        input_length=input_length,\n",
"        midpoint=40,\n",
"        output_length=output_length,\n",
"        stack_x=False,\n",
"    )\n",
"\n",
"    model = BaselineModel(\n",
"        data=data[0:1100],\n",
"        output_length=output_length,\n",
"        strides=strides,\n",
"        neighbors=neighbors,\n",
"        # Compute where the data already lives instead of re-detecting the device.\n",
"        device=data.device,\n",
"    )\n",
"\n",
"    rmse = np.zeros((len(test_ds), output_length))\n",
"    # Iterating over a range gives tqdm a total, so it can show progress and an ETA.\n",
"    for j in tqdm(range(len(test_ds))):\n",
"        x, y = test_ds[j]\n",
"        y_pred = model(x)\n",
"        rmse[j] = rmse_loss(y, y_pred).cpu().numpy()\n",
"\n",
"    np.save(\n",
"        f\"../data/tf_net/baseline_rmse_n{neighbors}_i{input_length}_s{strides[1]}.npy\",\n",
"        rmse\n",
"    )\n",
"\n",
"    return rmse"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "67a3a9fb",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Single-frame query, dense search along rows, every 8th column, 20 neighbours.\n",
"rmse = run_experiment(\n",
"    input_length=1,\n",
"    strides=(1, 8),\n",
"    neighbors=20,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "c6c299b9",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Mean RMSE over all test clips as a function of the predicted frame.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(rmse.mean(axis=0))\n",
"ax.set_ylim((0, 2600))\n",
"ax.set_xlabel(\"Forecast step (frames)\")\n",
"ax.set_ylabel(\"RMSE (original data units)\")\n",
"ax.set_title(\"Baseline RMSE, averaged over test clips\")\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment