# models-metadata-fix.py
# Source: GitHub Gist elishowk/688ac5d82e2e158d44ed6d5241bbdd68 (last active September 20, 2021 13:01).
# Note: GitHub flagged this file as possibly containing bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
import argparse
import os
import re
import shutil
import string
import subprocess
import sys
from pathlib import Path
from typing import Dict, Optional, Union

from huggingface_hub.hf_api import HfApi
from huggingface_hub.repocard import metadata_load, metadata_save
from huggingface_hub.repository import Repository
# Command-line parsing.
# Models are selected by the first character of their id (m.modelId). For ids of
# the form "owner/name" that is the first letter of the owner, not of the name.
LETTER_CHOICES = list(string.ascii_lowercase)
parser = argparse.ArgumentParser()
# type=str.lower lets "-s A" be accepted as "a" before the choices check.
parser.add_argument("-s", "--start", type=str.lower, choices=LETTER_CHOICES,
                    help="start from this first letter of model name (including itself)",
                    required=True)
parser.add_argument("-e", "--end", type=str.lower, choices=LETTER_CHOICES,
                    help="end at this first letter of model name (including itself)",
                    required=True)
parser.add_argument("-l", "--licenses", action="store_true",
                    help="normalize Apache license variants to apache-2.0 in README metadata")
parser.add_argument("-d", "--datasets", action="store_true",
                    help="remove datasets: null")
parser.add_argument("-p", "--proceed", action="store_true",
                    help="Process every repository without user confirmation")
args = parser.parse_args()
if args.start > args.end:
    # An inverted range would otherwise silently select no model at all.
    parser.error(f"--start ({args.start}) must not come after --end ({args.end})")

# Important: skip the git-lfs smudge filter so cloning does not download large files.
os.environ["GIT_LFS_SKIP_SMUDGE"] = "1"
os.makedirs(os.path.join(".", "models-clone"), exist_ok=True)

api = HfApi()
models = api.list_models(full=True)

# Inclusive range of first letters, e.g. ("a", "b", "c") for "-s a -e c".
letters = tuple(map(chr, range(ord(args.start), ord(args.end) + 1)))
# Filter on model ids only; which keys need correcting is decided later from tags.
selected_models = [m for m in models if m.modelId.lower().startswith(letters)]

if input(f"\nProcess modifications for {len(selected_models)} models starting with letters {', '.join(letters)} ? (y/N) ->") != "y":
    sys.exit(1)
def del_none(metadata: Dict) -> None:
    """
    Recursively remove ``None`` values from a metadata mapping, in place.

    - keys whose value is ``None`` are deleted;
    - ``None`` items are dropped from list values (a new list is assigned,
      the original list object is not mutated);
    - nested dicts, including dicts stored inside lists, are cleaned recursively.

    Dict and list subclasses (e.g. ``OrderedDict``) are handled too.

    :param metadata: mapping to clean; it is modified and nothing is returned.
    """
    # Iterate over a snapshot of the keys because entries are deleted in the loop.
    for key in list(metadata.keys()):
        value = metadata[key]
        if value is None:
            del metadata[key]
        elif isinstance(value, dict):
            del_none(value)
        elif isinstance(value, list):
            if any(item is None for item in value):
                value = [item for item in value if item is not None]
                metadata[key] = value
            for item in value:
                if isinstance(item, dict):
                    del_none(item)
def _corrected_metadata(metadata: Dict, metadata_to_correct: Dict) -> Dict:
    """
    Return a corrected shallow copy of ``metadata``; ``metadata`` is not modified.

    Only the keys of ``metadata_to_correct`` select corrections (its values are
    ignored). Keys absent from ``metadata`` are left alone.

    - "license": a string value starting with "apache" (case-insensitive) that is
      not exactly "apache-2.0" is replaced by "apache-2.0".
    - "datasets": ``None`` items are dropped from a list value; a bare null
      value (``datasets: null``) removes the key.
    """
    new_metadata = metadata.copy()

    license_value = metadata.get("license")
    if "license" in metadata_to_correct and isinstance(license_value, str):
        if license_value.lower().startswith("apache") and license_value != "apache-2.0":
            new_metadata["license"] = "apache-2.0"

    if "datasets" in metadata_to_correct and "datasets" in metadata:
        datasets = metadata["datasets"]
        if datasets is None:
            del new_metadata["datasets"]
        elif isinstance(datasets, list):
            # A new list, so the original metadata stays untouched for comparison.
            new_metadata["datasets"] = [d for d in datasets if d is not None]

    return new_metadata


