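"""cifar10/100 supervised classification training script.

Gist by @cormoran (created 2019-12-15). Builds a WideResNet feature extractor
with a linear classifier head and trains it through TrainManager. Note that
`modules`, `TrainManager`, `attrdict`, and the string component names
(e.g. 'classification.ClassificationModel') appear to refer to the author's
own project code and configuration registry, not a public library.
"""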
import argparse, json, gzip
from attrdict import AttrDict as D
import modules
from modules.manager.train import TrainManager
parser = argparse.ArgumentParser('cifar10/100 supervised classification')
parser.add_argument('--name', type=str, default='test')
parser.add_argument('--exp_id', type=str, default=None)
parser.add_argument('--run_id', type=str, default='0')
parser.add_argument('--depth', type=int, default=28)
parser.add_argument('--width', type=int, default=2)
parser.add_argument('--epoch', type=int, default=200)
parser.add_argument('--pre-trained-dataset', default=None, type=str)
parser.add_argument('--train-dataset', required=True, type=str)
parser.add_argument('--test-dataset', required=True, type=str)
parser.add_argument('--extractor-checkpoint', type=str, default=None)
parser.add_argument('--fix-extractor', action='store_true')
parser.add_argument('--keep-freq', type=int, default=None)
parser.add_argument('--debug', action='store_true')
args = parser.parse_args()
if args.fix_extractor:
    assert args.extractor_checkpoint is not None

modules.setup()


def dataset(dataset_file: str, train: bool):
    return D(
        dataset='dataset.ClassDataset',
        dataset_arg=D(
            dataset_file=dataset_file,
            transform='cifar.cifar_aug_transform',
            transform_arg=D(train=train, dataset=args.pre_trained_dataset),
        ),
    )

TrainManager(
    suffix=args.name,
    exp_id=args.exp_id,
    run_id=args.run_id,
    keep_freq=args.keep_freq,
    auto_resume_strict=True,
    model='classification.ClassificationModel',
    model_arg=D(
        extractor='cifar_wideresnet.WideResNet',
        extractor_arg=D(
            depth=args.depth,
            width=args.width,
            base_model_checkpoint_path=args.extractor_checkpoint,
        ),
        classifier='linear.Linear',
        classifier_arg=D(class_dim=train_dataset_info.num_class),
        fix_extractor=args.fix_extractor,
    ),
    trainer='iterative.IterativeTrainer',
    trainer_arg=D(
        loss='ce.CrossEntropy',
        loss_arg=D(),
        metrics=['accuracy.Accuracy'],
        metrics_arg=[D()],
        debug_max_itr=2 if args.debug else None,
        epoch=args.epoch,
        optimizer='SGD',
        optimizer_arg=D(lr=0.1, momentum=0.9, weight_decay=0.0005, nesterov=True),
        data_loader='dataloader.DataLoader',
        data_loader_arg=D(
            batchsize=128,
            shuffle=True,
            num_workers=4,
            **dataset(args.train_dataset, train=True),
        ),
        lr_scheduler='MultiStepLR',
        lr_scheduler_arg=D(milestones=[60, 120, 160], gamma=0.2),
        print_freq_itr=200,
    ),
    evaluators=['iterative.IterativeEvaluator'],
    evaluators_arg=[
        D(
            name='SeenClassification',
            eval_freq=1,
            metrics=['ce.CrossEntropy', 'accuracy.Accuracy'],
            metrics_arg=[D(), D()],
            debug_max_itr=2 if args.debug else None,
            print_freq_itr=100,
            data_loader='dataloader.DataLoader',
            data_loader_arg=D(
                batchsize=128,
                shuffle=False,
                num_workers=4,
                **dataset(args.test_dataset, train=False),
            ),
        ),
    ],
).run_train()
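
# Example invocation (a sketch only: the script filename and dataset paths below
# are assumptions, since `modules` and the dataset files are not part of this gist):
#
#   python train.py \
#       --train-dataset path/to/cifar10_train_dataset \
#       --test-dataset path/to/cifar10_test_dataset \
#       --name wrn28-2-baseline --depth 28 --width 2 --epoch 200
#
# Add --extractor-checkpoint and --fix-extractor to reuse a pre-trained
# extractor with only the linear classifier being trained.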