Created
July 15, 2022 18:21
-
-
Save lewtun/96154a8cae68f5944046f571a4407c61 to your computer and use it in GitHub Desktop.
Update label mappings in config.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import datasets | |
import transformers | |
from datasets import ClassLabel, load_dataset | |
from huggingface_hub import ( | |
HfFolder, | |
ModelFilter, | |
hf_hub_download, | |
list_models, | |
upload_file, | |
) | |
from tqdm.auto import tqdm | |
from transformers import AutoConfig | |
def update_label_mapping(model_id, dataset_id, label_column="label"):
    """Open a Hub pull request that replaces a model's default label mapping.

    The model's ``label2id`` is compared against the class labels of
    ``dataset_id``. When the model still uses ``LABEL_*`` style names and has
    as many labels as the dataset has classes, a patched ``config.json`` with
    the dataset's label names is uploaded as a pull request.

    Args:
        model_id: Hub repository ID of the model to update.
        dataset_id: Hub ID of the dataset whose label names should be used.
        label_column: Name of the dataset column holding the labels. The
            column must be a ``ClassLabel`` or expose one via ``.feature``.

    Returns:
        ``model_id`` when a pull request was opened, otherwise ``None``.
    """
    # Get token
    hf_token = HfFolder.get_token()

    # Download model config
    label2id = AutoConfig.from_pretrained(model_id).label2id

    # Download dataset features (dataset is opened in streaming mode)
    dataset_features = load_dataset(dataset_id, split="train", streaming=True).features
    label_feature = dataset_features[label_column]
    # Columns that are not a ClassLabel themselves are expected to wrap one
    # in their `.feature` attribute.
    class_label = (
        label_feature
        if isinstance(label_feature, ClassLabel)
        else label_feature.feature
    )

    uses_default_labels = any("LABEL_" in label for label in label2id)
    if not (uses_default_labels and len(label2id) == class_label.num_classes):
        return None

    new_label2id = {class_label.int2str(idx): idx for idx in label2id.values()}
    new_id2label = {idx: label for label, idx in new_label2id.items()}

    # Fetch the raw config file so every other key is preserved as-is
    filepath = hf_hub_download(
        model_id,
        filename="config.json",
        use_auth_token=hf_token,
        force_download=True,
    )
    with open(filepath, "r", encoding="utf-8") as f:
        config_dict = json.load(f)
    config_dict["label2id"] = new_label2id
    config_dict["id2label"] = new_id2label

    # Upload the patched config from memory instead of rewriting `filepath`:
    # that path lives in the local Hub cache and must not be modified.
    payload = json.dumps(config_dict, indent=2).encode("utf-8")

    commit_message = f"Align label mapping with {dataset_id} dataset"
    commit_description = (
        f"Hi there, your model is using a default label mapping. "
        f"Accept this PR to align the label mapping with the `{dataset_id}` dataset this model was trained on. "
        f"This will enable your model to be evaluated by [Hugging Face's automatic model evaluator]"
        f"(https://huggingface.co/spaces/autoevaluate/model-evaluator?dataset={dataset_id})"
    )
    hub_pr_url = upload_file(
        path_or_fileobj=payload,
        path_in_repo="config.json",
        repo_id=model_id,
        create_pr=True,
        commit_description=commit_description,
        commit_message=commit_message,
        token=hf_token,
    )
    print(f"Hub PR opened at {hub_pr_url}")
    return model_id
def update_models(task, dataset_id, label_column):
    """Open label-mapping pull requests for every matching model on the Hub.

    Models are selected with a ``ModelFilter`` on the ``transformers``
    library, the given ``task`` and ``dataset_id`` as the training dataset.
    ``update_label_mapping`` is called for each one; failures are reported
    and skipped so a single broken model does not stop the run.

    Args:
        task: Task tag used to filter models (e.g. ``"text-classification"``).
        dataset_id: Hub ID of the dataset the models were trained on.
        label_column: Name of the dataset column holding the labels.

    Returns:
        List of model IDs for which a pull request was opened.
    """
    filt = ModelFilter(library="transformers", task=task, trained_dataset=dataset_id)
    # Materialise the listing so len() works even if list_models hands back
    # a lazy iterable rather than a list.
    models = list(list_models(filter=filt))
    print(f"Found {len(models)} models for dataset {dataset_id}")

    updated_models = []
    for model in tqdm(models):
        if model.modelId in updated_models:
            print(f"Model {model.modelId} already updated. Skipping ...")
            continue
        try:
            updated_model = update_label_mapping(
                model.modelId, dataset_id, label_column=label_column
            )
        except Exception as err:
            # Best-effort: report and move on, but let KeyboardInterrupt /
            # SystemExit propagate so the run can still be aborted.
            print(f"Could not open Hub PR for model {model.modelId}: {err!r}")
            continue
        # update_label_mapping returns None when no PR was needed
        if updated_model is not None:
            updated_models.append(updated_model)

    print(f"Updated {len(updated_models)} for dataset {dataset_id}")
    return updated_models
# Usage: run this only once per dataset -- every run opens new PRs, so a
# second run creates duplicates.
updated_models = update_models(
    task="text-classification",
    dataset_id="amazon_polarity",
    label_column="label",
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment