@dudeperf3ct
Last active June 28, 2022 08:13
Added support for Ray Train AMP instead of the torch AMP GradScaler.
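For context, here is a minimal sketch contrasting the two approaches the description refers to: an explicit `torch.cuda.amp.GradScaler` loop versus Ray Train's AMP helpers (`train.torch.accelerate` / `train.torch.backward`, as used in `train_one_epoch` below). The tiny model, data, and loss here are hypothetical and exist only to illustrate the difference.

# Minimal AMP sketch (hypothetical model/data, not part of the gist's script)
import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 2).to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-3)
x, y = torch.randn(4, 8, device=device), torch.randn(4, 2, device=device)

# (a) Previous approach: explicit torch AMP GradScaler.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
with torch.cuda.amp.autocast(enabled=(device == "cuda")):
    loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# (b) Ray Train approach used in this gist: enable AMP once and let
# train.torch.backward handle loss scaling. This only runs inside a Ray Train
# worker function, so it is shown commented out:
# from ray import train
# import ray.train.torch
# train.torch.accelerate(amp=True)
# model = train.torch.prepare_model(model)
# optimizer = train.torch.prepare_optimizer(optimizer)
# loss = nn.functional.mse_loss(model(x), y)
# optimizer.zero_grad()
# train.torch.backward(loss)
# optimizer.step()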
from collections import defaultdict
import time
import os
import argparse
import copy
import json
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
from models.detr import build
from datasets.fmi_dataset import FMIDataset
import util.misc as utils
import ray
import ray.train.torch
from ray import train
from ray.train import Trainer
from ray import tune
from ray.tune import CLIReporter
from ray.tune.integration.wandb import wandb_mixin
from warmup_scheduler import GradualWarmupScheduler
start = time.time()
os.environ["NCCL_DEBUG"] = "INFO"
def train_one_epoch(
model, trainloader, criterion, optimizer, lr_scheduler, max_norm
):
running_loss, class_err, bbox_err, ce_err = 0.0, [], [], []
model.train()
stream = tqdm(trainloader, desc="Training")
for i, (image, targets) in enumerate(stream):
# assert image.shape[1] == 1
# image = image.repeat(1, 3, 1, 1)
targets = [{k: v.to(image.device) for k, v in t.items()} for t in targets]
outputs = model(image)
loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict
losses = sum(
loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
)
optimizer.zero_grad()
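# train.torch.backward replaces loss.backward(); it applies Ray's AMP loss scaling when accelerate(amp=True) is active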
train.torch.backward(losses)
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_scaled = {
k: v * weight_dict[k]
for k, v in loss_dict_reduced.items()
if k in weight_dict
}
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
loss_value = losses_reduced_scaled.item()
# compute metric
class_err.append(loss_dict_reduced["class_error"].cpu().numpy())
bbox_err.append(loss_dict_reduced["loss_bbox"].cpu().detach().numpy())
ce_err.append(loss_dict_reduced["loss_ce"].cpu().detach().numpy())
running_loss += loss_value * image.size(0)
lr_scheduler.step()
epoch_loss = running_loss / (len(trainloader.dataset))
epoch_class_err = np.mean(class_err)
epoch_bbox_err = np.mean(bbox_err)
epoch_ce_err = np.mean(ce_err)
return epoch_loss, epoch_class_err, epoch_bbox_err, epoch_ce_err
def valid_one_epoch(
model, valloader, postprocessors, criterion
):
running_loss, class_err, bbox_err, ce_err, res = 0.0, [], [], [], defaultdict(dict)
model.eval()
with torch.inference_mode():
stream = tqdm(valloader, desc="Validation")
for i, (image, targets) in enumerate(stream):
# assert image.shape[1] == 1
# image = image.repeat(1, 3, 1, 1)
targets = [{k: v.to(image.device) for k, v in t.items()} for t in targets]
outputs = model(image)
loss_dict = criterion(outputs, targets)
weight_dict = criterion.weight_dict
loss_dict_reduced = utils.reduce_dict(loss_dict)
loss_dict_reduced_scaled = {
k: v * weight_dict[k]
for k, v in loss_dict_reduced.items()
if k in weight_dict
}
losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
loss_value = losses_reduced_scaled.item()
# compute metric
class_err.append(loss_dict_reduced["class_error"].cpu().numpy())
bbox_err.append(loss_dict_reduced["loss_bbox"].cpu().detach().numpy())
ce_err.append(loss_dict_reduced["loss_ce"].cpu().detach().numpy())
running_loss += loss_value * image.size(0)
results = postprocessors["bbox"](outputs)
# print("Results:", results, len(results))
r = []
for j in range(len(results)):
r.append({k: v.cpu().numpy() for k, v in results[j].items()})
# print(r, len(r))
res["Results"] = r
# print("Targets:", targets, len(targets))
t = []
for j in range(len(targets)):
t.append({k: v.cpu().numpy() for k, v in targets[j].items()})
# print(t, len(t))
res["Targets"] = t
epoch_loss = running_loss / (len(valloader.dataset))
epoch_class_err = np.mean(class_err)
epoch_bbox_err = np.mean(bbox_err)
epoch_ce_err = np.mean(ce_err)
return epoch_loss, epoch_class_err, epoch_bbox_err, epoch_ce_err, res
@wandb_mixin
def train_ray_tuner(config):
# Ray Train AMP hook; mixed precision is currently disabled (set amp=True to enable it)
train.torch.accelerate(amp=False)
# required for ray train
args = argparse.Namespace(**config)
# folder_name = ""
exp_name = str(args.lr) + "_" + str(args.epochs) + "_" + args.backbone + "_" + args.loss_bbox_type + "_" + args.loss_ce_type
folder_name = pjoin(
config['output_dir'], exp_name
)
world_rank = train.world_rank()
local_rank = train.local_rank()
world_size = train.world_size()
print(
f"| distributed init | world rank: {world_rank}) | local rank: {local_rank}) | world size: {world_size}) |",
flush=True,
)
# import time
# time.sleep(20)
def print_with_rank(msg):
print("[LOCAL RANK {} | WORLD RANK {}]: {}".format(local_rank, world_rank, msg))
if train.local_rank() == 0 and args.log_wandb: # only on main process
wandb.init(project="cced-ray-tune-data-frac-zero-nan-patch685", entity="dklabs", name=exp_name)
wandb.config.update(config)
# RAY IS GETTING STUCK HERE
# run_name = wandb.run.name
# folder_name = pjoin(config['output_dir'], run_name)
# print(f"Saving trained model at folder {folder_name}")
# create folder to store results
os.makedirs(folder_name, exist_ok=True)
save_path = pjoin(folder_name, "checkpoint.pt")
print(f"Saving trained model with path: {save_path}")
# save hyperparameters to json
config["trained_save_path"] = save_path
with open(pjoin(folder_name, "hyperparameters.json"), "w") as f:
json.dump(config, f, indent=4, sort_keys=True)
# setup model and losses
model, criterion, postprocessors = build(args)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of params:", n_parameters)
# required for ray train
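# prepare_model moves the model to this worker's device and wraps it in DistributedDataParallel when training with multiple workers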
model = train.torch.prepare_model(model)
# setup datasets
dataset_train = FMIDataset(split="train", args=args)
dataset_val = FMIDataset(split="val", args=args)
# adjust this according to your GPU Memory and Model size
worker_batch_size = 128
trainloader = DataLoader(
dataset_train,
batch_size=worker_batch_size,
collate_fn=utils.collate_fn,
num_workers=10,
pin_memory=True,
)
valloader = DataLoader(
dataset_val,
batch_size=worker_batch_size,
drop_last=False,
collate_fn=utils.collate_fn,
num_workers=10,
)
# required for ray train
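# prepare_data_loader adds a DistributedSampler and moves each batch to the worker's device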
trainloader = train.torch.prepare_data_loader(trainloader)
valloader = train.torch.prepare_data_loader(valloader)
# setup optimizers
optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# required for ray train
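# prepare_optimizer wraps the optimizer so gradient scaling works with train.torch.backward when AMP is enabled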
optimizer = train.torch.prepare_optimizer(optimizer)
# setup schedulers
if args.lr_scheduler_type == 'step':
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
elif args.lr_scheduler_type == 'cosine':
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)
if args.lr_warmup:
lr_scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=args.lr_warmup_epoch, after_scheduler=lr_scheduler)
if args.frozen_weights is not None:
checkpoint = torch.load(args.frozen_weights, map_location="cpu")
model.module.detr.load_state_dict(checkpoint["model"])
# training and validation loops
(
loss_train,
loss_valid,
loss_train_error,
loss_valid_error,
loss_train_bbox_error,
loss_valid_bbox_error,
loss_train_ce_error,
loss_valid_ce_error,
best_loss,
best_epoch,
best_err,
) = (
[],
[],
[],
[],
[],
[],
[],
[],
9999,
9999,
9999,
)
since = time.time()
all_val_results = defaultdict(dict)
ltype = ""
if args.loss_ce_type == "ce":
ltype = "CE"
elif args.loss_ce_type == "focal":
ltype = "Focal"
for epoch in range(1, args.epochs + 1):
print(f"\n{'--' * 15} EPOCH: {epoch} | {args.epochs} {'--' * 15}\n")
epoch_train_loss, epoch_train_cls_err, epoch_train_bbox_err, epoch_train_ce_err = train_one_epoch(
model,
trainloader,
criterion,
optimizer,
lr_scheduler,
args.clip_max_norm,
)
print(
"\nPhase: train | Loss: {:.4f} | Loss {}: {:.4f} | Loss BBox: {:.4f} | Class Error: {:.4f} | ".format(
epoch_train_loss, ltype, epoch_train_ce_err, epoch_train_bbox_err, epoch_train_cls_err
)
)
epoch_val_loss, epoch_val_cls_err, epoch_val_bbox_err, epoch_val_ce_err, results = valid_one_epoch(model, valloader, postprocessors, criterion)
print(
"\nPhase: val | Loss: {:.4f} | Loss {}: {:.4f} | Loss BBox: {:.4f} | Class Error: {:.4f} | ".format(
epoch_val_loss, ltype, epoch_val_ce_err, epoch_val_bbox_err, epoch_val_cls_err
)
)
loss_train.append(epoch_train_loss)
loss_valid.append(epoch_val_loss)
loss_train_error.append(epoch_train_cls_err)
loss_valid_error.append(epoch_val_cls_err)
loss_train_bbox_error.append(epoch_train_bbox_err)
loss_valid_bbox_error.append(epoch_val_bbox_err)
loss_train_ce_error.append(epoch_train_ce_err)
loss_valid_ce_error.append(epoch_val_ce_err)
all_val_results[epoch] = results
if train.local_rank() == 0 and args.log_wandb:
wandb.log({"train_epoch_loss": epoch_train_loss})
wandb.log({"val_epoch_loss": epoch_val_loss})
wandb.log({"train_epoch_cls_err": epoch_train_cls_err})
wandb.log({"val_epoch_cls_err": epoch_val_cls_err})
wandb.log({"train_epoch_bbox_loss": epoch_train_bbox_err})
wandb.log({"val_epoch_bbox_loss": epoch_val_bbox_err})
wandb.log({f"train_epoch_{ltype.lower()}_err": epoch_train_ce_err})
wandb.log({f"val_epoch_{ltype.lower()}_err": epoch_val_ce_err})
if epoch_val_cls_err < best_err:
print_with_rank(
"Class Error improved from {:.4f} to {:.4f}, corresponding loss {:.4f}.".format(
best_err, epoch_val_cls_err, epoch_val_loss
)
)
best_loss = epoch_val_loss
best_model_wts = copy.deepcopy(model.module.state_dict())
best_epoch = epoch
best_err = epoch_val_cls_err
# required for ray train
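# train.report sends the per-epoch metrics back to Ray Tune, which drives the CLIReporter columns and trial selection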
train.report(loss=epoch_val_loss, class_error=epoch_val_cls_err)
time_elapsed = time.time() - since
print_with_rank(
"Training complete in {:.0f}m {:.0f}s".format(
time_elapsed // 60, time_elapsed % 60
)
)
print_with_rank(
"Best model loss: {:.4f} and corresponding class error: {:.4f}".format(
best_loss, best_err
)
)
# save best model
if train.local_rank() == 0:
print_with_rank(f"Saving trained model at {save_path}")
torch.save({"epoch": best_epoch, "state_dict": best_model_wts}, save_path)
np.save(pjoin(folder_name, 'val_results.npy'), all_val_results)
fig = plt.figure(figsize=(12, 8))
plt.plot(loss_train)
plt.plot(loss_valid)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend(["train", "val"])
plt.savefig(pjoin(folder_name, "loss.png"))
if train.local_rank() == 0 and args.log_wandb:
wandb.log({"epoch_losses": fig})
# log class errors
fig = plt.figure(figsize=(12, 8))
plt.plot(loss_train_error)
plt.plot(loss_valid_error)
plt.xlabel("Epochs")
plt.ylabel("Class Error")
plt.legend(["train", "val"])
plt.savefig(pjoin(folder_name, "class_error.png"))
if train.local_rank() == 0 and args.log_wandb:
wandb.log({"epoch_cls_error": fig})
# log cross entropy
fig = plt.figure(figsize=(12, 8))
plt.plot(loss_train_ce_error)
plt.plot(loss_valid_ce_error)
plt.xlabel("Epochs")
plt.ylabel(f"Loss {ltype}")
plt.legend(["train", "val"])
plt.savefig(pjoin(folder_name, f"{ltype.lower()}_error.png"))
if train.local_rank() == 0 and args.log_wandb:
wandb.log({f"epoch_{ltype.lower()}_error": fig})
# log bbox error
fig = plt.figure(figsize=(12, 8))
plt.plot(loss_train_bbox_error)
plt.plot(loss_valid_bbox_error)
plt.xlabel("Epochs")
plt.ylabel("Bbox Error")
plt.legend(["train", "val"])
plt.savefig(pjoin(folder_name, "bbox_error.png"))
if train.local_rank() == 0 and args.log_wandb:
wandb.log({"epoch_bbox_error": fig})
def main(args):
trainer = Trainer(
"torch",
num_workers=args.num_workers,
use_gpu=True,
resources_per_worker={"CPU": 6, "GPU": 1}
)
config = {
# model parameters
"lr": tune.loguniform(1e-6, 1e-2),
"weight_decay": tune.loguniform(1e-6, 1e-2),
"epochs": 100,
"lr_drop": 5, # adjust this according to epochs, after how many epochs to drop lr
"num_classes": 2,
# backbone
"clip_max_norm": 0.1,
"frozen_weights": None,
# "backbone": tune.choice(["resnet18", "resnet34", "convnext_tiny", "convnext_small", "mobilenet_v2", "mobilenet_v3_small", "mobilenet_v3_large", "efficientnet_b1", "efficientnet_b0"]) # --> NOT WORKING see backbones.py
"backbone": tune.choice(["resnet10", "resnet14", "resnet18"]),
# "dilation": tune.choice([[False, False, False], [False, False, True], [False, True, False], [True, False, False], [False, True, True], [True, False, True], [True, True, False]]),
"dilation": tune.choice([[False, False, False], [False, False, True]]),
"position_embedding": "learned", # default is sine "sine" not working yet
"enc_layers": tune.choice([2,3, 4]),
"dec_layers": tune.choice([2,3, 4]),
"dim_feedforward": tune.choice([512, 256, 1024]),
"hidden_dim": tune.choice([128, 256, 512]),
"dropout": tune.choice([0.1, 0.3, 0.5]),
"nheads": tune.choice([2, 4, 8]),
"num_queries": tune.choice([23, 24, 25]),
"pre_norm": False,
# loss
"aux_loss": tune.choice([False, True]),
# matcher
"set_cost_class": tune.uniform(1.0, 5.0),
"set_cost_bbox": tune.uniform(1.0, 5.0),
# * Loss coefficients
"loss_bbox_type": tune.choice(["l1", "l2"]),
"loss_ce_type": tune.choice(["focal", "ce"]),
"bbox_loss_coef": tune.uniform(1.0, 5.0),
"ce_loss_coef": tune.uniform(1.0, 5.0),
"eos_coef": tune.uniform(0.1, 0.5),
"root": "data/SH5_Bb_Frac_NaN_as_zero_patch685", # adjust this to your data folder
"output_dir": "ray_tune_results_zero_nan_patch685", # adjust this to your output folder
"seed": 42,
"eval": False,
"log_wandb": True,
"lr_warmup": tune.choice([False, True]),
"lr_warmup_epoch": tune.choice([3,4,5,6,7,8,9,10]),
"lr_scheduler_type": tune.choice(["step", "cosine"])
}
reporter = CLIReporter(metric_columns=["loss", "class_error", "training_iteration"])
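# to_tune_trainable wraps the distributed training function so Tune can launch it as a trial using the Trainer's workers and resources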
trainable = trainer.to_tune_trainable(train_ray_tuner)
result = tune.run(
trainable,
config=config,
num_samples=args.num_samples,
local_dir=config['output_dir'],
keep_checkpoints_num=1,
progress_reporter=reporter,
)
best_trial = result.get_best_trial("class_error", "min")
print("Best trial config: {}".format(best_trial.config))
print("Best trial final validation loss: {}".format(best_trial.last_result["loss"]))
print(
"Best trial final validation class_error: {}".format(
best_trial.last_result["class_error"]
)
)
best_checkpoint_dir = best_trial.checkpoint.value
print("Best checkpoint dir", best_checkpoint_dir)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Hyperparams")
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=2,
help="The number of workers for training.",
)
parser.add_argument(
"--num_samples",
"-ns",
type=int,
default=2,
help="No of random search experiments to run.",
)
args = parser.parse_args()
start = time.time()
ray.init()
main(args)
stop = time.time()
print("Total execution time is {} min".format((stop - start) / (60)))