Skip to content

Instantly share code, notes, and snippets.

@elishowk
Last active September 20, 2021 13:01
Show Gist options
  • Save elishowk/688ac5d82e2e158d44ed6d5241bbdd68 to your computer and use it in GitHub Desktop.
Save elishowk/688ac5d82e2e158d44ed6d5241bbdd68 to your computer and use it in GitHub Desktop.
models-metadata-fix.py
import argparse
import string
import os
import re
import shutil
import subprocess
from pathlib import Path
from typing import Dict, Optional, Union
from huggingface_hub.hf_api import HfApi
from huggingface_hub.repository import Repository
from huggingface_hub.repocard import metadata_load, metadata_save
# Command-line parsing.
# NOTE: `type=ascii` converts the raw argument to its repr ("a" -> "'a'"),
# which is why every choice is wrapped in single quotes and why the actual
# letter is read at index 1 further down.
_LETTER_CHOICES = ["'" + letter + "'" for letter in string.ascii_lowercase]

parser = argparse.ArgumentParser()
parser.add_argument(
    "-s", "--start", type=ascii, choices=_LETTER_CHOICES,
    help="start from this first letter of model name (including itself)",
    required=True,
)
parser.add_argument(
    "-e", "--end", type=ascii, choices=_LETTER_CHOICES,
    help="end at this first letter of model name (including itself)",
    required=True,
)
parser.add_argument(
    "-l", "--licenses", action="store_true",
    help="fix licenses when present repo in metadata",
)
parser.add_argument(
    "-d", "--datasets", action="store_true",
    help="remove datasets: null",
)
parser.add_argument(
    "-p", "--proceed", action="store_true",
    help="Process every repository without user confirmation",
)
args = parser.parse_args()

# Fail fast on a reversed range: it would otherwise produce an empty letter
# tuple and silently select no model, after the whole listing was fetched.
if args.start > args.end:
    parser.error("--start must not come after --end in the alphabet")

# Important, to not download the large files
os.environ["GIT_LFS_SKIP_SMUDGE"] = "1"
os.makedirs(os.path.join(".", "models-clone"), exist_ok=True)

# Fetch the model listing from the Hub API.
api = HfApi()
models = api.list_models(full=True)

# Inclusive range of first letters, e.g. start 'a' / end 'c' -> ("a", "b", "c").
letters = tuple(map(chr, range(ord(args.start[1]), ord(args.end[1]) + 1)))

# Keep only the models whose id starts with one of the selected letters.
selected_models = [
    m for m in models
    if m.modelId.lower().startswith(letters)
]

# Anything other than an explicit "y" aborts with a non-zero status.
# `raise SystemExit` is used rather than the `exit()` helper, which is only
# injected by the `site` module for interactive use.
if input(f"\nProcess modifications for {len(selected_models)} models starting with letters {', '.join(letters)} ? (y/N) ->") != "y":
    raise SystemExit(1)
def del_none(metadata: Dict) -> None:
    """
    Remove ``None`` values from *metadata*, in place.

    - A key whose value is ``None`` is deleted.
    - A list value has its ``None`` items dropped; dicts found inside the
      list are cleaned recursively.
    - A dict value is cleaned recursively.

    ``isinstance`` is used so that dict/list subclasses are cleaned as well.
    Nothing is returned: the mapping passed in is the one that is modified.
    """
    # Iterate over a snapshot of the keys because entries may be deleted.
    for key in list(metadata):
        value = metadata[key]
        if value is None:
            del metadata[key]
        elif isinstance(value, dict):
            del_none(value)
        elif isinstance(value, list):
            for item in value:
                if isinstance(item, dict):
                    del_none(item)
            # Only rebuild the list when there is actually something to drop.
            if None in value:
                metadata[key] = [item for item in value if item is not None]
def _corrected_metadata(metadata: Dict, metadata_to_correct: Dict) -> Dict:
    """Return a corrected shallow copy of *metadata*; the input is not modified."""
    corrected = dict(metadata)

    if "license" in metadata_to_correct:
        current_license = metadata.get("license")
        # Only a single string value that starts with "apache" is normalised;
        # a missing key or a non-string value is left alone instead of raising.
        if (
            isinstance(current_license, str)
            and current_license.lower().startswith("apache")
            and current_license != "apache-2.0"
        ):
            corrected["license"] = "apache-2.0"

    if "datasets" in metadata_to_correct and "datasets" in metadata:
        current_datasets = metadata["datasets"]
        if current_datasets is None:
            # `datasets: null` -> drop the key altogether.
            del corrected["datasets"]
        elif isinstance(current_datasets, list):
            corrected["datasets"] = [d for d in current_datasets if d is not None]
        # Any other shape (e.g. a plain string) is left untouched.

    return corrected


