Gist by @aribornstein, created May 20, 2021 16:08
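
from functools import partial
from typing import Tuple

import timm
from torch import nn
from flash.core.registry import FlashRegistry

# Registry that maps each timm model name to a backbone factory.
TIMM_BACKBONES_REGISTRY = FlashRegistry("backbones")


def _fn_timm(model_name: str, pretrained: bool = True, num_classes: int = 0) -> Tuple[nn.Module, int]:
    # num_classes=0 tells timm to drop the classifier head, so the model outputs pooled features.
    backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes)
    num_features = backbone.num_features
    return backbone, num_features


# Register one factory per timm architecture; partial binds the model name at registration time.
for model_name in timm.list_models():
    TIMM_BACKBONES_REGISTRY(fn=partial(_fn_timm, model_name), name=model_name)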
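The registration pattern above can be tried without installing Flash or timm. The sketch below is a hypothetical stand-in: `MiniRegistry`, `_fake_backbone` and `BACKBONES` are illustrative names, not part of either library, and the real `FlashRegistry` API may differ. It shows the same idea, a name mapped to a `functools.partial` factory that is only called when a backbone is requested, and why `partial` is used rather than a lambda inside the loop.

```python
from functools import partial
from typing import Callable, Dict, List, Tuple


class MiniRegistry:
    """Hypothetical stand-in for FlashRegistry: maps a name to a factory function."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._functions: Dict[str, Callable] = {}

    def __call__(self, fn: Callable, name: str) -> Callable:
        # This sketch refuses silent overwrites so two factories cannot claim one name.
        if name in self._functions:
            raise ValueError(f"{name!r} is already registered in {self.name!r}")
        self._functions[name] = fn
        return fn

    def get(self, name: str) -> Callable:
        return self._functions[name]

    def available_keys(self) -> List[str]:
        return sorted(self._functions)


def _fake_backbone(model_name: str, pretrained: bool = True, num_classes: int = 0) -> Tuple[str, int]:
    # Stand-in for timm.create_model: returns a label instead of an nn.Module,
    # plus a made-up feature width derived from the name length.
    return f"{model_name}(pretrained={pretrained}, num_classes={num_classes})", 64 * len(model_name)


BACKBONES = MiniRegistry("backbones")
for model_name in ["resnet18", "efficientnet_b0", "vit_base_patch16_224"]:
    # partial captures the current model_name; a lambda would read the loop
    # variable only when called, so every entry would build the last model.
    BACKBONES(fn=partial(_fake_backbone, model_name), name=model_name)

# Nothing is built until a factory is looked up and called.
backbone, num_features = BACKBONES.get("resnet18")(pretrained=False)
print(backbone, num_features)
```

Deferring construction this way keeps registration cheap even for the hundreds of names `timm.list_models()` returns: no weights are downloaded until a specific backbone is requested.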