This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR | |
# Total step count for the run: len(train_loader) times the number of epochs.
total_steps = args.epochs * len(train_loader)

# AdamW over all model parameters; learning rate and weight decay come
# from the parsed command-line arguments.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
| main_scheduler = CosineAnnealingLR( | |
| optimizer, T_max=total_steps - args.warmup_period, eta_min=1e-6 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from torchvision.transforms import v2 | |
class MixUpCollator:
    """Collate callable that applies MixUp to an already-collated batch.

    The batch is first assembled with ``default_collate`` and the
    resulting items are then passed, unpacked, to ``v2.MixUp``.
    """

    def __init__(self, num_classes):
        """Build the MixUp transform.

        Args:
            num_classes: Forwarded to ``v2.MixUp(num_classes=...)``.
        """
        self.mixup = v2.MixUp(num_classes=num_classes)

    def __call__(self, batch):
        """Collate ``batch`` and return the MixUp transform's output.

        Args:
            batch: Samples to be combined by ``default_collate``.

        Returns:
            Whatever ``self.mixup`` returns for the collated items.
        """
        collated = default_collate(batch)
        # Star-unpack so every collated item is passed as its own
        # positional argument.
        return self.mixup(*collated)
# MixUp collation is opt-in via args.mixup; otherwise collate_fn stays None.
if args.mixup:
    collate_fn = MixUpCollator(num_classes=args.num_classes)
else:
    collate_fn = None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import torchvision.transforms as transforms | |
| train_dataset = ImageNetDataset( | |
| args.data, | |
| "train", | |
| transforms.Compose( | |
| [ | |
| transforms.Resize((224, 224)), | |
| transforms.RandAugment(args.randaug_num_ops, args.randaug_magnitude), | |
| transforms.ToTensor(), |