
@adamjstewart
Last active June 11, 2022 04:24
Training a model on a TorchGeo dataset with PyTorch Lightning
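```python
from pytorch_lightning import Trainer
from torchgeo.datamodules import InriaAerialImageLabelingDataModule
from torchgeo.trainers import SemanticSegmentationTask

# Data: Inria Aerial Image Labeling dataset, batches of 64 loaded by 6 worker processes
datamodule = InriaAerialImageLabelingDataModule(root_dir="...", batch_size=64, num_workers=6)
# Model: semantic segmentation task using ResNet-18 with pretrained weights
task = SemanticSegmentationTask(model="resnet18", pretrained=True, learning_rate=0.1)
# Training loop on one GPU (`gpus` is the PyTorch Lightning 1.x argument)
trainer = Trainer(gpus=1, default_root_dir="...")
trainer.fit(model=task, datamodule=datamodule)
```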
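The snippet splits training into three objects: a datamodule that owns the data and batching, a task that owns the model, loss, and optimizer step, and a trainer that only runs the loop. As a rough, dependency-free sketch of that division of labour, the code below uses hypothetical stand-ins (`ToyDataModule`, `ToyTask`, `fit`) that are not the TorchGeo or PyTorch Lightning API; it fits a 1-D linear model with plain gradient descent instead of a segmentation network. The real `Trainer` additionally handles device placement, validation, checkpointing, and logging under `default_root_dir`.

```python
import random


class ToyDataModule:
    """Hypothetical stand-in for a datamodule: owns the data and batching."""

    def __init__(self, batch_size: int, n_samples: int = 256, seed: int = 0):
        rng = random.Random(seed)
        xs = [rng.uniform(-1.0, 1.0) for _ in range(n_samples)]
        # Noise-free targets y = 2x + 1, so the optimum (w=2, b=1) is known
        self.data = [(x, 2.0 * x + 1.0) for x in xs]
        self.batch_size = batch_size

    def train_dataloader(self):
        # Yield consecutive mini-batches; a new generator is created each epoch
        for i in range(0, len(self.data), self.batch_size):
            yield self.data[i:i + self.batch_size]


class ToyTask:
    """Hypothetical stand-in for a task: owns the model, loss, and update step."""

    def __init__(self, learning_rate: float):
        self.w = 0.0
        self.b = 0.0
        self.learning_rate = learning_rate

    def training_step(self, batch):
        # Mean squared error of a 1-D linear model and its gradients
        n = len(batch)
        errors = [(self.w * x + self.b - y, x) for x, y in batch]
        loss = sum(e * e for e, _ in errors) / n
        grad_w = sum(2.0 * e * x for e, x in errors) / n
        grad_b = sum(2.0 * e for e, _ in errors) / n
        # Plain gradient descent update
        self.w -= self.learning_rate * grad_w
        self.b -= self.learning_rate * grad_b
        return loss


def fit(task, datamodule, max_epochs: int):
    """Hypothetical stand-in for Trainer.fit: only the epoch/batch loop."""
    loss = float("inf")
    for _ in range(max_epochs):
        for batch in datamodule.train_dataloader():
            loss = task.training_step(batch)
    return loss


datamodule = ToyDataModule(batch_size=64)
task = ToyTask(learning_rate=0.1)
final_loss = fit(task, datamodule, max_epochs=200)
print(f"w={task.w:.3f} b={task.b:.3f} loss={final_loss:.6f}")
```

Because the loop never touches the data or the model directly, swapping in a different dataset or model means replacing only the datamodule or the task, which is the same reason the snippet can pair `InriaAerialImageLabelingDataModule` with `SemanticSegmentationTask` through an unchanged `trainer.fit` call.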