-
-
Save dudeperf3ct/906e18acb8f4cc728109b49c1de0f458 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
from network.model import CurveNet | |
from utils.dataset import PetaiSegDataset | |
from utils.utils import collate_fn | |
import torch | |
from utils import losses | |
import numpy as np | |
import os | |
import torch.optim as optim | |
import argparse | |
import cv2 | |
import glob | |
import copy | |
import matplotlib.pyplot as plt | |
from os.path import join as pjoin | |
import pandas as pd | |
from datetime import datetime | |
from torch.utils import data | |
import wandb | |
import random | |
from tqdm import tqdm | |
import segmentation_models_pytorch as smp | |
import ray | |
from ray import train | |
from ray.train import Trainer | |
from ray import tune | |
from ray.tune import CLIReporter | |
from ray.tune.schedulers import ASHAScheduler | |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present | |
from ray.tune.integration.wandb import wandb_mixin | |
# Make runs repeatable: deterministic cuDNN kernels plus fixed seeds for every RNG in use.
_TORCH_NUMPY_SEED = 2019
torch.backends.cudnn.deterministic = True
# NOTE(review): Python's `random` is seeded with 42 while torch/numpy use 2019 — confirm intentional.
random.seed(42)
torch.manual_seed(_TORCH_NUMPY_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(_TORCH_NUMPY_SEED)
np.random.seed(seed=_TORCH_NUMPY_SEED)
# Class names indexed by label id: id 0 is the background, ids 1-3 are the curves.
segmentation_classes = ["background"] + [f"curve{idx}" for idx in range(1, 4)]
def labels():
    """Return a fresh ``{class_id: class_name}`` dict built from ``segmentation_classes``."""
    return dict(enumerate(segmentation_classes))
# Number of segmentation classes; the network is built with the same number of masks.
NUM_CLASSES = 4
NUM_MASKS = NUM_CLASSES
base_data_folder = "/path/"
# Mixed-precision (fp16) training switch.
use_amp = True
# Total number of training epochs per trial.
epochs = 50
# Settings for the (currently disabled) ASHA scheduler in Ray Tune:
# MAX_NUM_EPOCHS = 25
# GRACE_PERIOD = 2
def train_one_epoch(model, trainloader, criterion, optimizer, lr_scheduler, scaler=None):
    """Run one optimisation pass over ``trainloader``.

    Args:
        model: network to train; switched to train mode.
        trainloader: iterable of ``(image, target, extra)`` batches; ``extra`` is
            ignored. Assumes tensors already live on the model's device (e.g. the
            loader was passed through Ray's ``prepare_data_loader``) — confirm.
        criterion: loss callable ``criterion(output, target)``.
        optimizer: optimizer updated once per batch.
        lr_scheduler: scheduler stepped once per batch (per-iteration schedule).
        scaler: optional ``GradScaler``; when given, forward runs under autocast
            and the backward/step go through the scaler.

    Returns:
        ``(epoch_loss, epoch_iou, epoch_f1)``: loss averaged over the samples this
        process actually processed, and the mean of the per-batch macro IoU / F1.
        IoU / F1 are NaN when the loader yields no batches.
    """
    model.train()  # Set model to training mode
    running_loss, num_samples, ious, f1_scores = 0.0, 0, [], []
    for image, target, _ in tqdm(trainloader, desc="Training"):
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)
            loss = criterion(output, target)
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        # The LR schedule is defined in iterations, so it advances every batch.
        lr_scheduler.step()

        # compute metric (evaluation only, no autograd graph needed)
        with torch.no_grad():
            tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(1),
                                                   target.long(),
                                                   mode='multiclass',
                                                   num_classes=NUM_CLASSES,
                                                   ignore_index=0)
            iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")
            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")
        # .item() brings the scalar to the host: np.mean cannot consume CUDA tensors.
        ious.append(iou_score.item())
        f1_scores.append(f1_score.item())

        batch_size = image.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Normalise by samples actually seen: with a DistributedSampler (and
    # drop_last) each rank only processes part of len(dataset).
    epoch_loss = running_loss / num_samples if num_samples else 0.0
    epoch_iou = float(np.mean(ious)) if ious else float("nan")
    epoch_f1 = float(np.mean(f1_scores)) if f1_scores else float("nan")
    return epoch_loss, epoch_iou, epoch_f1
def valid_one_epoch(model, valloader, criterion):
    """Evaluate ``model`` on ``valloader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode.
        valloader: iterable of ``(image, target, extra)`` batches; ``extra`` is
            ignored. Assumes tensors already live on the model's device — confirm.
        criterion: loss callable ``criterion(output, target)``.

    Returns:
        ``(epoch_loss, epoch_iou, epoch_f1)``: loss averaged over the samples this
        process actually evaluated, and the mean of the per-batch macro IoU / F1
        (target pixels equal to 0 are passed to ``get_stats`` as ``ignore_index``).
        IoU / F1 are NaN when the loader yields no batches.
    """
    model.eval()  # Set model to evaluate mode
    running_loss, num_samples, ious, f1_scores = 0.0, 0, [], []
    with torch.inference_mode():
        for image, target, _ in tqdm(valloader, desc="Validation"):
            output = model(image)
            loss = criterion(output, target)
            # compute metric (same smp API as the training loop)
            tp, fp, fn, tn = smp.metrics.get_stats(output.argmax(1),
                                                   target.long(),
                                                   mode='multiclass',
                                                   num_classes=NUM_CLASSES,
                                                   ignore_index=0)
            iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro")
            f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro")
            # .item() brings the scalar to the host: np.mean cannot consume CUDA tensors.
            ious.append(iou_score.item())
            f1_scores.append(f1_score.item())

            batch_size = image.size(0)
            running_loss += loss.item() * batch_size
            num_samples += batch_size

    # Normalise by samples actually seen: a DistributedSampler gives each rank
    # only a shard of len(dataset).
    epoch_loss = running_loss / num_samples if num_samples else 0.0
    epoch_iou = float(np.mean(ious)) if ious else float("nan")
    epoch_f1 = float(np.mean(f1_scores)) if f1_scores else float("nan")
    return epoch_loss, epoch_iou, epoch_f1
def wb_mask(bg_img, pred_mask, true_mask):
    """Wrap an image with predicted and ground-truth mask overlays as a ``wandb.Image``.

    Both overlays share the class-id -> name mapping returned by ``labels()``.
    """
    overlays = {}
    for overlay_name, mask in (("prediction", pred_mask), ("ground truth", true_mask)):
        overlays[overlay_name] = {"mask_data": mask, "class_labels": labels()}
    return wandb.Image(bg_img, masks=overlays)
def log_image_predictions(model, valloader, device, num_images=4):
    """Build W&B mask overlays for the first ``num_images`` validation samples.

    Args:
        model: trained network; moved to ``device`` and put in eval mode.
        valloader: iterable of ``(image, target, extra)`` batches; ``extra`` is unused.
        device: device (or device string) the model and batches are moved to.
        num_images: maximum number of samples to visualise (default 4).

    Returns:
        A list of up to ``num_images`` ``wandb.Image`` objects built by ``wb_mask``.
        The list is empty when the loader is empty.
    """
    model.to(device)
    model.eval()  # Set model to evaluate mode
    mask_list = []
    with torch.inference_mode():
        for image, target, _ in valloader:
            image, target = image.to(device), target.to(device)
            # Per-pixel class prediction: argmax over the class axis (dim 1 of the
            # batched output), matching how the metrics decode predictions.
            pred = model(image).argmax(1)
            # Never index past the batch; a short final batch is handled too.
            take = min(num_images - len(mask_list), image.size(0))
            for idx in range(take):
                # NOTE(review): assumes image values are in [0, 1] — confirm against the dataset.
                img = (image[idx].cpu().numpy() * 255).astype(np.uint8)
                out = pred[idx].cpu().numpy().astype(np.uint8)
                gt = target[idx].cpu().numpy().astype(np.uint8)
                mask_list.append(wb_mask(img, out, gt))
            if len(mask_list) >= num_images:
                break  # no need to run the rest of the validation set
    return mask_list
def _build_criterion(loss_fn):
    """Return the smp segmentation loss named by ``loss_fn``.

    Target pixels equal to 255 are ignored by every loss.

    Raises:
        ValueError: if ``loss_fn`` is not one of the supported names.
    """
    factories = {
        "dice": lambda: smp.losses.dice.DiceLoss(mode="multiclass", from_logits=True, ignore_index=255),
        "focal": lambda: smp.losses.focal.FocalLoss(mode="multiclass", ignore_index=255),
        "lovasz": lambda: smp.losses.lovasz.LovaszLoss(mode="multiclass", from_logits=True, ignore_index=255),
        "tversky": lambda: smp.losses.tversky.TverskyLoss(mode="multiclass", from_logits=True, ignore_index=255),
        "sce": lambda: smp.losses.soft_ce.SoftCrossEntropyLoss(ignore_index=255, smooth_factor=0.1),
    }
    if loss_fn not in factories:
        raise ValueError(f"Unknown loss_fn {loss_fn!r}; expected one of {sorted(factories)}")
    return factories[loss_fn]()


def _build_lr_scheduler(optimizer, config, iters_per_epoch):
    """Build the per-iteration LR schedule: optional warmup, then poly decay.

    The decay factor is ``(1 - it / decay_iters) ** 0.9``. ``decay_iters`` spans
    only the post-warmup epochs because SequentialLR restarts the decay schedule
    at the warmup milestone.

    Raises:
        ValueError: if warmup is enabled and ``config["lr_warmup_method"]`` is
            neither "linear" nor "constant".
    """
    warmup_epochs = config["lr_warmup_epochs"]
    # max(..., 1) avoids a ZeroDivisionError for an empty loader.
    decay_iters = max(iters_per_epoch * (epochs - warmup_epochs), 1)
    main_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        # Clamp at 0: a negative base raised to 0.9 would give a complex LR.
        optimizer, lambda it: max(0.0, 1 - it / decay_iters) ** 0.9
    )
    if warmup_epochs <= 0:
        return main_lr_scheduler

    warmup_iters = iters_per_epoch * warmup_epochs
    warmup_method = config["lr_warmup_method"].lower()
    if warmup_method == "linear":
        warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=config["lr_warmup_decay"], total_iters=warmup_iters
        )
    elif warmup_method == "constant":
        warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer, factor=config["lr_warmup_decay"], total_iters=warmup_iters
        )
    else:
        raise ValueError(f"Unknown lr_warmup_method {warmup_method!r}; expected 'linear' or 'constant'")
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
    )


def _save_loss_curves(loss_train, loss_valid, fname, out_dir="outputs/raytune_result"):
    """Plot per-epoch train/val losses, log the plot to W&B, and write PNG + CSV to ``out_dir``."""
    # The output directory is relative to the working directory and may not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    fig = plt.figure()
    plt.plot(loss_train)
    plt.plot(loss_valid)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(['train', 'val'])
    plt.savefig(pjoin(out_dir, fname + '_Loss.png'))
    wandb.log({"epoch_losses": plt})
    # Release the figure so later plots in this process start from a clean canvas.
    plt.close(fig)
    df = pd.DataFrame({'train': loss_train, 'val': loss_valid})
    df.to_csv(pjoin(out_dir, fname + '_loss.csv'), index=False)


@wandb_mixin
def train_seg_tuner(config):
    """Ray Train worker entry point for one Ray Tune trial.

    Trains CurveNet for ``epochs`` epochs using ``config`` (keys read: ``lr``,
    ``weight_decay``, ``lr_warmup_epochs``, ``lr_warmup_method``,
    ``lr_warmup_decay``, ``loss_fn``). It reports validation loss/F1/IoU to Ray
    every epoch and keeps the weights with the best validation IoU. Those weights
    are checkpointed with ``train.save_checkpoint``. On the global rank-0 worker it
    also logs metrics, sample predictions and loss curves to W&B.
    """
    world_rank = train.world_rank()
    local_rank = train.local_rank()
    world_size = train.world_size()
    # Only the global rank-0 process talks to W&B and writes artifacts; checking
    # local_rank would select one process per node on multi-node runs.
    is_main = world_rank == 0
    print(f"| distributed init | world rank: {world_rank} | local rank: {local_rank} | world size: {world_size} |", flush=True)

    def print_with_rank(msg):
        print("[LOCAL RANK {} | WORLD RANK {}]: {}".format(local_rank, world_rank, msg))

    exp = str(config["loss_fn"]) + "_loss"
    fname = "Result_" + str(config["lr"]) + '_' + str(epochs) + '_' + exp
    print_with_rank(f"Saving trained model name: {fname}")
    if is_main:
        wandb.init(project="v2", entity="p")
        wandb.config.update(config)

    # setup model
    net = CurveNet(num_masks=NUM_MASKS)
    net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)
    net = train.torch.prepare_model(net)
    # prepare_model does not always wrap in DDP (e.g. a single worker), so
    # unwrap with a fallback instead of assuming ``net.module`` exists.
    base_net = getattr(net, "module", net)
    device = next(base_net.parameters()).device

    # setup datasets (train and val come from separate file globs)
    train_files = glob.glob("/path/*/*/*/*/", recursive=True)
    val_files = glob.glob("/path/", recursive=True)
    # # use for debugging
    # train_files, val_files = train_files[:2000], val_files[:500]
    train_set = PetaiSegDataset(files=train_files)
    print_with_rank("Training samples: {}".format(len(train_set)))
    val_set = PetaiSegDataset(files=val_files)
    print_with_rank("Validation samples: {}".format(len(val_set)))

    # setup dataloaders
    trainloader = torch.utils.data.DataLoader(
        train_set,
        batch_size=16,
        # DataLoader defaults to shuffle=False, which would replay the glob
        # order every epoch.
        shuffle=True,
        num_workers=10,
        collate_fn=collate_fn,
        drop_last=True,
        pin_memory=True,
    )
    valloader = torch.utils.data.DataLoader(
        val_set,
        batch_size=16,
        num_workers=10,
        collate_fn=collate_fn,
        pin_memory=True,
    )
    trainloader = train.torch.prepare_data_loader(trainloader)
    valloader = train.torch.prepare_data_loader(valloader)

    # setup loss, optimizer, schedulers and amp (fp16) scaler
    criterion = _build_criterion(config["loss_fn"])
    optimizer = optim.AdamW(net.parameters(), lr=config["lr"], weight_decay=config["weight_decay"], amsgrad=True)
    lr_scheduler = _build_lr_scheduler(optimizer, config, len(trainloader))
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # training and validation loops
    loss_train, loss_valid = [], []
    best_loss, best_iou, best_f1 = 9999, 0, 0
    # Start from the initial weights (epoch 0) so a run whose IoU never improves
    # still has weights to evaluate and checkpoint.
    best_epoch = 0
    best_model_wts = copy.deepcopy(base_net.state_dict())
    since = time.time()
    for epoch in range(1, epochs + 1):
        # A DistributedSampler needs set_epoch() to reshuffle differently each epoch.
        sampler = getattr(trainloader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)

        # training step
        print_with_rank(f"\n{'--' * 15} EPOCH: {epoch} | {epochs} {'--' * 15}\n")
        epoch_train_loss, epoch_train_iou, epoch_train_f1 = train_one_epoch(
            net, trainloader, criterion, optimizer, lr_scheduler, scaler
        )
        print_with_rank("\nPhase: train | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_train_loss, epoch_train_iou, epoch_train_f1))
        # validation step
        epoch_val_loss, epoch_val_iou, epoch_val_f1 = valid_one_epoch(net, valloader, criterion)
        print_with_rank("\nPhase: val | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_val_loss, epoch_val_iou, epoch_val_f1))

        # logs
        loss_train.append(epoch_train_loss)
        loss_valid.append(epoch_val_loss)
        if is_main:
            # One wandb.log call per epoch keeps all metrics on the same W&B step
            # (every call advances the step counter).
            wandb.log({
                "train_epoch_loss": epoch_train_loss,
                "val_epoch_loss": epoch_val_loss,
                "train_epoch_f1": epoch_train_f1,
                "val_epoch_f1": epoch_val_f1,
                "train_epoch_iou": epoch_train_iou,
                "val_epoch_iou": epoch_val_iou,
            })

        # Keep the weights only for an IoU gain of at least 1e-3 (after rounding to 4 d.p.).
        improvement = np.round(epoch_val_iou, 4) - np.round(best_iou, 4)
        if improvement >= 1e-3:
            print_with_rank("IoU improved from {:.4f} to {:.4f}.".format(best_iou, epoch_val_iou))
            best_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(base_net.state_dict())
            best_epoch = epoch
            best_iou = epoch_val_iou
            best_f1 = epoch_val_f1
        train.report(loss=epoch_val_loss, f1=epoch_val_f1, iou=epoch_val_iou)

    time_elapsed = time.time() - since
    print_with_rank("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    consume_prefix_in_state_dict_if_present(best_model_wts, "module.")

    if is_main:
        # log image predictions to wandb (only the logging rank needs them)
        bst_model = CurveNet(num_masks=NUM_MASKS)
        bst_model.load_state_dict(best_model_wts)
        mask_list = log_image_predictions(bst_model, valloader, device)
        wandb.log({"prediction": mask_list})

    # save best model
    print_with_rank("Saving best model")
    train.save_checkpoint(epoch=best_epoch, model_weights=best_model_wts)
    print_with_rank("Best model loss: {:.4f} and corresponding iou: {:.4f}, f1 : {:.4f}".format(best_loss, best_iou, best_f1))

    if is_main:
        _save_loss_curves(loss_train, loss_valid, fname)
def main(args, num_samples=2):
    """Run a Ray Tune random search over ``train_seg_tuner`` and print the best trial.

    Args:
        args: parsed CLI namespace. ``num_workers`` and ``num_gpus`` (GPUs per
            worker) size the Ray Train ``Trainer``. Missing attributes fall back
            to 2 for each.
        num_samples: number of hyper-parameter configurations to sample.

    Returns:
        The analysis object returned by ``tune.run``.
    """
    # Honour the CLI flags; the fallbacks equal the values used before they were wired up.
    num_workers = getattr(args, "num_workers", 2)
    gpus_per_worker = getattr(args, "num_gpus", 2)
    trainer = Trainer(
        "torch",
        num_workers=num_workers,
        use_gpu=True,
        resources_per_worker={"GPU": gpus_per_worker, "CPU": 20},
    )
    config = {
        "lr": tune.loguniform(1e-3, 1e-1),
        "weight_decay": tune.loguniform(6e-6, 1e-3),
        "lr_warmup_epochs": tune.choice([5, 7, 9, 10]),
        "lr_warmup_method": tune.choice(['linear', 'constant']),
        "lr_warmup_decay": tune.loguniform(1e-3, 1e-1),
        "loss_fn": tune.choice(["focal", "dice", "sce", "lovasz", "tversky"]),
    }
    # ASHA early stopping is disabled. To enable it, build
    # ASHAScheduler(metric="iou", mode="max", max_t=MAX_NUM_EPOCHS,
    #               grace_period=GRACE_PERIOD, reduction_factor=2)
    # and pass it to tune.run as ``scheduler=``.
    reporter = CLIReporter(metric_columns=["loss", "f1", "iou", "training_iteration"])
    trainable = trainer.to_tune_trainable(train_seg_tuner)
    result = tune.run(
        trainable,
        config=config,
        num_samples=num_samples,
        local_dir='outputs/raytune_result',
        keep_checkpoints_num=1,
        checkpoint_score_attr='max-iou',
        progress_reporter=reporter)

    best_trial = result.get_best_trial("iou", "max")
    if best_trial is None:
        # get_best_trial returns None when no trial reported the metric.
        print("No trial reported an 'iou' metric; no best trial to show.")
        return result
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
    print("Best trial final validation f1: {}".format(best_trial.last_result["f1"]))
    print("Best trial final validation iou: {}".format(best_trial.last_result["iou"]))
    best_checkpoint_dir = best_trial.checkpoint.value
    print("Best checkpoint dir", best_checkpoint_dir)
    return result
if __name__ == '__main__':
    # CLI: Ray Train sizing flags, consumed by main().
    cli = argparse.ArgumentParser(description='Hyperparams')
    cli.add_argument("--num-gpus", "-g", type=int, default=2,
                     help="The number of GPUs per worker.")
    cli.add_argument("--num-workers", "-n", type=int, default=2,
                     help="The number of workers for training.")
    cli_args = cli.parse_args()

    t_begin = time.time()
    ray.init()
    # num_samples -> number of random search experiments to run
    main(cli_args, num_samples=10)
    elapsed_minutes = (time.time() - t_begin) / 60
    print('Total execution time is {} min'.format(elapsed_minutes))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment