@rwightman
Last active June 28, 2021 12:48
MLP model training hparams w/ timm bits and PyTorch XLA on TPU VM
aa: rand-m6-n4-mstd1.0-inc1
amp: false
apex_amp: false
aug_splits: 0
batch_size: 224
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 10
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: 0.9
cutmix: 1.0
cutmix_minmax: null
data_dir: gs://my-imagenet
dataset: tfds/imagenet2012:5.0.0
decay_epochs: 1.0
decay_rate: 0.988
dist_bn: reduce
drop: 0.01
drop_block: null
drop_connect: null
drop_path: 0.1
epoch_repeats: 0.0
epochs: 600
eval_metric: top1
experiment: ''
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.0008
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.5
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: gmixer_24_224
model_ema: true
model_ema_decay: 0.99992
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_prefetcher: false
no_resume_opt: false
num_classes: 1000
opt: adamw
opt_betas: null
opt_eps: 1.0e-06
output: ''
patience_epochs: 10
pin_mem: false
pretrained: false
ratio:
- 0.67
- 1.5
recount: 3
recovery_interval: 0
remode: pixel
reprob: 0.0
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size_multiplier: 1
vflip: 0.0
warmup_epochs: 20
warmup_lr: 1.0e-06
weight_decay: 0.067
workers: 4
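
As a rough illustration of how the schedule keys above (`sched: cosine`, `lr`, `warmup_lr`, `warmup_epochs`, `min_lr`, `epochs`, `cooldown_epochs`) fit together, here is a minimal per-epoch sketch using the gmixer values. It assumes three things: a linear warmup from `warmup_lr` to `lr`; a cosine decay from `lr` to `min_lr` that is indexed from epoch 0, so the warmup is not prefixed; and an LR held at `min_lr` for the cooldown epochs. The bits branch may step its scheduler differently (e.g. per update), so treat this as an approximation and not the training code's actual scheduler. `lr_at_epoch` is a hypothetical helper.

```python
import math

def lr_at_epoch(epoch: float,
                lr: float = 0.0008,           # lr
                warmup_lr: float = 1e-6,      # warmup_lr
                warmup_epochs: int = 20,      # warmup_epochs
                min_lr: float = 1e-5,         # min_lr
                epochs: int = 600) -> float:  # epochs (length of the cosine phase)
    """Approximate LR for a warmup + cosine schedule (assumed behaviour, see lead-in)."""
    if epoch < warmup_epochs:
        # Linear ramp from warmup_lr towards lr.
        return warmup_lr + epoch * (lr - warmup_lr) / warmup_epochs
    if epoch >= epochs:
        # Cooldown epochs (600..609 here) stay at the floor.
        return min_lr
    # Cosine from lr (at epoch 0) down to min_lr (at epoch == epochs).
    return min_lr + 0.5 * (lr - min_lr) * (1 + math.cos(math.pi * epoch / epochs))

total_epochs = 600 + 10  # epochs + cooldown_epochs
for e in (0, 10, 19, 20, 300, 599, total_epochs - 1):
    print(e, f"{lr_at_epoch(e):.3e}")
```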

aa: rand-m6-n3-mstd1.0-inc1
amp: false
apex_amp: false
aug_splits: 0
batch_size: 192
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 10
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: 0.9
cutmix: 1.0
cutmix_minmax: null
data_dir: gs://my-imagenet
dataset: tfds/imagenet2012:5.0.0
decay_epochs: 1.0
decay_rate: 0.988
dist_bn: reduce
drop: 0.02333
drop_block: null
drop_connect: null
drop_path: 0.1
epoch_repeats: 0.0
epochs: 600
eval_metric: top1
experiment: ''
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.0007
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.5
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: gmlp_s16_224
model_ema: true
model_ema_decay: 0.99992
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_prefetcher: false
no_resume_opt: false
num_classes: 1000
opt: adamw
opt_betas: null
opt_eps: 1.0e-06
output: ''
patience_epochs: 10
pin_mem: false
pretrained: false
ratio:
- 0.67
- 1.5
recount: 3
recovery_interval: 0
remode: pixel
reprob: 0.0
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size_multiplier: 1
vflip: 0.0
warmup_epochs: 20
warmup_lr: 1.0e-06
weight_decay: 0.067
workers: 4
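
`model_ema_decay: 0.99992` is easier to reason about as an averaging horizon. The back-of-envelope sketch below makes four assumptions. The EMA is updated once per optimizer step, as timm's train loop does, though the bits branch may differ. `batch_size` is per device, so the global batch is `batch_size × 8` on the 8-device TPU VM. The last partial batch is dropped. The data is the 1,281,167-image ImageNet-1k train split. The 1/(1 - decay) horizon is a common rule of thumb, not an exact window.

```python
import math

decay = 0.99992            # model_ema_decay
per_device_batch = 192     # batch_size (gmlp_s16_224 config)
num_devices = 8            # assumption: --num-devices 8 and batch_size is per device
train_images = 1_281_167   # ImageNet-1k train split size

global_batch = per_device_batch * num_devices    # 1536
steps_per_epoch = train_images // global_batch   # assumes the last partial batch is dropped

# Rule-of-thumb averaging window of an EMA with this decay, in optimizer steps.
horizon_steps = 1.0 / (1.0 - decay)
# Number of steps for an old weight's contribution to fall to one half.
half_life_steps = math.log(0.5) / math.log(decay)

print(f"global batch: {global_batch}, steps/epoch: {steps_per_epoch}")
print(f"horizon ~{horizon_steps:.0f} steps (~{horizon_steps / steps_per_epoch:.1f} epochs)")
print(f"half-life ~{half_life_steps:.0f} steps (~{half_life_steps / steps_per_epoch:.1f} epochs)")
```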

Using a TPU VM instance w/ the pre-alpha timm bits setup, as per: https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits#readme

python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet --config hparams.yaml
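
The `--config hparams.yaml` argument supplies defaults for the training script's arguments. Below is a minimal `argparse` sketch of that pattern. It assumes, as I understand timm's `train.py`, that the YAML keys match the argument destinations and are applied with `parser.set_defaults(**cfg)`, so flags given explicitly on the command line still win. The four arguments are a hypothetical subset, and the `cfg` dict stands in for `yaml.safe_load` on the file. `gp` is deliberately left undefined in this subset to show that a key with no matching argument simply becomes an attribute nothing reads, which is one way a config can carry unused args.

```python
import argparse

# Hypothetical subset of the training script's arguments.
parser = argparse.ArgumentParser()
parser.add_argument("data_dir")
parser.add_argument("--model", default="resnet50")
parser.add_argument("--lr", type=float, default=0.1)
parser.add_argument("-b", "--batch-size", type=int, default=128)

# Stand-in for yaml.safe_load(open("hparams.yaml")); keys match argparse dests.
cfg = {"model": "gmlp_s16_224", "lr": 0.0007, "batch_size": 192, "gp": None}
parser.set_defaults(**cfg)  # YAML values replace the hard-coded defaults

# An explicit flag on the command line still overrides the YAML value.
args = parser.parse_args(["gs://my-imagenet", "--lr", "0.0005"])
print(args.model, args.lr, args.batch_size, args.gp)
```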

Note that the config YAML files do contain args that are unused or inactive, either because other code overrides them or because of the current state of the training code. The bits code is under heavy development, so these configs will likely require a specific code revision (currently https://github.com/rwightman/pytorch-image-models/commit/5e95ced5a7763541f7219f35fd155e3fbfe66e8b).

The gMLP hparams are the last (latest) in the series and will likely produce better results than the earlier gmixer / resmlp variants.

Note on adapting the LR to a different batch size: AdamW is being used here, and I use sqrt scaling of the learning rate w.r.t. the (global) batch size. I typically use linear LR scaling w/ SGD or RMSProp for most from-scratch training.
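
A minimal sketch of the two scaling rules follows. It assumes that `batch_size` in the YAML is per device, so the global batch is `batch_size × num_devices` (8 devices here). `scale_lr` is a hypothetical helper and is not part of timm.

```python
import math

def scale_lr(base_lr: float, base_global_batch: int, new_global_batch: int,
             rule: str = "sqrt") -> float:
    """Rescale an LR tuned at one global batch size to another global batch size."""
    ratio = new_global_batch / base_global_batch
    if rule == "sqrt":      # the rule described above for AdamW
        return base_lr * math.sqrt(ratio)
    if rule == "linear":    # typical for SGD / RMSProp from-scratch runs
        return base_lr * ratio
    raise ValueError(f"unknown scaling rule: {rule!r}")

# gmlp_s16_224 config: lr 0.0007 at batch_size 192 per device on 8 devices.
base_global = 192 * 8  # 1536 (assumes batch_size is per device)

# e.g. same per-device batch, but only 4 devices -> global batch 768
new_lr = scale_lr(0.0007, base_global, 192 * 4)
print(f"sqrt-scaled lr: {new_lr:.3e}")
```

Halving the global batch therefore lowers the LR by a factor of about 1.41 under sqrt scaling, versus a factor of 2 under linear scaling.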

aa: rand-m6-n4-mstd1.0-inc1
amp: false
apex_amp: false
aug_splits: 0
batch_size: 256
bn_eps: null
bn_momentum: null
bn_tf: false
channels_last: false
checkpoint_hist: 10
clip_grad: 1.0
clip_mode: norm
color_jitter: 0.4
cooldown_epochs: 10
crop_pct: 0.9
cutmix: 1.0
cutmix_minmax: null
data_dir: gs://my-imagenet
dataset: tfds/imagenet2012:5.0.0
decay_epochs: 1.0
decay_rate: 0.988
dist_bn: reduce
drop: 0.01
drop_block: null
drop_connect: null
drop_path: 0.1
epoch_repeats: 0.0
epochs: 600
eval_metric: top1
experiment: ''
gp: null
hflip: 0.5
img_size: 224
initial_checkpoint: ''
input_size: null
interpolation: ''
jsd: false
local_rank: 0
log_interval: 50
log_wandb: false
lr: 0.00088
lr_cycle_limit: 1
lr_cycle_mul: 1.0
lr_noise: null
lr_noise_pct: 0.67
lr_noise_std: 1.0
mean: null
min_lr: 1.0e-05
mixup: 0.5
mixup_mode: batch
mixup_off_epoch: 0
mixup_prob: 1.0
mixup_switch_prob: 0.5
model: resmlp_24_224
model_ema: true
model_ema_decay: 0.99992
model_ema_force_cpu: false
momentum: 0.9
native_amp: false
no_aug: false
no_prefetcher: false
no_resume_opt: false
num_classes: 1000
opt: adamw
opt_betas: null
opt_eps: 1.0e-06
output: ''
patience_epochs: 10
pin_mem: false
pretrained: false
ratio:
- 0.67
- 1.5
recount: 3
recovery_interval: 0
remode: pixel
reprob: 0.0
resplit: false
resume: ''
save_images: false
scale:
- 0.08
- 1.0
sched: cosine
seed: 42
smoothing: 0.1
split_bn: false
start_epoch: null
std: null
sync_bn: false
torchscript: false
train_interpolation: random
train_split: train
tta: 0
use_multi_epochs_loader: false
val_split: validation
validation_batch_size_multiplier: 1
vflip: 0.0
warmup_epochs: 20
warmup_lr: 1.0e-06
weight_decay: 0.067
workers: 4
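
The `mixup`, `cutmix`, `mixup_prob`, `mixup_switch_prob`, `mixup_mode: batch` and `smoothing` keys together determine the soft targets for each batch. Below is a simplified pure-Python sketch of that logic, based on my understanding of timm's `Mixup` helper. Per batch, mixing is applied with probability `mixup_prob`. CutMix is chosen over Mixup with probability `mixup_switch_prob`. λ is drawn from Beta(α, α), with α taken from `cutmix` or `mixup`. Each sample's label-smoothed one-hot target is then blended with that of its partner in the flipped batch. The sketch omits the image mixing itself and any adjustment of λ to the actual CutMix box area. The function names are hypothetical.

```python
import random

def sample_batch_mix(rng: random.Random,
                     mixup_alpha: float = 0.5,   # mixup
                     cutmix_alpha: float = 1.0,  # cutmix
                     mix_prob: float = 1.0,      # mixup_prob
                     switch_prob: float = 0.5):  # mixup_switch_prob
    """Pick (use_cutmix, lam) once for the whole batch ("batch" mode)."""
    if rng.random() >= mix_prob:
        return False, 1.0  # no mixing for this batch
    use_cutmix = rng.random() < switch_prob
    alpha = cutmix_alpha if use_cutmix else mixup_alpha
    return use_cutmix, rng.betavariate(alpha, alpha)

def smoothed_one_hot(label: int, num_classes: int, smoothing: float) -> list[float]:
    """One-hot target with label smoothing; each row sums to 1."""
    off = smoothing / num_classes
    on = 1.0 - smoothing + off
    return [on if c == label else off for c in range(num_classes)]

def mixed_targets(labels: list[int], lam: float, num_classes: int = 1000,
                  smoothing: float = 0.1) -> list[list[float]]:
    """Blend sample i's target with that of sample len-1-i (the flipped batch)."""
    out = []
    for a, b in zip(labels, labels[::-1]):
        ya = smoothed_one_hot(a, num_classes, smoothing)
        yb = smoothed_one_hot(b, num_classes, smoothing)
        out.append([lam * p + (1.0 - lam) * q for p, q in zip(ya, yb)])
    return out

rng = random.Random(42)
use_cutmix, lam = sample_batch_mix(rng)
targets = mixed_targets([3, 7, 1, 5], lam)
print("cutmix" if use_cutmix else "mixup", round(lam, 3), round(sum(targets[0]), 6))
```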