Ocean Bench Lightning Datamodule
"""Ocean Bench Datamodules."""
import itertools
import os
from collections import namedtuple
from typing import Any
import hydra
import numpy as np
import ocn_tools._src.geoprocessing.gridding as obgrid
import pandas as pd
import torch
import xarray as xr
import xrpatcher
from oceanbench._src.utils.hydra import pipe
from omegaconf import DictConfig
from torch import Tensor
from torch.utils.data import DataLoader, Dataset, Subset, default_collate
from torchgeo.datamodules import NonGeoDataModule
def get_cfg(cfg_path: str) -> Any:
    """Load and return the configuration from a given path.

    Args:
        cfg_path: The path to the configuration file.

    Returns:
        The loaded task outputs.
    """
    with hydra.initialize("../../../config", version_base="1.3"):
        cfg = hydra.compose(cfg_path)
        # rewrite the relative base dir on the composed config so that
        # interpolations in task.outputs resolve against the adjusted path
        base_dir = cfg.task.data.base_dir
        cfg.task.data.base_dir = base_dir.replace("../", "home/user/")
        return hydra.utils.call(cfg.task.outputs)
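
# A minimal usage sketch: the relative config path mirrors the oceanbench
# repository layout used by the datamodule further below.
#
# outputs = get_cfg("task/osse_gf_nadir/task")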
def norm_stats(patcher: xrpatcher.XRDAPatcher) -> tuple[float, float]:
    """Compute normalization statistics (mean and std) for the 'ssh' variable.

    Args:
        patcher: Patcher whose underlying DataArray stacks the task variables.

    Returns:
        Mean and standard deviation of the 'ssh' slice.
    """
    return patcher.da.sel(variable="ssh").pipe(
        lambda da: (da.mean().item(), da.std().item())
    )
def patcher_from_osse_task(
    task: DictConfig,
    patcher_kw: dict[str, Any],
    ref_var: str = "ssh",
    split: str = "trainval",
) -> xrpatcher.XRDAPatcher:
    """Create a patcher from an OSSE task.

    Args:
        task: The OSSE task.
        patcher_kw: The patcher keyword arguments.
        ref_var: The reference variable. Defaults to 'ssh'.
        split: The split name. Defaults to 'trainval'.

    Returns:
        The created patcher.
    """
    default_domain_limits = dict(
        time=slice(*task.splits[split]),
        lat=slice(*task.domain.lat),
        lon=slice(*task.domain.lon),
    )
    domain_limits = {**default_domain_limits, **patcher_kw.get("domain_limits", {})}
    task_data = {k: v().sel(domain_limits) for k, v in task.data.items()}
    # align every variable onto the reference variable's coordinates
    da = xr.Dataset(
        {
            k: v.assign_coords(task_data[ref_var].coords) if k != ref_var else v
            for k, v in task_data.items()
        }
    ).to_array()
    return xrpatcher.XRDAPatcher(da, **patcher_kw)
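
# A hedged sketch of building an OSSE patcher on its own; the patch and stride
# values are illustrative, not taken from any task config.
#
# task_cfg = get_cfg("task/osse_gf_nadir/task")
# patcher = patcher_from_osse_task(
#     task_cfg,
#     patcher_kw=dict(patches={"time": 5}, strides={"time": 1}),
#     split="trainval",
# )
# print(len(patcher))  # number of patches over the trainval window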
def patcher_from_ose_task(
    task: DictConfig,
    patcher_kw: dict[str, Any],
    tgt_grid_resolution: dict[str, Any] | None = None,
    ref_var: str = "ssh",
    split: str = "train",
) -> xrpatcher.XRDAPatcher:
    """Create a patcher from an OSE task.

    Args:
        task: The OSE task.
        patcher_kw: The patcher keyword arguments.
        tgt_grid_resolution: The target grid resolution (lat/lon steps and a
            pandas time frequency) used to grid the along-track observations.
        ref_var: The reference variable. Defaults to 'ssh'.
        split: The split name. Defaults to 'train'.

    Returns:
        The created patcher.
    """
    if tgt_grid_resolution is None:
        raise ValueError("tgt_grid_resolution is required for OSE tasks")
    default_domain_limits = dict(
        time=task.splits[split],
        lat=task.domain.lat,
        lon=task.domain.lon,
    )
    domain_limits = {**default_domain_limits, **patcher_kw.get("domain_limits", {})}

    def select(da: xr.DataArray) -> xr.DataArray:
        """Restrict along-track data to the time window and bounding box."""
        return (
            da.sel(time=slice(*domain_limits["time"]))
            .where(lambda da: da.lat > domain_limits["lat"][0], drop=True)
            .where(lambda da: da.lon > domain_limits["lon"][0], drop=True)
            .where(lambda da: da.lat < domain_limits["lat"][1], drop=True)
            .where(lambda da: da.lon < domain_limits["lon"][1], drop=True)
        )

    # regular target grid onto which the along-track observations are gridded
    tgt_grid = xr.Dataset(
        coords=dict(
            lat=np.arange(*domain_limits["lat"], tgt_grid_resolution["lat"]),
            lon=np.arange(*domain_limits["lon"], tgt_grid_resolution["lon"]),
            time=pd.date_range(
                *domain_limits["time"], freq=tgt_grid_resolution["time"]
            ),
        )
    )
    data = dict(
        train=xr.combine_nested(
            [v().pipe(select) for k, v in task.data["train"].items()],
            concat_dim="time",
        ),
        test=xr.combine_nested(
            [v().pipe(select) for k, v in task.data["test"].items()],
            concat_dim="time",
        ),
    )
    da = xr.Dataset(
        {
            k: obgrid.coord_based_to_grid(v.to_dataset(name="ssh"), tgt_grid).ssh
            for k, v in data.items()
        }
    ).to_array()
    return xrpatcher.XRDAPatcher(da, **patcher_kw)
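
# Sketch for the OSE path; the resolution values are illustrative assumptions,
# not taken from the 'ose_gf' task config. 'time' is a pandas frequency string.
#
# ose_patcher = patcher_from_ose_task(
#     get_cfg("task/ose_gf/task"),
#     patcher_kw=dict(patches={"time": 5}, strides={"time": 1}),
#     tgt_grid_resolution={"lat": 0.05, "lon": 0.05, "time": "1D"},
# )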
class XrTorchDataset(Dataset):
    """Dataset for xarray data backed by an XRDAPatcher."""

    def __init__(self, patcher: xrpatcher.XRDAPatcher, item_postpro=None) -> None:
        """Initialize a new instance of XrTorchDataset.

        Args:
            patcher: XR patcher yielding stacked (input, target) patches.
            item_postpro: Optional postprocessing function applied to each item.
        """
        self.patcher = patcher
        self.postpro = item_postpro

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Get item at index `idx`.

        Args:
            idx: Index.

        Returns:
            Dict with 'input' and 'target' patches.
        """
        item = self.patcher[idx].load().values
        if self.postpro:
            item = self.postpro(item)
        return {"input": item[0, ...], "target": item[1, ...]}

    def reconstruct_from_batches(self, batches, **rec_kws):
        """Reconstruct a dataset from batches of patches.

        Args:
            batches: List of batches.
        """
        return self.patcher.reconstruct([*itertools.chain(*batches)], **rec_kws)

    def __len__(self) -> int:
        """Get length of dataset."""
        return len(self.patcher)
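
# Sketch: the dataset can also be used with a plain torch DataLoader,
# independent of the Lightning datamodule below (batch size is illustrative).
#
# from torch.utils.data import DataLoader
# ds = XrTorchDataset(patcher, item_postpro=None)
# sample = ds[0]  # {"input": ..., "target": ...}
# loader = DataLoader(ds, batch_size=4, collate_fn=default_collate)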
class OceanBenchDataModule(NonGeoDataModule):
    """Ocean Bench DataModule."""

    valid_tasks = ["osse_gf_nadir", "osse_gf_nadirswot", "osse_gf_nadir_sst", "ose_gf"]

    def __init__(
        self,
        task_name: str,
        patcher_kw: dict[str, Any],
        ref_var: str = "ssh",
        batch_size: int = 32,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Initialize a new instance of OceanBenchDataModule.

        Args:
            task_name: Task name, one of `valid_tasks`.
            patcher_kw: Keyword arguments for xrpatcher.XRDAPatcher.
            ref_var: Reference variable for the patcher.
            batch_size: Batch size.
            num_workers: Number of workers.
        """
        super().__init__(XrTorchDataset, batch_size, num_workers, **kwargs)
        assert task_name in self.valid_tasks, f"Task must be one of {self.valid_tasks}"
        self.task_name = task_name
        self.task_cfg = get_cfg(f"task/{task_name}/task")
        self.ref_var = ref_var
        self.patcher_kw = patcher_kw
        # collate function for tensors
        self.collate_fn = default_collate

        if "osse" in self.task_name:
            self.patcher_task_fn = patcher_from_osse_task
        else:
            self.patcher_task_fn = patcher_from_ose_task

        self.train_patcher = self.patcher_task_fn(
            self.task_cfg, self.patcher_kw, ref_var=self.ref_var, split="trainval"
        )
        self.test_patcher = self.patcher_task_fn(
            self.task_cfg, self.patcher_kw, ref_var=self.ref_var, split="test"
        )

        self.mean, self.std = norm_stats(self.train_patcher)
        # normalize with the train statistics and cast to float32
        self.item_postpro = lambda item: pipe(
            item,
            [
                lambda i: (i - self.mean) / self.std,
                lambda i: i.astype(np.float32),
            ],
        )
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Called at the beginning of fit, validate, test, or predict. During
        distributed training, this method is called from every process across
        all the nodes. Setting state here is recommended.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ["fit", "validate"]:
            train_dataset = XrTorchDataset(
                self.train_patcher, item_postpro=self.item_postpro
            )
            # create train and validation splits randomly by index (80/20)
            total_length = len(train_dataset)
            train_length = int(total_length * 0.8)
            val_length = total_length - train_length
            self.train_dataset, self.val_dataset = random_split(
                train_dataset,
                [train_length, val_length],
                generator=torch.Generator().manual_seed(42),
            )
        if stage in ["test"]:
            self.test_dataset = XrTorchDataset(
                self.test_patcher, item_postpro=self.item_postpro
            )
if __name__ == "__main__":
    # example: build the datamodule for the nadir OSSE task and draw one batch
    dm = OceanBenchDataModule(
        task_name="osse_gf_nadir",
        patcher_kw=dict(patches={"time": 5}, strides={"time": 1}),
        ref_var="ssh",
        batch_size=32,
        num_workers=0,
    )
    dm.setup(stage="fit")
    train_loader = dm.train_dataloader()
    batch = next(iter(train_loader))
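
    # A hedged sketch of the test path: gather patches and reconstruct a field.
    # Feeding targets back in is purely illustrative; in practice these would
    # be model predictions with the same patch layout. Any extra keyword
    # arguments are forwarded to xrpatcher's reconstruct.
    dm.setup(stage="test")
    test_batches = [b["target"] for b in dm.test_dataloader()]
    reconstruction = dm.test_dataset.reconstruct_from_batches(test_batches)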