younesbelkada / shard_weights.py
Last active October 9, 2023 22:17
A script to shard any model stored in the Hugging Face format
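As a rough illustration of what sharding a checkpoint involves, the sketch below splits a list of parameter names into a requested number of groups and builds the index that maps each parameter to its shard file. It is not taken from the script itself: `build_shard_index` and `param_sizes` are hypothetical names, the sizes are assumed to be in bytes, and the file names assume the `pytorch_model-0000X-of-0000Y.bin` convention used for sharded PyTorch checkpoints in `transformers`. It deliberately uses only the standard library, so no tensors are loaded or saved.

```python
import json


def build_shard_index(param_sizes, sharding_factor):
    """Assign parameter names to shards and build a checkpoint index.

    param_sizes: dict mapping parameter name -> size in bytes (hypothetical input).
    sharding_factor: requested number of shards.
    Returns (shards, index), where shards is a list of (filename, names) pairs.
    """
    if sharding_factor < 1:
        raise ValueError("sharding_factor must be at least 1")

    names = list(param_sizes)
    # Never create more shards than there are parameters.
    num_shards = max(1, min(sharding_factor, len(names)))
    # Spread parameters as evenly as possible; the first `extra` shards get one more.
    base, extra = divmod(len(names), num_shards)

    shards, weight_map, start = [], {}, 0
    for i in range(num_shards):
        stop = start + base + (1 if i < extra else 0)
        filename = f"pytorch_model-{i + 1:05d}-of-{num_shards:05d}.bin"
        shards.append((filename, names[start:stop]))
        for name in names[start:stop]:
            weight_map[name] = filename
        start = stop

    index = {
        "metadata": {"total_size": sum(param_sizes.values())},
        "weight_map": weight_map,
    }
    return shards, index


if __name__ == "__main__":
    # Made-up parameter names and sizes, for illustration only.
    sizes = {
        "encoder.weight": 4096,
        "encoder.bias": 64,
        "decoder.weight": 4096,
        "decoder.bias": 64,
        "lm_head.weight": 2048,
    }
    shards, index = build_shard_index(sizes, sharding_factor=2)
    print(json.dumps(index, indent=2))
```

Splitting by parameter count keeps the sketch short; a size-aware split that caps each shard at a byte budget may be preferable when parameter sizes vary widely.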
import argparse
import json
import os

import torch

# Command-line options: shard count, source folder, and output folder.
parser = argparse.ArgumentParser(description='Shard Hugging Face models')
parser.add_argument('--sharding_factor', default=4, type=int, help='Number of shards to create')
parser.add_argument('--source_model_path', default="t5-v1_1-xl", type=str, help='Relative path to the source model folder')
parser.add_argument('--sharded_model_path', default="t5-v1_1-xl-sharded", type=str, help='Relative path to the target sharded model folder')
args = parser.parse_args()