Demo: tuning YOLOv7 hyperparameters with Ray Tune's PB2 scheduler
# yolov7
import argparse
import logging
import math
import os
import random
import time
from copy import deepcopy
from pathlib import Path
from threading import Thread

import numpy as np
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import yaml
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

import test  # import test.py to get mAP after each epoch
from models.experimental import attempt_load
from models.yolo import Model
from utils.autoanchor import check_anchors
from utils.datasets import create_dataloader
from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
    fitness, strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
    check_requirements, print_mutation, set_logging, one_cycle, colorstr
from utils.google_utils import attempt_download
from utils.loss import ComputeLoss, ComputeLossOTA
from utils.plots import plot_images, plot_labels, plot_results, plot_evolution
from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_distributed_zero_first, is_parallel
from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume

logger = logging.getLogger(__name__)
# Ray Tune / PB2 additions (duplicate imports of argparse, os, numpy, random, torch,
# torch.optim and yaml removed; they are already imported above)
import glob
import shutil

import torch
from ray import air, tune
from ray.air import session
from ray.air.checkpoint import Checkpoint
from ray.tune.schedulers.pb2 import PB2

from train import train
def set_seed(seed=42):
    # Seed every RNG and force deterministic cuDNN kernels for reproducible trials.
    # cudnn.benchmark is disabled because its algorithm autotuning is non-deterministic.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


set_seed()
search_space = {
    "fl_gamma": tune.grid_search([0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]),
    "lr0": tune.loguniform(0.0001, 0.1),
    "lrf": tune.loguniform(0.001, 0.2),
    "momentum": tune.grid_search(np.arange(0.6, 0.98 + 0.038, 0.038).tolist()),
    "iou_t": tune.grid_search([0.2, 0.3, 0.4, 0.5]),
    "mosaic": tune.quniform(0.0, 1.0, 0.2),
    "label_smoothing": tune.grid_search([0.0, 0.5, 1.0, 2.0]),
    "linear_lr": tune.grid_search([True, False]),
    "adam": tune.grid_search([True, False]),
    "image_weights": tune.grid_search([True, False]),
    "batch_size": tune.grid_search([2, 4, 8, 16]),
    "mixup": tune.grid_search([0.0, 0.05, 0.1, 0.15]),
    "paste_in": tune.grid_search([0.0, 0.05, 0.1, 0.15]),
    "scale": tune.grid_search(np.arange(0.2, 0.9 + 0.14, 0.14).tolist()),
    "box": tune.grid_search(np.arange(0.02, 0.2 + 0.036, 0.036).tolist()),
    "cls": tune.grid_search(np.arange(0.2, 4.0 + 0.19, 0.19).tolist()),
    "obj": tune.grid_search(np.arange(0.2, 4.0 + 0.19, 0.19).tolist()),
    "obj_pw": tune.grid_search(np.arange(0.5, 2.0 + 0.15, 0.15).tolist()),
    "cls_pw": tune.grid_search(np.arange(0.5, 2.0 + 0.15, 0.15).tolist()),
}
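# Note: each tune.grid_search entry spawns one trial per listed value, and multiple
# grid_search entries combine as a Cartesian product, so this space as written is far
# larger than any practical PB2 population; the scheduler below only perturbs values
# inside its hyperparam_bounds once the trials are running.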
pb2 = PB2(
    time_attr="training_iteration",  # set the time attribute to training iterations
    metric="map",
    mode="max",
    perturbation_interval=12,
    hyperparam_bounds={
        "fl_gamma": [0, 3.0],
        "lr0": [0.0001, 0.1],
        "lrf": [0.001, 0.2],
        "momentum": [0.6, 0.98],
        "iou_t": [0.2, 0.5],
        "mosaic": [0.0, 1.0],
        "label_smoothing": [0.0, 2.0],
        "linear_lr": [True, False],
        "adam": [True, False],
        "image_weights": [True, False],
        "batch_size": [2, 16],
        "mixup": [0.0, 0.15],
        "paste_in": [0.0, 0.15],
        "scale": [0.2, 0.9],
        "box": [0.02, 0.2],
        "cls": [0.2, 4.0],
        "obj": [0.2, 4.0],
        "obj_pw": [0.5, 2.0],
        "cls_pw": [0.5, 2.0],
    },
)
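# With this scheduler, every `perturbation_interval` reports of "map" (one report per
# session.report() call below) PB2 clones the checkpoint of better-performing trials
# into worse ones and proposes new hyperparameter values inside `hyperparam_bounds`
# using its GP-bandit model (exploit + explore).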
def train_ray_tune(config):
    # Fixed training options; the paths point to the Colab / Drive layout used for this demo
    opt = argparse.Namespace()
    opt.adam = False
    opt.artifact_alias = 'latest'
    opt.batch_size = 1
    opt.bbox_interval = -1
    opt.bucket = ''
    opt.cache_images = True
    opt.cfg = '/content/yolov7/cfg/training/yolov7-custom.yaml'
    opt.data = '/content/yolov7/data/data_fold1.yaml'
    opt.device = ''
    opt.entity = None
    opt.evolve = False
    opt.exist_ok = False
    opt.freeze = [0]
    opt.global_rank = -1
    opt.hyp = '/content/yolov7/data/hyp.scratch.custom.yaml'
    opt.image_weights = False
    opt.img_size = [640, 640]
    opt.label_smoothing = 0.0
    opt.linear_lr = False
    opt.local_rank = -1
    opt.multi_scale = False
    opt.name = 'test'
    opt.noautoanchor = False
    opt.nosave = False
    opt.notest = False
    opt.project = '/content/drive/MyDrive/trash/Step2_tuning_yolo_v7/carton/PB2/yolov7_carton_raytune_pb2_24_08_2022/tmp/'
    opt.quad = False
    opt.rect = False
    opt.resume = False
    opt.save_dir = '/content/drive/MyDrive/trash/Step2_tuning_yolo_v7/carton/PB2/yolov7_carton_raytune_pb2_24_08_2022/tmp/weights/'
    opt.save_period = -1
    opt.single_cls = False
    opt.sync_bn = False
    opt.total_batch_size = 1
    opt.upload_dataset = False
    opt.weights = '/content/yolov7/yolov7_training.pt'
    opt.workers = 1
    opt.world_size = 1
    # Hyperparameters
    with open(opt.hyp) as f:
        hyp = yaml.load(f, Loader=yaml.SafeLoader)  # load hyps

    # opt.epochs = 10
    EPOCH = 300  # epochs per call to train() below, i.e. per tuning step
    opt.epochs = EPOCH
    # Update opt and hyp with the values sampled / perturbed by Tune
    opt.batch_size = config["batch_size"]
    opt.label_smoothing = config["label_smoothing"]
    opt.linear_lr = bool(config["linear_lr"])
    opt.image_weights = bool(config["image_weights"])
    opt.adam = bool(config["adam"])
    hyp["fl_gamma"] = config["fl_gamma"]
    hyp["lr0"] = config["lr0"]
    hyp["lrf"] = config["lrf"]
    hyp["momentum"] = config["momentum"]
    hyp["iou_t"] = config["iou_t"]
    hyp["mosaic"] = config["mosaic"]
    hyp["paste_in"] = config["paste_in"]
    hyp["mixup"] = config["mixup"]
    hyp["scale"] = config["scale"]
    hyp["box"] = config["box"]
    hyp["cls"] = config["cls"]
    hyp["obj"] = config["obj"]
    hyp["obj_pw"] = config["obj_pw"]
    hyp["cls_pw"] = config["cls_pw"]
    # Set DDP variables
    opt.world_size = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1
    opt.global_rank = int(os.environ['RANK']) if 'RANK' in os.environ else -1
    set_logging(opt.global_rank)
    # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
    opt.data, opt.cfg, opt.hyp = check_file(opt.data), check_file(opt.cfg), check_file(opt.hyp)  # check files
    assert len(opt.cfg) or len(opt.weights), 'either --cfg or --weights must be specified'
    opt.img_size.extend([opt.img_size[-1]] * (2 - len(opt.img_size)))  # extend to 2 sizes (train, test)
    opt.name = 'evolve' if opt.evolve else opt.name
    # opt.save_dir = increment_path(Path(opt.project) / opt.name, exist_ok=opt.exist_ok | opt.evolve)  # increment run

    # Per-trial temporary checkpoint directory prefix, named after the current Unix time
    root_checkpoint_tmp = str(int(math.modf(time.time())[1]))

    # DDP mode
    opt.total_batch_size = opt.batch_size
    device = select_device(opt.device, batch_size=opt.batch_size)

    # Train
    logger.info(opt)

    # 3-fold cross-validation: every tuning step trains and evaluates each fold
    list_fold = ["/content/yolov7/data/data_fold1.yaml",
                 "/content/yolov7/data/data_fold2.yaml",
                 "/content/yolov7/data/data_fold3.yaml"]
    # list_fold = ["/content/yolov7/data/data_fold1.yaml"]
    metrics_map50 = []
    metrics_map = []
    metrics_mp = []
    metrics_mr = []
    step = 0
    list_model = {}
    # Restore the per-fold weights and the step counter if Tune hands this trial a
    # checkpoint (this is how PB2 exploits: a trial restarts from a better trial's state).
    # The checkpoint reported below is built with Checkpoint.from_dict(), so it is read
    # back here as a dict rather than as a file inside a checkpoint directory.
    if session.get_checkpoint():
        print("Loading from checkpoint.")
        checkpoint = session.get_checkpoint().to_dict()
        for idx, fold in enumerate(list_fold):
            list_model["model_fold" + str(idx)] = checkpoint["model_fold" + str(idx)].copy()
        step = checkpoint["step"]
    while True:
        # Reset the metric buffers so the values reported below reflect only the
        # folds trained in the current tuning step
        metrics_map50, metrics_map, metrics_mp, metrics_mr = [], [], [], []
        if step == 0:
            # First step of this trial: train each fold from the pretrained weights
            for idx, fold in enumerate(list_fold):
                opt.data = fold
                if not os.path.exists(root_checkpoint_tmp + "_fold" + str(idx)):
                    os.makedirs(root_checkpoint_tmp + "_fold" + str(idx))
                opt.save_dir = ""
                result, ckpt = train(hyp, opt, device, dir_checkpoint=root_checkpoint_tmp + "_fold" + str(idx))
                list_model["model_fold" + str(idx)] = ckpt.copy()
                del ckpt
                mp, mr, map50, map, _, _, _ = result
                metrics_map50.append(map50)
                metrics_map.append(map)
                metrics_mp.append(mp)
                metrics_mr.append(mr)
        else:
            # Later steps: resume each fold from the weights kept in memory
            opt.resume = True
            for idx, fold in enumerate(list_fold):
                opt.data = fold
                # ckpt = get_latest_run(search_dir=root_checkpoint_tmp + "_" + str(idx) + "/")  # specified or most recent path
                if not os.path.exists(root_checkpoint_tmp + "_fold" + str(idx)):
                    os.makedirs(root_checkpoint_tmp + "_fold" + str(idx))
                torch.save(list_model["model_fold" + str(idx)], root_checkpoint_tmp + "_fold" + str(idx) + "/" + "last.pt")
                opt.weights = root_checkpoint_tmp + "_fold" + str(idx) + "/" + "last.pt"
                opt.cfg = ''
                result, ckpt = train(hyp, opt, device, dir_checkpoint=root_checkpoint_tmp + "_fold" + str(idx))
                list_model["model_fold" + str(idx)] = ckpt.copy()
                del ckpt
                mp, mr, map50, map, _, _, _ = result
                metrics_map50.append(map50)
                metrics_map.append(map)
                metrics_mp.append(mp)
                metrics_mr.append(mr)
        # Bundle the current weights of every fold plus the step counter into an
        # in-memory checkpoint so PB2 can clone this trial's state into other trials
        checkpoint_dict = {}
        checkpoint_dict["step"] = step
        for idx, fold in enumerate(list_fold):
            checkpoint_dict["model_fold" + str(idx)] = list_model["model_fold" + str(idx)].copy()
        checkpoint_dict["map"] = sum(metrics_map) / len(metrics_map)
        checkpoint = Checkpoint.from_dict(checkpoint_dict)

        # Delete the temporary *.pt weights written by train() to keep disk usage flat
        for idx, fold in enumerate(list_fold):
            file_pt = glob.glob("./" + root_checkpoint_tmp + "_fold" + str(idx) + "/*.pt")
            for file in file_pt:
                os.remove(file)

        # Aggregate the per-fold metrics for this step
        result = {}
        result["map"] = sum(metrics_map) / len(metrics_map)
        result["map50"] = sum(metrics_map50) / len(metrics_map50)
        result["mr"] = sum(metrics_mr) / len(metrics_mr)
        result["mp"] = sum(metrics_mp) / len(metrics_mp)
        for idx, fold in enumerate(list_fold):
            result["map_fold" + str(idx)] = metrics_map[idx]
        step += 1
        # Hand the fold-averaged mAP and the checkpoint back to Tune; this call also
        # advances "training_iteration", which drives the PB2 perturbation schedule
        session.report({"map": result["map"]}, checkpoint=checkpoint)
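
# --- Launch sketch (not part of the original gist) ---------------------------------
# A minimal example of how the pieces above could be wired together with the Ray 2.x
# AIR API that the imports assume. The resource numbers, experiment name, results
# directory, stop condition and num_samples are illustrative assumptions.
if __name__ == "__main__":
    trainable = tune.with_resources(train_ray_tune, {"cpu": 2, "gpu": 1})
    tuner = tune.Tuner(
        trainable,
        param_space=search_space,
        tune_config=tune.TuneConfig(
            scheduler=pb2,  # PB2 scheduler defined above (carries metric="map", mode="max")
            num_samples=1,  # with grid_search entries, each grid combination is already one trial
        ),
        run_config=air.RunConfig(
            name="yolov7_pb2_demo",
            local_dir="./ray_results",
            stop={"training_iteration": 24},  # two PB2 perturbation intervals of 12
        ),
    )
    results = tuner.fit()
    print(results.get_best_result(metric="map", mode="max").config)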