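"""Distributed training and hyperparameter search for the CurveNet segmentation model.

Uses Ray Train for multi-GPU data-parallel training, Ray Tune for a random search over
the learning-rate / warmup / loss-function space defined in ``main``, and Weights &
Biases for logging metrics and prediction masks.
"""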
import time
from network.model import CurveNet
from utils.dataset import PetaiSegDataset
from utils.utils import collate_fn
import torch
from utils import losses
import numpy as np
import os
import torch.optim as optim
import argparse
import cv2
import glob
import copy
import matplotlib.pyplot as plt
from os.path import join as pjoin
import pandas as pd
from datetime import datetime
from torch.utils import data
import wandb
import random
from tqdm import tqdm
import segmentation_models_pytorch as smp
import ray
from ray import train
from ray.train import Trainer
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from ray.tune.integration.wandb import wandb_mixin
# Fix the random seeds:
torch.backends.cudnn.deterministic = True
random.seed(42)
torch.manual_seed(2019)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(2019)
np.random.seed(seed=2019)

segmentation_classes = [
    "background", "curve1", "curve2", "curve3"
]


def labels():
    l = {}
    for i, label in enumerate(segmentation_classes):
        l[i] = label
    return l


NUM_CLASSES = NUM_MASKS = 4
base_data_folder = "/path/"
# fp16 training
use_amp = True
epochs = 50
# # For ASHA scheduler in Ray Tune.
# MAX_NUM_EPOCHS = 25
# GRACE_PERIOD = 2
def train_one_epoch(model, trainloader, criterion, optimizer, lr_scheduler, scaler=None):
    model.train()  # Set model to training mode
    running_loss, ious, f1_scores = 0.0, [], []
    stream = tqdm(trainloader, desc="Training")
    for i, (image, target, _) in enumerate(stream):
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        # compute metrics (ignore the background class)
        tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(1),
                                               target.long(),
                                               mode='multiclass',
                                               num_classes=NUM_CLASSES,
                                               ignore_index=0)
        iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")
        f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")
        # .item() so that np.mean below works even when the metric tensors live on GPU
        ious.append(iou_score.item())
        f1_scores.append(f1_score.item())
        running_loss += loss.item() * image.size(0)
        lr_scheduler.step()
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_iou = np.mean(ious)
    epoch_f1 = np.mean(f1_scores)
    return epoch_loss, epoch_iou, epoch_f1
def valid_one_epoch(model, valloader, criterion):
    model.eval()  # Set model to evaluate mode
    running_loss, ious, f1_scores = 0.0, [], []
    with torch.inference_mode():
        stream = tqdm(valloader, desc="Validation")
        for i, (image, target, _) in enumerate(stream):
            output = model(image)
            loss = criterion(output, target)
            # compute metrics (ignore the background class)
            tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(1),
                                                   target.long(),
                                                   mode='multiclass',
                                                   num_classes=NUM_CLASSES,
                                                   ignore_index=0)
            iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")
            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")
            ious.append(iou_score.item())
            f1_scores.append(f1_score.item())
            running_loss += loss.item() * image.size(0)
    epoch_loss = running_loss / len(valloader.dataset)
    epoch_iou = np.mean(ious)
    epoch_f1 = np.mean(f1_scores)
    return epoch_loss, epoch_iou, epoch_f1
def wb_mask(bg_img, pred_mask, true_mask):
    return wandb.Image(bg_img, masks={
        "prediction": {"mask_data": pred_mask, "class_labels": labels()},
        "ground truth": {"mask_data": true_mask, "class_labels": labels()}})
def log_image_predictions(model, valloader, device):
    model.to(device)
    model.eval()  # Set model to evaluate mode
    with torch.inference_mode():
        stream = tqdm(valloader, desc="Validation")
        for i, (image, target, shapes) in enumerate(stream):
            image, target = image.to(device), target.to(device)
            output = model(image)
            mask_list = []
            # print(image.shape, target.shape, output.shape)
            # log the first 4 samples of the first validation batch only
            for j in range(4):
                # print(j, shapes[j], image[j].shape, output[j].shape, target[j].shape)
                sh = shapes[j]
                img = (image[j].cpu().numpy() * 255).astype(np.uint8)
                out = (output[j].argmax(0).cpu().numpy()).astype(np.uint8)  # argmax over the class dim
                gt = (target[j].cpu().numpy()).astype(np.uint8)
                mask_list.append(wb_mask(img, out, gt))
            return mask_list
@wandb_mixin
def train_seg_tuner(config):
    world_rank = train.world_rank()
    local_rank = train.local_rank()
    world_size = train.world_size()
    print(f"| distributed init | world rank: {world_rank} | local rank: {local_rank} | world size: {world_size} |", flush=True)

    def print_with_rank(msg):
        print("[LOCAL RANK {} | WORLD RANK {}]: {}".format(local_rank, world_rank, msg))

    exp = str(config["loss_fn"]) + "_loss"
    fname = "Result_" + str(config["lr"]) + '_' + str(epochs) + '_' + exp
    print_with_rank(f"Saving trained model name: {fname}")

    if train.local_rank() == 0:  # log to wandb only on the main process
        wandb.init(project="v2", entity="p")
        wandb.config.update(config)

    # setup model
    net = CurveNet(num_masks=NUM_MASKS)
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = train.torch.prepare_model(net)
    # setup datasets -> 80%-20% train-val split
    train_files = glob.glob("/path/*/*/*/*/", recursive=True)
    val_files = glob.glob("/path/", recursive=True)
    # # use for debugging
    # train_files, val_files = train_files[:2000], val_files[:500]
    train_set = PetaiSegDataset(files=train_files)
    print_with_rank("Training samples: {}".format(len(train_set)))
    val_set = PetaiSegDataset(files=val_files)
    print_with_rank("Validation samples: {}".format(len(val_set)))

    # setup dataloaders
    trainloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=16,
        num_workers=10,
        collate_fn=collate_fn,
        drop_last=True,
        pin_memory=True,
    )
    valloader = torch.utils.data.DataLoader(
        val_set,
        batch_size=16,
        num_workers=10,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    trainloader = train.torch.prepare_data_loader(trainloader)
    valloader = train.torch.prepare_data_loader(valloader)
    # setup loss function
    if config["loss_fn"] == "dice":
        criterion = smp.losses.dice.DiceLoss(mode="multiclass", from_logits=True, ignore_index=255)
    if config["loss_fn"] == "focal":
        criterion = smp.losses.focal.FocalLoss(mode="multiclass", ignore_index=255)
    if config["loss_fn"] == "lovasz":
        criterion = smp.losses.lovasz.LovaszLoss(mode="multiclass", from_logits=True, ignore_index=255)
    if config["loss_fn"] == "tversky":
        criterion = smp.losses.tversky.TverskyLoss(mode="multiclass", from_logits=True, ignore_index=255)
    if config["loss_fn"] == "sce":
        criterion = smp.losses.soft_ce.SoftCrossEntropyLoss(ignore_index=255, smooth_factor=0.1)

    # setup optimizer
    optimizer = optim.AdamW(net.parameters(), lr=config["lr"], weight_decay=config["weight_decay"], amsgrad=True)

    # setup schedulers: polynomial decay, optionally preceded by a warmup phase
    iters_per_epoch = len(trainloader)
    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda x: (1 - x / (iters_per_epoch * (epochs - config["lr_warmup_epochs"]))) ** 0.9
    )
    if config["lr_warmup_epochs"] > 0:
        warmup_iters = iters_per_epoch * config["lr_warmup_epochs"]
        warmup_method = config["lr_warmup_method"].lower()
        if warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=config["lr_warmup_decay"], total_iters=warmup_iters
            )
        elif warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=config["lr_warmup_decay"], total_iters=warmup_iters
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler

    # setup amp (fp16) training: allows a larger batch size and faster training
    scaler = torch.cuda.amp.GradScaler() if use_amp else None
    # training and validation loops
    loss_train, loss_valid, best_loss, best_epoch, best_iou, best_f1 = [], [], 9999, 9999, 0, 0
    since = time.time()
    for epoch in range(1, epochs + 1):
        # training step
        print_with_rank(f"\n{'--' * 15} EPOCH: {epoch} | {epochs} {'--' * 15}\n")
        epoch_train_loss, epoch_train_iou, epoch_train_f1 = train_one_epoch(net, trainloader, criterion, optimizer, lr_scheduler, scaler)
        print_with_rank("\nPhase: train | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_train_loss, epoch_train_iou, epoch_train_f1))
        # validation step
        epoch_val_loss, epoch_val_iou, epoch_val_f1 = valid_one_epoch(net, valloader, criterion)
        print_with_rank("\nPhase: val | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_val_loss, epoch_val_iou, epoch_val_f1))
        # logs
        loss_train.append(epoch_train_loss)
        loss_valid.append(epoch_val_loss)
        if train.local_rank() == 0:
            wandb.log({"train_epoch_loss": epoch_train_loss})
            wandb.log({"val_epoch_loss": epoch_val_loss})
            wandb.log({"train_epoch_f1": epoch_train_f1})
            wandb.log({"val_epoch_f1": epoch_val_f1})
            wandb.log({"train_epoch_iou": epoch_train_iou})
            wandb.log({"val_epoch_iou": epoch_val_iou})
        # keep the best weights whenever validation IoU improves by at least 1e-3
        diff = abs(np.round(epoch_val_iou, 4) - np.round(best_iou, 4))
        if (np.round(epoch_val_iou, 4) > np.round(best_iou, 4)) and diff >= 1e-3:
            print_with_rank("IoU improved from {:.4f} to {:.4f}.".format(best_iou, epoch_val_iou))
            best_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(net.module.state_dict())
            best_epoch = epoch
            best_iou = epoch_val_iou
            best_f1 = epoch_val_f1
        # report metrics to Ray Tune every epoch
        train.report(loss=epoch_val_loss, f1=epoch_val_f1, iou=epoch_val_iou)
    time_elapsed = time.time() - since
    print_with_rank("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    consume_prefix_in_state_dict_if_present(best_model_wts, "module.")

    # log image predictions to wandb
    bst_model = CurveNet(num_masks=NUM_MASKS)
    bst_model.load_state_dict(best_model_wts)
    mask_list = log_image_predictions(bst_model, valloader, "cuda")
    if train.local_rank() == 0:
        wandb.log({"prediction": mask_list})

    # save best model
    print_with_rank("Saving best model")
    train.save_checkpoint(epoch=best_epoch, model_weights=best_model_wts)
    print_with_rank("Best model loss: {:.4f} and corresponding iou: {:.4f}, f1: {:.4f}".format(best_loss, best_iou, best_f1))

    # plot and save the loss curves
    plt.plot(loss_train)
    plt.plot(loss_valid)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(['train', 'val'])
    plt.savefig(pjoin('outputs/raytune_result', fname + '_Loss.png'))
    if train.local_rank() == 0:
        wandb.log({"epoch_losses": plt})
    df = pd.DataFrame([loss_train, loss_valid]).T
    df.columns = ['train', 'val']
    df.to_csv(pjoin('outputs/raytune_result', fname + '_loss.csv'), index=False)
def main(args, num_samples=2):
    # timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
    trainer = Trainer(
        "torch",
        num_workers=args.num_workers,
        use_gpu=True,
        resources_per_worker={"GPU": args.num_gpus, "CPU": 20},
        # logdir=f"/home/ray_results/train_{timestr}"
    )
    config = {
        "lr": tune.loguniform(1e-3, 1e-1),
        "weight_decay": tune.loguniform(6e-6, 1e-3),
        "lr_warmup_epochs": tune.choice([5, 7, 9, 10]),
        "lr_warmup_method": tune.choice(['linear', 'constant']),
        "lr_warmup_decay": tune.loguniform(1e-3, 1e-1),
        "loss_fn": tune.choice(["focal", "dice", "sce", "lovasz", "tversky"]),
    }
    # scheduler = ASHAScheduler(
    #     metric="iou",
    #     mode="max",
    #     max_t=MAX_NUM_EPOCHS,
    #     grace_period=GRACE_PERIOD,
    #     reduction_factor=2)
    reporter = CLIReporter(metric_columns=["loss", "f1", "iou", "training_iteration"])
    trainable = trainer.to_tune_trainable(train_seg_tuner)
    result = tune.run(
        trainable,
        config=config,
        num_samples=num_samples,
        # scheduler=scheduler,
        local_dir='outputs/raytune_result',
        keep_checkpoints_num=1,
        checkpoint_score_attr='max-iou',
        progress_reporter=reporter)
    best_trial = result.get_best_trial("iou", "max")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
    print("Best trial final validation f1: {}".format(best_trial.last_result["f1"]))
    print("Best trial final validation iou: {}".format(best_trial.last_result["iou"]))
    best_checkpoint_dir = best_trial.checkpoint.value
    print("Best checkpoint dir", best_checkpoint_dir)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Hyperparams')
    parser.add_argument(
        "--num-gpus",
        "-g",
        type=int,
        default=2,
        help="The number of GPUs per worker.",
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="The number of workers for training.",
    )
    args = parser.parse_args()
    start = time.time()
    ray.init()
    # num_samples -> number of random search experiments to run
    main(args, num_samples=10)
    stop = time.time()
    print('Total execution time is {} min'.format((stop - start) / 60))
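# Example usage (the script filename below is an assumption; update the dataset glob
# paths and the wandb project/entity above before running):
#   python train_seg_raytune.py --num-gpus 2 --num-workers 2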