Skip to content

Instantly share code, notes, and snippets.

@isaaccorley
Last active August 31, 2022 00:58
Show Gist options
  • Save isaaccorley/725aaf23d6100859136ff58b15e21851 to your computer and use it in GitHub Desktop.
Save isaaccorley/725aaf23d6100859136ff58b15e21851 to your computer and use it in GitHub Desktop.
TorchGeo Minimal Segmentation Train/Val Example
import torch
import torch.nn as nn
import torch.optim as optim
import segmentation_models_pytorch as smp
from torchmetrics import Accuracy
from tqdm import tqdm
from torchgeo.datasets import ETCI2021
from torchgeo.datamodules import ETCI2021DataModule
# Make sure every ETCI2021 split is present under ./data by constructing each
# one with download=True (train, then val, then test). The resulting dataset
# objects are kept as module-level names.
train_dataset, val_dataset, test_dataset = (
    ETCI2021(root="data", split=split_name, download=True)
    for split_name in ("train", "val", "test")
)
# Wrap the data under ./data in TorchGeo's ETCI2021 datamodule
# (batch_size=16, num_workers=4), run its setup step, and then pull out the
# train / val / test dataloaders in that order.
dm = ETCI2021DataModule(
    root_dir="data",
    batch_size=16,
    num_workers=4,
)
dm.setup()
train_dataloader, val_dataloader, test_dataloader = (
    dm.train_dataloader(),
    dm.val_dataloader(),
    dm.test_dataloader(),
)
# --- Training configuration ------------------------------------------------
epochs = 5
# Fall back to CPU so the example still runs on machines without a GPU
# (a hard-coded "cuda" fails at model.to(device) there).
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 0.001

# U-Net with a ResNet-50 encoder and no pretrained encoder weights
# (encoder_weights=None); 6 input channels, 2 output classes.
model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None,
    in_channels=6,
    classes=2,
)
model = model.to(device)
opt = optim.Adam(model.parameters(), lr=lr)

# Score every pixel. With only 2 classes, ignore_index=0 would drop all
# class-0 targets from the loss, which would then be minimised by predicting
# class 1 everywhere, and the accuracy metric would reward exactly that.
# NOTE(review): if the masks turn out to carry a dedicated "nodata" label,
# pass ignore_index=<that label id> to the loss and both metrics instead.
loss_fn = nn.CrossEntropyLoss()
train_acc = Accuracy(num_classes=2, mdmc_average="global").to(device)
val_acc = Accuracy(num_classes=2, mdmc_average="global").to(device)
# Train for `epochs` epochs. Each epoch is one optimisation pass over
# train_dataloader followed by a gradient-free pass over val_dataloader,
# and ends with a one-line summary of accuracy and validation loss.
for epoch in range(epochs):
    # --- Train ---------------------------------------------------------
    model.train()
    pbar = tqdm(train_dataloader, position=0, leave=True)
    for batch in pbar:
        opt.zero_grad()
        x = batch["image"].to(device)
        y = batch["mask"].to(device)
        y_hat = model(x)  # class scores along dim 1, as CrossEntropyLoss expects
        train_loss = loss_fn(y_hat, y)
        train_loss.backward()
        opt.step()
        # argmax yields an integer tensor with no autograd history, so it
        # can go straight into the metric.
        train_acc.update(y_hat.argmax(dim=1), y)
        # .item() gives a plain float. Formatting the tensor itself would put
        # its full repr (device, grad_fn) into the progress bar.
        pbar.set_description(desc=f"Train Loss: {train_loss.item():.4f}")

    # --- Validate ------------------------------------------------------
    model.eval()
    val_loss_sum = 0.0
    num_val_samples = 0
    pbar = tqdm(val_dataloader, position=0, leave=True)
    with torch.no_grad():
        for batch in pbar:
            x = batch["image"].to(device)
            y = batch["mask"].to(device)
            y_hat = model(x)
            batch_size = x.shape[0]
            # Weight each batch's mean loss by its size, so a short final
            # batch does not count as much as a full one in the average.
            val_loss_sum += loss_fn(y_hat, y).item() * batch_size
            num_val_samples += batch_size
            val_acc.update(y_hat.argmax(dim=1), y)
    # Guard against an empty validation loader instead of dividing by zero.
    val_loss = val_loss_sum / num_val_samples if num_val_samples else float("nan")

    print(
        f"Epoch:{epoch} | Train acc:{train_acc.compute().item():.4f}"
        f" | Val acc:{val_acc.compute().item():.4f} | Val loss:{val_loss:.4f}"
    )
    # Metrics accumulate across update() calls; clear them for the next epoch.
    train_acc.reset()
    val_acc.reset()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment