Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save jasonjewik/b611e836f7fd8dbcb485a4eefb09035b to your computer and use it in GitHub Desktop.
Save jasonjewik/b611e836f7fd8dbcb485a4eefb09035b to your computer and use it in GitHub Desktop.
Temporary ClimateLearn forecasting quickstart script as we update the documentation. Assumes the data is already downloaded. You must set ROOT_DIR.
from climate_learn.data.climate_dataset.args import ERA5Args
from climate_learn.data.task.args import ForecastingArgs
from climate_learn.data.dataset import MapDatasetArgs
from climate_learn.data import DataModule
from climate_learn.models import load_model, set_climatology
from climate_learn.training import Trainer
import torch
# ---------------------------------------------------------------------------
# ClimateLearn forecasting quickstart.
#
# Builds train/val/test dataset descriptions for ERA5, wraps them in a
# DataModule, loads a ResNet forecasting model, then trains and tests it.
# Assumes the data is already downloaded under ROOT_DIR.
# ---------------------------------------------------------------------------

# Load the data
ROOT_DIR = "/path/to/your/data"  # CHANGE ME!
variables = ["temperature_850", "geopotential_500"]

# Number of past time steps fed to the model. Defined once so the task
# definition and the model configuration below cannot drift apart.
HISTORY = 3

# Year ranges are half-open (`range` excludes the stop value) and are kept
# disjoint so that no validation or test year is seen during training:
#   train: 1979-2014, val: 2015-2016, test: 2017-2018
climate_dataset_args = ERA5Args(
    ROOT_DIR,
    variables,
    years=range(1979, 2015),
    split="train",
)

# The variable list is passed twice and followed by an empty list.
# NOTE(review): presumably these are the input variables, the output
# variables, and the constant fields - confirm against ForecastingArgs.
task_args = ForecastingArgs(
    variables,
    variables,
    [],
    history=HISTORY,  # time steps behind
    pred_range=72,  # hours ahead to predict
    subsample=6,  # hours per time step
)

train_dataset_args = MapDatasetArgs(climate_dataset_args, task_args)

# Validation and test arguments are derived from the training arguments,
# overriding only the years and the split label.
val_dataset_args = train_dataset_args.create_copy({
    "climate_dataset_args": {
        "years": range(2015, 2017),
        "split": "val",
    }
})
test_dataset_args = val_dataset_args.create_copy({
    "climate_dataset_args": {
        "years": range(2017, 2019),
        # Label the held-out years as the test split rather than inheriting
        # "val" from the arguments this copy is derived from.
        "split": "test",
    }
})

dm = DataModule(
    train_dataset_args,
    val_dataset_args,
    test_dataset_args,
    batch_size=32,
    num_workers=8,
)

# Load the model
model_kwargs = {
    "in_channels": len(variables),
    "history": HISTORY,  # must match the value given to 'ForecastingArgs'
    "n_blocks": 4,  # number of residual blocks to use
}
optim_kwargs = {}  # just use the default settings
mm = load_model("resnet", "forecasting", model_kwargs, optim_kwargs)
# NOTE(review): presumably derives a climatology from the data module and
# attaches it to the model - confirm against the ClimateLearn docs.
set_climatology(mm, dm)

# Train
trainer = Trainer()
trainer.fit(mm, dm)

# Test
trainer.test(mm, dm)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment