Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Last active January 23, 2024 07:15
Show Gist options
  • Save pszemraj/5bc8dcc59d99f6a8f5cd8c3f784d5b08 to your computer and use it in GitHub Desktop.
Save pszemraj/5bc8dcc59d99f6a8f5cd8c3f784d5b08 to your computer and use it in GitHub Desktop.
huggingface hub - download a full snapshot of a repository without using git
"""
hf_hub_download.py
This script downloads a full snapshot of a repository from the Hugging Face Hub to a local directory, without needing Git and without loading the model.
Usage:
python hf_hub_download.py <repo_id> [options]
Arguments:
<repo_id> Repository ID in the format "organization/repository".
Options:
--revision <str> Revision of the repository (commit/tag/branch). Default: None.
--cache_dir <str> Directory to store the downloaded files. Default: "./downloaded-models" (a folder inside the current working directory).
--library_name <str> Name of the library associated with the download. Default: None.
--library_version <str> Version of the library associated with the download. Default: None.
--user_agent <str> User agent string. Default: None.
--ignore_files <str> List of file patterns to ignore. Default: None.
--use_auth_token <str> Authentication token for private repositories. Default: None.
"""
import logging
from pathlib import Path
from typing import Dict, List, Optional, Union
from fnmatch import fnmatch
from packaging import version
from tqdm.auto import tqdm
import fire
import huggingface_hub
from huggingface_hub import HfApi, HfFolder, cached_download, hf_hub_url
# Fallback download root used by snapshot_download() when no cache_dir is given.
HUGGINGFACE_HUB_CACHE = Path.home() / ".cache" / "huggingface" / "transformers"
# Download root used by the command-line entry point when no cache_dir is given.
DEFAULT_CACHE = Path.cwd().joinpath("downloaded-models")
def setup_logging():
    """Configure root logging at INFO level and return this module's logger.

    Returns:
        logging.Logger: The logger named after the current module.
    """
    # basicConfig does nothing if the root logger already has handlers,
    # so calling this function repeatedly is harmless.
    logging.basicConfig(level=logging.INFO)
    return logging.getLogger(__name__)
def snapshot_download(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
    use_auth_token: Union[bool, str, None] = None,
) -> str:
    """Download every file of a Hub repository into a local folder.

    Files are fetched one by one with ``cached_download`` and stored under
    ``<cache_dir>/<repo_id with "/" replaced by "_">``, keeping the
    repository's relative directory structure.

    Args:
        repo_id: Repository ID in the format "organization/repository".
        revision: Commit, tag or branch to resolve. ``None`` lets the Hub
            pick its default.
        cache_dir: Root directory for downloads. Falls back to
            ``HUGGINGFACE_HUB_CACHE`` when empty or ``None``.
        library_name: Forwarded to ``cached_download``.
        library_version: Forwarded to ``cached_download``.
        user_agent: Forwarded to ``cached_download``.
        ignore_files: ``fnmatch`` pattern(s); files whose repository path
            matches any pattern are skipped. A single string is treated as
            one pattern.
        use_auth_token: ``True`` to use the locally stored token, or a token
            string to use directly. Falsy values mean anonymous access.

    Returns:
        The storage folder path as a string.
    """
    # Expand "~" so user-supplied paths behave like the default cache path.
    cache_dir = Path(cache_dir).expanduser() if cache_dir else HUGGINGFACE_HUB_CACHE

    # An explicit token string must be used as-is for the metadata request;
    # only a bare truthy flag means "use the token saved on this machine".
    if isinstance(use_auth_token, str):
        token = use_auth_token
    elif use_auth_token:
        token = HfFolder.get_token()
    else:
        token = None

    # A lone string (e.g. from `--ignore_files=*.bin` on the command line)
    # would otherwise be iterated character by character, and the "*"
    # character alone matches every file.
    if ignore_files is None:
        patterns = []
    elif isinstance(ignore_files, str):
        patterns = [ignore_files]
    else:
        patterns = list(ignore_files)

    model_info = HfApi().model_info(repo_id=repo_id, revision=revision, token=token)
    storage_folder = cache_dir / repo_id.replace("/", "_")

    # Work on a copy so the ModelInfo object is not mutated, and move
    # modules.json to the end so it is downloaded last.
    # NOTE(review): presumably so its presence marks a complete download
    # (sentence-transformers convention) -- confirm.
    all_files = list(model_info.siblings or [])
    all_files.sort(key=lambda sibling: sibling.rfilename == "modules.json")

    download_kwargs = {
        "cache_dir": storage_folder,
        "library_name": library_name,
        "library_version": library_version,
        "user_agent": user_agent,
        "use_auth_token": use_auth_token,
    }
    # Only pass legacy_cache_layout from the release this version check
    # targets onwards; passing it to earlier releases is assumed to fail on
    # an unknown keyword argument. True keeps files at the forced filename
    # inside storage_folder.
    if version.parse(huggingface_hub.__version__) >= version.parse("0.8.1"):
        download_kwargs["legacy_cache_layout"] = True

    logger = setup_logging()
    # The context manager closes the progress bar even if a download raises.
    with tqdm(all_files, desc="Downloading files", unit="file") as pbar:
        for model_file in pbar:
            filename = model_file.rfilename
            if any(fnmatch(filename, pattern) for pattern in patterns):
                continue

            # Pin the URL to the resolved commit so all files come from the
            # same revision.
            url = hf_hub_url(repo_id, filename=filename, revision=model_info.sha)
            relative_filepath = Path(filename)
            (storage_folder / relative_filepath.parent).mkdir(
                parents=True, exist_ok=True
            )

            path = cached_download(
                url=url,
                force_filename=str(relative_filepath),
                **download_kwargs,
            )

            # Remove the leftover lock file, if any; a missing file is fine.
            try:
                Path(f"{path}.lock").unlink()
            except FileNotFoundError:
                pass

    logger.info("Download completed.")
    return str(storage_folder)
def main(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[Union[str, Path]] = None,
    library_name: Optional[str] = None,
    library_version: Optional[str] = None,
    user_agent: Union[Dict, str, None] = None,
    ignore_files: Optional[List[str]] = None,
    use_auth_token: Union[bool, str, None] = None,
):
    """
    Command-line entry point: download a repository snapshot and report where it went.

    Delegates to snapshot_download, which fetches a repo into a local directory
    without needing git or loading the model through an AutoModel class.

    **Credit to sentence-transformers**

    Args:
        repo_id (str): Repository ID in the format "organization/repository".
        revision (str, optional): Revision of the repository (commit/tag/branch). Defaults to None.
        cache_dir (Union[str, Path, None], optional): Directory to store the downloaded files.
            When empty or None, DEFAULT_CACHE is used.
        library_name (str, optional): Name of the library associated with the download. Defaults to None.
        library_version (str, optional): Version of the library associated with the download. Defaults to None.
        user_agent (Union[Dict, str, None], optional): User agent string. Defaults to None.
        ignore_files (List[str], optional): List of file patterns to ignore. Defaults to None.
        use_auth_token (Union[bool, str, None], optional): Authentication token for private repositories. Defaults to None.

    Returns:
        None. The storage folder path is printed to standard output.
    """
    if cache_dir:
        target_dir = Path(cache_dir)
    else:
        target_dir = DEFAULT_CACHE

    # Everything except the resolved cache directory is forwarded untouched.
    passthrough = dict(
        revision=revision,
        library_name=library_name,
        library_version=library_version,
        user_agent=user_agent,
        ignore_files=ignore_files,
        use_auth_token=use_auth_token,
    )
    storage_folder = snapshot_download(
        repo_id=repo_id, cache_dir=target_dir, **passthrough
    )
    print(f"Snapshot repository downloaded to: {storage_folder}")
if __name__ == "__main__":
    # Expose main() as a command-line interface; Fire maps CLI flags to its parameters.
    fire.Fire(component=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment