Skip to content

Instantly share code, notes, and snippets.

@jasonjewik
Created April 21, 2023 22:51
Show Gist options
  • Save jasonjewik/c339325e4ae33c85e4cecc1356fdae38 to your computer and use it in GitHub Desktop.
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.task.args import ForecastingArgs
from climate_learn.data.dataset import MapDatasetArgs
from climate_learn.data import DataModule
from climate_learn.models import set_climatology, load_model
from climate_learn.training import Trainer
# --- Data configuration -----------------------------------------------------
# CHANGE ME: point root_dir at a local copy of the 5.625-degree ERA5 data.
root_dir = "/data0/datasets/weatherbench/data/weatherbench/era5/5.625deg"
# The same variables serve as model inputs and as forecast targets (see the
# two `variables` arguments to ForecastingArgs below).
variables = ["temperature_850", "geopotential_500"]

# Training data: a single year (2014) of ERA5.
climate_dataset_args = ERA5Args(
    root_dir,
    variables,
    years=range(2014, 2015),
    split="train",
)
# Forecasting task built from `variables` -> `variables`.
# NOTE(review): the third positional argument is presumably the list of
# constant fields (empty here), and history/pred_range/subsample are
# presumably counted in dataset time steps — confirm against the
# climate_learn ForecastingArgs documentation.
task_args = ForecastingArgs(
    variables,  # input variables
    variables,  # output variables
    [],         # presumably constant fields — none used
    history=3,
    pred_range=72,
    subsample=6,
)
train_dataset_args = MapDatasetArgs(climate_dataset_args, task_args)

# Validation and test reuse the training configuration and only override the
# year range and split label. Both are derived from the training args so each
# split's lineage is explicit.
val_dataset_args = train_dataset_args.create_copy({
    "climate_dataset_args": {
        "years": range(2015, 2016),
        "split": "val",
    }
})
test_dataset_args = train_dataset_args.create_copy({
    "climate_dataset_args": {
        "years": range(2016, 2017),
        # Previously "val" (copy-paste slip): the held-out 2016 data is the
        # test split, not a second validation split.
        "split": "test",
    }
})

dm = DataModule(
    train_dataset_args,
    val_dataset_args,
    test_dataset_args,
    batch_size=32,
    num_workers=8,
)
# --- Model configuration ----------------------------------------------------
# Channel counts are derived from `variables` instead of being hard-coded, so
# editing the variable list cannot silently desynchronise the model from the
# data. The task uses `variables` for both inputs and outputs.
# NOTE(review): if constant fields are ever added to the task, in_channels
# presumably has to include them as well — confirm with climate_learn.
# "history" must match the history=3 passed to ForecastingArgs.
model_kwargs = {
    "in_channels": len(variables),
    "out_channels": len(variables),
    "history": 3,
    "hidden_channels": 128,
    "activation": "leaky",
    "norm": True,
    "dropout": 0.1,
    "n_blocks": 19,
}
# Empty dict: no optimizer overrides are passed (presumably library defaults).
optim_kwargs = {}
mm = load_model("resnet", "forecasting", model_kwargs, optim_kwargs)
# NOTE(review): presumably attaches the DataModule's climatology to the model
# (e.g. for climatology-based metrics) — confirm in climate_learn docs.
set_climatology(mm, dm)
# Sanity check: print the shape of every entry in the training-split
# climatology mapping.
for item in dm.get_climatology(split="train").values():
    print(item.shape)
# --- Training ---------------------------------------------------------------
# Fit the forecasting model on the GPU. seed=0 is passed so runs are
# presumably reproducible (the seeding itself happens inside Trainer).
trainer = Trainer(
    accelerator="gpu",
    task="forecasting",
    seed=0,
)
trainer.fit(mm, dm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment