@peterm790
Created December 4, 2023 11:59
{
"cells": [
{
"cell_type": "markdown",
"id": "972fc250-9b71-4b1a-a0ac-60df04ad5057",
"metadata": {},
"source": [
"# A simple Convolution Neural Net for Downscaling Single Band Raster Data"
]
},
{
"cell_type": "markdown",
"id": "e63fed36-a864-4b21-b381-98c1b0a41390",
"metadata": {},
"source": [
"The goal of this notebook is to reproduce the CNN used in [\"Physics-Constrained Deep Learning for Climate Downscaling\" (Harder et al. 2023](https://arxiv.org/pdf/2208.05424.pdf) using as simple code as possible. I have done my best rewrite the code in as verbose a way a possible with comments throughout. \n",
"\n",
"The original code is available from this [github repo](https://github.com/RolnickLab/constrained-downscaling/tree/main). Although much of it has been rewritten so it is probably fair to treat this as a generic super resolution CNN. \n",
"\n",
"This notebook was executed on the Microsoft Planatery Computer Jupyter Hub deployment with a T4 Nvidia GPU. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2974da68-ea8c-4b60-b1db-cef477664873",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"device = 'cpu'\n",
"dim_channels = 1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "79b9661e-85ac-4cc6-a3bb-081dbbcd89fd",
"metadata": {},
"outputs": [],
"source": [
"import fsspec\n",
"\n",
"fs = fsspec.filesystem(\"\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1ca154c9-4c3d-4640-86c4-1359b221d58c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['/terra/projects/heat_center/code/downscaling/ERA5_to_WRF_3km/test',\n",
" '/terra/projects/heat_center/code/downscaling/ERA5_to_WRF_3km/.ipynb_checkpoints',\n",
" '/terra/projects/heat_center/code/downscaling/ERA5_to_WRF_3km/prediction',\n",
" '/terra/projects/heat_center/code/downscaling/ERA5_to_WRF_3km/train',\n",
" '/terra/projects/heat_center/code/downscaling/ERA5_to_WRF_3km/val']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"fs.ls(\"ERA5_to_WRF_3km/\")"
]
},
{
"cell_type": "markdown",
"id": "76f2768f-42c9-48c7-ba86-038a84a58037",
"metadata": {},
"source": [
"# Load Data \n",
"\n",
"Pytorch relies on a data loader. In this case loading from a pytorch specific (.pt) dataset. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1d49f97b-0aee-4e7d-9b12-72f378e8055f",
"metadata": {},
"outputs": [],
"source": [
"def load_data(validation_data = 'test', batch_size = 64, dim_channels = 1): # batch size to control memory usage and dim_channels where 1 is for raster data with a single band and 3 could be for RGB data etc.\n",
" # load training data\n",
" input_train = torch.load(f'ERA5_to_WRF_3km/train/input_train.pt') # opens the data need to set if train test or val \n",
" target_train = torch.load(f'ERA5_to_WRF_3km/train/target_train.pt')\n",
" # load validation data\n",
" if validation_data == 'test':\n",
" input_val = torch.load(f'ERA5_to_WRF_3km/test/input_test.pt')\n",
" target_val = torch.load(f'ERA5_to_WRF_3km/test/target_test.pt')\n",
" elif validation_data == 'val':\n",
" input_val = torch.load(f'ERA5_to_WRF_3km/val/input_val.pt')\n",
" target_val = torch.load(f'ERA5_to_WRF_3km/val/target_val.pt')\n",
" # get dimensions\n",
" train_shape_in = input_train.shape\n",
" train_shape_out = target_train.shape\n",
" val_shape_in = input_val.shape\n",
" val_shape_out = target_val.shape\n",
" # get mean, std of training data\n",
" mean = target_train.mean()\n",
" std = target_train.std()\n",
" # min and max - first constructing an array to save to, assume this speeds things up\n",
" max_val = torch.zeros((dim_channels,1))\n",
" min_val = torch.zeros((dim_channels,1))\n",
" for i in range(dim_channels):\n",
" max_val[i] = target_train[:,0,i,...].max()\n",
" min_val[i] = target_train[:,0,i,...].min() \n",
" # standardise the data\n",
" for i in range(dim_channels):\n",
" input_train[:,0,i,...] = (input_train[:,0,i,...]-min_val[i]) /(max_val[i]-min_val[i]) # (X - min(x)) / (max(X) - min(X)) = scaled to max 1 and min 0\n",
" target_train[:,0,i,...] = (target_train[:,0,i,...] -min_val[i])/(max_val[i]-min_val[i])\n",
" input_val[:,0,i,...] = (input_val[:,0,i,...]-min_val[i])/(max_val[i]-min_val[i])\n",
" target_val[:,0,i,...] = (target_val[:,0,i,...]-min_val[i])/(max_val[i]-min_val[i])\n",
" # create a TensorDataset with input and target training data \n",
" train_data = TensorDataset(input_train, target_train)\n",
" # and for validation data\n",
" val_data = TensorDataset(input_val, target_val)\n",
" # create a data loader with a pre-defined batch size \n",
" train = DataLoader(train_data, batch_size=batch_size, shuffle=True) # not sure why one is shuffled the other not\n",
" val = DataLoader(val_data, batch_size=batch_size, shuffle=False)\n",
" return [train, val, mean, std, max_val, min_val, train_shape_in, train_shape_out, val_shape_in, val_shape_out] "
]
},
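{
"cell_type": "markdown",
"id": "added-minmax-roundtrip-note",
"metadata": {},
"source": [
"A minimal sketch of the scaling used in `load_data`, on a made-up tensor: values are mapped with `(x - min) / (max - min)` using the minimum and maximum of the training target, and the same two numbers invert the mapping later. The tensor `target` and its values are hypothetical and only serve as an illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-minmax-roundtrip-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# hypothetical training target with shape [time, 1, channel, height, width]\n",
"target = torch.tensor([270.0, 280.0, 290.0, 300.0]).reshape(4, 1, 1, 1, 1)\n",
"\n",
"t_min = target.min()\n",
"t_max = target.max()\n",
"\n",
"scaled = (target - t_min) / (t_max - t_min) # values now lie between 0 and 1\n",
"restored = scaled * (t_max - t_min) + t_min # inverse transform, as done in process_for_eval\n",
"\n",
"print(scaled.flatten())\n",
"print(torch.allclose(restored, target))"
]
},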
{
"cell_type": "markdown",
"id": "3db29f37-d303-4182-9876-7d100d088138",
"metadata": {},
"source": [
"# Create CNN model\n",
"\n",
"Next up we construct the model in pytorch. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a8bd94a4-33b8-4b9f-b588-af1599df23e6",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"from torch.autograd import Variable\n",
"\n",
"\n",
"# this is doing the lifting in our convolution neural net. A kernel size of 3 sets the kernel window to (3,3) \n",
"# thus pixels are considered in the context of their immediate neighbours\n",
"def conv3x3(in_channels, out_channels, stride=1):\n",
" return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
"\n",
"\n",
"# the residual block aka skip connection - only implements a forward block here \n",
"# in this case the block is made up of the following layers:\n",
"# 1) 3x3 2-dimensional convolution layer\n",
"# 2) A ReLU (Rectified Linear Unit) Activation Layer: adds non linearity by setting negative values to Zero\n",
"# 3) another 3x3 2-dimensional convolution layer but with no stride this time so only moving one pixel along\n",
"# 4) the addition of a residual or constant value to the ouput (implicit bias correction?)\n",
"# 5) Another ReLU (Rectified Linear Unit) Activation Layer\n",
"class ResidualBlock(nn.Module):\n",
" def __init__(self, in_channels, out_channels, stride=1, downsample=None):\n",
" super(ResidualBlock, self).__init__()\n",
" self.conv1 = conv3x3(in_channels, out_channels, stride)\n",
" self.relu = nn.ReLU(inplace=True)\n",
" self.conv2 = conv3x3(out_channels, out_channels)\n",
" \n",
" def forward(self, x):\n",
" residual = x\n",
" out = self.conv1(x)\n",
" out = self.relu(out)\n",
" out = self.conv2(out)\n",
" out += residual\n",
" out = self.relu(out)\n",
" return out\n",
"\n",
" \n",
"# here is the the model itself in this case - The ResNet model is based on the Deep Residual Learning for Image Recognition paper.\n",
"# in this case the model consists of only a forward pass made up of the following layers:\n",
"# 1) conv1: 3x3 convolution with ReLU activation to the input tensor, (considering the first channel which has no impact as this is single channel data)\n",
" # this layer takes 1 input and produces 64 outputs for each pixel (considered in a 3x3 matrix of neighbouring points)\n",
"# 2) upscale: in the case of upsampling by a factor of 4, two nn.ConvTranspose2d layer are added.\n",
" # this performs a transposed convolution operation, also known as fractionally strided convolution or deconvolution. \n",
" # This operation effectively \"blows up\" the spatial dimensions of the input feature map, creating a larger output feature map.\n",
"# 3) conv2: 3x3 convolution with ReLU activation,\n",
" # this layer takes 64 inputs and produces 64 outputs \n",
"# 4) residual: the residual blocks described above now occur. There are repeated 4 times in this case\n",
"# 5) conv3: 3x3 convolution with ReLU activation,\n",
" # this layer takes 64 inputs and produces 64 outputs \n",
"# 6) conv4: 1x1 convolution with no activation layer \n",
" # this layer takes in 64 inputs and 1 input no padding is applied in this layer\n",
"class ResNet(nn.Module):\n",
" def __init__(self, number_channels=64, number_residual_blocks=4, upsampling_factor=2, dim=1):\n",
" super(ResNet, self).__init__()\n",
" # First layer\n",
" self.conv1 = nn.Sequential(nn.Conv2d(dim, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))\n",
" #Residual Blocks\n",
" self.res_blocks = nn.ModuleList()\n",
" for k in range(number_residual_blocks):\n",
" self.res_blocks.append(ResidualBlock(number_channels, number_channels))\n",
" # Second conv layer post residual blocks\n",
" self.conv2 = nn.Sequential(\n",
" nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))\n",
" # Upsampling layers\n",
" self.upsampling = nn.ModuleList()\n",
" for k in range(int(np.rint(np.log2(upsampling_factor)))):\n",
" self.upsampling.append(nn.ConvTranspose2d(number_channels, number_channels, kernel_size=2, padding=0, stride=2) )\n",
" # Next layer after upper sampling\n",
" self.conv3 = nn.Sequential(nn.Conv2d(number_channels, number_channels, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True))\n",
" # Final output layer\n",
" self.conv4 = nn.Conv2d(number_channels, dim, kernel_size=1, stride=1, padding=0) \n",
" self.dim = dim \n",
" \n",
" def forward(self, x, z=None): \n",
" out = self.conv1(x[:,0,...])\n",
" for layer in self.upsampling:\n",
" out = layer(out)\n",
" out = self.conv2(out) \n",
" for layer in self.res_blocks:\n",
" out = layer(out)\n",
" out = self.conv3(out)\n",
" out = self.conv4(out)\n",
" out = out.unsqueeze(1)\n",
" return out "
]
},
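{
"cell_type": "markdown",
"id": "added-upsampling-depth-note",
"metadata": {},
"source": [
"A small sketch of how the upsampling depth follows from `upsampling_factor`: the model stacks `int(np.rint(np.log2(upsampling_factor)))` transposed convolutions, each doubling height and width, so a factor that is not a power of two is rounded to the nearest one. The single-channel layers and the 13 x 14 input grid below are assumptions chosen for illustration; they are not taken from the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-upsampling-depth-code",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"upsampling_factor = 10\n",
"n_layers = int(np.rint(np.log2(upsampling_factor))) # log2(10) is about 3.32, which rounds to 3\n",
"print(n_layers, 2 ** n_layers) # number of layers and the effective upsampling factor\n",
"\n",
"# each ConvTranspose2d with kernel_size=2 and stride=2 doubles height and width\n",
"layers = nn.Sequential(*[nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2) for _ in range(n_layers)])\n",
"x = torch.zeros(1, 1, 13, 14) # assumed low-resolution grid, for illustration only\n",
"print(layers(x).shape) # spatial size grows from 13 x 14 to 104 x 112"
]
},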
{
"cell_type": "markdown",
"id": "f9cd5610-c96c-44e2-b219-af8079f626e4",
"metadata": {},
"source": [
"And load the model with our desired parameters"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5caabf8d-0bb2-411a-ad0d-cb51f793ad41",
"metadata": {},
"outputs": [],
"source": [
"# saved as a function as we need to reload the same model again later for evaluation\n",
"def load_model():\n",
" model = ResNet(number_channels=32, \n",
" number_residual_blocks=4, \n",
" upsampling_factor=10, \n",
" dim=1)\n",
" model = model.to(device)\n",
" return model\n",
"\n",
"model = load_model() "
]
},
{
"cell_type": "markdown",
"id": "f1fd386b-7b54-4810-8ef4-3573aff636c3",
"metadata": {},
"source": [
"Initialize the optimiser, here using the Adam (Adaptive Moment Estimation) optimiser."
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3c419c8d-96b0-4128-aef2-6aa2daba39eb",
"metadata": {},
"outputs": [],
"source": [
"optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-9) # learning rate set here"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "cd901f6c-4552-4f8e-aa5b-ea338d6e689b",
"metadata": {},
"outputs": [],
"source": [
"def optimizer_step(model, optimizer, inputs, targets, tepoch, discriminator=False):\n",
" optimizer.zero_grad() # Resets the gradients of all optimized tensors to zero\n",
" outputs = model(inputs) # get outputs from current model iteration\n",
" loss = torch.nn.functional.mse_loss(outputs, targets) # calculate loss (simple MSE between input and target)\n",
" loss.backward() # Computes the gradient of loss (RMSE) tensor wrt graph leaves.\n",
" optimizer.step() # perform a single optimization step.\n",
" return loss.item() # returns tensor as standard python number (implies the tensor is a single value not an array)"
]
},
{
"cell_type": "markdown",
"id": "427c41e6-c9c7-4ac7-ae2c-4cae27c65fd3",
"metadata": {},
"source": [
"chatGPT :: In simpler terms, the statement \"Computes the gradient of the current tensor with respect to graph leaves\" means that you are calculating how much the current tensor contributes to the overall output of your neural network, and specifically, how sensitive it is to changes in the parameters (graph leaves) of the network. This is a crucial step in the training process known as backpropagation, where the network adjusts its parameters to minimize the difference between its predictions and the actual target values."
]
},
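{
"cell_type": "markdown",
"id": "added-autograd-note",
"metadata": {},
"source": [
"To make the role of `loss.backward()` concrete, here is a minimal sketch with a made-up one-parameter model (`w`, `x` and `target` are hypothetical values, not part of the downscaling model). After the backward call, the gradient of the loss with respect to the leaf `w` is available as `w.grad`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-autograd-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"w = torch.tensor(2.0, requires_grad=True) # a trainable parameter, i.e. a graph leaf\n",
"x = torch.tensor(3.0)\n",
"target = torch.tensor(10.0)\n",
"\n",
"loss = (w * x - target) ** 2 # squared error of a one-parameter model\n",
"loss.backward() # fills w.grad with d(loss)/dw\n",
"\n",
"print(w.grad) # analytically: 2 * (w*x - target) * x = 2 * (6 - 10) * 3 = -24"
]
},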
{
"cell_type": "markdown",
"id": "ad42e3d5-37a5-48df-9012-367c3b64d6aa",
"metadata": {},
"source": [
"Validation function - to assess current model performance against out of training validation data - \"val los\".\n",
"\n",
"This is done alongside training to monitor when val loss stops decreasing alongside training loss - implying overfitting is occuring (I assume)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f026818f-9c8c-4c88-9b93-16baac3394a6",
"metadata": {},
"outputs": [],
"source": [
"def validate_model(model, data, best, epoch, discriminator_model=None, criterion_discr=None):\n",
" model.eval() # set the model to evaluation mode only. ensure no layers or gradients are changed during evaluation \n",
" running_loss = 0 \n",
" for i, (inputs, targets) in enumerate(data):\n",
" inputs = inputs.to(device) # cast data to GPU if applicable\n",
" targets = targets.to(device)\n",
" outputs = model(inputs) # run the model \n",
" loss = torch.nn.functional.mse_loss(outputs, targets) # Mean Square Error\n",
" running_loss = running_loss + loss.item() # sum up MSE across all data (ussualy time steps in our case) \n",
" loss = running_loss/len(data) # get mean\n",
" model.train() # this just sets the model back to training mode. ie. the opposite of model.eval()\n",
" return loss"
]
},
{
"cell_type": "markdown",
"id": "9a1cced3-d53c-4b4a-8ef2-e6c02210b7bd",
"metadata": {},
"source": [
"This sets the best inital val los to infinite"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "6695fc03-82ef-4edb-8836-e6ce4f353ae8",
"metadata": {},
"outputs": [],
"source": [
"best = np.inf"
]
},
{
"cell_type": "markdown",
"id": "f60f2500-4207-4021-9990-284188e460e2",
"metadata": {},
"source": [
"Best will continualy be updated by each val los as the model is trained (loops through each epoch). The function below saves the model if it performs better than the previous epoch. So if you run 100 epochs and val_loss is constant or not decreasing for the last 50 the model will not be updated. ie saved model is model from epoch 50. "
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5125324e-dca6-47a1-b293-d3d41a924fea",
"metadata": {},
"outputs": [],
"source": [
"def checkpoint(model, val_loss, best, epoch):\n",
" if val_loss < best:\n",
" print(f'saving at epoch: {epoch}')\n",
" checkpoint = {'model': model,'state_dict': model.state_dict()}\n",
" torch.save(checkpoint, './models/minimal_CNN_ERA5_WRF_3km.pth')"
]
},
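{
"cell_type": "markdown",
"id": "added-best-tracking-note",
"metadata": {},
"source": [
"The save-on-improvement logic can be sketched without a model at all. The snippet below replays the first five validation losses from the training log in this notebook and records the epochs at which a checkpoint would be written; `saved_at` is a hypothetical helper list used only for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-best-tracking-code",
"metadata": {},
"outputs": [],
"source": [
"val_losses = [0.01966, 0.00571, 0.00357, 0.00358, 0.00325] # first five val losses from the training log\n",
"\n",
"best = float('inf') # nothing has been seen yet, so any first loss is an improvement\n",
"saved_at = []\n",
"for epoch, val_loss in enumerate(val_losses):\n",
"    if val_loss < best: # same test as in checkpoint()\n",
"        saved_at.append(epoch)\n",
"    best = min(best, val_loss) # same update as in the training loop\n",
"\n",
"print(saved_at) # epoch index 3 is skipped because 0.00358 is not below 0.00357"
]
},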
{
"cell_type": "markdown",
"id": "bddd5bbf-72cb-4d6f-90f2-b9d0450e94aa",
"metadata": {},
"source": [
"# Training time!"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7af9ebaa-77ec-4ff5-a5ed-198c94def9c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1, Train Loss: 0.06642\n",
"Val loss: 0.01966\n",
"saving at epoch: 0\n",
"Epoch 2, Train Loss: 0.01311\n",
"Val loss: 0.00571\n",
"saving at epoch: 1\n",
"Epoch 3, Train Loss: 0.00439\n",
"Val loss: 0.00357\n",
"saving at epoch: 2\n",
"Epoch 4, Train Loss: 0.00341\n",
"Val loss: 0.00358\n",
"Epoch 5, Train Loss: 0.00314\n",
"Val loss: 0.00325\n",
"saving at epoch: 4\n",
"Epoch 6, Train Loss: 0.00310\n",
"Val loss: 0.00361\n",
"Epoch 7, Train Loss: 0.00304\n",
"Val loss: 0.00441\n",
"Epoch 8, Train Loss: 0.00334\n",
"Val loss: 0.00315\n",
"saving at epoch: 7\n",
"Epoch 9, Train Loss: 0.00308\n",
"Val loss: 0.00319\n",
"Epoch 10, Train Loss: 0.00337\n",
"Val loss: 0.00339\n",
"Epoch 11, Train Loss: 0.00310\n",
"Val loss: 0.00312\n",
"saving at epoch: 10\n",
"Epoch 12, Train Loss: 0.00294\n",
"Val loss: 0.00316\n",
"Epoch 13, Train Loss: 0.00294\n",
"Val loss: 0.00322\n",
"Epoch 14, Train Loss: 0.00307\n",
"Val loss: 0.00322\n",
"Epoch 15, Train Loss: 0.00310\n",
"Val loss: 0.00331\n",
"Epoch 16, Train Loss: 0.00339\n",
"Val loss: 0.00318\n",
"Epoch 17, Train Loss: 0.00336\n",
"Val loss: 0.00321\n",
"Epoch 18, Train Loss: 0.00308\n",
"Val loss: 0.00321\n",
"Epoch 19, Train Loss: 0.00297\n",
"Val loss: 0.00323\n",
"Epoch 20, Train Loss: 0.00300\n",
"Val loss: 0.00320\n",
"Epoch 21, Train Loss: 0.00300\n",
"Val loss: 0.00360\n",
"Epoch 22, Train Loss: 0.00321\n",
"Val loss: 0.00308\n",
"saving at epoch: 21\n",
"Epoch 23, Train Loss: 0.00301\n",
"Val loss: 0.00355\n",
"Epoch 24, Train Loss: 0.00298\n",
"Val loss: 0.00336\n",
"Epoch 25, Train Loss: 0.00293\n",
"Val loss: 0.00373\n",
"Epoch 26, Train Loss: 0.00302\n",
"Val loss: 0.00333\n",
"Epoch 27, Train Loss: 0.00295\n",
"Val loss: 0.00314\n",
"Epoch 28, Train Loss: 0.00292\n",
"Val loss: 0.00331\n",
"Epoch 29, Train Loss: 0.00297\n",
"Val loss: 0.00308\n",
"Epoch 30, Train Loss: 0.00297\n",
"Val loss: 0.00308\n",
"Epoch 31, Train Loss: 0.00315\n",
"Val loss: 0.00418\n",
"Epoch 32, Train Loss: 0.00319\n",
"Val loss: 0.00311\n",
"Epoch 33, Train Loss: 0.00308\n",
"Val loss: 0.00314\n",
"Epoch 34, Train Loss: 0.00293\n",
"Val loss: 0.00332\n",
"Epoch 35, Train Loss: 0.00292\n",
"Val loss: 0.00309\n",
"Epoch 36, Train Loss: 0.00294\n",
"Val loss: 0.00316\n",
"Epoch 37, Train Loss: 0.00295\n",
"Val loss: 0.00310\n",
"Epoch 38, Train Loss: 0.00304\n",
"Val loss: 0.00312\n",
"Epoch 39, Train Loss: 0.00294\n",
"Val loss: 0.00355\n",
"Epoch 40, Train Loss: 0.00289\n",
"Val loss: 0.00314\n",
"Epoch 41, Train Loss: 0.00304\n",
"Val loss: 0.00306\n",
"saving at epoch: 40\n",
"Epoch 42, Train Loss: 0.00313\n",
"Val loss: 0.00337\n",
"Epoch 43, Train Loss: 0.00311\n",
"Val loss: 0.00314\n",
"Epoch 44, Train Loss: 0.00292\n",
"Val loss: 0.00306\n",
"Epoch 45, Train Loss: 0.00299\n",
"Val loss: 0.00365\n",
"Epoch 46, Train Loss: 0.00349\n",
"Val loss: 0.00328\n",
"Epoch 47, Train Loss: 0.00290\n",
"Val loss: 0.00306\n",
"Epoch 48, Train Loss: 0.00309\n",
"Val loss: 0.00381\n",
"Epoch 49, Train Loss: 0.00293\n",
"Val loss: 0.00310\n",
"Epoch 50, Train Loss: 0.00296\n",
"Val loss: 0.00306\n",
"Epoch 51, Train Loss: 0.00288\n",
"Val loss: 0.00321\n",
"Epoch 52, Train Loss: 0.00305\n",
"Val loss: 0.00355\n",
"Epoch 53, Train Loss: 0.00294\n",
"Val loss: 0.00341\n",
"Epoch 54, Train Loss: 0.00298\n",
"Val loss: 0.00314\n",
"Epoch 57, Train Loss: 0.00290\n",
"Val loss: 0.00342\n",
"Epoch 58, Train Loss: 0.00304\n",
"Val loss: 0.00307\n",
"Epoch 59, Train Loss: 0.00293\n",
"Val loss: 0.00309\n",
"Epoch 60, Train Loss: 0.00291\n",
"Val loss: 0.00322\n",
"Epoch 61, Train Loss: 0.00284\n",
"Val loss: 0.00306\n",
"saving at epoch: 60\n",
"Epoch 62, Train Loss: 0.00289\n",
"Val loss: 0.00317\n",
"Epoch 63, Train Loss: 0.00298\n",
"Val loss: 0.00316\n",
"Epoch 64, Train Loss: 0.00298\n",
"Val loss: 0.00327\n",
"Epoch 65, Train Loss: 0.00294\n",
"Val loss: 0.00310\n",
"Epoch 66, Train Loss: 0.00301\n",
"Val loss: 0.00307\n",
"Epoch 67, Train Loss: 0.00295\n",
"Val loss: 0.00407\n",
"Epoch 68, Train Loss: 0.00316\n",
"Val loss: 0.00313\n",
"Epoch 69, Train Loss: 0.00323\n",
"Val loss: 0.00355\n",
"Epoch 70, Train Loss: 0.00303\n",
"Val loss: 0.00304\n",
"saving at epoch: 69\n",
"Epoch 71, Train Loss: 0.00286\n",
"Val loss: 0.00348\n",
"Epoch 72, Train Loss: 0.00291\n",
"Val loss: 0.00317\n",
"Epoch 73, Train Loss: 0.00299\n",
"Val loss: 0.00315\n",
"Epoch 74, Train Loss: 0.00297\n",
"Val loss: 0.00312\n",
"Epoch 75, Train Loss: 0.00292\n",
"Val loss: 0.00373\n",
"Epoch 76, Train Loss: 0.00301\n",
"Val loss: 0.00307\n",
"Epoch 77, Train Loss: 0.00294\n",
"Val loss: 0.00325\n",
"Epoch 78, Train Loss: 0.00327\n",
"Val loss: 0.00313\n",
"Epoch 79, Train Loss: 0.00287\n",
"Val loss: 0.00305\n",
"Epoch 80, Train Loss: 0.00298\n",
"Val loss: 0.00316\n",
"Epoch 81, Train Loss: 0.00288\n",
"Val loss: 0.00308\n",
"Epoch 82, Train Loss: 0.00289\n",
"Val loss: 0.00303\n",
"saving at epoch: 81\n",
"Epoch 83, Train Loss: 0.00288\n",
"Val loss: 0.00334\n",
"Epoch 84, Train Loss: 0.00302\n",
"Val loss: 0.00363\n",
"Epoch 85, Train Loss: 0.00293\n",
"Val loss: 0.00306\n",
"Epoch 86, Train Loss: 0.00290\n",
"Val loss: 0.00308\n",
"Epoch 87, Train Loss: 0.00286\n",
"Val loss: 0.00308\n",
"Epoch 88, Train Loss: 0.00292\n",
"Val loss: 0.00305\n",
"Epoch 89, Train Loss: 0.00294\n",
"Val loss: 0.00338\n",
"Epoch 90, Train Loss: 0.00311\n",
"Val loss: 0.00328\n",
"Epoch 91, Train Loss: 0.00295\n",
"Val loss: 0.00326\n",
"Epoch 92, Train Loss: 0.00299\n",
"Val loss: 0.00308\n",
"Epoch 93, Train Loss: 0.00290\n",
"Val loss: 0.00304\n",
"Epoch 94, Train Loss: 0.00295\n",
"Val loss: 0.00331\n",
"Epoch 95, Train Loss: 0.00298\n",
"Val loss: 0.00311\n",
"Epoch 96, Train Loss: 0.00296\n",
"Val loss: 0.00303\n",
"Epoch 97, Train Loss: 0.00293\n",
"Val loss: 0.00359\n",
"Epoch 98, Train Loss: 0.00301\n",
"Val loss: 0.00385\n",
"Epoch 99, Train Loss: 0.00299\n",
"Val loss: 0.00304\n",
"Epoch 100, Train Loss: 0.00298\n",
"Val loss: 0.00302\n",
"saving at epoch: 99\n",
"Epoch 101, Train Loss: 0.00294\n",
"Val loss: 0.00320\n",
"Epoch 102, Train Loss: 0.00294\n",
"Val loss: 0.00306\n",
"Epoch 103, Train Loss: 0.00292\n",
"Val loss: 0.00308\n",
"Epoch 104, Train Loss: 0.00302\n",
"Val loss: 0.00306\n",
"Epoch 105, Train Loss: 0.00292\n",
"Val loss: 0.00303\n",
"Epoch 106, Train Loss: 0.00295\n",
"Val loss: 0.00332\n",
"Epoch 107, Train Loss: 0.00288\n",
"Val loss: 0.00317\n",
"Epoch 108, Train Loss: 0.00292\n",
"Val loss: 0.00330\n",
"Epoch 109, Train Loss: 0.00288\n",
"Val loss: 0.00307\n",
"Epoch 110, Train Loss: 0.00287\n",
"Val loss: 0.00313\n",
"Epoch 111, Train Loss: 0.00285\n",
"Val loss: 0.00313\n",
"Epoch 112, Train Loss: 0.00287\n",
"Val loss: 0.00304\n",
"Epoch 113, Train Loss: 0.00289\n",
"Val loss: 0.00336\n",
"Epoch 114, Train Loss: 0.00291\n",
"Val loss: 0.00304\n",
"Epoch 115, Train Loss: 0.00290\n",
"Val loss: 0.00331\n",
"Epoch 116, Train Loss: 0.00299\n",
"Val loss: 0.00321\n",
"Epoch 117, Train Loss: 0.00302\n",
"Val loss: 0.00337\n",
"Epoch 118, Train Loss: 0.00308\n",
"Val loss: 0.00323\n",
"Epoch 119, Train Loss: 0.00307\n",
"Val loss: 0.00304\n",
"Epoch 120, Train Loss: 0.00289\n",
"Val loss: 0.00309\n",
"Epoch 121, Train Loss: 0.00293\n",
"Val loss: 0.00308\n",
"Epoch 122, Train Loss: 0.00297\n",
"Val loss: 0.00318\n",
"Epoch 123, Train Loss: 0.00289\n",
"Val loss: 0.00373\n",
"Epoch 124, Train Loss: 0.00305\n",
"Val loss: 0.00304\n",
"Epoch 125, Train Loss: 0.00290\n",
"Val loss: 0.00309\n",
"Epoch 126, Train Loss: 0.00286\n",
"Val loss: 0.00305\n",
"Epoch 127, Train Loss: 0.00292\n",
"Val loss: 0.00327\n",
"Epoch 128, Train Loss: 0.00321\n",
"Val loss: 0.00309\n",
"Epoch 129, Train Loss: 0.00295\n",
"Val loss: 0.00338\n",
"Epoch 130, Train Loss: 0.00300\n",
"Val loss: 0.00306\n",
"Epoch 131, Train Loss: 0.00289\n",
"Val loss: 0.00349\n",
"Epoch 132, Train Loss: 0.00295\n",
"Val loss: 0.00321\n",
"Epoch 133, Train Loss: 0.00286\n",
"Val loss: 0.00309\n",
"Epoch 134, Train Loss: 0.00288\n",
"Val loss: 0.00317\n",
"Epoch 135, Train Loss: 0.00288\n",
"Val loss: 0.00330\n",
"Epoch 136, Train Loss: 0.00292\n",
"Val loss: 0.00305\n",
"Epoch 137, Train Loss: 0.00287\n",
"Val loss: 0.00307\n",
"Epoch 138, Train Loss: 0.00295\n",
"Val loss: 0.00308\n",
"Epoch 139, Train Loss: 0.00297\n",
"Val loss: 0.00308\n",
"Epoch 140, Train Loss: 0.00294\n",
"Val loss: 0.00305\n",
"Epoch 141, Train Loss: 0.00289\n",
"Val loss: 0.00314\n",
"Epoch 142, Train Loss: 0.00283\n",
"Val loss: 0.00305\n",
"Epoch 143, Train Loss: 0.00284\n",
"Val loss: 0.00304\n",
"Epoch 144, Train Loss: 0.00284\n",
"Val loss: 0.00306\n",
"Epoch 145, Train Loss: 0.00286\n",
"Val loss: 0.00307\n",
"Epoch 146, Train Loss: 0.00300\n",
"Val loss: 0.00324\n",
"Epoch 147, Train Loss: 0.00289\n",
"Val loss: 0.00330\n",
"Epoch 148, Train Loss: 0.00298\n",
"Val loss: 0.00322\n",
"Epoch 149, Train Loss: 0.00293\n",
"Val loss: 0.00306\n",
"Epoch 150, Train Loss: 0.00296\n",
"Val loss: 0.00315\n",
"Epoch 151, Train Loss: 0.00285\n",
"Val loss: 0.00321\n",
"Epoch 152, Train Loss: 0.00285\n",
"Val loss: 0.00305\n",
"Epoch 153, Train Loss: 0.00287\n",
"Val loss: 0.00337\n",
"Epoch 154, Train Loss: 0.00298\n",
"Val loss: 0.00352\n",
"Epoch 155, Train Loss: 0.00288\n",
"Val loss: 0.00309\n",
"Epoch 156, Train Loss: 0.00289\n",
"Val loss: 0.00306\n",
"Epoch 157, Train Loss: 0.00286\n",
"Val loss: 0.00303\n",
"Epoch 158, Train Loss: 0.00295\n",
"Val loss: 0.00328\n",
"Epoch 159, Train Loss: 0.00307\n",
"Val loss: 0.00306\n",
"Epoch 160, Train Loss: 0.00292\n",
"Val loss: 0.00306\n",
"Epoch 161, Train Loss: 0.00283\n",
"Val loss: 0.00320\n",
"Epoch 162, Train Loss: 0.00286\n",
"Val loss: 0.00310\n",
"Epoch 163, Train Loss: 0.00284\n",
"Val loss: 0.00301\n",
"saving at epoch: 162\n",
"Epoch 164, Train Loss: 0.00296\n",
"Val loss: 0.00315\n",
"Epoch 165, Train Loss: 0.00300\n",
"Val loss: 0.00328\n",
"Epoch 166, Train Loss: 0.00301\n",
"Val loss: 0.00301\n",
"Epoch 167, Train Loss: 0.00282\n",
"Val loss: 0.00314\n",
"Epoch 168, Train Loss: 0.00298\n",
"Val loss: 0.00308\n",
"Epoch 169, Train Loss: 0.00287\n",
"Val loss: 0.00304\n",
"Epoch 170, Train Loss: 0.00287\n",
"Val loss: 0.00304\n",
"Epoch 171, Train Loss: 0.00293\n",
"Val loss: 0.00302\n",
"Epoch 172, Train Loss: 0.00304\n",
"Val loss: 0.00410\n",
"Epoch 173, Train Loss: 0.00325\n",
"Val loss: 0.00313\n",
"Epoch 174, Train Loss: 0.00282\n",
"Val loss: 0.00303\n",
"Epoch 175, Train Loss: 0.00291\n",
"Val loss: 0.00309\n",
"Epoch 176, Train Loss: 0.00291\n",
"Val loss: 0.00308\n",
"Epoch 177, Train Loss: 0.00291\n",
"Val loss: 0.00307\n",
"Epoch 178, Train Loss: 0.00294\n",
"Val loss: 0.00307\n",
"Epoch 179, Train Loss: 0.00289\n",
"Val loss: 0.00303\n",
"Epoch 180, Train Loss: 0.00284\n",
"Val loss: 0.00302\n",
"Epoch 181, Train Loss: 0.00283\n",
"Val loss: 0.00302\n",
"Epoch 182, Train Loss: 0.00291\n",
"Val loss: 0.00302\n",
"Epoch 183, Train Loss: 0.00289\n",
"Val loss: 0.00323\n",
"Epoch 184, Train Loss: 0.00293\n",
"Val loss: 0.00312\n",
"Epoch 185, Train Loss: 0.00286\n",
"Val loss: 0.00308\n",
"Epoch 186, Train Loss: 0.00308\n",
"Val loss: 0.00314\n",
"Epoch 187, Train Loss: 0.00285\n",
"Val loss: 0.00306\n",
"Epoch 188, Train Loss: 0.00298\n",
"Val loss: 0.00309\n",
"Epoch 189, Train Loss: 0.00287\n",
"Val loss: 0.00311\n",
"Epoch 190, Train Loss: 0.00286\n",
"Val loss: 0.00309\n",
"Epoch 191, Train Loss: 0.00286\n",
"Val loss: 0.00305\n",
"Epoch 192, Train Loss: 0.00286\n",
"Val loss: 0.00313\n",
"Epoch 193, Train Loss: 0.00288\n",
"Val loss: 0.00302\n",
"Epoch 194, Train Loss: 0.00297\n",
"Val loss: 0.00309\n",
"Epoch 195, Train Loss: 0.00291\n",
"Val loss: 0.00340\n",
"Epoch 196, Train Loss: 0.00307\n",
"Val loss: 0.00318\n",
"Epoch 197, Train Loss: 0.00290\n",
"Val loss: 0.00306\n",
"Epoch 198, Train Loss: 0.00290\n",
"Val loss: 0.00307\n",
"Epoch 199, Train Loss: 0.00285\n",
"Val loss: 0.00311\n",
"Epoch 200, Train Loss: 0.00293\n",
"Val loss: 0.00312\n",
"CPU times: user 3d 1h 14min 11s, sys: 19h 11min 3s, total: 3d 20h 25min 14s\n",
"Wall time: 3h 53min 4s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"train, val, mean, std, max_val, min_val, train_shape_in, train_shape_out, val_shape_in, val_shape_out = load_data(validation_data = 'test') \n",
"\n",
"epochs = 200 # only 5 epochs\n",
"\n",
"for epoch in range(epochs): # loop through n times number of epochs set. \n",
" running_loss = 0 \n",
" running_discr_loss = 0\n",
" running_adv_loss = 0\n",
" for (inputs, targets) in train: # for each time step in data \n",
" inputs = inputs.to(device) # cast data to GPU if applicable \n",
" targets = targets.to(device)\n",
" loss = optimizer_step(model, optimizer, inputs, targets, train) # run optimizer given this particular set of data (time step)\n",
" running_loss = running_loss + loss # sum up MSE across all data (ussualy time steps in our case) \n",
" loss = running_loss/len(train) # mean MSE for this epoch\n",
" print('Epoch {}, Train Loss: {:.5f}'.format(epoch+1, loss))\n",
" val_loss = validate_model(model, val, best, epoch) # validate model against validation data (slightly unfair advantage to train loss as model is incrementally optimised for each time step in training data) \n",
" print('Val loss: {:.5f}'.format(val_loss))\n",
" checkpoint(model, val_loss, best, epoch) # save model \n",
" best = np.minimum(best, val_loss)# update best if val_loss < best"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d547b038-3e24-4aae-9780-d174e27c2dcf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0030101195443421602"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"best"
]
},
{
"cell_type": "markdown",
"id": "037326c3-7a08-406d-9d3a-66a66a3437fd",
"metadata": {},
"source": [
"Epoch 162 was the best performant / last saved "
]
},
{
"cell_type": "markdown",
"id": "5245b489-feeb-4da0-96b5-7c1bdb46c2fc",
"metadata": {},
"source": [
"This is the end of training. Everything below here is just evaluating the model."
]
},
{
"cell_type": "markdown",
"id": "1ea17aaf-d2c7-4f28-8bfd-3b4591d50d37",
"metadata": {},
"source": [
"# Evaluate model performance\n",
"\n",
"against validation set"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "31e05f3e-a29d-4d39-8db3-23aa3f88d993",
"metadata": {},
"outputs": [],
"source": [
"def load_trained_model(model): # renamed from load_weights to avoid confusion \n",
" checkpoint = torch.load('./models/minimal_CNN_ERA5_WRF_3km.pth') \n",
" model.load_state_dict(checkpoint['state_dict'])\n",
" model = model.to(device)\n",
" return model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1ee8b1f5-a5a5-4df9-a18f-0973bb3c24e6",
"metadata": {},
"outputs": [],
"source": [
"train, val, mean, std, max_val,min_val, train_shape_in, train_shape_out, val_shape_in, val_shape_out = load_data('val') "
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e2ec1758-831c-4135-86bb-db2a43f1e74c",
"metadata": {},
"outputs": [],
"source": [
"def process_for_eval(outputs, targets, mean, std, max_val, min_val): \n",
" for i in range(dim_channels):\n",
" outputs[:,0,i,...] = outputs[:,0,i,...]*(max_val[i].to(device)-min_val[i].to(device))+min_val[i].to(device) \n",
" targets[:,0,i,...] = targets[:,0,i,...]*(max_val[i].to(device)-min_val[i].to(device))+min_val[i].to(device)\n",
" return outputs, targets"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0023ed02-352d-463c-8e6b-86d465946237",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 51.2 s, sys: 13.5 s, total: 1min 4s\n",
"Wall time: 3.01 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"model = load_model() # so this loads the model as defined above\n",
"load_trained_model(model) # then load the weights we have pretrained above \n",
"model.eval() # set the model to eval model \n",
"full_pred = torch.zeros(val_shape_out) # an empty array to save result to \n",
"\n",
"batch_size = 64 # number of time steps at a time, so in this case 64 at a time until last step which is remainder (16 in this example) - explicitly commented here as the source of 16*128*128 array confused me\n",
"\n",
"for i,(inputs, targets) in enumerate(val): \n",
" inputs = inputs.to(device) # cast data to GPU if applicable \n",
" targets = targets.to(device)\n",
" outputs = model(inputs) \n",
" outputs, targets = process_for_eval(outputs, targets,mean, std, max_val, min_val) \n",
" full_pred[i*batch_size:i*batch_size+outputs.shape[0],...] = outputs.detach().cpu()\n",
"torch.save(full_pred, 'ERA5_to_WRF_3km/prediction/eval_out.pt')"
]
},
{
"cell_type": "markdown",
"id": "903137fc-fc8c-4d69-bc70-e6918ce3cfb8",
"metadata": {},
"source": [
"So this is now a torch tensor array with downscaled outputs from out eval input dataset"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f964b1ab-296a-4a0f-9974-edd87d5a9f7e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([293, 1, 1, 104, 112])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"full_pred.shape"
]
},
{
"cell_type": "markdown",
"id": "a050af3c-a784-4fd6-bd0f-3a2a464fb49c",
"metadata": {},
"source": [
"Calculate Scores"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a75aa475-c212-4aef-8485-432a83f86c62",
"metadata": {},
"outputs": [],
"source": [
"input_val = torch.load('ERA5_to_WRF_3km/val/input_val.pt')\n",
"target_val = torch.load('ERA5_to_WRF_3km/val/target_val.pt')\n",
"val_data = TensorDataset(input_val, target_val)\n",
"pred = np.zeros(target_val.shape)\n",
"max_val = target_val.max()\n",
"min_val = target_val.min()"
]
},
{
"cell_type": "markdown",
"id": "1ecdd264-30c5-4b8e-81a8-f94d0113e1f7",
"metadata": {},
"source": [
"define some evaluation functions "
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "15708960-bdfa-4794-9f21-c5337c3c20b3",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-30 00:52:10.898945: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-11-30 00:52:33.778733: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
}
],
"source": [
"from torchmetrics.functional import multiscale_structural_similarity_index_measure, structural_similarity_index_measure\n",
"from skimage import transform\n",
"\n",
"def pearsonr(x, y):\n",
" mean_x = torch.mean(x)\n",
" mean_y = torch.mean(y)\n",
" xm = x.sub(mean_x)\n",
" ym = y.sub(mean_y)\n",
" r_num = xm.dot(ym)\n",
" r_den = torch.norm(xm, 2) * torch.norm(ym, 2)\n",
" r_val = r_num / r_den\n",
" return r_val\n",
"\n",
"def calculate_pnsr(mse, max_val):\n",
" return 20 * torch.log10(max_val / torch.sqrt(torch.Tensor([mse])))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "38a50895-e459-4ad2-98fb-6c48d30670c8",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/pmarsh/.local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:70: FutureWarning: Importing `multiscale_structural_similarity_index_measure` from `torchmetrics.functional` was deprecated and will be removed in 2.0. Import `multiscale_structural_similarity_index_measure` from `torchmetrics.image` instead.\n",
" _future_warning(\n",
"/home/pmarsh/.local/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:70: FutureWarning: Importing `spectral_angle_mapper` from `torchmetrics.functional` was deprecated and will be removed in 2.0. Import `spectral_angle_mapper` from `torchmetrics.image` instead.\n",
" _future_warning(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 1min 3s, sys: 803 ms, total: 1min 4s\n",
"Wall time: 3.57 s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"mse = 0\n",
"mae = 0\n",
"ssim = 0\n",
"mean_bias = 0\n",
"mean_abs_bias = 0\n",
"mass_violation = 0\n",
"ms_ssim = 0\n",
"corr = 0\n",
"crps = 0\n",
"neg_mean = 0\n",
"neg_num = 0\n",
"\n",
"l2_crit = nn.MSELoss() # creates a criteron that measures the mean squared error (squared L2 norm) between each element in the input and target (assume this is used for performance benefits from pytorch)\n",
"l1_crit = nn.L1Loss() # creates a criteron that measures the mean absolute error (MAE) between each element in the input and target\n",
"\n",
"en_pred = torch.load('ERA5_to_WRF_3km/prediction/eval_out.pt') # load predictions of validations set (saved in cell aboves)\n",
"pred = en_pred.detach().cpu().numpy()\n",
"\n",
"upsampling_factor = 8\n",
"n = input_val.shape[0] # number of time steps (length)\n",
"\n",
"j = 0 \n",
"for i,(lr, hr) in enumerate(val_data):\n",
" im = lr.numpy()\n",
" mse += l2_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()\n",
" mae += l1_crit(torch.Tensor(pred[i,j,...]), hr[j,...]).item()\n",
" mean_bias += torch.mean( hr[j,...]-torch.Tensor(pred[i,j,...]))\n",
" mean_abs_bias += torch.abs(torch.mean( hr[j,...]-torch.Tensor(pred[i,j,...])))\n",
" corr += pearsonr(torch.Tensor(pred[i,j,...]).flatten(), hr[j,...].flatten())\n",
" ms_ssim += multiscale_structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...], data_range=max_val-min_val, kernel_size=11, betas=(0.2856, 0.3001, 0.2363))\n",
" ssim += structural_similarity_index_measure(torch.Tensor(pred[i,j:j+1,...]), hr[j:j+1,...] , data_range=max_val-min_val, kernel_size=11)\n",
" neg_num += np.sum(pred[i,j,...] < 0)\n",
" neg_mean += np.sum(pred[pred < 0])/(pred.shape[-1]*pred.shape[-1])\n",
" mass_violation += np.mean( np.abs(transform.downscale_local_mean(pred[i,j,...], (1,upsampling_factor,upsampling_factor)) -im[j,...]))\n",
"\n",
"mse = mse/n\n",
"mae = mae/n\n",
"ssim = ssim/n\n",
"mean_bias = mean_bias/n\n",
"mean_abs_bias = mean_abs_bias/n\n",
"corr = corr/n\n",
"ms_ssim = ms_ssim/n\n",
"crps = crps/n\n",
"neg_mean = neg_mean/n\n",
"mass_violation = mass_violation/n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "37f7935c-8e7b-473c-bd9a-32b1faacb909",
"metadata": {},
"outputs": [],
"source": [
"psnr = calculate_pnsr(mse, target_val.max() ) \n",
"rmse = torch.sqrt(torch.Tensor([mse])).numpy()[0]\n",
"ssim = float(ssim.numpy())\n",
"ms_ssim =float( ms_ssim.numpy())\n",
"psnr = psnr.numpy()\n",
"corr = float(corr.numpy())\n",
"mean_bias = float(mean_bias.numpy())\n",
"mean_abs_bias = float(mean_abs_bias.numpy())\n",
"\n",
"scores = {'MSE':mse, 'RMSE':rmse, 'PSNR': psnr[0], 'MAE':mae, 'SSIM':ssim, 'MS SSIM': ms_ssim, 'Pearson corr': corr, 'Mean bias': mean_bias, 'Mean abs bias': mean_abs_bias, 'Mass_violation': mass_violation, 'neg mean': neg_mean, 'neg num': neg_num,'CRPS': crps}"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "af4c45f3-9086-454a-b3ae-ad6cd78af9ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"MSE : 4.333741138005826\n",
"RMSE : 2.081764\n",
"PSNR : 24.841316\n",
"MAE : 1.6233453388507049\n",
"SSIM : 0.8405888676643372\n",
"MS SSIM : 0.8236749768257141\n",
"Pearson corr : 0.8625420928001404\n",
"Mean bias : 0.23750866949558258\n",
"Mean abs bias : 1.1562156677246094\n",
"Mass_violation : 0.5928562571367713\n",
"neg mean : 0.0\n",
"neg num : 0\n",
"CRPS : 0.0\n"
]
}
],
"source": [
"for index in scores:\n",
" print(index, ':', scores[index])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cda40650-4c1f-4295-97c2-a4a337437ac0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Pangeo",
"language": "python",
"name": "pangeo"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}