
@benob
Created March 4, 2023 17:35
Script to decompose/recompose LLaMA models into a different number of shards.
# script to decompose/recompose a LLaMA model into a different number of shards
# note: it holds roughly two full copies of the model in CPU memory
import os
import json
import sys
import torch
import glob

if len(sys.argv) != 4:
    print('usage: %s <new-shards> <input-model-path> <output-model-path>' % sys.argv[0], file=sys.stderr)
    sys.exit(1)

num_shards = int(sys.argv[1])
input_model_dir = sys.argv[2]
output_model_dir = sys.argv[3]

with open(os.path.join(input_model_dir, 'params.json'), 'r') as fp:
    params = json.loads(fp.read())

assert params['dim'] % num_shards == 0, "number of shards needs to divide parameter dimension %d" % params['dim']

print('loading...')
# sort the paths so shards are concatenated in rank order (consolidated.00.pth, consolidated.01.pth, ...)
checkpoints = [torch.load(path, map_location=torch.device('cpu'))
               for path in sorted(glob.glob(os.path.join(input_model_dir, '*.pth')))]

# how each parameter is split across shards: ColumnParallelLinear weights are
# split along dim 0, RowParallelLinear and ParallelEmbedding weights along
# dim 1, and None means the tensor is replicated on every shard
layer_kind = {
    'tok_embeddings': 'ParallelEmbedding',
    'output': 'ColumnParallelLinear',
    'attention.wq': 'ColumnParallelLinear',
    'attention.wk': 'ColumnParallelLinear',
    'attention.wv': 'ColumnParallelLinear',
    'attention.wo': 'RowParallelLinear',
    'feed_forward.w1': 'ColumnParallelLinear',
    'feed_forward.w2': 'RowParallelLinear',
    'feed_forward.w3': 'ColumnParallelLinear',
    'attention_norm': None,
    'ffn_norm': None,
    'norm': None,
    'rope.freqs': None,
}

output = [dict() for _ in range(num_shards)]

print('converting...')
for key in checkpoints[0].keys():
    tensors = [m[key] for m in checkpoints]
    print(key)
    print('  in shapes=', [p.shape for p in tensors])
    for pattern, kind in layer_kind.items():
        if key.replace('.weight', '').endswith(pattern):
            print('  kind=', kind)
            if kind == 'ColumnParallelLinear':
                # merge along dim 0, then re-slice into the new number of shards
                with torch.no_grad():
                    merged = torch.cat(tensors, 0)
                    slice_size = merged.shape[0] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[slice_size * rank: slice_size * (rank + 1), :].clone().detach()
            elif kind in ('ParallelEmbedding', 'RowParallelLinear'):
                # merge along dim 1, then re-slice into the new number of shards
                with torch.no_grad():
                    merged = torch.cat(tensors, 1)
                    slice_size = merged.shape[1] // num_shards
                    for rank in range(num_shards):
                        output[rank][key] = merged[:, slice_size * rank: slice_size * (rank + 1)].clone().detach()
            else:
                # replicated tensors are copied as-is into every shard
                for rank in range(num_shards):
                    output[rank][key] = tensors[0]
            print('  out shapes=', [output[rank][key].shape for rank in range(num_shards)])
            print()
            break
    else:
        raise Exception('parameter name not recognized: %s' % key)

print('saving...')
os.makedirs(output_model_dir, exist_ok=True)
with open(os.path.join(output_model_dir, 'params.json'), 'w') as fp:
    fp.write(json.dumps(params))
for rank in range(num_shards):
    print(' ', rank)
    torch.save(output[rank], os.path.join(output_model_dir, 'consolidated.%02d.pth' % rank))
print('done.')
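
Example invocation (the file name reshard.py and the paths are illustrative, not part of the gist): re-splitting a model into 2 shards looks like

python reshard.py 2 /data/LLaMA/65B /data/LLaMA/65Bx2

The output directory then contains params.json plus consolidated.00.pth and consolidated.01.pth. Note that the script writes only params.json and the consolidated.*.pth files; any other files the loader expects (for example the tokenizer) must be copied over by hand.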
@Qubitium

Qubitium commented Mar 9, 2023

This script doesn't work for 65B models for some reason. It completes, but running the resulting shards outputs strange tokens. The 13B-30B models resharded with this script have no issue. Not sure what makes 65B special, causing the resharded model to run incorrectly.

I tried resharding 65B into 2 and 4 shards for execution on 4x A100 80GB, without success. 13B-30B have no issue with resharding.

@benob

benob commented Mar 11, 2023

I tested with 65B split in two shards and it worked fine.
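
One way to narrow down whether the reshard itself or the inference side is at fault is to re-merge both checkpoint directories and compare them tensor by tensor. Below is a minimal sketch, assuming both copies fit in CPU memory and that the two directories are passed on the command line; CAT_AXIS just mirrors layer_kind from the script, and remerge is a hypothetical helper, not something from the gist.

# Illustrative round-trip check, not part of the gist: re-merge both the
# original and the resharded checkpoints along the same axes the script uses
# and verify every tensor is identical. Both copies must fit in CPU memory.
import os
import glob
import sys
import torch

# concat axis per parameter suffix (mirrors layer_kind in the script):
# 0 = ColumnParallelLinear, 1 = RowParallelLinear / ParallelEmbedding,
# None = replicated across shards
CAT_AXIS = {
    'tok_embeddings': 1, 'output': 0,
    'attention.wq': 0, 'attention.wk': 0, 'attention.wv': 0, 'attention.wo': 1,
    'feed_forward.w1': 0, 'feed_forward.w2': 1, 'feed_forward.w3': 0,
    'attention_norm': None, 'ffn_norm': None, 'norm': None, 'rope.freqs': None,
}

def remerge(model_dir):
    shards = [torch.load(p, map_location='cpu')
              for p in sorted(glob.glob(os.path.join(model_dir, '*.pth')))]
    merged = {}
    for key in shards[0]:
        axis = next(a for suffix, a in CAT_AXIS.items()
                    if key.replace('.weight', '').endswith(suffix))
        merged[key] = shards[0][key] if axis is None else torch.cat([s[key] for s in shards], axis)
    return merged

original = remerge(sys.argv[1])   # e.g. the original model directory
resharded = remerge(sys.argv[2])  # e.g. the resharded output directory
for key in original:
    assert torch.equal(original[key], resharded[key]), 'mismatch in %s' % key
print('round trip OK')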

@fabawi

fabawi commented Mar 16, 2023

Why is the resharded file so much larger? The 13B model has 2 checkpoints totaling 26 GB; after consolidating into 1 file, it jumps to 39 GB.

@emsi

emsi commented Mar 20, 2023

> This script doesn't work for 65B models for some reason. It completes, but running the resulting shards outputs strange tokens. The 13B-30B models resharded with this script have no issue. Not sure what makes 65B special, causing the resharded model to run incorrectly.
>
> I tried resharding 65B into 2 and 4 shards for execution on 4x A100 80GB, without success. 13B-30B have no issue with resharding.

Interestingly, I have a similar problem (garbage tokens predicted) with the 30B model. I'm using a single A100 GPU with 80 GB of VRAM. The loaded model uses around 67 GB with a batch size of 2.

Except for the largest model, the consolidated files are larger than the sharded ones (x1 stands for 1 shard):

25G     /data/LLaMA/13B
37G     /data/LLaMA/13Bx1
61G     /data/LLaMA/30B
76G     /data/LLaMA/30Bx1
122G    /data/LLaMA/65B
122G    /data/LLaMA/65Bx1
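
One way to investigate the size difference reported above is to tally bytes per dtype in each directory; if the resharded files come out in a wider dtype, it would show up here. This is an illustrative diagnostic with an assumed helper name (dtype_bytes) and example paths, not part of the gist.

# Illustrative diagnostic, not part of the gist: tally bytes per dtype across
# all shards in a checkpoint directory, to see where the extra size comes from.
import os
import glob
import sys
import torch
from collections import Counter

def dtype_bytes(model_dir):
    counts = Counter()
    for path in sorted(glob.glob(os.path.join(model_dir, '*.pth'))):
        shard = torch.load(path, map_location='cpu')
        for tensor in shard.values():
            counts[str(tensor.dtype)] += tensor.numel() * tensor.element_size()
    return counts

print(dtype_bytes(sys.argv[1]))  # e.g. /data/LLaMA/13B
print(dtype_bytes(sys.argv[2]))  # e.g. /data/LLaMA/13Bx1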

@emsi

emsi commented Mar 20, 2023

I did succeed in consolidating the 30B model with this script, though:
https://github.com/randaller/llama-chat/blob/main/merge-weights.py

@wangjiyang

How long does it take to run on an A100?

@Nova-Rift

Worked great! Thank you a bunch!
