@younesbelkada
Created May 21, 2023 18:14
A gist to get the module names of all architectures in `transformers` that support `accelerate`
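import transformers
import transformers.models as models

models_that_support_accelerate = []

for model in dir(models):
    # Skip private and dunder attributes of the `transformers.models` package.
    if model[0] != '_':
        model_module = getattr(models, model)
        # Only architectures that ship a `modeling_<name>` module are considered.
        if hasattr(model_module, "modeling_" + model):
            modeling_module = getattr(model_module, "modeling_" + model)
            all_classes = dir(modeling_module)
            for class_object in all_classes:
                # Look for the architecture's base class, e.g. `BertPreTrainedModel`.
                if class_object.endswith("PreTrainedModel") and hasattr(transformers, class_object):
                    pretrained_model_class = getattr(transformers, class_object)
                    # A non-None `_no_split_modules` is treated here as the sign of `accelerate` support.
                    if pretrained_model_class._no_split_modules is not None:
                        # Avoid listing the same architecture twice.
                        if model not in models_that_support_accelerate:
                            models_that_support_accelerate.append(model)

print(models_that_support_accelerate)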
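The check at the heart of the script is whether a class defines a non-None `_no_split_modules`. The sketch below isolates that test so it can be tried without installing `transformers`. The `Fake*` classes and the `supports_accelerate` helper are hypothetical stand-ins invented for illustration, not part of the library, and the `None` default on the base class is an assumption that mirrors what the script relies on.

```python
# Hypothetical stand-ins for model base classes; these are NOT transformers classes.
class FakePreTrainedModel:
    # Assumed default: the base class declares no un-splittable modules.
    _no_split_modules = None


class FakeSupportedModel(FakePreTrainedModel):
    # An architecture that lists the blocks which must stay on one device.
    _no_split_modules = ["FakeDecoderLayer"]


class FakeUnsupportedModel(FakePreTrainedModel):
    # Inherits the None default, so it is filtered out.
    pass


def supports_accelerate(model_class):
    """Apply the same test as the gist: `_no_split_modules` is not None."""
    # getattr with a default also tolerates classes lacking the attribute entirely.
    return getattr(model_class, "_no_split_modules", None) is not None


results = {
    cls.__name__: supports_accelerate(cls)
    for cls in (FakeSupportedModel, FakeUnsupportedModel)
}
print(results)
```

Using `getattr` with a default rather than direct attribute access is a defensive choice: it keeps the helper from raising `AttributeError` on a class that does not define the attribute at all.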