Last active
April 9, 2024 09:26
-
-
Save nilsleh/0a5333b586ad217e8e5a24f780e9b6dd to your computer and use it in GitHub Desktop.
DeepSensor Reproducible Example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any | |
import deepsensor.torch | |
import numpy as np | |
import xarray as xr | |
from deepsensor.data.loader import TaskLoader | |
from deepsensor.data.processor import DataProcessor | |
from deepsensor.train import Trainer | |
from hydra.utils import instantiate | |
from omegaconf import DictConfig, OmegaConf | |
from tqdm import tqdm | |
import random | |
import lab as B | |
import torch | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
from deepsensor.model.convnp import ConvNP | |
# Normalisation parameters in the on-disk format DataProcessor saves/loads.
# NOTE(review): this dict is defined at module level but never referenced by
# the code below — presumably kept for reference; `DataProcessor("processed_data")`
# loads its own copy from disk. Confirm before deleting.
norm_params = {
    # Coordinate section: maps internal coord names (time, x1, x2) to the raw
    # dataset names, with [min, max] extents for lat/lon.
    # NOTE(review): assumed to be the rescaling range used by DataProcessor —
    # TODO confirm the exact mapping convention against the deepsensor docs.
    "coords": {
        "time": {
            "name": "time"
        },
        "x1": {
            "name": "lat",
            "map": [
                35.00000000000013,
                74.94999999999786
            ]
        },
        "x2": {
            "name": "lon",
            "map": [
                -59.95000000000108,
                -20.000000000003354
            ]
        }
    },
    # Per-variable normalisation: each variable is standardised with
    # (x - mean) / std using the statistics recorded here.
    "sst": {
        "method": "mean_std",
        "params": {
            "mean": 12.843529143943822,
            "std": 5.802931939346345
        }
    },
    "obs": {
        "method": "mean_std",
        "params": {
            "mean": -0.03140673828521643,
            "std": 0.30809609068904104
        }
    },
    "ssh": {
        "method": "mean_std",
        "params": {
            "mean": -0.022150070094930537,
            "std": 0.30898174646146725
        }
    }
}
def prepare_data() -> Any:
    """Build the training TaskLoader and DataProcessor over dummy data.

    Generates random SST/obs/SSH fields on a fixed (218, 400, 800) grid so the
    example is reproducible without the real dataset (the commented-out
    ``xr.open_dataset`` calls show how to load the actual files instead).

    Returns:
        dict with keys:
            "train_task_loader": TaskLoader whose context is the observation
                variables plus the ocean mask and whose target is SSH.
            "data_processor": DataProcessor initialised from the
                "processed_data" directory.
    """
    # Context (observation) variable names; originally read from the Hydra
    # config via cfg.input_vars.
    obs_var_names = ["sst", "obs"]
    # NOTE(review): assumes a "processed_data" directory with saved
    # normalisation parameters exists next to the script — confirm.
    data_processor = DataProcessor("processed_data")
    time = pd.date_range("2013-01-01", periods=218)
    x1 = np.linspace(0, 0.4994, 400)
    x2 = np.linspace(0, 1.0, 800)
    # Create the dummy data
    sst_data = np.random.rand(218, 400, 800)
    obs_data = np.random.rand(218, 400, 800)
    ssh_data = np.random.rand(218, 400, 800)
    ocean_mask_data = np.random.choice([True, False], size=(218, 400, 800))
    # Mask out land cells: SSH is only defined over ocean.
    # (Boolean-negation indexing replaces the non-idiomatic `== False` test.)
    ssh_data[~ocean_mask_data] = np.nan
    # Create the xarray datasets
    train_obs = xr.Dataset(
        {"sst": (("time", "x1", "x2"), sst_data), "obs": (("time", "x1", "x2"), obs_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )
    train_targets = xr.Dataset(
        {"ssh": (("time", "x1", "x2"), ssh_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )
    train_ocean_mask = xr.Dataset(
        {"ocean_mask": (("time", "x1", "x2"), ocean_mask_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )
    # or alternatively load the actual data
    # train_obs = xr.open_dataset("processed_data/train_obs.nc")[obs_var_names]
    # train_targets = xr.open_dataset("processed_data/train_targets.nc")
    # train_ocean_mask = xr.open_dataset("processed_data/train_ocean_mask.nc")
    # train loader: each obs variable is its own context set; the mask is last.
    train_task_loader = TaskLoader(
        context=[train_obs[var] for var in obs_var_names] + [train_ocean_mask],
        target=train_targets,
    )
    return {
        "train_task_loader": train_task_loader,
        "data_processor": data_processor,
    }
def gen_tasks(dates, task_loader, progress=True):
    """Sample one task per date from *task_loader*.

    Args:
        dates: iterable of timestamps to generate tasks for.
        task_loader: callable producing a task for a given date.
        progress: when True, display a tqdm progress bar.

    Returns:
        list of tasks, one per entry of *dates*, in order.
    """
    return [
        task_loader(stamp, context_sampling=0.8, target_sampling=0.7)
        for stamp in tqdm(dates, disable=not progress)
    ]
def run() -> None:
    """Runs the training pipeline: data prep, task generation, ConvNP training."""
    data = prepare_data()
    train_task_loader = data["train_task_loader"]
    data_processor = data["data_processor"]
    # One training date per timestamp of the first context variable.
    train_dates = train_task_loader.context[0].time.values
    train_tasks = gen_tasks(train_dates, train_task_loader)
    # UNet backbone with eight 128-channel levels.
    model = ConvNP(data_processor, train_task_loader, unet_channels=[128, 128, 128, 128, 128, 128, 128, 128])
    # set gpu device
    # NOTE(review): the default device is switched to CUDA *after* the model
    # is constructed — confirm the model parameters actually land on the GPU.
    torch.set_default_device("cuda")
    # B.set_global_device(cfg.cuda_device)
    def compute_val_rmse(model, target_var_ID, tasks):
        """Compute RMSE on a set of tasks."""
        errors = []
        for task in tasks:
            # Un-normalise both the predicted mean and the target before comparing.
            mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
            true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
            errors.extend(np.abs(mean - true))
        # Root-mean-square of the pooled absolute errors across all tasks.
        return np.sqrt(np.mean(np.concatenate(errors) ** 2))
    # Visual sanity check: plot and save the first sampled task.
    fig = deepsensor.plot.task(train_tasks[0], train_task_loader)
    fig.savefig("debug_task.png")
    plt.show()
    train_rmse = []
    eval_every = 1  # evaluate RMSE every epoch
    # Train model
    trainer = Trainer(model, lr=5e-5)
    for epoch in tqdm(range(300)):
        random.shuffle(train_tasks)  # fresh task order each epoch
        batch_losses = trainer(train_tasks, batch_size=None)
        train_loss = np.mean(batch_losses)
        print(f"Epoch {epoch} | Train Loss: {train_loss}")
        if epoch % eval_every == 0:
            # NOTE(review): RMSE is computed on the *training* tasks, not a
            # held-out validation set — confirm this is intended.
            train_rmse.append(compute_val_rmse(model, "ssh", train_tasks))
# Script entry point: run the full training pipeline.
if __name__ == "__main__":
    run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment