Skip to content

Instantly share code, notes, and snippets.

@rkube
Created September 21, 2022 20:47
Show Gist options
  • Save rkube/dfe6847198c4ca6b21425c4824d76ae9 to your computer and use it in GitHub Desktop.
Save rkube/dfe6847198c4ca6b21425c4824d76ae9 to your computer and use it in GitHub Desktop.
AE_pred_LSTMtest_1723xx.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b2cea2eb",
"metadata": {},
"outputs": [],
"source": [
"# Automatically reload edited modules (such as d3d_loaders) before executing code\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "2cb09a0b",
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"# Use the D3D loader to fetch and prepare data\n",
"# https://github.com/PlasmaControl/d3d_loaders\n",
"sys.path.append(\"/home/rkube/repos/d3d_loaders\")\n",
"\n",
"from os.path import join\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.autograd import Variable\n",
"\n",
"\n",
"from torch.utils.data import ConcatDataset\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from matplotlib.ticker import MultipleLocator\n",
"import logging"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cee5f538",
"metadata": {},
"outputs": [],
"source": [
"from d3d_loaders.d3d_loaders import D3D_dataset\n",
"from d3d_loaders.samplers import SequentialSequenceSampler, RandomSequenceSampler, collate_fn_randseq"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "18e5a535",
"metadata": {},
"outputs": [],
"source": [
"# Set up the shots and define the time interval for prediction\n",
"shot_list_train = [172337, 172339, 172341, 172342]\n",
"shot_list_valid = [172340]\n",
"tstart = 110.0  # Time of first sample for upper triangularity is 100.0\n",
"tend = 1000.0\n",
"# Define re-sampling parameters\n",
"t_params = {\"tstart\": tstart, \"tend\": tend, \"tsample\": 10.0}\n",
"# Time shift applied to the ae_prob target (passed as shift_targets below)\n",
"t_shift = 10.0\n",
"\n",
"# Seed torch (weight initialization) and numpy so that re-runs are reproducible.\n",
"# NOTE(review): the d3d_loaders samplers may draw from another RNG - confirm they are covered.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"np.random.seed(SEED)\n",
"\n",
"# Prefer the second GPU, but fall back to the first GPU or the CPU when it does not exist\n",
"# (moving tensors to a non-existent cuda:1 raises 'invalid device ordinal').\n",
"if torch.cuda.device_count() > 1:\n",
"    device = torch.device(\"cuda:1\")\n",
"elif torch.cuda.is_available():\n",
"    device = torch.device(\"cuda:0\")\n",
"else:\n",
"    device = torch.device(\"cpu\")\n",
"\n",
"seq_length = 10\n",
"# collate_fn_randseq turns one batch into 16 sequences of seq_length inputs + 1 target sample\n",
"# (compare the [10, 16, 12] / [1, 16, 5] shapes printed further down).\n",
"batch_size = (seq_length + 1) * 16"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a13dfa29",
"metadata": {},
"outputs": [],
"source": [
"# Exhaustive list of predictors, following https://doi.org/10.1088/1741-4326/abe08d\n",
"# The paper uses 5 profiles and 10 scalars. Here only scalar signals are used: we add\n",
"# neutronsrate and ae_prob, and leave out the signals that are unavailable for these\n",
"# shots (commented out below).\n",
"pred_list = [\"pinj\",          # Injected power\n",
"             \"tinj\",          # Injected torque\n",
"             \"neutronsrate\",  # Measured neutron rate\n",
"             \"iptipp\",        # Target current\n",
"             \"dstdenp\",       # Target density\n",
"             \"doutu\",         # Top triangularity\n",
"             # \"doutl\",       # Bottom triangularity - data not available\n",
"             # \"elongm\",      # Plasma elongation - data not available\n",
"             # \"vout\",        # Plasma volume - data not available\n",
"             # \"ali\",         # Internal inductance - this data is just zero\n",
"             # \"echpwrc\",     # Total ECH power - no ECH in this shot\n",
"             \"dssdenest\",     # Line-averaged density\n",
"             \"ae_prob\"]       # AE mode probability\n",
"\n",
"# The target is the AE mode probability, shifted by t_shift via shift_targets\n",
"targ_list = [\"ae_prob\"]\n",
"\n",
"# Build one dataset per shot and chain them with ConcatDataset\n",
"# (https://pytorch.org/docs/stable/data.html#torch.utils.data.ConcatDataset).\n",
"# The data has to be downloaded beforehand with\n",
"# https://github.com/PlasmaControl/d3d_loaders/blob/main/d3d_loaders/downloading.py\n",
"ds_list_train = [\n",
"    D3D_dataset(shot, t_params,\n",
"                predictors=pred_list,\n",
"                targets=targ_list,\n",
"                shift_targets={\"ae_prob\": t_shift},\n",
"                datapath=\"/projects/EKOLEMEN/d3dloader/test\",\n",
"                device=device)\n",
"    for shot in shot_list_train\n",
"]\n",
"ds_train = ConcatDataset(ds_list_train)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7fff4bf7",
"metadata": {},
"outputs": [],
"source": [
"# Validation dataset, built exactly like the training data. It reuses the shared\n",
"# resampling parameters (t_params) and target shift (t_shift) instead of repeating\n",
"# literals, so editing the config cell keeps train and validation consistent.\n",
"ds_list_valid = []\n",
"for shotnr in shot_list_valid:\n",
"    ds = D3D_dataset(shotnr, t_params,\n",
"                     predictors=pred_list,\n",
"                     targets=targ_list,\n",
"                     shift_targets={\"ae_prob\": t_shift},\n",
"                     datapath=\"/projects/EKOLEMEN/d3dloader/test\",\n",
"                     device=device)\n",
"    ds_list_valid.append(ds)\n",
"ds_valid = ConcatDataset(ds_list_valid)"
]
},
{
"cell_type": "markdown",
"id": "220ff11d",
"metadata": {},
"source": [
"When loader_train uses the default sampler, a batch has shapes like\n",
"x.shape = [16, 12] and y.shape = [16, 5].\n",
"The first dimension is the batch size, the second dimension is the number of features.\n",
"\n",
"An LSTM with `batch_first=False` expects input of shape [L, N, H_in], where L is the sequence length, N is the batch size, and H_in is the number of input features.\n",
"\n",
"To get batches of this shape, we use a random sequence sampler:\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fc0432cc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(10, 176)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Inspect the sequence length and the number of samples drawn per batch\n",
"(seq_length, batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b34c59f4",
"metadata": {},
"outputs": [],
"source": [
"# Adapter that binds the sequence length for collate_fn_randseq: DataLoader calls\n",
"# collate_fn with the list of samples only, so seq_length is stored on the object.\n",
"class collate_functor():\n",
"    \"\"\"Callable that forwards a batch to collate_fn_randseq with a fixed seq_length.\"\"\"\n",
"\n",
"    def __init__(self, seq_length):\n",
"        self.seq_length = seq_length\n",
"\n",
"    def __call__(self, batch):\n",
"        return collate_fn_randseq(batch, self.seq_length)\n",
"\n",
"\n",
"my_collate_fn = collate_functor(seq_length)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "f5a12aa4",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([10, 16, 12]) torch.Size([1, 16, 5])\n"
]
}
],
"source": [
"def make_sequence_loader(dataset):\n",
"    \"\"\"Wrap `dataset` in a DataLoader that samples random sequences and collates them with my_collate_fn.\"\"\"\n",
"    sampler = RandomSequenceSampler(range(len(dataset)), seq_length=seq_length)\n",
"    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0,\n",
"                                       sampler=sampler, collate_fn=my_collate_fn)\n",
"\n",
"\n",
"loader_train = make_sequence_loader(ds_train)\n",
"loader_valid = make_sequence_loader(ds_valid)\n",
"\n",
"# Peek at one training batch to check the shapes\n",
"x, y = next(iter(loader_train))\n",
"print(x.shape, y.shape)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a62e00e8",
"metadata": {},
"outputs": [],
"source": [
"class my_lstm(nn.Module):\n",
"    \"\"\"LSTM followed by two fully connected layers, predicting num_classes outputs.\n",
"\n",
"    The input x has shape [L, N, input_size] (batch_first=False): L is the sequence length,\n",
"    N the batch size. The final hidden state of the last LSTM layer goes through\n",
"    ReLU -> Linear(hidden_size, 128) -> ReLU -> Linear(128, num_classes), so the output\n",
"    has shape [N, num_classes].\n",
"    \"\"\"\n",
"    def __init__(self, num_classes, input_size, hidden_size, seq_length, num_layers=2, device=\"cpu\"):\n",
"        super(my_lstm, self).__init__()\n",
"        self.num_classes = num_classes  # Number of output features\n",
"        self.input_size = input_size  # Number of features in the input x\n",
"        self.hidden_size = hidden_size  # Number of features in hidden state h\n",
"        self.seq_length = seq_length  # Stored only; forward() accepts any sequence length\n",
"        self.num_layers = num_layers\n",
"        self.device = device\n",
"\n",
"        self.lstm = nn.LSTM(input_size=self.input_size,\n",
"                            hidden_size=self.hidden_size,\n",
"                            num_layers=self.num_layers,\n",
"                            batch_first=False).to(device)\n",
"        self.fc_1 = nn.Linear(hidden_size, 128).to(device)  # fully connected 1\n",
"        self.fc = nn.Linear(128, num_classes).to(device)  # fully connected last layer\n",
"\n",
"        self.relu = nn.ReLU()\n",
"\n",
"    def forward(self, x):\n",
"        # Zero initial hidden and cell states. They are allocated on the input's device and dtype,\n",
"        # so the model keeps working after .to(...) moves it. The deprecated\n",
"        # torch.autograd.Variable wrapper is not needed: plain tensors take part in autograd.\n",
"        h_0 = torch.zeros(self.num_layers, x.size(1), self.hidden_size, device=x.device, dtype=x.dtype)\n",
"        c_0 = torch.zeros(self.num_layers, x.size(1), self.hidden_size, device=x.device, dtype=x.dtype)\n",
"        # Propagate input through the LSTM; hn has shape [num_layers, N, hidden_size]\n",
"        output, (hn, cn) = self.lstm(x, (h_0, c_0))\n",
"\n",
"        out = hn[-1]  # Final hidden state of the last recurrent layer: [N, hidden_size]\n",
"        out = self.relu(out)\n",
"        out = self.fc_1(out)  # first dense layer\n",
"        out = self.relu(out)\n",
"        out = self.fc(out)  # final output\n",
"        return out"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "65ae6d08",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/rkube/.conda/envs/ml202205/lib/python3.10/site-packages/torch/backends/cudnn/__init__.py:73: UserWarning: PyTorch was compiled without cuDNN/MIOpen support. To use cuDNN/MIOpen, rebuild PyTorch making sure the library is visible to the build system.\n",
" warnings.warn(\n"
]
}
],
"source": [
"# Infer the input/output widths from the sample batch (x, y) drawn above instead of\n",
"# hard-coding 12 and 5, so the model follows any change to pred_list / targ_list.\n",
"n_features_in = x.shape[-1]\n",
"n_features_out = y.shape[-1]\n",
"model = my_lstm(n_features_out, n_features_in, hidden_size=64, seq_length=seq_length, device=device)\n",
"\n",
"loss_fn = nn.MSELoss()\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "a3e1482b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/rkube/.conda/envs/ml202205/lib/python3.10/site-packages/torch/nn/modules/loss.py:530: UserWarning: Using a target size (torch.Size([1, 16, 5])) that is different to the input size (torch.Size([16, 5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n",
"/home/rkube/.conda/envs/ml202205/lib/python3.10/site-packages/torch/nn/modules/loss.py:530: UserWarning: Using a target size (torch.Size([1, 10, 5])) that is different to the input size (torch.Size([10, 5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n",
"/home/rkube/.conda/envs/ml202205/lib/python3.10/site-packages/torch/nn/modules/loss.py:530: UserWarning: Using a target size (torch.Size([1, 15, 5])) that is different to the input size (torch.Size([15, 5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n",
" return F.mse_loss(input, target, reduction=self.reduction)\n"
]
}
],
"source": [
"num_epochs = 10\n",
"\n",
"losses_train_epoch = np.zeros(num_epochs)\n",
"losses_valid_epoch = np.zeros(num_epochs)\n",
"\n",
"for epoch in range(num_epochs):\n",
"    model.train()\n",
"\n",
"    loss_train = 0.0\n",
"    loss_valid = 0.0\n",
"    for data, target in loader_train:\n",
"        optimizer.zero_grad()\n",
"\n",
"        outputs = model(data)\n",
"        # target has shape [1, N, n_out] while outputs is [N, n_out]; drop the leading\n",
"        # singleton dimension so MSELoss compares matching shapes (no broadcasting).\n",
"        loss = loss_fn(outputs, target.squeeze(0))\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        loss_train += loss.item()\n",
"\n",
"    # Evaluation mode for validation (matters once dropout/batch-norm layers are added)\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        for data, target in loader_valid:\n",
"            outputs = model(data)\n",
"            loss_valid += loss_fn(outputs, target.squeeze(0)).item()\n",
"\n",
"    # MSELoss (reduction='mean') already averages over all elements of a batch,\n",
"    # so only average over the number of batches here.\n",
"    losses_train_epoch[epoch] = loss_train / len(loader_train)\n",
"    losses_valid_epoch[epoch] = loss_valid / len(loader_valid)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "97849d26",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x2008bcd43cd0>]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Training and validation loss per epoch\n",
"fig_loss, ax_loss = plt.subplots()\n",
"ax_loss.plot(losses_train_epoch, label=\"train\")\n",
"ax_loss.plot(losses_valid_epoch, label=\"validation\")\n",
"ax_loss.set_xlabel(\"epoch\")\n",
"ax_loss.set_ylabel(\"MSE loss\")\n",
"ax_loss.legend();"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "a14fc739",
"metadata": {},
"outputs": [],
"source": [
"# Instantiate data loader that processes entire shot in-order\n",
"# Instantiate a data loader that processes the entire validation shot in order.\n",
"# NOTE(review): assumes SequentialSequenceSampler (imported from d3d_loaders.samplers) takes the\n",
"# same arguments as RandomSequenceSampler and is compatible with collate_fn_randseq - confirm.\n",
"loader_rcr = torch.utils.data.DataLoader(ds_valid, batch_size=batch_size, num_workers=0,\n",
"                                         sampler=SequentialSequenceSampler(range(len(ds_valid)), seq_length=seq_length),\n",
"                                         collate_fn=my_collate_fn)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "007d22d3",
"metadata": {},
"outputs": [],
"source": [
"# Statistics used below to undo the normalization of ae_prob (value * std + mean).\n",
"# NOTE(review): the predictor's stats are reused for the target, and only the first\n",
"# validation shot is consulted - fine while shot_list_valid has one entry; confirm otherwise.\n",
"ds = ds_list_valid[0]\n",
"ae_prob_signal = ds.predictors[\"ae_prob\"]\n",
"mean, std = ae_prob_signal.data_mean.cpu(), ae_prob_signal.data_std.cpu()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "66836a91",
"metadata": {},
"outputs": [],
"source": [
"# Compile de-normalized model outputs and targets for an entire shot.\n",
"model.eval()  # inference mode; the training loop leaves the model in train mode\n",
"\n",
"output_list = []\n",
"target_list = []\n",
"\n",
"with torch.no_grad():\n",
"    for data, target in loader_rcr:\n",
"        output_list.append(model(data).cpu() * std + mean)\n",
"        # target is [1, N, n_out]; keep the [N, n_out] slice to match the model output\n",
"        target_list.append(target[0].cpu() * std + mean)\n",
"\n",
"targets = torch.cat(target_list)\n",
"outputs = torch.cat(output_list)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "809527b7",
"metadata": {},
"outputs": [],
"source": [
"def make_pred_plots(outputs, targets, epoch=None):\n",
"    \"\"\"Scatter-plot predicted vs. true AE mode probabilities, one panel per mode.\n",
"\n",
"    Args:\n",
"        outputs: 2D tensor [n_samples, n_modes] of predictions, plotted on the x axis.\n",
"        targets: 2D tensor [n_samples, n_modes] of true values, plotted on the y axis.\n",
"        epoch: Optional integer epoch number; when given it appears in the figure caption.\n",
"\n",
"    Returns:\n",
"        The matplotlib Figure.\n",
"\n",
"    Raises:\n",
"        ValueError: if outputs has no columns.\n",
"    \"\"\"\n",
"    n_modes = outputs.shape[1]\n",
"    if n_modes == 0:\n",
"        raise ValueError(\"outputs must have at least one column\")\n",
"    # Panels sit side by side; shrink them if too many to fit (0.15 wide for up to 5 modes).\n",
"    ax_dx = min(0.15, 0.8 / n_modes)\n",
"    ax_dy = 0.15\n",
"\n",
"    fig_vs = plt.figure(figsize=(12, 6))\n",
"    ax_list = []\n",
"    for i in range(n_modes):\n",
"        ax = fig_vs.add_axes([0.1 + i * ax_dx, 0.1, ax_dx, ax_dy])\n",
"        ax.plot([0.0, 1.0], [0.0, 1.0], 'k', alpha=0.5)  # y = x reference line\n",
"        ax.plot(outputs[:, i].cpu(), targets[:, i].cpu(), '.', ms=1)\n",
"        ax.set_xlim((0.0, 1.0))\n",
"        ax.set_ylim((0.0, 1.0))\n",
"\n",
"        ax.xaxis.set_major_locator(MultipleLocator(0.25))\n",
"        ax.xaxis.set_minor_locator(MultipleLocator(0.05))\n",
"        ax.yaxis.set_major_locator(MultipleLocator(0.25))\n",
"        ax.yaxis.set_minor_locator(MultipleLocator(0.05))\n",
"\n",
"        ax.set_xlabel(\"predicted\")\n",
"        ax.set_title(f\"AE mode {i}\")\n",
"        ax_list.append(ax)\n",
"\n",
"    # All panels share ylim (0, 1): only the outermost panels keep y tick labels.\n",
"    for ax in ax_list[1:-1]:\n",
"        ax.yaxis.set_ticklabels([])\n",
"\n",
"    ax_list[0].set_ylabel(\"True\")\n",
"    if n_modes > 1:\n",
"        ax_list[-1].yaxis.tick_right()\n",
"\n",
"    # Explicit None check instead of relying on the TypeError raised when formatting None.\n",
"    # NOTE(review): the epoch caption mentions d/dt although the panels show probabilities.\n",
"    if epoch is None:\n",
"        fig_vs.text(0.5, 0.3, \"Probability of AE modes: True vs. predicted\", ha=\"center\")\n",
"    else:\n",
"        fig_vs.text(0.5, 0.3, f\"epoch {epoch:03d}: d(Probability of AE modes) / dt\", ha=\"center\")\n",
"\n",
"    return fig_vs"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "c80fc854",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 864x432 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Scatter predicted vs. true AE mode probabilities on the validation shot\n",
"res = make_pred_plots(outputs, targets, epoch=None)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment