Skip to content

Instantly share code, notes, and snippets.

@nilsleh
Last active April 9, 2024 09:26
Show Gist options
  • Save nilsleh/0a5333b586ad217e8e5a24f780e9b6dd to your computer and use it in GitHub Desktop.
Save nilsleh/0a5333b586ad217e8e5a24f780e9b6dd to your computer and use it in GitHub Desktop.
DeepSensor Reproducible Example
from typing import Any
import deepsensor.torch
import numpy as np
import xarray as xr
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.train import Trainer
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
import random
import lab as B
import torch
import matplotlib.pyplot as plt
import pandas as pd
from deepsensor.model.convnp import ConvNP
# Normalisation parameters previously computed by DataProcessor on the real
# dataset: coordinate name/extent mappings plus per-variable mean/std
# normalisation statistics.
norm_params = {
    "coords": {
        "time": {"name": "time"},
        # x1 maps to latitude over roughly [35.0, 74.95] degrees.
        "x1": {"name": "lat", "map": [35.00000000000013, 74.94999999999786]},
        # x2 maps to longitude over roughly [-59.95, -20.0] degrees.
        "x2": {"name": "lon", "map": [-59.95000000000108, -20.000000000003354]},
    },
    "sst": {
        "method": "mean_std",
        "params": {"mean": 12.843529143943822, "std": 5.802931939346345},
    },
    "obs": {
        "method": "mean_std",
        "params": {"mean": -0.03140673828521643, "std": 0.30809609068904104},
    },
    "ssh": {
        "method": "mean_std",
        "params": {"mean": -0.022150070094930537, "std": 0.30898174646146725},
    },
}
def prepare_data() -> Any:
    """Build the training TaskLoader and DataProcessor.

    Creates dummy xarray datasets (SST and obs context variables, an SSH
    target, and a boolean ocean mask) on a fixed 218 x 400 x 800 grid and
    wraps them in a deepsensor TaskLoader.

    Returns:
        dict with keys "train_task_loader" (TaskLoader) and
        "data_processor" (DataProcessor).
    """
    # obs_var_names = OmegaConf.to_object(cfg.input_vars)
    obs_var_names = ["sst", "obs"]
    # NOTE(review): the processor is pointed at "processed_data", but the
    # arrays below are generated in memory — confirm the directory is only
    # needed when loading the real NetCDF files (commented out below).
    data_processor = DataProcessor("processed_data")

    time = pd.date_range("2013-01-01", periods=218)
    x1 = np.linspace(0, 0.4994, 400)
    x2 = np.linspace(0, 1.0, 800)

    # Create the dummy data standing in for the real processed files.
    sst_data = np.random.rand(218, 400, 800)
    obs_data = np.random.rand(218, 400, 800)
    ssh_data = np.random.rand(218, 400, 800)
    ocean_mask_data = np.random.choice([True, False], size=(218, 400, 800))
    # Blank out land points in the target. Use boolean negation rather
    # than the non-idiomatic element-wise `== False` comparison.
    ssh_data[~ocean_mask_data] = np.nan

    # Create the xarray datasets
    train_obs = xr.Dataset(
        {"sst": (("time", "x1", "x2"), sst_data), "obs": (("time", "x1", "x2"), obs_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )
    train_targets = xr.Dataset(
        {"ssh": (("time", "x1", "x2"), ssh_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )
    train_ocean_mask = xr.Dataset(
        {"ocean_mask": (("time", "x1", "x2"), ocean_mask_data)},
        coords={"time": time, "x1": x1, "x2": x2},
    )

    # or alternatively load the actual data
    # train_obs = xr.open_dataset("processed_data/train_obs.nc")[obs_var_names]
    # train_targets = xr.open_dataset("processed_data/train_targets.nc")
    # train_ocean_mask = xr.open_dataset("processed_data/train_ocean_mask.nc")

    # train loader: each obs variable is its own context set, plus the mask.
    train_task_loader = TaskLoader(
        context=[train_obs[var] for var in obs_var_names] + [train_ocean_mask],
        target=train_targets,
    )
    return {
        "train_task_loader": train_task_loader,
        "data_processor": data_processor,
    }
def gen_tasks(dates, task_loader, progress=True):
    """Generate one deepsensor Task per date.

    Each task samples 80% of context points and 70% of target points.
    A tqdm progress bar is shown unless `progress` is False.
    """
    return [
        task_loader(date, context_sampling=0.8, target_sampling=0.7)
        for date in tqdm(dates, disable=not progress)
    ]
def run() -> None:
    """Runs the training pipeline: build data, train ConvNP, track train RMSE."""
    data = prepare_data()
    train_task_loader = data["train_task_loader"]
    data_processor = data["data_processor"]
    # One task per available date in the first context set.
    train_dates = train_task_loader.context[0].time.values
    train_tasks = gen_tasks(train_dates, train_task_loader)
    model = ConvNP(data_processor, train_task_loader, unet_channels=[128, 128, 128, 128, 128, 128, 128, 128])
    # set gpu device
    # NOTE(review): the default device is switched to CUDA *after* the model
    # is constructed — confirm the parameters actually end up on the GPU, or
    # move this call before ConvNP(...).
    torch.set_default_device("cuda")
    # B.set_global_device(cfg.cuda_device)
    def compute_val_rmse(model, target_var_ID, tasks):
        """Compute RMSE on a set of tasks.

        Predictions and targets are unnormalised via the enclosing
        data_processor before the error is measured.
        """
        errors = []
        for task in tasks:
            mean = data_processor.map_array(model.mean(task), target_var_ID, unnorm=True)
            true = data_processor.map_array(task["Y_t"][0], target_var_ID, unnorm=True)
            # abs() keeps per-point error magnitudes; they are squared below.
            errors.extend(np.abs(mean - true))
        return np.sqrt(np.mean(np.concatenate(errors) ** 2))
    # Visual sanity check of the first training task; plt.show() blocks
    # until the window is closed when run interactively.
    fig = deepsensor.plot.task(train_tasks[0], train_task_loader)
    fig.savefig("debug_task.png")
    plt.show()
    train_rmse = []
    eval_every = 1
    # Train model
    trainer = Trainer(model, lr=5e-5)
    for epoch in tqdm(range(300)):
        # Reshuffle task order each epoch before the training pass.
        random.shuffle(train_tasks)
        batch_losses = trainer(train_tasks, batch_size=None)
        train_loss = np.mean(batch_losses)
        print(f"Epoch {epoch} | Train Loss: {train_loss}")
        # NOTE(review): RMSE is evaluated on the *training* tasks — there is
        # no held-out validation split in this reproducer.
        if epoch % eval_every == 0:
            train_rmse.append(compute_val_rmse(model, "ssh", train_tasks))
# Script entry point.
if __name__ == "__main__":
    run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment