@lewtun
Created July 15, 2022 18:21
Update label mappings in config.json
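The script below scans the Hub for transformers models trained on a given dataset whose config.json still carries the default LABEL_0, LABEL_1, ... names, rebuilds label2id/id2label from the dataset's ClassLabel feature, and opens a pull request on each model repo with the corrected file. As a rough sketch (the label names follow the amazon_polarity example used at the bottom of the script), each PR proposes a change along these lines:

# Before: placeholder mapping that transformers writes when no id2label is provided
config.label2id = {"LABEL_0": 0, "LABEL_1": 1}
# After: human-readable names taken from the dataset's ClassLabel feature
config.label2id = {"negative": 0, "positive": 1}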
import json
import datasets
import transformers
from datasets import ClassLabel, load_dataset
from huggingface_hub import (
    HfFolder,
    ModelFilter,
    hf_hub_download,
    list_models,
    upload_file,
)
from tqdm.auto import tqdm
from transformers import AutoConfig


def update_label_mapping(model_id, dataset_id, label_column="label"):
    # Get token for authenticated Hub calls
    hf_token = HfFolder.get_token()
    # Download model config
    config = AutoConfig.from_pretrained(model_id)
    label2id = config.label2id
    # Download dataset features (streaming avoids downloading the full dataset)
    dataset_features = load_dataset(dataset_id, split="train", streaming=True).features
    label_feature = dataset_features[label_column]
    int2str_function = (
        label_feature.int2str
        if isinstance(label_feature, ClassLabel)
        else label_feature.feature.int2str
    )
    num_classes = (
        label_feature.num_classes
        if isinstance(label_feature, ClassLabel)
        else label_feature.feature.num_classes
    )
    commit_message = f"Align label mapping with {dataset_id} dataset"
    commit_description = f"Hi there, your model is using a default label mapping. Accept this PR to align the label mapping with the `{dataset_id}` dataset this model was trained on. This will enable your model to be evaluated by [Hugging Face's automatic model evaluator](https://huggingface.co/spaces/autoevaluate/model-evaluator?dataset={dataset_id})"
    # Only patch configs that still use the default LABEL_N names and whose
    # number of labels matches the dataset
    if (
        any("LABEL_" in label for label in label2id.keys())
        and len(label2id) == num_classes
    ):
        new_label2id = {int2str_function(idx): idx for idx in label2id.values()}
        new_id2label = {idx: label for label, idx in new_label2id.items()}
        # Update config file
        filepath = hf_hub_download(
            model_id,
            filename="config.json",
            use_auth_token=hf_token,
            force_download=True,
        )
        with open(filepath, "r") as f:
            config = json.load(f)
        config["label2id"] = new_label2id
        config["id2label"] = new_id2label
        with open(filepath, "w") as f:
            json.dump(config, f, indent=2)
        # Open a pull request on the model repo with the updated config
        hub_pr_url = upload_file(
            path_or_fileobj=filepath,
            path_in_repo="config.json",
            repo_id=model_id,
            create_pr=True,
            commit_description=commit_description,
            commit_message=commit_message,
            token=hf_token,
        )
        print(f"Hub PR opened at {hub_pr_url}")
        return model_id
    else:
        return None


def update_models(task, dataset_id, label_column):
    filt = ModelFilter(library="transformers", task=task, trained_dataset=dataset_id)
    models = list_models(filter=filt)
    print(f"Found {len(models)} models for dataset {dataset_id}")
    updated_models = []
    for model in tqdm(models):
        if model.modelId in updated_models:
            print(f"Model {model.modelId} already updated. Skipping ...")
            continue
        try:
            updated_model = update_label_mapping(
                model.modelId, dataset_id, label_column=label_column
            )
            updated_models.append(updated_model)
        except Exception:
            print(f"Could not open Hub PR for model {model.modelId}")
    # Drop models that were skipped because they already had a custom mapping
    updated_models = [m for m in updated_models if m is not None]
    print(f"Updated {len(updated_models)} models for dataset {dataset_id}")
    return updated_models

# Usage - don't run this more than once or you'll get duplicate PRs!
updated_models = update_models("text-classification", "amazon_polarity", "label")
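To open a PR against a single repository rather than every model returned by the Hub search, the helper can also be called directly; the model id below is a hypothetical placeholder:

# Patch one specific model repo (repo id is a made-up example)
update_label_mapping(
    "your-username/distilbert-base-uncased-finetuned-amazon_polarity",
    "amazon_polarity",
    label_column="label",
)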