Added support for Ray AMP instead of the torch AMP scaler.
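The switch this gist captures: instead of driving a `torch.cuda.amp.GradScaler` by hand, the training function opts into Ray Train's built-in AMP with `train.torch.accelerate(amp=...)` and calls `train.torch.backward(losses)` in place of `losses.backward()`; `train.torch.prepare_optimizer` then takes care of loss scaling. A minimal sketch of the two patterns (toy step function, names are illustrative), before the full script below:

    import torch

    # before: manual torch AMP scaler in plain PyTorch
    def torch_amp_step(model, optimizer, scaler, x, y, loss_fn):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = loss_fn(model(x), y)
        scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)         # unscales grads, then runs optimizer.step()
        scaler.update()
        return loss

    # after: Ray Train AMP, as used in the script below
    #   train.torch.accelerate(amp=True)                      # opt in once, before setup
    #   model = train.torch.prepare_model(model)
    #   optimizer = train.torch.prepare_optimizer(optimizer)  # wraps the scaler for you
    #   ...
    #   optimizer.zero_grad()
    #   train.torch.backward(losses)                          # scaled backward via Ray
    #   optimizer.step()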
from collections import defaultdict
import time
import os
import argparse
import copy
import json
from os.path import join as pjoin

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.optim as optim
from torch.utils.data import DataLoader

import wandb

from models.detr import build
from datasets.fmi_dataset import FMIDataset
import util.misc as utils

import ray
import ray.train.torch
from ray import train
from ray.train import Trainer
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.wandb import wandb_mixin

from warmup_scheduler import GradualWarmupScheduler

start = time.time()
os.environ["NCCL_DEBUG"] = "INFO"
def train_one_epoch(model, trainloader, criterion, optimizer, lr_scheduler, max_norm):
    running_loss, class_err, bbox_err, ce_err = 0.0, [], [], []
    model.train()
    stream = tqdm(trainloader, desc="Training")
    for i, (image, targets) in enumerate(stream):
        # assert image.shape[1] == 1
        # image = image.repeat(1, 3, 1, 1)
        targets = [{k: v.to(image.device) for k, v in t.items()} for t in targets]
        outputs = model(image)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict
        losses = sum(
            loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
        )
        optimizer.zero_grad()
        # Ray Train's backward handles AMP loss scaling when accelerate(amp=True) is set
        train.torch.backward(losses)
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        # average the loss dict across workers for logging
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k]
            for k, v in loss_dict_reduced.items()
            if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()

        # accumulate metrics
        class_err.append(loss_dict_reduced["class_error"].cpu().numpy())
        bbox_err.append(loss_dict_reduced["loss_bbox"].cpu().detach().numpy())
        ce_err.append(loss_dict_reduced["loss_ce"].cpu().detach().numpy())
        running_loss += loss_value * image.size(0)

    # per-epoch schedulers (StepLR / cosine / warmup): step once per epoch, after the batch loop
    lr_scheduler.step()

    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_class_err = np.mean(class_err)
    epoch_bbox_err = np.mean(bbox_err)
    epoch_ce_err = np.mean(ce_err)
    return epoch_loss, epoch_class_err, epoch_bbox_err, epoch_ce_err
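
# NOTE (assumption about the custom build): the criterion returns a dict of
# per-component losses plus metrics (e.g. "loss_ce", "loss_bbox", "class_error");
# only the keys present in criterion.weight_dict enter the objective. Mirroring
# upstream DETR, the weights would come from the tuned coefficients, roughly:
#     weight_dict = {"loss_ce": args.ce_loss_coef, "loss_bbox": args.bbox_loss_coef}
#     # plus "_{i}"-suffixed copies per decoder layer when aux_loss is True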
def valid_one_epoch(model, valloader, postprocessors, criterion):
    running_loss, class_err, bbox_err, ce_err, res = 0.0, [], [], [], defaultdict(dict)
    model.eval()
    with torch.inference_mode():
        stream = tqdm(valloader, desc="Validation")
        for i, (image, targets) in enumerate(stream):
            # assert image.shape[1] == 1
            # image = image.repeat(1, 3, 1, 1)
            targets = [{k: v.to(image.device) for k, v in t.items()} for t in targets]
            outputs = model(image)
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            loss_dict_reduced = utils.reduce_dict(loss_dict)
            loss_dict_reduced_scaled = {
                k: v * weight_dict[k]
                for k, v in loss_dict_reduced.items()
                if k in weight_dict
            }
            losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
            loss_value = losses_reduced_scaled.item()
            # accumulate metrics
            class_err.append(loss_dict_reduced["class_error"].cpu().numpy())
            bbox_err.append(loss_dict_reduced["loss_bbox"].cpu().detach().numpy())
            ce_err.append(loss_dict_reduced["loss_ce"].cpu().detach().numpy())
            running_loss += loss_value * image.size(0)
            # NOTE: these keys are overwritten every batch, so only the final
            # batch's predictions and targets are kept in res
            results = postprocessors["bbox"](outputs)
            res["Results"] = [
                {k: v.cpu().numpy() for k, v in result.items()} for result in results
            ]
            res["Targets"] = [
                {k: v.cpu().numpy() for k, v in target.items()} for target in targets
            ]

    epoch_loss = running_loss / len(valloader.dataset)
    epoch_class_err = np.mean(class_err)
    epoch_bbox_err = np.mean(bbox_err)
    epoch_ce_err = np.mean(ce_err)
    return epoch_loss, epoch_class_err, epoch_bbox_err, epoch_ce_err, res
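
# `utils.reduce_dict` used above comes from DETR's util/misc.py. Assuming the
# upstream implementation, it all-reduces every tensor in the dict across workers
# and divides by the world size, so each rank logs the same averaged losses:
#
#     def reduce_dict(input_dict, average=True):
#         world_size = dist.get_world_size()
#         if world_size < 2:
#             return input_dict
#         with torch.no_grad():
#             names, values = zip(*sorted(input_dict.items()))
#             values = torch.stack(list(values))
#             dist.all_reduce(values)
#             if average:
#                 values /= world_size
#             return dict(zip(names, values))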
@wandb_mixin
def train_ray_tuner(config):
    # Ray's native AMP replaces the manual torch GradScaler;
    # flip to amp=True to enable mixed precision training
    train.torch.accelerate(amp=False)
    # required for ray train
    args = argparse.Namespace(**config)
    exp_name = (
        str(args.lr) + "_" + str(args.epochs) + "_" + args.backbone
        + "_" + args.loss_bbox_type + "_" + args.loss_ce_type
    )
    folder_name = pjoin(config["output_dir"], exp_name)

    world_rank = train.world_rank()
    local_rank = train.local_rank()
    world_size = train.world_size()
    print(
        f"| distributed init | world rank: {world_rank} | local rank: {local_rank} | world size: {world_size} |",
        flush=True,
    )

    def print_with_rank(msg):
        print("[LOCAL RANK {} | WORLD RANK {}]: {}".format(local_rank, world_rank, msg))

    if train.local_rank() == 0 and args.log_wandb:  # only on the main process
        wandb.init(
            project="cced-ray-tune-data-frac-zero-nan-patch685",
            entity="dklabs",
            name=exp_name,
        )
        wandb.config.update(config)  # NOTE: ray was observed to get stuck here

    # create folder to store results
    os.makedirs(folder_name, exist_ok=True)
    save_path = pjoin(folder_name, "checkpoint.pt")
    print(f"Saving trained model with path: {save_path}")
    # save hyperparameters to json
    config["trained_save_path"] = save_path
    with open(pjoin(folder_name, "hyperparameters.json"), "w") as f:
        json.dump(config, f, indent=4, sort_keys=True)

    # set up model and losses
    model, criterion, postprocessors = build(args)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)
    # required for ray train
    model = train.torch.prepare_model(model)
    # set up datasets
    dataset_train = FMIDataset(split="train", args=args)
    dataset_val = FMIDataset(split="val", args=args)
    # adjust this according to your GPU memory and model size
    worker_batch_size = 128
    trainloader = DataLoader(
        dataset_train,
        batch_size=worker_batch_size,
        collate_fn=utils.collate_fn,
        num_workers=10,
        pin_memory=True,
    )
    valloader = DataLoader(
        dataset_val,
        batch_size=worker_batch_size,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=10,
    )
    # required for ray train: adds a DistributedSampler and device placement
    trainloader = train.torch.prepare_data_loader(trainloader)
    valloader = train.torch.prepare_data_loader(valloader)

    # set up optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # required for ray train
    optimizer = train.torch.prepare_optimizer(optimizer)

    # set up lr schedulers
    if args.lr_scheduler_type == "step":
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
    elif args.lr_scheduler_type == "cosine":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)
    if args.lr_warmup:
        lr_scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=1,
            total_epoch=args.lr_warmup_epoch,
            after_scheduler=lr_scheduler,
        )

    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model.module.detr.load_state_dict(checkpoint["model"])
    # training and validation bookkeeping
    loss_train, loss_valid = [], []
    loss_train_error, loss_valid_error = [], []
    loss_train_bbox_error, loss_valid_bbox_error = [], []
    loss_train_ce_error, loss_valid_ce_error = [], []
    best_loss, best_epoch, best_err = 9999, 9999, 9999
    since = time.time()
    all_val_results = defaultdict(dict)
    ltype = ""
    if args.loss_ce_type == "ce":
        ltype = "CE"
    elif args.loss_ce_type == "focal":
        ltype = "Focal"
    for epoch in range(1, args.epochs + 1):
        print(f"\n{'--' * 15} EPOCH: {epoch} / {args.epochs} {'--' * 15}\n")
        (
            epoch_train_loss,
            epoch_train_cls_err,
            epoch_train_bbox_err,
            epoch_train_ce_err,
        ) = train_one_epoch(
            model, trainloader, criterion, optimizer, lr_scheduler, args.clip_max_norm
        )
        print(
            "\nPhase: train | Loss: {:.4f} | Loss {}: {:.4f} | Loss BBox: {:.4f} | Class Error: {:.4f} |".format(
                epoch_train_loss, ltype, epoch_train_ce_err, epoch_train_bbox_err, epoch_train_cls_err
            )
        )
        (
            epoch_val_loss,
            epoch_val_cls_err,
            epoch_val_bbox_err,
            epoch_val_ce_err,
            results,
        ) = valid_one_epoch(model, valloader, postprocessors, criterion)
        print(
            "\nPhase: val | Loss: {:.4f} | Loss {}: {:.4f} | Loss BBox: {:.4f} | Class Error: {:.4f} |".format(
                epoch_val_loss, ltype, epoch_val_ce_err, epoch_val_bbox_err, epoch_val_cls_err
            )
        )

        loss_train.append(epoch_train_loss)
        loss_valid.append(epoch_val_loss)
        loss_train_error.append(epoch_train_cls_err)
        loss_valid_error.append(epoch_val_cls_err)
        loss_train_bbox_error.append(epoch_train_bbox_err)
        loss_valid_bbox_error.append(epoch_val_bbox_err)
        loss_train_ce_error.append(epoch_train_ce_err)
        loss_valid_ce_error.append(epoch_val_ce_err)
        all_val_results[epoch] = results

        if train.local_rank() == 0 and args.log_wandb:
            # one log call per epoch keeps all metrics on the same wandb step
            wandb.log(
                {
                    "train_epoch_loss": epoch_train_loss,
                    "val_epoch_loss": epoch_val_loss,
                    "train_epoch_cls_err": epoch_train_cls_err,
                    "val_epoch_cls_err": epoch_val_cls_err,
                    "train_epoch_bbox_loss": epoch_train_bbox_err,
                    "val_epoch_bbox_loss": epoch_val_bbox_err,
                    f"train_epoch_{ltype.lower()}_err": epoch_train_ce_err,
                    f"val_epoch_{ltype.lower()}_err": epoch_val_ce_err,
                }
            )

        if epoch_val_cls_err < best_err:
            print_with_rank(
                "Class Error improved from {:.4f} to {:.4f}, corresponding loss {:.4f}.".format(
                    best_err, epoch_val_cls_err, epoch_val_loss
                )
            )
            best_loss = epoch_val_loss
            best_model_wts = copy.deepcopy(model.module.state_dict())
            best_epoch = epoch
            best_err = epoch_val_cls_err

        # required for ray train: report metrics back to Ray Tune
        train.report(loss=epoch_val_loss, class_error=epoch_val_cls_err)

    time_elapsed = time.time() - since
    print_with_rank(
        "Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60)
    )
    print_with_rank(
        "Best model loss: {:.4f} and corresponding class error: {:.4f}".format(best_loss, best_err)
    )

    # save best model and raw validation results (main process only)
    if train.local_rank() == 0:
        print_with_rank(f"Saving trained model at {save_path}")
        torch.save({"epoch": best_epoch, "state_dict": best_model_wts}, save_path)
        np.save(pjoin(folder_name, "val_results.npy"), all_val_results)
    # plot and log loss curves
    fig = plt.figure(figsize=(12, 8))
    plt.plot(loss_train)
    plt.plot(loss_valid)
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend(["train", "val"])
    plt.savefig(pjoin(folder_name, "loss.png"))
    if train.local_rank() == 0 and args.log_wandb:
        wandb.log({"epoch_losses": fig})

    # log class errors
    fig = plt.figure(figsize=(12, 8))
    plt.plot(loss_train_error)
    plt.plot(loss_valid_error)
    plt.xlabel("Epochs")
    plt.ylabel("Class Error")
    plt.legend(["train", "val"])
    plt.savefig(pjoin(folder_name, "class_error.png"))
    if train.local_rank() == 0 and args.log_wandb:
        wandb.log({"epoch_cls_error": fig})

    # log cross entropy / focal loss
    fig = plt.figure(figsize=(12, 8))
    plt.plot(loss_train_ce_error)
    plt.plot(loss_valid_ce_error)
    plt.xlabel("Epochs")
    plt.ylabel(f"Loss {ltype}")
    plt.legend(["train", "val"])
    plt.savefig(pjoin(folder_name, f"{ltype.lower()}_error.png"))
    if train.local_rank() == 0 and args.log_wandb:
        wandb.log({f"epoch_{ltype.lower()}_error": fig})

    # log bbox error
    fig = plt.figure(figsize=(12, 8))
    plt.plot(loss_train_bbox_error)
    plt.plot(loss_valid_bbox_error)
    plt.xlabel("Epochs")
    plt.ylabel("Bbox Error")
    plt.legend(["train", "val"])
    plt.savefig(pjoin(folder_name, "bbox_error.png"))
    if train.local_rank() == 0 and args.log_wandb:
        wandb.log({"epoch_bbox_error": fig})
def main(args):
    trainer = Trainer(
        "torch",
        num_workers=args.num_workers,
        use_gpu=True,
        resources_per_worker={"CPU": 6, "GPU": 1},
    )
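    # NOTE (assumption): this targets the Ray 1.x API. `ray.train.Trainer` and
    # `trainer.to_tune_trainable(...)` were removed in Ray 2.x in favor of
    # `ray.train.torch.TorchTrainer`, so pin e.g. `ray<2.0` to run this as written.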
    config = {
        # optimization
        "lr": tune.loguniform(1e-6, 1e-2),
        "weight_decay": tune.loguniform(1e-6, 1e-2),
        "epochs": 100,
        "lr_drop": 5,  # after how many epochs to drop the lr; adjust according to epochs
        "num_classes": 2,
        "clip_max_norm": 0.1,
        "frozen_weights": None,
        # backbone
        # "backbone": tune.choice(["resnet18", "resnet34", "convnext_tiny", "convnext_small", "mobilenet_v2", "mobilenet_v3_small", "mobilenet_v3_large", "efficientnet_b1", "efficientnet_b0"]),  # --> NOT WORKING, see backbones.py
        "backbone": tune.choice(["resnet10", "resnet14", "resnet18"]),
        # "dilation": tune.choice([[False, False, False], [False, False, True], [False, True, False], [True, False, False], [False, True, True], [True, False, True], [True, True, False]]),
        "dilation": tune.choice([[False, False, False], [False, False, True]]),
        "position_embedding": "learned",  # default is "sine", not working yet
        # transformer
        "enc_layers": tune.choice([2, 3, 4]),
        "dec_layers": tune.choice([2, 3, 4]),
        "dim_feedforward": tune.choice([256, 512, 1024]),
        "hidden_dim": tune.choice([128, 256, 512]),
        "dropout": tune.choice([0.1, 0.3, 0.5]),
        "nheads": tune.choice([2, 4, 8]),
        "num_queries": tune.choice([23, 24, 25]),
        "pre_norm": False,
        # loss
        "aux_loss": tune.choice([False, True]),
        # matcher
        "set_cost_class": tune.uniform(1.0, 5.0),
        "set_cost_bbox": tune.uniform(1.0, 5.0),
        # loss coefficients
        "loss_bbox_type": tune.choice(["l1", "l2"]),
        "loss_ce_type": tune.choice(["focal", "ce"]),
        "bbox_loss_coef": tune.uniform(1.0, 5.0),
        "ce_loss_coef": tune.uniform(1.0, 5.0),
        "eos_coef": tune.uniform(0.1, 0.5),
        # data / logging
        "root": "data/SH5_Bb_Frac_NaN_as_zero_patch685",  # adjust this to your data folder
        "output_dir": "ray_tune_results_zero_nan_patch685",  # adjust this to your output folder
        "seed": 42,
        "eval": False,
        "log_wandb": True,
        # lr schedule
        "lr_warmup": tune.choice([False, True]),
        "lr_warmup_epoch": tune.choice([3, 4, 5, 6, 7, 8, 9, 10]),
        "lr_scheduler_type": tune.choice(["step", "cosine"]),
    }
    reporter = CLIReporter(metric_columns=["loss", "class_error", "training_iteration"])
    # convert the Ray Train trainer into a Ray Tune trainable
    trainable = trainer.to_tune_trainable(train_ray_tuner)
    result = tune.run(
        trainable,
        config=config,
        num_samples=args.num_samples,
        local_dir=config["output_dir"],
        keep_checkpoints_num=1,
        progress_reporter=reporter,
    )
    best_trial = result.get_best_trial("class_error", "min")
    print("Best trial config: {}".format(best_trial.config))
    print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
    print(
        "Best trial final validation class_error: {}".format(
            best_trial.last_result["class_error"]
        )
    )
    best_checkpoint_dir = best_trial.checkpoint.value
    print("Best checkpoint dir:", best_checkpoint_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparams")
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="The number of workers for training.",
    )
    parser.add_argument(
        "--num_samples",
        "-ns",
        type=int,
        default=2,
        help="Number of random-search experiments to run.",
    )
    args = parser.parse_args()
    start = time.time()
    ray.init()
    main(args)
    stop = time.time()
    print("Total execution time is {:.2f} min".format((stop - start) / 60))