def update_repocard_and_commit(repo_dir: Union[str, Path], metadata_to_correct: Dict):
    """
    Correct the README.md metadata of a cloned repository, then commit and push.

    Args:
        repo_dir: path of an already cloned repository.
        metadata_to_correct: mapping whose keys name the metadata entries to
            fix. Only the keys "license" and "datasets" are acted upon; the
            values are not read here.

    The git diff is printed before committing. Unless ``args.proceed`` is set,
    the user is asked to confirm each commit. Nothing is written when the
    metadata is absent or already correct.
    """
    # Do not clone_from with Repository,
    # as it apparently forces downloading the LFS files.
    repo = Repository(
        local_dir=repo_dir,
        use_auth_token=False,
    )
    metadata = repo.repocard_metadata_load()
    if not metadata:
        # No metadata could be loaded (None or empty): nothing to correct.
        return

    new_metadata = _corrected_metadata(metadata, metadata_to_correct)
    if metadata == new_metadata:
        return

    try:
        repo.repocard_metadata_save(new_metadata)
        # Show exactly what is about to be committed.
        diff = subprocess.run(
            ["git", "diff", "--color", "--minimal"],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=repo_dir,
        )
        print(f"{repo.git_remote_url()} on branch {repo.current_branch}:")
        print(diff.stdout)
        if args.proceed or input("Commit & Push ? (y/N) ->") == "y":
            repo.git_add()
            repo.git_commit(f"Automatic correction of README.md metadata for keys {', '.join(metadata_to_correct.keys())}. Contact website@huggingface.co for any question")
            repo.git_push()
            print("Pushed a new commit")
    except Exception as exc:
        # NOTE(review): debugging aid kept from the original. It opens an
        # interactive debugger, so an unattended run (--proceed) will block
        # here until someone resumes it.
        print(exc)
        import pdb
        pdb.set_trace()
# Tag values (the part after "license:") that get normalised to "apache-2.0".
_APACHE_LICENSE_ALIASES = (
    "apache",
    "apache 2.0",
    "apache-2",
    "apache license 2.O",
    "apache-2.0-license",
    "apache v2.0",
)

# Every repository is cloned into the same scratch directory, one at a time.
local_dir = os.path.join(".", "models-clone")
repo_dir = os.path.join(local_dir, "repo_dir")

for m in selected_models:
    # keep localhost by precaution to force executing on moonrise in production
    url = f"http://localhost:5564/{m.modelId}"
    metadata_to_correct = {}

    # Search in ModelInfo.tags() in moon-landing to find the info to correct
    if args.licenses:
        # Split on the first colon only, so a value containing ":" stays whole.
        license_values = [
            tag.split(":", 1)[1] for tag in m.tags if tag.startswith("license:")
        ]
        if any(value in _APACHE_LICENSE_ALIASES for value in license_values):
            metadata_to_correct["license"] = "apache-2.0"

    if args.datasets:
        dataset_values = [
            tag.split(":", 1)[1] for tag in m.tags if tag.startswith("datasets:")
        ]
        if "null" in dataset_values:
            # Compare the value after the prefix, not the whole tag, so that
            # the "null" entry is really excluded.
            metadata_to_correct["datasets"] = [
                value for value in dataset_values if value != "null"
            ]

    # clone and correct only if there are keys to correct
    if not metadata_to_correct:
        continue

    # Start from a clean directory; a missing one is ignored.
    shutil.rmtree(repo_dir, ignore_errors=True)
    try:
        subprocess.run(
            ["git", "clone", url, "repo_dir"],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=local_dir,
        )
        update_repocard_and_commit(repo_dir, metadata_to_correct)
    except subprocess.CalledProcessError as exc:
        # stderr is captured, so surface it before aborting the run.
        print(exc.stderr or exc)
        raise
    finally:
        # Always remove the clone, even when processing failed.
        shutil.rmtree(repo_dir, ignore_errors=True)

print("\nFinished")
raise SystemExit(0)
@elishowk
Copy link
Author

WIP

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment