# This script was adapted from merge.py from the KoboldAI discord server.
# I believe the original author is concedo
import os
import gc
import json
import shutil
import resource
import torch
from itertools import zip_longest
diff_weight = 0.6  # 1.0
assert 0.0 < diff_weight <= 1.0  # disable if you are brave
model_0_folder = 'gpt-j-6B-shardfp16' # base model
model_1_folder = 'gpt-jt-6B-v1-shardfp16'
model_2_folder = 'ppo_hh_gpt-j-shardfp16'
merged_model_folder = 'gpt-r-diff_0.6-6B'
# output = A + (B - C) * diff_weight
# A: model_1
# B: model_2 (compare model)
# C: model_0 (base model)
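#
# e.g. for a single scalar weight with diff_weight = 0.6:
#   A = 1.0, B = 1.4, C = 1.0  ->  1.0 + (1.4 - 1.0) * 0.6 = 1.24
# (60% of the "task vector" B - C applied on top of A)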
torch_map_location = 'cpu'
if os.path.exists(merged_model_folder):
    if len(os.listdir(merged_model_folder)) != 0:
        raise Exception(f'Non-empty directory "{merged_model_folder}" already exists')

print(f"[*] Merging models: {model_1_folder} + ({model_2_folder} - {model_0_folder}) * {diff_weight}\n")
def format_size(num, suffix="B"):
    for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]:
        if abs(num) < 1024.0:
            return f"{num:3.1f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.1f}Yi{suffix}"
model_0_files = [file for file in os.listdir(model_0_folder) if file.endswith('.bin')]
model_1_files = [file for file in os.listdir(model_1_folder) if file.endswith('.bin')]
model_2_files = [file for file in os.listdir(model_2_folder) if file.endswith('.bin')]
model_0_files.sort()
model_1_files.sort()
model_2_files.sort()
max_files_length = max(len(model_0_files), len(model_1_files), len(model_2_files))
# zip_longest pads with None when the models are split into different numbers of shards
model_files = list(zip_longest(model_0_files, model_1_files, model_2_files))
model_size_bytes = 0
for file in model_0_files:
    model_size_bytes += os.path.getsize(f'{model_0_folder}/{file}')
bin_weight_map = {}
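# The three models may be sharded at different layer boundaries. Layers that show
# up in only one shard of a pair are parked here until the matching shard is
# loaded on a later pass; anything still parked at the end triggers the final warning.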
backlog_layers = {"model_1": {}, "model_2": {}, "model_02": {}, "model_01": {}}
folder_created = False
for model_file_idx, model_file in enumerate(model_files, start=1):
    print(f'-- {model_file_idx} / {len(model_files)} --')
    diff_model = {}
    merged_model = {}
    model_0_file, model_1_file, model_2_file = model_file
    model_0_layers = {}
    model_1_layers = {}
    model_2_layers = {}
    if model_0_file is not None:
        print('[*] Reading', f"{model_0_folder}/{model_0_file}")
        model_0_layers = torch.load(f"{model_0_folder}/{model_0_file}", map_location=torch_map_location, weights_only=True)
    if model_2_file is not None:
        print('[*] Reading', f"{model_2_folder}/{model_2_file}")
        model_2_layers = torch.load(f"{model_2_folder}/{model_2_file}", map_location=torch_map_location, weights_only=True)
    # Consume layers backlogged from earlier shards, then reset those backlogs
    model_0_layers.update(backlog_layers['model_02'])
    model_2_layers.update(backlog_layers['model_2'])
    backlog_layers['model_02'] = {}
    backlog_layers['model_2'] = {}
    # Diff: (B - C) * diff_weight for layers common to both shards
    for backlog_layer in set(model_0_layers).symmetric_difference(set(model_2_layers)):
        if backlog_layer in model_0_layers:
            backlog_layers['model_02'][backlog_layer] = model_0_layers[backlog_layer]
        if backlog_layer in model_2_layers:
            backlog_layers['model_2'][backlog_layer] = model_2_layers[backlog_layer]
    for common_layer in set(model_0_layers).intersection(set(model_2_layers)):
        w_model_0 = model_0_layers[common_layer]
        w_model_2 = model_2_layers[common_layer]
        diff_model[common_layer] = (w_model_2 - w_model_0) * diff_weight if diff_weight != 1.0 else w_model_2 - w_model_0
    del model_2_layers
    gc.collect()
    # Merge: A + diff for layers present in both this shard of A and the computed diff
    if model_1_file is not None:
        print('[*] Reading', f"{model_1_folder}/{model_1_file}")
        model_1_layers = torch.load(f"{model_1_folder}/{model_1_file}", map_location=torch_map_location, weights_only=True)
    model_0_layers.update(backlog_layers['model_01'])
    model_1_layers.update(backlog_layers['model_1'])
    backlog_layers['model_01'] = {}
    backlog_layers['model_1'] = {}
    for backlog_layer in set(model_1_layers).symmetric_difference(set(model_0_layers)):
        if backlog_layer in model_1_layers:
            backlog_layers['model_1'][backlog_layer] = model_1_layers[backlog_layer]
        if backlog_layer in model_0_layers:
            backlog_layers['model_01'][backlog_layer] = model_0_layers[backlog_layer]
    del model_0_layers
    gc.collect()
    for common_layer in set(model_1_layers).intersection(set(diff_model)):
        w_model_1 = model_1_layers[common_layer]
        w_model_diff = diff_model[common_layer]
        merged_model[common_layer] = w_model_1 + w_model_diff
        bin_weight_map[common_layer] = f'pytorch_model-{model_file_idx:05}-of-{max_files_length:05}.bin'
    if not folder_created:
        os.makedirs(merged_model_folder, exist_ok=True)
        # Copy config/tokenizer files (but not the old shard index) from the base model
        for file_to_copy in [file for file in os.listdir(model_0_folder) if (file.endswith('.json') or file.endswith('.txt')) and not file.endswith('.index.json')]:
            shutil.copy(f'{model_0_folder}/{file_to_copy}', merged_model_folder)
        folder_created = True
    if len(model_files) == 1:
        print(f'[*] Saving model: {merged_model_folder}/pytorch_model.bin')
        torch.save(merged_model, f'{merged_model_folder}/pytorch_model.bin')
    else:
        print(f'[*] Saving shard: {merged_model_folder}/pytorch_model-{model_file_idx:05}-of-{max_files_length:05}.bin')
        torch.save(merged_model, f'{merged_model_folder}/pytorch_model-{model_file_idx:05}-of-{max_files_length:05}.bin')
    print('[*] Memory used:', format_size(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * 1024))
    del model_1_layers
    gc.collect()
if len(model_files) > 1:
    print('[*] Saving bin weight map:', f'{merged_model_folder}/pytorch_model.bin.index.json')
    with open(f'{merged_model_folder}/pytorch_model.bin.index.json', 'w+') as f:
        f.write(json.dumps({"metadata": {"total_size": model_size_bytes}, "weight_map": bin_weight_map}, sort_keys=True, indent=4))
if any(backlog_layers.values()):
    print('[WARN] Not all layers were merged, model might be in a broken state')
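
# Optional sanity check (hypothetical, assumes the Hugging Face `transformers`
# package is installed and the config/tokenizer files copied above are sufficient):
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# model = AutoModelForCausalLM.from_pretrained(merged_model_folder, torch_dtype=torch.float16)
# tokenizer = AutoTokenizer.from_pretrained(merged_model_folder)
# inputs = tokenizer("Hello, my name is", return_tensors="pt")
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))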