def update_repocard_and_commit(repo_dir: Union[str, Path], metadata_to_correct: Dict):
    """
    Correct the README.md metadata of an already cloned repository, then commit and push.

    The diff is printed first. Unless ``--proceed`` was given, the user must confirm
    with "y" before committing and pushing.

    :param repo_dir: directory of a repository that has already been cloned.
    :param metadata_to_correct: mapping whose keys ("license", "datasets") select
        the corrections to apply; the keys are also listed in the commit message.
    """
    # Do not use Repository's clone_from, as it apparently forces
    # downloading the LFS files; the repository must already be cloned.
    repo = Repository(
        local_dir=repo_dir,
        use_auth_token=False,
    )
    metadata = repo.repocard_metadata_load()
    if not metadata:
        # No README.md, or no (or an empty) YAML metadata block: nothing to correct.
        print(f"No README.md metadata found in {repo_dir}, skipping")
        return

    new_metadata = _corrected_metadata(metadata, metadata_to_correct)
    if metadata == new_metadata:
        return

    try:
        repo.repocard_metadata_save(new_metadata)
        p = subprocess.run(
            ["git", "diff", "--color", "--minimal"],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=repo_dir,
        )
        print(f"{repo.git_remote_url()} on branch {repo.current_branch}:")
        print(p.stdout)
        if args.proceed or input("Commit & Push ? (y/N) ->") == "y":
            repo.git_add()
            repo.git_commit(f"Automatic correction of README.md metadata for keys {', '.join(metadata_to_correct.keys())}. Contact website@huggingface.co for any question")
            repo.git_push()
            print("Pushed a new commit")
    except Exception as exc:
        # Per-repository boundary: report the failure instead of aborting the batch.
        print(f"{type(exc).__name__}: {exc}")
        if not args.proceed:
            # Interactive runs keep the debugger so the failing repo can be inspected;
            # unattended (--proceed) runs must not block on a pdb prompt.
            import pdb
            pdb.set_trace()
# Apache license spellings (values of "license:<value>" model tags, compared
# lowercased) that get normalized to "apache-2.0". "apache license 2.o" keeps the
# original entry (possibly a real misspelling); "apache license 2.0" is the intended one.
APACHE_LICENSE_VARIANTS = {
    "apache",
    "apache 2.0",
    "apache-2",
    "apache license 2.0",
    "apache license 2.o",
    "apache-2.0-license",
    "apache v2.0",
}


def _tag_values(tags, prefix):
    """Return the values of the "<prefix>:<value>" entries of ``tags``."""
    marker = prefix + ":"
    # maxsplit=1 so a value that itself contains ":" is kept whole.
    return [tag.split(":", 1)[1] for tag in (tags or []) if tag.startswith(marker)]


local_dir = os.path.join(".", "models-clone")
repo_dir = os.path.join(local_dir, "repo_dir")
for m in selected_models:
    # keep localhost by precaution, to force executing on moonrise in production
    url = f"http://localhost:5564/{m.modelId}"
    metadata_to_correct = {}
    # Use the model tags (ModelInfo.tags, computed by moon-landing) to find what to correct
    if args.licenses:
        licenses = _tag_values(m.tags, "license")
        if any(value.lower() in APACHE_LICENSE_VARIANTS for value in licenses):
            metadata_to_correct["license"] = "apache-2.0"
    if args.datasets:
        datasets = _tag_values(m.tags, "datasets")
        if "null" in datasets:
            metadata_to_correct["datasets"] = [d for d in datasets if d != "null"]
    # clone and correct only if there are keys to correct
    if not metadata_to_correct:
        continue
    if os.path.exists(repo_dir):
        shutil.rmtree(repo_dir, ignore_errors=True)
    try:
        subprocess.run(
            ["git", "clone", url, "repo_dir"],
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=True,
            encoding="utf-8",
            cwd=local_dir,
        )
    except subprocess.CalledProcessError as exc:
        # One unreachable repository should not abort the whole batch.
        print(f"Failed to clone {url}: {exc.stderr}")
        continue
    try:
        update_repocard_and_commit(repo_dir, metadata_to_correct)
    finally:
        # Always remove the clone, even if the update raised.
        shutil.rmtree(repo_dir, ignore_errors=True)

print("\nFinished")
sys.exit(0)
# (End of the GitHub Gist page; commenting on the gist requires signing in to GitHub.)
# Status: WIP