@muhark
Last active June 20, 2024 05:37
Offline HuggingFace Models on HPC

Downloading HuggingFace Models

This gist shares a little workflow and script for a task that most people using university HPCs for NLP research will need to do: downloading and storing HuggingFace models for use on compute nodes.

What this workflow is for:

  • Context: you want to use HuggingFace models on Della (or other HPC clusters).
  • Problem 1: you cannot call AutoModel.from_pretrained('model/name') at run time because compute nodes are not connected to the internet.
  • Problem 2: running AutoModel.from_pretrained() on the head node is impractical because the model is too large to be loaded.
  • Problem 3: you do not want to save the model weights to the default ~/.cache/ because you only get 10GB of storage on /home.

Step 1: Configure HF_HOME

HuggingFace tools check a couple of environment variables to determine where to look for and save pre-downloaded (cached) models and datasets. By default these locations are under /home/$USER, which might not have enough space to store large models and datasets.

You will want to set the values of two environment variables: HF_HOME and HF_DATASETS_CACHE. (The latter governs datasets, but while you're in there you may as well change both.)

I set these to the /scratch partition. Add these two lines to your ~/.bashrc.

[...]
export HF_HOME="/scratch/gpfs/$USER/.cache/huggingface"
export HF_DATASETS_CACHE="/scratch/gpfs/$USER/.cache/huggingface/datasets"
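After opening a new shell (or running source ~/.bashrc), you can sanity-check that Python sees the variables. A minimal sketch, assuming the standard HuggingFace fallback locations when the variables are unset:

```python
import os
from pathlib import Path

# Where models will be cached; falls back to the HF default if HF_HOME is unset
hf_home = Path(os.environ.get("HF_HOME", "~/.cache/huggingface")).expanduser()
# Where datasets will be cached; by convention a 'datasets' subdirectory
datasets_cache = Path(os.environ.get("HF_DATASETS_CACHE", hf_home / "datasets"))
print(hf_home)
print(datasets_cache)
```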

Step 2: Download Scripts on Head Node

The attached script is a wrapper around the snapshot_download function from huggingface_hub; the huggingface_hub library is its only dependency. Assuming you keep the script's name as hf_model_downloader.py, usage is as follows:

python hf_model_downloader.py --repo_id='HF_MODEL_REF' --revision='main' --cache_dir=''

Step 3: Loading cached models

Use the following function inside scripts to load the cached models:

import os
from pathlib import Path
from typing import Any

HF_DEFAULT_HOME = os.environ.get("HF_HOME", "~/.cache/huggingface/hub")

def get_weight_dir(
    model_ref: str,
    *,
    model_dir: str | os.PathLike[Any] = HF_DEFAULT_HOME,
    revision: str = "main",
) -> Path:
    """
    Parse model name to locally stored weights.
    Args:
        model_ref (str): Model reference containing org_name/model_name such as 'meta-llama/Llama-2-7b-chat-hf'.
        model_dir (str | os.PathLike[Any]): Path to directory where models are stored. Defaults to the value of $HF_HOME (falling back to the default cache directory).
        revision (str): Model revision branch. Defaults to 'main'.

    Returns:
        Path: Path to the model weights within the model directory.
    """
    model_dir = Path(model_dir).expanduser()
    assert model_dir.is_dir(), f"Model directory {model_dir} does not exist"
    model_path = model_dir / "--".join(["models", *model_ref.split("/")])
    assert model_path.is_dir(), f"No cached model found at {model_path}"
    snapshot_hash = (model_path / "refs" / revision).read_text().strip()
    weight_dir = model_path / "snapshots" / snapshot_hash
    assert weight_dir.is_dir(), f"Snapshot directory {weight_dir} does not exist"
    return weight_dir
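To make the cache layout concrete, here is a self-contained sketch that builds a fake cache entry (the snapshot hash abc123 is made up for illustration) and resolves it the same way get_weight_dir does:

```python
import tempfile
from pathlib import Path

# Layout under $HF_HOME for a model 'google/flan-t5-small':
#   models--google--flan-t5-small/
#     refs/main          <- text file containing the current snapshot hash
#     snapshots/<hash>/  <- the actual weight files (config.json, *.safetensors, ...)
with tempfile.TemporaryDirectory() as d:
    model_path = Path(d) / "models--google--flan-t5-small"
    (model_path / "refs").mkdir(parents=True)
    (model_path / "snapshots" / "abc123").mkdir(parents=True)
    (model_path / "refs" / "main").write_text("abc123")

    # Resolve the weights directory from the ref file, as get_weight_dir does
    snapshot_hash = (model_path / "refs" / "main").read_text().strip()
    weight_dir = model_path / "snapshots" / snapshot_hash
    print(weight_dir.is_dir())  # True
```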

And then you can use the .from_pretrained methods as usual!

from transformers import AutoModel, AutoTokenizer

weights_dir = get_weight_dir('google/flan-t5-small')
model = AutoModel.from_pretrained(weights_dir)
tokenizer = AutoTokenizer.from_pretrained(weights_dir)
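As an alternative to resolving the snapshot path yourself, huggingface_hub supports an offline mode via the HF_HUB_OFFLINE environment variable; with it set, from_pretrained('org/name') should resolve against the local cache without attempting network calls (a sketch, assuming the model was already downloaded to your cache):

```python
import os

# Assumption: huggingface_hub/transformers respect HF_HUB_OFFLINE; setting it
# before importing transformers prevents any network access at load time.
os.environ["HF_HUB_OFFLINE"] = "1"

# from transformers import AutoModel
# model = AutoModel.from_pretrained("google/flan-t5-small")  # cache only
```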
hf_model_downloader.py:

#! /usr/bin/python3
# HuggingFace Model Downloader Script
# Author: Dr Musashi Hinck
"""
Convenience command-line script to download model assets to cache for offline usage on compute nodes.

Usage:
`python hf_model_downloader.py --repo_id='HF_MODEL_REF' --revision='main' --cache_dir=''`
"""
import argparse
import os

from huggingface_hub import snapshot_download


def download_save_huggingface_model(repo_id: str, revision: str, cache_dir: str):
    # Fall back to $HF_HOME when no cache directory is given
    if cache_dir == '':
        cache_dir = os.environ.get("HF_HOME")
    snapshot_download(repo_id=repo_id, revision=revision, cache_dir=cache_dir)


def main():
    # Read in arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_id', type=str, help="HF Model Hub Repo ID")
    parser.add_argument('--revision', type=str, default='main', help="Branch or revision to download. Defaults to 'main'.")
    parser.add_argument('--cache_dir', type=str, default='', help="Location to save model. Defaults to $HF_HOME.")
    args = parser.parse_args()
    download_save_huggingface_model(args.repo_id, args.revision, args.cache_dir)


if __name__ == '__main__':
    main()