Skip to content

Instantly share code, notes, and snippets.

@reachsumit
Last active December 28, 2023 22:18
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save reachsumit/f4a55186706675a085157c64fd1e0634 to your computer and use it in GitHub Desktop.
Save reachsumit/f4a55186706675a085157c64fd1e0634 to your computer and use it in GitHub Desktop.
Bayesian LSTM end-to-end demo in PyTorch
Display the source blob
Display the rendered blob
Raw
{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.7.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"code","source":"import os\nimport pandas as pd\nimport numpy as np\nfrom tqdm import trange, tqdm\n\nfrom io import BytesIO\nfrom urllib.request import urlopen\nfrom zipfile import ZipFile\n\nfrom pandas import read_csv\nfrom scipy import stats\n\ndef prep_data(data, covariates, data_start, train = True):\n time_len = data.shape[0]\n input_size = window_size-stride_size\n windows_per_series = np.full((num_series), (time_len-input_size-target_window_size) // stride_size)\n if train: windows_per_series -= (data_start+stride_size-1) // stride_size\n total_windows = np.sum(windows_per_series)\n x_input = np.zeros((total_windows, window_size, 1 + num_covariates), dtype='float32')\n label = np.zeros((total_windows, target_window_size, 1 + num_covariates), dtype='float32')\n v_input = np.zeros((total_windows, 2), dtype='float32')\n count = 0\n for series in trange(num_series): # for each time series\n for i in range(windows_per_series[series]):\n if train:\n window_start = stride_size*i+data_start[series]\n else:\n window_start = stride_size*i\n window_end = window_start+window_size\n target_window_end = window_end+target_window_size\n x_input[count, :, 0] = data[window_start:window_end, series]\n x_input[count, :, 1:1+num_covariates] = covariates[window_start:window_end, :]\n label[count, :, 0] = data[window_end:target_window_end, series]\n label[count,:, 1:1+num_covariates] = covariates[window_end:target_window_end, :]\n nonzero_sum = (x_input[count, 1:input_size, 0]!=0).sum()\n if nonzero_sum == 0:\n v_input[count, 0] = 0\n else:\n v_input[count, 0] = np.true_divide(x_input[count, :input_size, 0].sum(),nonzero_sum)+1\n 
x_input[count, :, 0] = x_input[count, :, 0]/v_input[count, 0]\n label[count, :, 0] = label[count, :, 0]/v_input[count, 0]\n count += 1\n return x_input, v_input, label\n\ndef gen_covariates(times, num_covariates):\n covariates = np.zeros((times.shape[0], num_covariates))\n for i, input_time in enumerate(times):\n covariates[i, 0] = input_time.weekday()\n covariates[i, 1] = input_time.hour\n covariates[i, 2] = input_time.month\n for i in range(num_covariates):\n covariates[:,i] = stats.zscore(covariates[:,i])\n return covariates[:, :num_covariates]","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:32:55.634300Z","iopub.execute_input":"2023-01-15T07:32:55.634704Z","iopub.status.idle":"2023-01-15T07:32:55.979749Z","shell.execute_reply.started":"2023-01-15T07:32:55.634622Z","shell.execute_reply":"2023-01-15T07:32:55.978454Z"},"trusted":true},"execution_count":1,"outputs":[]},{"cell_type":"code","source":"name = 'LD2011_2014.txt'\nsave_name = 'elect'\nwindow_size = 192\nstride_size = 24\ntarget_window_size = 24\nnum_covariates = 3\ntrain_start = '2011-01-01 00:00:00'\ntrain_end = '2013-12-31 23:00:00'\nvalidation_start = '2014-01-01 23:00:00'\nvalidation_end = '2014-08-31 23:00:00'\ntest_start = '2014-08-25 00:00:00'\ntest_end = '2014-09-07 23:00:00'\n\nsave_path = os.path.join('data', save_name)\nif not os.path.exists(save_path):\n os.makedirs(save_path)\ncsv_path = os.path.join(save_path, name)\nif not os.path.exists(csv_path):\n zipurl = 'https://archive.ics.uci.edu/ml/machine-learning-databases/00321/LD2011_2014.txt.zip'\n with urlopen(zipurl) as zipresp:\n with ZipFile(BytesIO(zipresp.read())) as zfile:\n zfile.extractall(save_path)\n\ndata_frame = pd.read_csv(csv_path, sep=\";\", index_col=0, parse_dates=True, decimal=',')\ndata_frame = data_frame.resample('1H',label = 'left',closed = 'right').sum()[train_start:test_end]\ndata_frame.fillna(0, inplace=True)\n\ncovariates = gen_covariates(data_frame[train_start:test_end].index, 
num_covariates)\n\ntrain_data = data_frame[train_start:train_end].values\nvalidation_data = data_frame[validation_start:validation_end].values\ntest_data = data_frame[test_start:test_end].values\n\ndata_start = (train_data!=0).argmax(axis=0) #find first nonzero value in each time series\ntotal_time = data_frame.shape[0] #32304\nnum_series = data_frame.shape[1] #370\n\nX_train, v_train, y_train = prep_data(train_data, covariates, data_start)\nX_validation, v_validation, y_validation = prep_data(validation_data, covariates, data_start, train=False)\nX_test, v_test, y_test = prep_data(test_data, covariates, data_start, train=False)","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:32:55.982048Z","iopub.execute_input":"2023-01-15T07:32:55.982410Z","iopub.status.idle":"2023-01-15T07:33:27.003205Z","shell.execute_reply.started":"2023-01-15T07:32:55.982371Z","shell.execute_reply":"2023-01-15T07:33:27.002086Z"},"trusted":true},"execution_count":2,"outputs":[{"name":"stderr","text":"100%|██████████| 370/370 [00:07<00:00, 49.72it/s]\n100%|██████████| 370/370 [00:01<00:00, 198.72it/s]\n100%|██████████| 370/370 [00:00<00:00, 8429.07it/s]\n","output_type":"stream"}]},{"cell_type":"code","source":"import torch\nimport torch.nn as nn","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.004994Z","iopub.execute_input":"2023-01-15T07:33:27.005416Z","iopub.status.idle":"2023-01-15T07:33:27.580017Z","shell.execute_reply.started":"2023-01-15T07:33:27.005377Z","shell.execute_reply":"2023-01-15T07:33:27.578928Z"},"trusted":true},"execution_count":3,"outputs":[]},{"cell_type":"code","source":"class VariationalDropout(nn.Module):\n \"\"\"\n See https://arxiv.org/abs/1512.05287 for more details.\n \"\"\"\n def __init__(self, dropout):\n super().__init__()\n self.dropout = dropout\n\n def forward(self, x: torch.Tensor) -> torch.Tensor:\n if not self.training:\n return x\n max_batch_size = x.size(1)\n # Drop same mask across entire sequence\n m = x.new_empty(1, 
max_batch_size, x.size(2), requires_grad=False).bernoulli_(1 - self.dropout)\n x = x.masked_fill(m == 0, 0) / (1 - self.dropout)\n return x\n\nclass LSTM(nn.LSTM):\n def __init__(self,\n *args, \n dropouti=0.,\n dropoutw=0., \n dropouto=0.,\n unit_forget_bias=True, \n **kwargs):\n super().__init__(*args, **kwargs)\n self.unit_forget_bias = unit_forget_bias\n self.dropoutw = dropoutw\n self.input_drop = VariationalDropout(dropouti)\n self.output_drop = VariationalDropout(dropouto)\n\n def _drop_weights(self):\n for name, param in self.named_parameters():\n if \"weight_hh\" in name:\n getattr(self, name).data = torch.nn.functional.dropout(param.data, p=self.dropoutw, training=self.training).contiguous()\n\n def forward(self, input):\n self._drop_weights()\n input = self.input_drop(input)\n seq, state = super().forward(input)\n return self.output_drop(seq), state\n\nclass VDEncoder(nn.Module):\n def __init__(self, in_features, out_features, p):\n super(VDEncoder, self).__init__()\n self.model = nn.ModuleDict({\n 'lstm1': LSTM(in_features, 32,dropouto=p),\n 'lstm2': LSTM(32, 8, dropouto=p),\n 'lstm3': LSTM(8, out_features, dropouto=p)\n })\n \n def forward(self, x):\n out, _ = self.model['lstm1'](x)\n out, _ = self.model['lstm2'](out)\n out, _ = self.model['lstm3'](out)\n\n return out\n\n\nclass VDDecoder(nn.Module):\n def __init__(self, p):\n super(VDDecoder, self).__init__()\n self.model = nn.ModuleDict({\n 'lstm1': LSTM(1, 2, dropouto=p),\n 'lstm2': LSTM(2, 2, dropouto=p),\n 'lstm3': LSTM(2, 1, dropouto=p)\n })\n \n def forward(self, x):\n out, _ = self.model['lstm1'](x)\n out, _ = self.model['lstm2'](out)\n out, _ = self.model['lstm3'](out)\n\n return out\n\n\nclass VDEncoderDecoder(nn.Module):\n def __init__(self, in_features, output_steps, p):\n super(VDEncoderDecoder, self).__init__()\n self.enc_in_features = in_features\n self.output_steps = output_steps # f in the paper\n self.enc_out_features = 1\n self.traffic_col = 4\n \n self.model = nn.ModuleDict({\n 
'encoder': VDEncoder(self.enc_in_features, self.enc_out_features, p),\n 'decoder': VDDecoder(p),\n 'fc1': nn.Linear(window_size, 32),\n 'fc2': nn.Linear(32, self.output_steps)\n })\n\n def forward(self, x):\n out = self.model['encoder'](x)\n out = self.model['decoder'](out)\n out = self.model['fc1'](out.squeeze().view(64,-1))\n out = self.model['fc2'](out)\n\n return out","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.583343Z","iopub.execute_input":"2023-01-15T07:33:27.583976Z","iopub.status.idle":"2023-01-15T07:33:27.606472Z","shell.execute_reply.started":"2023-01-15T07:33:27.583929Z","shell.execute_reply":"2023-01-15T07:33:27.605172Z"},"trusted":true},"execution_count":4,"outputs":[]},{"cell_type":"code","source":"from torch.utils.data import DataLoader, Dataset\nfrom torch.utils.data.sampler import RandomSampler\n\nclass TrainDataset(Dataset):\n def __init__(self, data, label):\n self.data = data\n self.label = label\n self.train_len = self.data.shape[0]\n\n def __len__(self):\n return self.train_len\n\n def __getitem__(self, index):\n return (self.data[index,:,0], self.data[index,:,1:1+num_covariates], self.label[index,:,0], self.label[index,:,1:1+num_covariates])\n\n\nclass ValidationAndTestDataset(Dataset):\n def __init__(self, data, v, label):\n self.data = data\n self.v = v\n self.label = label\n self.test_len = self.data.shape[0]\n\n def __len__(self):\n return self.test_len\n\n def __getitem__(self, index):\n return (self.data[index,:,0], self.data[index,:,1:1+num_covariates], self.v[index], self.label[index,:,0], self.label[index,:,1:1+num_covariates])","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.609439Z","iopub.execute_input":"2023-01-15T07:33:27.610296Z","iopub.status.idle":"2023-01-15T07:33:27.622084Z","shell.execute_reply.started":"2023-01-15T07:33:27.610253Z","shell.execute_reply":"2023-01-15T07:33:27.621084Z"},"trusted":true},"execution_count":5,"outputs":[]},{"cell_type":"code","source":"train_batch_size = 
64\n\ntrain_set = TrainDataset(data=X_train, label=y_train)\nvalidation_set = ValidationAndTestDataset(data=X_validation, v=v_validation, label=y_validation)\ntest_set = ValidationAndTestDataset(data=X_test, v=v_test, label=y_test)\n\ntrain_loader = DataLoader(train_set, batch_size=train_batch_size, drop_last=True)\nvalidation_loader = DataLoader(validation_set, batch_size=train_batch_size, sampler=RandomSampler(test_set))\ntest_loader = DataLoader(test_set, batch_size=train_batch_size, sampler=RandomSampler(test_set), drop_last=True)","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.623721Z","iopub.execute_input":"2023-01-15T07:33:27.624296Z","iopub.status.idle":"2023-01-15T07:33:27.635745Z","shell.execute_reply.started":"2023-01-15T07:33:27.624256Z","shell.execute_reply":"2023-01-15T07:33:27.634751Z"},"trusted":true},"execution_count":6,"outputs":[]},{"cell_type":"code","source":"import torch.nn.functional as F\nimport torch.optim as optim\nfrom tqdm import tqdm\n\ndef train_encdec(model, device=torch.device('cuda'), num_epochs = 2, learning_rate = 1e-3):\n train_len = len(train_loader)\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n loss_summary = []\n loss_fn = F.mse_loss\n \n for epoch in range(num_epochs):\n model.train()\n epoch_loss_sum = 0.0\n total_sample = 0\n \n pbar = tqdm(train_loader)\n for (train_batch, current_covs_batch, labels_batch, next_covs_batch) in pbar:\n batch_size, seq_len, horizon_size = train_batch.shape[0], train_batch.shape[1], labels_batch.shape[0]\n total_sample += batch_size * seq_len * horizon_size\n optimizer.zero_grad()\n \n train_batch = train_batch.unsqueeze(2).permute(1,0,2).to(torch.float32).to(device)\n\n out = model(train_batch)\n loss = loss_fn(out.float(), labels_batch.squeeze().to(device).float())\n \n pbar.set_description(f\"Loss:{loss.item()}\")\n loss.backward()\n optimizer.step()\n \n loss_summary.append(loss.cpu().detach())\n return 
loss_summary","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.637437Z","iopub.execute_input":"2023-01-15T07:33:27.637726Z","iopub.status.idle":"2023-01-15T07:33:27.650246Z","shell.execute_reply.started":"2023-01-15T07:33:27.637701Z","shell.execute_reply":"2023-01-15T07:33:27.649432Z"},"trusted":true},"execution_count":7,"outputs":[]},{"cell_type":"code","source":"encdec_model = VDEncoderDecoder(in_features=1, output_steps=target_window_size, p=0.25).cuda()","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:27.651805Z","iopub.execute_input":"2023-01-15T07:33:27.652256Z","iopub.status.idle":"2023-01-15T07:33:31.383773Z","shell.execute_reply.started":"2023-01-15T07:33:27.652222Z","shell.execute_reply":"2023-01-15T07:33:31.382804Z"},"trusted":true},"execution_count":8,"outputs":[]},{"cell_type":"code","source":"train_encdec(encdec_model, num_epochs = 2, device=torch.device('cuda'))","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:33:31.385153Z","iopub.execute_input":"2023-01-15T07:33:31.386123Z","iopub.status.idle":"2023-01-15T07:40:10.269164Z","shell.execute_reply.started":"2023-01-15T07:33:31.386083Z","shell.execute_reply":"2023-01-15T07:40:10.268057Z"},"trusted":true},"execution_count":9,"outputs":[{"name":"stderr","text":"Loss:0.055384665727615356: 100%|██████████| 5004/5004 [03:20<00:00, 24.97it/s] \nLoss:0.08962950110435486: 100%|██████████| 5004/5004 [03:18<00:00, 25.22it/s] \n","output_type":"stream"},{"execution_count":9,"output_type":"execute_result","data":{"text/plain":"[tensor(0.0554), tensor(0.0896)]"},"metadata":{}}]},{"cell_type":"code","source":"class PredictionNetwork(nn.Module):\n def __init__(self, encoder_decoder, p=0.25):\n super(PredictionNetwork, self).__init__()\n self.encoder = encoder_decoder.model['encoder'].eval()\n self.model = nn.Sequential(\n nn.Linear((1 + num_covariates)*window_size, 128),\n nn.Dropout(p),\n nn.ReLU(),\n nn.Dropout(p),\n nn.Linear(128, 64),\n nn.ReLU(),\n nn.Dropout(p),\n 
nn.Linear(64, target_window_size)\n )\n\n def forward(self, x_input, cov_input):\n extracted = self.encoder(x_input)\n x_concat = torch.cat([extracted, cov_input], dim=-1)\n out = self.model(x_concat.view(train_batch_size, -1)).squeeze()\n return out","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:40:10.274057Z","iopub.execute_input":"2023-01-15T07:40:10.274982Z","iopub.status.idle":"2023-01-15T07:40:10.283384Z","shell.execute_reply.started":"2023-01-15T07:40:10.274943Z","shell.execute_reply":"2023-01-15T07:40:10.282143Z"},"trusted":true},"execution_count":10,"outputs":[]},{"cell_type":"code","source":"def train_prediction_network(model, device=torch.device('cuda'), num_epochs = 2, learning_rate = 1e-3):\n train_len = len(train_loader)\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n loss_summary = []\n loss_fn = F.mse_loss\n \n for epoch in range(num_epochs):\n model.train()\n epoch_loss_sum = 0.0\n total_sample = 0\n \n pbar = tqdm(train_loader)\n for (train_batch, current_covs_batch, labels_batch, next_covs_batch) in pbar:\n batch_size, seq_len, horizon_size = train_batch.shape[0], train_batch.shape[1], labels_batch.shape[0]\n total_sample += batch_size * seq_len * horizon_size\n optimizer.zero_grad()\n\n train_batch = train_batch.unsqueeze(2).permute(1,0,2).to(torch.float32).to(device)\n current_covs_batch = current_covs_batch.permute(1,0,2).to(torch.float32).to(device)\n\n out = model(train_batch, current_covs_batch)\n loss = loss_fn(out.float(), labels_batch.squeeze().to(device).float())\n \n pbar.set_description(f\"Loss:{loss.item()}\")\n loss.backward()\n optimizer.step()\n \n loss_summary.append(loss.cpu().detach())\n return loss_summary, optimizer\n\ndef evaluate_prediction_network(model, optimizer, device=torch.device('cuda')):\n criterion = nn.MSELoss()\n rmse_results = []\n\n with torch.no_grad():\n model.eval()\n loss_epoch = np.zeros(len(train_loader))\n\n pbar = tqdm(test_loader)\n for (ts_data_batch, current_covs_batch, 
v_batch, labels_batch, next_covs_batch) in pbar:\n optimizer.zero_grad()\n ts_data_batch = ts_data_batch.unsqueeze(2).permute(1,0,2).to(torch.float32).to(device)\n current_covs_batch = current_covs_batch.permute(1,0,2).to(torch.float32).to(device)\n\n out = model(ts_data_batch, current_covs_batch)\n rmse_results.append(torch.sqrt(criterion(out.detach().cpu(), labels_batch)).item())\n\n test_rmse = np.mean(rmse_results)\n return test_rmse","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:40:10.285085Z","iopub.execute_input":"2023-01-15T07:40:10.285465Z","iopub.status.idle":"2023-01-15T07:40:10.301457Z","shell.execute_reply.started":"2023-01-15T07:40:10.285429Z","shell.execute_reply":"2023-01-15T07:40:10.300495Z"},"trusted":true},"execution_count":11,"outputs":[]},{"cell_type":"code","source":"prednet_model = PredictionNetwork(encdec_model).cuda()","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:40:10.304650Z","iopub.execute_input":"2023-01-15T07:40:10.304981Z","iopub.status.idle":"2023-01-15T07:40:10.319577Z","shell.execute_reply.started":"2023-01-15T07:40:10.304949Z","shell.execute_reply":"2023-01-15T07:40:10.318568Z"},"trusted":true},"execution_count":12,"outputs":[]},{"cell_type":"code","source":"loss_summary, optimizer = train_prediction_network(prednet_model, num_epochs = 2, device=torch.device('cuda'))","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:40:10.321139Z","iopub.execute_input":"2023-01-15T07:40:10.321757Z","iopub.status.idle":"2023-01-15T07:43:41.837780Z","shell.execute_reply.started":"2023-01-15T07:40:10.321719Z","shell.execute_reply":"2023-01-15T07:43:41.836718Z"},"trusted":true},"execution_count":13,"outputs":[{"name":"stderr","text":"Loss:0.06553472578525543: 100%|██████████| 5004/5004 [01:45<00:00, 47.31it/s] \nLoss:0.08229856193065643: 100%|██████████| 5004/5004 [01:45<00:00, 47.33it/s] \n","output_type":"stream"}]},{"cell_type":"code","source":"evaluate_prediction_network(prednet_model, 
optimizer)","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:43:41.839144Z","iopub.execute_input":"2023-01-15T07:43:41.840070Z","iopub.status.idle":"2023-01-15T07:43:42.188301Z","shell.execute_reply.started":"2023-01-15T07:43:41.840021Z","shell.execute_reply":"2023-01-15T07:43:42.187247Z"},"trusted":true},"execution_count":14,"outputs":[{"name":"stderr","text":"100%|██████████| 34/34 [00:00<00:00, 100.99it/s]\n","output_type":"stream"},{"execution_count":14,"output_type":"execute_result","data":{"text/plain":"0.9305521276067285"},"metadata":{}}]},{"cell_type":"code","source":"def accuracy_RMSE(mu: torch.Tensor, labels: torch.Tensor, relative = False):\n zero_index = (labels != 0)\n diff = torch.sum(torch.mul((mu[zero_index] - labels[zero_index]), (mu[zero_index] - labels[zero_index]))).item()\n if relative is False:\n return [diff, torch.sum(zero_index).item(), torch.sum(zero_index).item()]\n else:\n summation = torch.sum(torch.abs(labels[zero_index])).item()\n if summation == 0:\n logger.error('summation denominator error! 
')\n return [diff, summation, torch.sum(zero_index).item()]\n\ndef update_metrics(raw_metrics, sample_mu, labels, relative=False):\n # TODO: use samples to calculate rou50, rou90 metrics\n raw_metrics['RMSE'] = raw_metrics['RMSE'] + accuracy_RMSE(sample_mu, labels, relative=relative)\n return raw_metrics\n\ndef final_metrics(raw_metrics):\n summary_metric = {}\n summary_metric['RMSE'] = np.sqrt(raw_metrics['RMSE'][0] / raw_metrics['RMSE'][2]) / (\n raw_metrics['RMSE'][1] / raw_metrics['RMSE'][2])\n return summary_metric\n\ndef dropout_on(m):\n if type(m) in [torch.nn.Dropout, LSTM]:\n m.train()\n\ndef dropout_off(m):\n if type(m) in [torch.nn.Dropout, LSTM]:\n m.eval()\n\ndef mc_dropout(model, B, device):\n model = model.apply(dropout_on)\n\n pbar = range(B)\n pbar = tqdm(pbar)\n\n y_hats = []\n for b in pbar:\n for (x, cov, v, y, ncov) in test_loader:\n x,cov,y = x.to(device), cov.to(device), y.to(device)\n break\n x = x.unsqueeze(2).permute(1, 0, 2).to(torch.float32).to(device)\n cov = cov.permute(1,0,2).to(torch.float32).to(device)\n\n y_hat_b = model(x, cov).float()\n y_hats.append(y_hat_b.cpu().detach().numpy())\n\n ymc_hats = np.mean(y_hats, axis=0)\n eta_1s = np.mean((ymc_hats[:,0] - np.stack(y_hats)[:,:,0])**2, axis=0)\n return ymc_hats, eta_1s\n\n\ndef inference(model, B=100, device=torch.device('cuda')):\n # mc dropout\n ymc_hats, eta_1s = mc_dropout(model, B, device)\n \n # inherent noise\n model.apply(dropout_off)\n for (x, cov, v, y, ncov) in validation_loader:\n x,cov,y = x.to(device), cov.to(device), y.to(device)\n break\n x = x.unsqueeze(2).permute(1, 0, 2).to(torch.float32).to(device)\n cov = cov.permute(1,0,2).to(torch.float32).to(device)\n y_hat_b = model(x, cov)\n \n eta_2sq = np.mean(y_hat_b.cpu().detach().numpy()[:,0])\n # total noise\n etas = np.sqrt(eta_1s + eta_2sq)\n \n for (x, cov, v, y, ncov) in test_loader:\n break\n normalized_label = y.T/v[:,0]\n diff = torch.nansum(torch.mul((normalized_label.T - torch.tensor(ymc_hats)), 
(normalized_label.T - torch.tensor(ymc_hats)))).item()\n test_rmse = np.sqrt(diff/len(test_set))\n \n return test_rmse","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:43:42.189997Z","iopub.execute_input":"2023-01-15T07:43:42.190552Z","iopub.status.idle":"2023-01-15T07:43:42.209569Z","shell.execute_reply.started":"2023-01-15T07:43:42.190514Z","shell.execute_reply":"2023-01-15T07:43:42.208558Z"},"trusted":true},"execution_count":15,"outputs":[]},{"cell_type":"code","source":"test_rmse = inference(prednet_model, B=500, device=torch.device('cuda'))","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:43:42.211160Z","iopub.execute_input":"2023-01-15T07:43:42.211514Z","iopub.status.idle":"2023-01-15T07:43:47.468414Z","shell.execute_reply.started":"2023-01-15T07:43:42.211477Z","shell.execute_reply":"2023-01-15T07:43:47.467354Z"},"trusted":true},"execution_count":16,"outputs":[{"name":"stderr","text":"100%|██████████| 500/500 [00:05<00:00, 95.69it/s] \n","output_type":"stream"}]},{"cell_type":"code","source":"test_rmse","metadata":{"execution":{"iopub.status.busy":"2023-01-15T07:43:47.470161Z","iopub.execute_input":"2023-01-15T07:43:47.470865Z","iopub.status.idle":"2023-01-15T07:43:47.477447Z","shell.execute_reply.started":"2023-01-15T07:43:47.470826Z","shell.execute_reply":"2023-01-15T07:43:47.476484Z"},"trusted":true},"execution_count":17,"outputs":[{"execution_count":17,"output_type":"execute_result","data":{"text/plain":"0.819039727423851"},"metadata":{}}]}]}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment