Skip to content

Instantly share code, notes, and snippets.

@ralphbean
Last active July 6, 2024 00:53
Show Gist options
  • Save ralphbean/8ea941b0cf06b92191ac4b3074ede656 to your computer and use it in GitHub Desktop.
Save ralphbean/8ea941b0cf06b92191ac4b3074ede656 to your computer and use it in GitHub Desktop.
refresh-oci-copy-file.py
#!/usr/bin/env python
"""Write an oci-copy.yaml file based on the latest data on Hugging Face.
In order to get the latest revision:
$ python3 refresh-oci-copy-file.py prometheus-eval/prometheus-8x7b-v2.0
In order to get files and digests for a specific revision in the history:
$ python3 refresh-oci-copy-file.py --revision e0bb4692356a1738acf25f15180e9f025725b0f2 prometheus-eval/prometheus-8x7b-v2.0
"""
import argparse
import hashlib
import logging
import mimetypes
import os
import yaml
import httpx
# Command-line interface: a Hugging Face repository id, plus an optional
# revision and a debug-logging switch. The module docstring (with usage
# examples) is reused as the --help description.
parser = argparse.ArgumentParser(
    description=__doc__,
    formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument(
    "repository",
    help="Hugging Face model repository id, e.g. prometheus-eval/prometheus-8x7b-v2.0",
)
parser.add_argument(
    "--revision",
    default="main",
    help="branch name or commit sha to resolve (default: %(default)s)",
)
parser.add_argument(
    "--debug",
    default=False,
    action="store_true",
    help="enable DEBUG-level logging",
)
args = parser.parse_args()
# --debug lowers the root logger threshold to DEBUG; otherwise only warnings
# and above are emitted.
logging.basicConfig(level=logging.DEBUG if args.debug else logging.WARNING)
# Model-weight file extensions that the stdlib `mimetypes` table may not
# recognise. All of them are registered as opaque binary data so that
# mimetypes.guess_type() returns a type for them.
known_types = dict.fromkeys(
    (".safetensors", ".model", ".gguf", ".pt"),
    "application/octet-stream",
)
for extension, media_type in known_types.items():
    mimetypes.add_type(media_type, extension)
def determine_digest(url, info, token):
    """Return the hex SHA-256 digest of the file at *url*.

    When the sibling metadata carries an ``lfs`` entry, its ``sha256`` value
    is returned directly and nothing is downloaded. Otherwise the file is
    fetched and hashed locally.

    :param url: resolve URL of the file on huggingface.co.
    :param info: sibling entry for the file from the model-info API response.
    :param token: optional API token; when truthy it is sent as a bearer
        ``Authorization`` header.
    :returns: hex SHA-256 digest string.
    :raises httpx.HTTPStatusError: if the download returns an error status.
    """
    lfs = info.get("lfs")
    if lfs:
        return lfs["sha256"]

    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    digest = hashlib.sha256()
    # Stream the body so it is hashed chunk by chunk instead of being held in
    # memory all at once. Follow redirects (e.g. a renamed repository) rather
    # than failing on a 3xx response in raise_for_status().
    with httpx.stream("GET", url, headers=headers, follow_redirects=True) as response:
        response.raise_for_status()
        for chunk in response.iter_bytes():
            digest.update(chunk)
    return digest.hexdigest()
# Optional API token read from the environment. When set, it is sent as a
# bearer token in the Authorization header; otherwise requests go out
# unauthenticated.
token = os.environ.get("HUGGINGFACE_TOKEN")
print(f"🤗 Querying huggingface.co for {args.repository}")
headers = {}
if token:
    # Echo only the first six characters and mask the remainder.
    print(f"🔑 Using key {token[:6]}{'*' * len(token[6:])}")
    headers["Authorization"] = f"Bearer {token}"
else:
    print(
        "🤷 No $HUGGINGFACE_TOKEN environment variable found. "
        "Proceeding unauthenticated."
    )
# Fetch the model metadata for the requested revision. `blobs=True` requests
# extended per-file metadata; determine_digest() relies on the `lfs` entries
# of the returned `siblings`.
url = f"https://huggingface.co/api/models/{args.repository}/revision/{args.revision}"
params = dict(blobs=True)
# Follow redirects (e.g. a renamed repository) instead of failing on a 3xx.
response = httpx.get(url=url, headers=headers, params=params, follow_redirects=True)
response.raise_for_status()
data = response.json()
# Pin all file URLs below to the concrete commit sha, even if a branch name
# was passed as --revision.
revision = data["sha"]
# Build one artifact entry per repository file (dotfiles excluded). Each entry
# points at the file's resolve URL pinned to the resolved commit sha.
result = {"artifact_type": "application/x-mlmodel", "artifacts": []}
artifacts = result["artifacts"]
for entry in data["siblings"]:
    filename = entry["rfilename"]
    if filename.startswith("."):
        # Skip hidden files (names starting with a dot).
        continue
    url = f"https://huggingface.co/{args.repository}/resolve/{revision}/{filename}"
    print(f"🔗 Considering {url}")
    mime_type, _encoding = mimetypes.guess_type(url)
    artifacts.append(
        {
            "source": url,
            "filename": filename,
            "type": mime_type,
            "sha256sum": determine_digest(url, entry, token),
        }
    )
# Serialise the manifest to oci-copy.yaml in the current working directory,
# with keys sorted so the output order is deterministic.
print("💾 Writing oci-copy.yaml")
with open("oci-copy.yaml", "w") as handle:
    yaml.dump(result, handle, sort_keys=True)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment