/flair_base_model_detector.py Secret
Last active
September 5, 2023 20:44
Flair Base Model Detector
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pickle
import shutil
from pathlib import Path

import flair
from flair.embeddings import StackedEmbeddings, TransformerWordEmbeddings
from flair.models import SequenceTagger
from huggingface_hub import HfApi, login
# Please adjust! Cache location for downloaded Flair models (a full scan downloads
# tens of GB, so point this at large storage, e.g. a NAS mount).
# Honor FLAIR_CACHE_ROOT -- the environment variable Flair itself uses for its
# default cache_root -- instead of unconditionally overwriting it; fall back to
# the original hard-coded location when it is not set.
flair.cache_root = Path(
    os.environ.get("FLAIR_CACHE_ROOT", "/mnt/datasets/.flair")
)

# Error messages collected by determine_base_model(); pickled to errors.pkl at the end.
errors = []
def determine_base_model(flair_model_name: str) -> str:
    """Return the Hugging Face base model behind a Flair sequence tagger.

    The tagger is loaded with ``SequenceTagger.load``. If its embeddings are a
    ``TransformerWordEmbeddings`` -- either directly, or as a member of a
    ``StackedEmbeddings`` (the first such member wins) -- the ``name_or_path``
    of the underlying transformer model is returned.

    Any exception raised while loading or inspecting the model is appended to
    the module-level ``errors`` list and printed. In that case, and when no
    transformer embeddings are found, an empty string is returned.
    """
    try:
        tagger = SequenceTagger.load(flair_model_name)
        # Normalize both supported layouts into one list of candidate embeddings.
        if isinstance(tagger.embeddings, StackedEmbeddings):
            candidates = tagger.embeddings.embeddings
        elif isinstance(tagger.embeddings, TransformerWordEmbeddings):
            candidates = [tagger.embeddings]
        else:
            candidates = []
        for candidate in candidates:
            if isinstance(candidate, TransformerWordEmbeddings):
                return candidate.model.name_or_path
    except Exception as exc:
        # Best-effort scan over many Hub models: record the failure and move on.
        message = f"Could not parse Flair Model {flair_model_name} :" + str(exc)
        errors.append(message)
        print(message)
    return ""
# Authenticate against the Hugging Face Hub with the token from the HF_TOKEN
# environment variable (None if unset; the README expects it to be set).
# NOTE(review): add_to_git_credential=True also stores the token in git's
# credential helper, which a read-only scan should not need -- consider False.
hf_token = os.environ.get("HF_TOKEN")
login(token=hf_token, add_to_git_credential=True)
api = HfApi()

# Maps Hub model id -> name_or_path of the transformer behind its embeddings.
base_model_mapping = {}
try:
    for flair_model in api.list_models(filter="flair"):
        # NOTE(review): newer huggingface_hub releases expose this as `.id`;
        # verify `.modelId` against the installed version.
        model_id = flair_model.modelId
        print("Detecting base model for:", model_id)
        base_model = determine_base_model(model_id)
        if not base_model:
            # Loading failed or no transformer embeddings were found.
            continue
        print("Detected Base model for", model_id, "is:", base_model)
        base_model_mapping[model_id] = base_model
finally:
    # A full scan runs for hours and downloads every Flair model. Persist
    # whatever was collected even if the loop is interrupted (Ctrl+C, network
    # error while paging list_models, ...), so a partial run is not lost.
    with open("base_model_mapping.pkl", "wb") as f_out:
        pickle.dump(base_model_mapping, f_out, pickle.HIGHEST_PROTOCOL)
    with open("errors.pkl", "wb") as f_out:
        pickle.dump(errors, f_out, pickle.HIGHEST_PROTOCOL)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
🤗 Flair Base Model Detector
To get the script working, the following settings need to be configured:
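- `HF_TOKEN` environment variable: visit the Hugging Face "Access Tokens" settings page and copy your access token. Then set it on the command line via `export HF_TOKEN="<your-token>"`. Add a leading space so the command is not saved in your shell history. This only works if your shell ignores space-prefixed commands, e.g. bash with `HISTCONTROL=ignorespace` (or `ignoreboth`), or zsh with `setopt HIST_IGNORE_SPACE`.
- `flair.cache_root`: set it to a location with enough free space, e.g. NAS storage (a full run downloads ~85 GB, see below).

05.09.2023, 22:31: Stats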
The script was run for the first time: the runtime was ~1.5 hours, and it downloaded ~85 GB of data.
The base model table can be created with the following code:
It then outputs: