Ross Wightman rwightman

import math
import os
from collections import defaultdict
from pathlib import Path
from huggingface_hub import CommitOperationAdd, preupload_lfs_files, create_commit
# fast transfers using a Rust library, `pip install hf-transfer`
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
@rwightman
rwightman / timm_vit_attention_map.py
Created November 21, 2023 18:39
Extract attention maps from timm ViTs with Torch FX
import torch
import timm
from torchvision.models.feature_extraction import get_graph_node_names
timm.layers.set_fused_attn(False) # disable F.sdpa so softmax node is exposed
mm = timm.create_model('vit_medium_patch16_gap_256.sw_in12k_ft_in1k', pretrained=True)
softmax_nodes = [n for n in get_graph_node_names(mm)[0] if 'softmax' in n]
ff = timm.models.create_feature_extractor(mm, softmax_nodes)
with torch.no_grad():
model               image_size  embed_dim  gmacs  macts  mparams  image_gmacs  image_macts  image_mparams  text_gmacs  text_macts  text_mparams
ViT-S-32-alt        224         256        1.78   4.71   43.22    1.15         2.5          22.59          0.64        2.21        20.63
ViT-S-32            224         384        2.84   6.48   63.09    1.15         2.5          22.64          1.69        3.98        40.44
ViT-M-32-alt        224         384        3.69   7.31   80.07    2.0          3.34         39.63          1.69        3.98        40.44
ViT-M-32            224         512        4.98   8.64   103.12   2.0          3.34         39.69          2.98        5.3         63.43
ViT-S-16-alt        224         256        5.25   14.16  42.4     4.61         11.95        21.76          0.64        2.21        20.63
ViT-S-16            224         384        6.3    15.92  62.26    4.61         11.95        21.81          1.69        3.98        40.44
ViT-B-32-quickgelu  224         512        7.4    10.31  151.28   4.41         5.01         87.85          2.98        5.3         63.43
ViT-B-32            224         512        7.4    10.31  151.28   4.41         5.01         87.85          2.98        5.3         63.43
convnext_tiny       224         1024       7.46   18.74  92.3     4.47         13.44        28.61          2.98        5.3         63.69

Some hparams related to RegNets (and other nets) in TPU training series https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights

All models were trained on 8x TPUs, so global batch size == batch_size * 8.

If the weight name contains ra3, it means rmsproptf + mixup + cutmix + rand erasing + (usually) lr noise + rand-aug + head dropout + drop path (stochastic depth). The older ra2 scheme was very similar but had no cutmix, and rand-aug always used normal sampling (mstd0.5 or mstd1.0) for the rand-aug magnitude, whereas ra3 often (not always) uses uniform sampling (mstd101).
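
To make the mstd difference concrete, here is a minimal sketch of the two magnitude sampling modes. It assumes that an mstd value above 100 (e.g. mstd101) selects uniform sampling between 0 and the base magnitude, and that smaller positive values are the std-dev of a normal distribution centred on the base magnitude. The helper name `sample_magnitude` and the clamp to [0, 10] are illustrative simplifications, not a copy of timm's rand-aug code.

```python
import random


def sample_magnitude(magnitude, mstd, rng, upper_bound=10.0):
    """Hypothetical helper: pick the magnitude for one rand-aug op."""
    if mstd > 100:
        # assumption: mstd > 100 (e.g. mstd101) means uniform sampling in [0, magnitude]
        sampled = rng.uniform(0.0, magnitude)
    elif mstd > 0:
        # normal sampling around the base magnitude (ra2-style mstd0.5 / mstd1.0)
        sampled = rng.gauss(magnitude, mstd)
    else:
        # no randomization, always use the base magnitude
        sampled = magnitude
    # keep the result inside the valid magnitude range
    return max(0.0, min(sampled, upper_bound))


rng = random.Random(0)
normal_draws = [sample_magnitude(9, 0.5, rng) for _ in range(1000)]
uniform_draws = [sample_magnitude(9, 101, rng) for _ in range(1000)]
```

With normal sampling the draws cluster tightly around the base magnitude, while uniform sampling spreads them over the whole range below it, so the average augmentation strength is noticeably lower and more varied.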

Some weights were trained with sgd + grad clipping (cx in the name, where x is one of h, 1, 2, 3); h = amped-up augreg.

I believe the 064 regnety was very close with both the ra3 and sgd approaches. The hparams I have kept were the sgd ones, but I believe the published weights were the rmsproptf ones, which edged ahead by a hair.

rwightman / _timm_hparams.md
Last active May 30, 2023 05:18
Recent timm hparams...

A variety of hparams used to train vit, convnext, vit-hybrids (maxvit, coatnet) recently in timm

All variations on the same theme (DeiT / Swin pretraining) but with different tweaks here and there.

These were all run on 4-8 GPU or TPU devices. They use --lr-base, which rescales the LR automatically based on the global batch size (relative to --lr-base-size), so adapting to different GPU counts will work well within a range. Running at significantly lower or higher global batch sizes will require re-running an LR search.
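
A rough sketch of the rescaling idea, for intuition only: the effective LR is the base LR multiplied by the ratio of the actual global batch size to the reference batch size, optionally with a square root applied to that ratio. The function name `scaled_lr`, the `scale` argument, and the numbers in the example are assumptions for illustration; the timm train script has its own logic for choosing between linear and sqrt scaling.

```python
def scaled_lr(lr_base, lr_base_size, batch_size, num_devices, scale="linear"):
    """Hypothetical helper: rescale a base LR to the actual global batch size."""
    # global batch size is the per-device batch size times the device count
    global_batch_size = batch_size * num_devices
    ratio = global_batch_size / lr_base_size
    if scale == "sqrt":
        # sqrt rule: LR grows with the square root of the batch ratio
        ratio = ratio ** 0.5
    return lr_base * ratio


# illustrative values only, not taken from any of the yaml files
lr_linear = scaled_lr(0.1, 256, batch_size=128, num_devices=8, scale="linear")
lr_sqrt = scaled_lr(0.1, 256, batch_size=128, num_devices=8, scale="sqrt")
```

Here the global batch size is 1024, four times the reference size, so the linear rule gives 4x the base LR and the sqrt rule gives 2x. Halving the device count halves the ratio, which is why moderate changes in GPU count need no manual LR edits.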

More recently, DeiT-III has been shown to be a very compelling set of hparams for ViT-like models. I've yet to do full runs myself, but theirs can be adapted to the timm train scripts (3A aug was added recently): https://github.com/facebookresearch/deit/blob/main/README_revenge.md

To use the yaml files directly w/ timm train script.

Hparams were run on 8x A100 for in12k or 12k fine-tune runs and 4x V100 for the rest, so global batch size = 320 * 4, etc. The LR should be rescaled using a sqrt rule if changing the global batch size.

rwightman / vit-aot.csv
Created July 13, 2022 05:22
timm vit models, eager vs aot vs torchscript, AMP, PyTorch 1.12
model                          infer_samples_per_sec  infer_step_time  infer_batch_size  infer_img_size  train_samples_per_sec  train_step_time  train_batch_size  train_img_size  param_count
vit_small_patch16_224          2444.7                 104.691          256               224             955.88                 267.078          256               224             22.05
vit_relpos_medium_patch16_224  1107.38                231.158          256               224             502.75                 253.69           128               224             38.75
vit_base_patch16_224           1013.88                252.477          256               224             358.36                 356.433          128               224             86.57
vit_base_patch16_384           288.27                 888.045          256               384             102.82                 300.795          31                384             86.86

comparing ln

Comparing some LayerNorm implementations for 2D (rank-4 NCHW) tensors via ConvNeXt models on 3090 and V100.

All runs done with native torch AMP, PyTorch 1.12 cu113.
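
For reference, this is what a LayerNorm over the channel dim of an NCHW tensor computes, written as a plain-Python sketch on nested lists. It assumes the ConvNeXt-style definition (normalize across C independently at every spatial location) and leaves out the learned affine weight and bias. The function name `layer_norm_nchw` is made up for illustration, and the benchmarked impls are of course tensor ops, not loops like this.

```python
import math


def layer_norm_nchw(x, eps=1e-6):
    """Normalize over C at every (n, h, w) location of a nested-list NCHW tensor."""
    n_, c_, h_, w_ = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    out = [[[[0.0] * w_ for _ in range(h_)] for _ in range(c_)] for _ in range(n_)]
    for n in range(n_):
        for h in range(h_):
            for w in range(w_):
                # gather the channel vector at this spatial location
                vals = [x[n][c][h][w] for c in range(c_)]
                mean = sum(vals) / c_
                var = sum((v - mean) ** 2 for v in vals) / c_  # biased variance
                inv_std = 1.0 / math.sqrt(var + eps)
                for c in range(c_):
                    out[n][c][h][w] = (vals[c] - mean) * inv_std
    return out


# one image, three channels, 1x2 spatial grid
x = [[[[1.0, 10.0]], [[2.0, 20.0]], [[3.0, 30.0]]]]
y = layer_norm_nchw(x)
```

The awkward part for real kernels is that the reduction runs over C, which is not the innermost dim in NCHW, so impls differ in whether they permute to NHWC first or reduce over a strided dim directly.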

Some col descriptions

  • fmt - PyTorch memory_format
  • cg - full model codegen (one of torchscript, aot, eager (none))
  • layer - the LayerNorm impl
rwightman / BENCHMARK.md
Last active July 12, 2023 09:40
timm model benchmark compare

NCHW and NHWC benchmark numbers for some common image classification models in timm.

For NCHW: python benchmark.py --model-list model.txt --amp -b 128

For NHWC: python benchmark.py --model-list model.txt --amp -b 128 --channels-last
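
To show what differs between the two runs, the sketch below computes the element strides for the same logical (N, C, H, W) shape under each layout. It assumes --channels-last maps to PyTorch's channels_last memory format, where the tensor keeps its NCHW shape but the channel dim becomes the fastest-moving one in memory. The helper names are made up for illustration.

```python
def contiguous_strides(n, c, h, w):
    """Element strides in (N, C, H, W) order for the default NCHW layout."""
    return (c * h * w, h * w, w, 1)


def channels_last_strides(n, c, h, w):
    """Element strides in (N, C, H, W) order for the NHWC (channels-last) layout."""
    return (h * w * c, 1, w * c, c)


# batch of 128 RGB images at 224x224, matching -b 128 above
shape = (128, 3, 224, 224)
nchw = contiguous_strides(*shape)
nhwc = channels_last_strides(*shape)
```

In NCHW the stride-1 dim is W, so neighbouring pixels of one channel sit next to each other. In NHWC the stride-1 dim is C, so all channels of one pixel sit together, which is the layout many AMP convolution kernels prefer.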

Note: the test resolution for efficientnet_b1/b2/b3/b4 and regnety_160 was adjusted in timm to match the original papers rather than the timm defaults. The benchmark script is in the root of timm: https://github.com/rwightman/pytorch-image-models/blob/master/benchmark.py

rwightman / standalone_bifpn.py
Last active January 27, 2023 06:09
Use effdet BiFPN standalone
from typing import Callable, Union
from dataclasses import dataclass
import timm
import torch.nn as nn
from effdet.efficientdet import BiFpn
from effdet.config import fpn_config
from omegaconf import DictConfig