# Source: GitHub gist rgtjf/aa90fc37efe38ad773046623780a1026 (created April 10, 2024 02:32).
# Convert a DeepSpeed checkpoint.
# NOTE: GitHub flagged this file as possibly containing bidirectional Unicode text;
# review it in an editor that reveals hidden Unicode characters.
"""This script converts a DeepSpeed checkpoint from one format to another. | |
It requires specifying an input_folder and a target_folder before starting the | |
conversion. To determine the target folder, first run the script without checkpointing | |
using the target cluster. | |
The conversion process involves the following steps: | |
1. Building a linked matrix on the input DeepSpeed checkpoint to establish mappings | |
between tensor slices. | |
2. Merging the slice files based on the linked matrix. | |
3. Saving the common optimizer states from the input DeepSpeed checkpoint. | |
4. Copying tensors from the input checkpoint to overwrite the corresponding tensors in | |
the target checkpoint, based on a linked matrix built on the target checkpoint. | |
5. Validating the conversion by ensuring that the model parameters match the FP32 values | |
stored in the optimizers. | |
""" | |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# nyonic Team
# DeepSpeed Team
import argparse
import copy
import glob
import itertools
import os
import re
import typing as t
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from functools import partial

import torch
import tqdm
from deepspeed.checkpoint import (
    BASE_OPTIMIZER_STATE,
    CAT_DIM,
    OPTIMIZER_STATE_DICT,
    PARAM,
    PARAM_N_SUB_PARAMS,
    PARAM_SHAPES,
    PARAM_SLICE_MAPPINGS,
    PARAMETER_TO_AVERAGE_PATTERNS,
    PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
    PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
    SINGLE_PARTITION_OF_FP32_GROUPS,
    TP_REPLICATED_PARAMETER_PATTERNS,
    VOCAB_TENSOR,
    VOCABULARY_PARAMETER_PATTERNS,
    DeepSpeedCheckpoint,
)

# from pprint import pprint
# Flattened keys under which _save_optimizer_state stores the optimizer step and
# param_groups in <output_folder>/zero/optimizer_state.pt; copy_zero_shard reads
# them back when rebuilding the target shards.
OPTIMIZER_STATE_STEP = "base_optimizer_state.state.0.step"
OPTIMIZER_STATE_PARAM_GROUPS = "base_optimizer_state.param_groups"
# Currently not used by this script.
OFFSET_FILE = "_OFFSET_.pt"
def parse_arguments() -> argparse.Namespace: | |
"""Parses command-line arguments for the script. | |
Returns: | |
argparse.Namespace: An object containing the parsed command-line arguments. | |
- `input_folder`: The folder containing the input DeepSpeed checkpoint. | |
- `target_folder`: The folder where the target DeepSpeed checkpoint is stored. | |
- `output_folder`: The folder to store the output DeepSpeed checkpoint. | |
- `strict`: Flag for performing validity checks on the converted checkpoint. | |
""" | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--input_folder", | |
type=str, | |
required=True, | |
help="Input DeepSpeed Checkpoint folder", | |
) | |
parser.add_argument( | |
"--target_folder", | |
type=str, | |
required=True, | |
help="Target DeepSpeed Checkpoint folder", | |
) | |
parser.add_argument( | |
"--output_folder", | |
type=str, | |
required=True, | |
help="Output DeepSpeed checkpoint folder", | |
) | |
parser.add_argument( | |
"--no_strict", | |
dest="strict", | |
action="store_false", | |
help="Do not perform validity checks on converted checkpoint.", | |
) | |
parser.add_argument( | |
"--num_workers", | |
default=4, | |
type=int, | |
help="How many parallel processes for zero shards", | |
) | |
args = parser.parse_args() | |
print(f"args = {args}") | |
return args | |
def _create_checkpoint_paths( | |
base_folder: str, iteration: int, tp_degree: int, pp_degree: int | |
) -> list: | |
"""Creates paths for checkpoints based on TP and PP degrees. | |
Parameters: | |
base_folder (str): Base folder where the checkpoints are stored. | |
iteration (int): Iteration number of the checkpoint. | |
tp_degree (int): Degree of tensor parallelism. | |
pp_degree (int): Degree of pipeline parallelism. | |
Returns: | |
list: A list of lists containing checkpoint paths. | |
""" | |
path_list = [] | |
iter_folder = f"iter_{iteration:07d}" | |
for i in range(0, tp_degree): | |
path_list.append([]) | |
for j in range(0, pp_degree): | |
rank_folder = ( | |
f"mp_rank_{i:02d}" if pp_degree == 1 else f"mp_rank_{i:02d}_{j:03d}" | |
) | |
ckpt_path = os.path.join(rank_folder, "model_optim_rng.pt") | |
path_list[i].append(os.path.join(base_folder, iter_folder, ckpt_path)) | |
return path_list | |
def _save_checkpoint(file_path: str, chkpt_sd: dict) -> None: | |
"""Saves the checkpoint data to the specified file path. | |
Parameters: | |
file_path (str): The file path where the checkpoint data will be saved. | |
chkpt_sd (dict): The checkpoint data to be saved. | |
""" | |
dir, _ = os.path.split(file_path) | |
os.makedirs(dir, exist_ok=True) | |
torch.save(chkpt_sd, file_path) | |
def extract_zero_shards(
    dir: str, ds_checkpoint: DeepSpeedCheckpoint, indices_3D: tuple
) -> None:
    """Dump every parameter fragment held by one ZeRO shard to ``dir``.

    For each parameter group of the shard selected by ``indices_3D``, the
    fragment of every parameter is cut out of the flat ``exp_avg``,
    ``exp_avg_sq`` and FP32 tensors and written by ``dump_param_fragment`` to
    ``<dir>/<param_name>/<tp_index>/<state>.<dp_index:03d>``.

    Parameters:
        dir (str): The directory to save extracted fragments to.
        ds_checkpoint (DeepSpeedCheckpoint): The checkpoint to read the shard from.
        indices_3D (tuple): ``(pp_index, tp_index, dp_index)`` of the shard.

    Note:
        The output path does not include ``pp_index``; this presumably relies on
        parameter names being unique across pipeline stages. TODO confirm for
        PP > 1.
    """
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )
    optim_sd = sd[OPTIMIZER_STATE_DICT]
    # Per group: {param_name: fragment address (start, numel)}.
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    # Optimizer state, indexed by parameter-group id.
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # Flat FP32 partition, one entry per parameter group.
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]

    for param_group_id in range(len(state_groups)):
        group_state = state_groups[param_group_id]
        # Flat tensors of this shard; each parameter occupies the range
        # [start, start + numel) in every one of them.
        flat_state = {
            "exp_avg": group_state["exp_avg"],
            "exp_avg_sq": group_state["exp_avg_sq"],
            "fp32": fp32_groups[param_group_id],
        }
        for name, fragment_mapping in param_slice_mappings[param_group_id].items():
            for state_key, flat_tensor in flat_state.items():
                dump_param_fragment(
                    dir,
                    tp_index,
                    dp_index,
                    state_key,
                    flat_tensor,
                    name,
                    fragment_mapping.start,
                    fragment_mapping.numel,
                )
def dump_param_fragment(
    dir: str,
    tp_index: int,
    dp_index: int,
    state_name: str,
    state_flat_tensor: torch.Tensor,
    param_name: str,
    offset: int,
    numel: int,
) -> None:
    """Save ``state_flat_tensor[offset : offset + numel]`` as one fragment file.

    The fragment is written to
    ``<dir>/<param_name>/<tp_index>/<state_name>.<dp_index:03d>``.

    Parameters:
        dir (str): The directory to save the parameter fragment under.
        tp_index (int): The tensor parallelism index.
        dp_index (int): The data parallelism index (zero-padded in the file name).
        state_name (str): The name of the state (e.g., "exp_avg").
        state_flat_tensor (torch.Tensor): The flat tensor containing the state.
        param_name (str): The name of the parameter.
        offset (int): Start of the fragment along dim 0 of the flat tensor.
        numel (int): The number of elements in the fragment.
    """
    param_base_path = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(param_base_path, exist_ok=True)
    # e.g. "007"; build_linked_list reads this suffix back as the partition id.
    counter = f"{dp_index:0>3d}"
    path = os.path.join(param_base_path, f"{state_name}.{counter}")
    print(f"{param_name}: {offset}: {numel} => {path}")
    # clone() so only this fragment is serialized: torch.save on a narrow()
    # view would store the whole underlying flat buffer.
    fragment = state_flat_tensor.narrow(0, offset, numel).clone()
    # The directory was created above, so save directly.
    torch.save(fragment, path)
def _merge_zero_shards( | |
param_base_path: str, state: str, tp_degree: int, slice_shape: tuple | |
) -> list: | |
"""Merges zero shards. | |
Merges zero shards for a given parameter state and TP degree, reshaping according | |
to slice shape. | |
Parameters: | |
param_base_path (str): Base path for the parameter to merge shards. | |
state (str): The state of the parameter (e.g., "fp32"). | |
tp_degree (int): The tensor parallelism degree. | |
slice_shape (tuple): The shape of the tensor slice. | |
Returns: | |
list: A list of merged tensor slices. | |
""" | |
linked_matrix, reversed_linked_matrix = torch.load("./output/temp/linked_matrix.pt") | |
slices = [] | |
for tp_index in range(tp_degree): | |
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}") | |
# paths = sorted(glob.glob(f"{prefix_path}.*")) | |
# shards = [torch.load(p) for p in paths] | |
shards = [] | |
for partition, _numel in linked_matrix[os.path.basename(param_base_path)]: | |
path = f"{prefix_path}.{partition}" | |
shards.append(torch.load(path)) | |
slice = torch.cat(shards, dim=0).reshape(slice_shape) | |
slices.append(slice) | |
return slices | |
def check_sharding_links(temp_dir: str) -> None:
    """Print the shape of every FP32 fragment extracted for TP rank 0.

    A debugging aid. For each ``<temp_dir>/<param_name>`` entry, it lists the
    ``fp32.<partition>`` fragment files of TP rank 0 together with their shapes.

    Parameters:
        temp_dir (str): The directory the ZeRO fragments were extracted to.
    """
    for param_dir in sorted(glob.glob(f"{temp_dir}/*")):
        fragment_prefix = os.path.join(param_dir, "0", "fp32")
        for fragment_path in sorted(glob.glob(f"{fragment_prefix}.*")):
            fragment = torch.load(fragment_path, map_location="cpu")
            print(f"{param_dir} => {os.path.basename(fragment_path)}, {fragment.shape}")
def merge_tp_slices(
    ds_checkpoint: DeepSpeedCheckpoint,
    dir: str,
    slice_dir: str,
    tp_degree: int,
    name_and_shape: tuple,
) -> set:
    """Merge the TP slices of one parameter into full-size state files.

    For each of the ``fp32``, ``exp_avg`` and ``exp_avg_sq`` states:
    1. the ZeRO fragments are reassembled per TP rank (``_merge_zero_shards``);
    2. the TP slices are combined according to the first matching pattern list
       (replicate / average / two sub-params / concatenate);
    3. the result is saved to ``<dir>/<name>/<state>.pt`` as a dict whose
       ``PARAM`` entry holds the merged tensor.

    Parameters:
        ds_checkpoint (DeepSpeedCheckpoint): The input checkpoint. It is only
            referenced by the disabled universal-checkpoint-info lookup.
        dir (str): The directory to save the merged state files to.
        slice_dir (str): The directory containing the extracted fragments.
        tp_degree (int): The degree of tensor parallelism.
        name_and_shape (tuple): ``(parameter_name, slice_shape)``.

    Returns:
        set: The patterns that did not match this parameter.

    Note:
        ``universal_checkpoint_info`` is hard-coded to ``{}``. Every pattern list
        is therefore empty, and each parameter is concatenated along dim 0.
    """
    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)
    # universal_checkpoint_info = ds_checkpoint.get_checkpoint_info(
    #     UNIVERSAL_CHECKPOINT_INFO
    # )
    universal_checkpoint_info = {}
    replicated_parameters = universal_checkpoint_info.get(
        TP_REPLICATED_PARAMETER_PATTERNS, []
    )
    parameters_to_average = universal_checkpoint_info.get(
        PARAMETER_TO_AVERAGE_PATTERNS, []
    )
    parameters_with_row_parallelism = universal_checkpoint_info.get(
        PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, []
    )
    vocabulary_parameters = universal_checkpoint_info.get(
        VOCABULARY_PARAMETER_PATTERNS, []
    )
    parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(
        PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, []
    )
    unmatched_patterns = set(
        replicated_parameters
        + parameters_to_average
        + parameters_with_row_parallelism
        + vocabulary_parameters
        + parameters_with_2_sub_params_cat_dim_0
    )

    def get_matched_pattern(patterns_: list[str], name_: str) -> str | None:
        """Return the single pattern of ``patterns_`` matching ``name_``, if any."""
        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
        assert (
            len(matched_) <= 1
        ), f"Got more than one matching patterns={matched_} for {name_}"
        if matched_:
            pattern_ = matched_[0]
            unmatched_patterns.discard(pattern_)
            return pattern_
        return None

    # The merge strategy depends only on the parameter name, so resolve it once
    # (same lookup order as before) instead of once per optimizer state.
    cat_dim = 0
    if get_matched_pattern(replicated_parameters, name):
        merge_mode = "replicate"
    elif get_matched_pattern(parameters_to_average, name):
        merge_mode = "average"
    elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
        merge_mode = "two_sub_params"
    else:
        merge_mode = "concat"
        if get_matched_pattern(parameters_with_row_parallelism, name):
            cat_dim = 1
    is_vocab = bool(get_matched_pattern(vocabulary_parameters, name))

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")
        ckpt_dict = {}
        if merge_mode == "replicate":
            # Every TP rank holds an identical copy; keep the first one.
            if len(slices) > 1:
                assert all(slices[0].equal(other_slice) for other_slice in slices[1:])
            param = slices[0]
        elif merge_mode == "average":
            param = sum(slices) / len(slices)
        elif merge_mode == "two_sub_params":
            # Each slice stacks two sub-parameters along dim 0. Merge the first
            # halves and the second halves separately, then stack the results.
            chunked_slices = [torch.chunk(s, 2, dim=0) for s in slices]
            merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=0)
            merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=0)
            param = torch.cat([merged_chunks_0, merged_chunks_1], dim=0)
            ckpt_dict[CAT_DIM] = 0
            ckpt_dict[PARAM_N_SUB_PARAMS] = 2
        else:
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
        if is_vocab:
            # Strip padding rows beyond the original vocabulary size.
            original_vocab_size = universal_checkpoint_info["original_vocab_size"]
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)
    return unmatched_patterns
def _do_parallel_work(
    do_work: t.Callable, work_chunks: t.Any, num_workers: int
) -> list:
    """Apply ``do_work`` to every item of ``work_chunks`` and collect the results.

    Parameters:
        do_work (Callable): Called once per work item, with the item as its only
            argument. With ``num_workers > 1`` the callable and the items must be
            picklable, because they are sent to worker processes.
        work_chunks (Iterable): The work items.
        num_workers (int): Number of worker processes. ``<= 1`` runs serially in
            the current process (e.g. in unit tests, where child processes
            cannot be created).

    Returns:
        list: Results in the same order as ``work_chunks``.
    """
    if num_workers > 1:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [executor.submit(do_work, work) for work in work_chunks]
            # Collect in submission order so results line up with work_chunks.
            return [future.result() for future in tqdm.tqdm(futures)]
    # Pass each item whole, exactly like the parallel path. Items such as
    # (pp, tp, dp) tuples or (name, tensor) pairs must not be unpacked into
    # separate do_work calls.
    return [do_work(work) for work in tqdm.tqdm(work_chunks)]
def _extract_zero_shard_files(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint, temp_dir: str
) -> None:
    """Extract the fragments of every ZeRO shard of ``ds_checkpoint``.

    ``extract_zero_shards`` runs once per (PP, TP, DP) index combination,
    spread over ``args.num_workers`` processes. Afterwards the FP32 fragments of
    TP rank 0 are listed by ``check_sharding_links``.

    Parameters:
        args (argparse.Namespace): Command-line arguments (``num_workers`` is read).
        ds_checkpoint (DeepSpeedCheckpoint): The checkpoint to extract from.
        temp_dir (str): The directory receiving the fragment files.
    """
    shard_indices = list(
        itertools.product(
            range(ds_checkpoint.pp_degree),
            range(ds_checkpoint.tp_degree),
            range(ds_checkpoint.dp_degree),
        )
    )
    _do_parallel_work(
        partial(extract_zero_shards, temp_dir, ds_checkpoint),
        shard_indices,
        args.num_workers,
    )
    check_sharding_links(temp_dir)
def build_linked_list(temp_dir: str, name_and_tensor: tuple) -> list:
    """Locate each extracted FP32 fragment of one parameter inside that parameter.

    The fragments ``<temp_dir>/<tensor_name>/0/fp32.<partition>`` (TP rank 0)
    are placed greedily. Starting at offset 0, the first not-yet-placed fragment
    whose values are ``torch.allclose`` to the flattened model tensor at the
    current offset is placed there, and the offset advances by its size.

    Args:
        temp_dir: The directory the ZeRO fragments were extracted to.
        name_and_tensor: ``(tensor_name, mp_tensor)``, i.e. the parameter name
            and its value from the model-states file.

    Returns:
        list: ``(tensor_name, partition_idx, offset, numel)`` tuples in ascending
        ``offset`` order.
    """
    tensor_name, mp_tensor = name_and_tensor
    # Escape so that names or paths containing glob metacharacters still match
    # literally.
    fragment_pattern = os.path.join(
        glob.escape(temp_dir), glob.escape(tensor_name), "0", "fp32.*"
    )
    fp32_fragments = []
    for fragment_path in sorted(glob.glob(fragment_pattern)):
        # File names end in ".<partition>" (the zero-padded DP index).
        partition_idx = os.path.basename(fragment_path).rsplit(".", 1)[-1]
        fragment = torch.load(fragment_path, map_location="cpu")
        fp32_fragments.append((partition_idx, fragment))

    flat_tensor = mp_tensor.flatten()
    total_numel = flat_tensor.numel()
    placed = [False] * len(fp32_fragments)
    offset = 0
    linked_list = []
    # Every productive pass places at least one fragment, so len() passes suffice.
    for _pass in range(len(fp32_fragments)):
        for index, (partition_idx, fp32_tensor) in enumerate(fp32_fragments):
            if placed[index]:
                continue
            numel = fp32_tensor.numel()
            if offset + numel > total_numel:
                # A fragment longer than the remaining elements cannot start here.
                continue
            # Comparing just the candidate window is equivalent to overwriting
            # it in a copy and comparing whole tensors. It avoids one deep copy
            # per candidate, and NaNs elsewhere in the tensor cannot defeat it.
            window = flat_tensor[offset : offset + numel]
            if torch.allclose(window, fp32_tensor.to(flat_tensor)):
                placed[index] = True
                linked_list.append((tensor_name, partition_idx, offset, numel))
                offset += numel
    if offset != total_numel:
        print(
            f"Warning: located {offset} of {total_numel} elements of {tensor_name} "
            f"in its FP32 fragments"
        )
    print(linked_list)
    return linked_list
def _build_linked_matrix(args: argparse.Namespace, temp_dir: str, mp_sd: dict) -> None:
    """Build and save the linked matrix of a checkpoint's ZeRO fragments.

    ``build_linked_list`` runs for every tensor in ``mp_sd["module"]``. The
    results are saved as a ``(linked_matrix, reversed_linked_matrix)`` tuple in
    ``<temp_dir>/linked_matrix.pt``:

    - ``linked_matrix[tensor_name]``: ``[(partition_idx, numel), ...]`` in
      ascending-offset order within the flattened tensor.
    - ``reversed_linked_matrix[partition_idx]``:
      ``[(tensor_name, offset, numel), ...]``, where ``offset`` is the start of
      the fragment inside the flattened tensor.

    Only TP rank 0 fragments are consulted (see ``build_linked_list``).

    Args:
        args: Command-line arguments (``num_workers`` is read).
        temp_dir: The directory the ZeRO fragments were extracted to.
        mp_sd: A loaded model-states dict; its ``"module"`` entry maps tensor
            names to tensors.

    Example:
        linked_matrix["model.transformer_encoder.layers.31.linear1.weight"]
          -> [('018', 54715476), ('019', 12393388)]
        reversed_linked_matrix['019']
          -> [('model.transformer_encoder.layers.31.linear1.weight', 54715476, 12393388),
              ('model.transformer_encoder.layers.31.linear1.bias', 0, 16384),
              ('model.transformer_encoder.layers.31.linear2.weight', 0, 46882008)]
        i.e. 12393388 elements of layers.31.linear1.weight, starting at offset
        54715476, live in partition '019'. This matches the DP-19 shard's
        param_slice_mappings, as returned by
        ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=19):
          [('model.transformer_encoder.layers.31.linear1.weight',
            fragment_address(numel=12393388, start=0)),
           ('model.transformer_encoder.layers.31.linear1.bias',
            fragment_address(numel=16384, start=12393388)),
           ('model.transformer_encoder.layers.31.linear2.weight',
            fragment_address(numel=46882008, start=12409772))]
    """
    # Each work item is one (tensor_name, mp_tensor) pair.
    linked_lists = _do_parallel_work(
        partial(build_linked_list, temp_dir),
        mp_sd["module"].items(),
        args.num_workers,
    )

    linked_matrix = defaultdict(list)
    reversed_linked_matrix = defaultdict(list)
    for linked_list in linked_lists:
        for tensor_name, partition_idx, offset, numel in linked_list:
            linked_matrix[tensor_name].append((partition_idx, numel))
            reversed_linked_matrix[partition_idx].append((tensor_name, offset, numel))
    torch.save(
        (linked_matrix, reversed_linked_matrix),
        os.path.join(temp_dir, "linked_matrix.pt"),
    )
def _merge_tp_slice_files(
    args: argparse.Namespace,
    ds_checkpoint: DeepSpeedCheckpoint,
    slice_shapes: dict,
    temp_dir: str,
) -> None:
    """Merge the TP slices of every parameter into ``<output_folder>/zero``.

    Parameters:
        args (argparse.Namespace): Command-line arguments (``output_folder``,
            ``num_workers`` and ``strict`` are read).
        ds_checkpoint (DeepSpeedCheckpoint): The checkpoint whose TP slices are
            merged.
        slice_shapes (dict): Maps parameter names to their slice shapes.
        temp_dir (str): The directory where the fragment files are stored.

    Raises:
        AssertionError: In strict mode, if some pattern matched no parameter.
    """
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(
        merge_tp_slices,
        ds_checkpoint,
        zero_output_folder,
        temp_dir,
        ds_checkpoint.tp_degree,
    )
    unmatched_patterns_lists = _do_parallel_work(
        do_work, list(slice_shapes.items()), args.num_workers
    )
    # A pattern left unmatched by every worker was never used at all.
    # set.intersection() needs at least one argument, so handle "no parameters".
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = list(set.intersection(*sets)) if sets else []
    if args.strict:
        assert (
            not unmatched_patterns
        ), f"Unused patterns={unmatched_patterns} while merging tp slices"
    elif unmatched_patterns:
        print(f"Warning: Unused patterns={unmatched_patterns} while merging tp slices")
def _save_optimizer_state(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint
) -> None:
    """Save the optimizer state shared by all ZeRO shards.

    The state is taken from shard (pp=0, tp=0, dp=0). Every optimizer entry
    except the per-shard ones (base optimizer state, slice mappings, FP32
    partitions) is kept. The base optimizer's group-0 ``step`` and its
    ``param_groups`` are added under the flattened keys
    ``OPTIMIZER_STATE_STEP`` / ``OPTIMIZER_STATE_PARAM_GROUPS``. The result
    is written to ``<output_folder>/zero/optimizer_state.pt``.

    Parameters:
        args (argparse.Namespace): Command-line arguments (``output_folder`` is read).
        ds_checkpoint (DeepSpeedCheckpoint): The checkpoint to read the state from.
    """
    sharded_states = {
        BASE_OPTIMIZER_STATE,
        PARAM_SLICE_MAPPINGS,
        SINGLE_PARTITION_OF_FP32_GROUPS,
    }
    optim_sd = ds_checkpoint.get_zero_checkpoint_state(
        pp_index=0, tp_index=0, dp_index=0
    )[OPTIMIZER_STATE_DICT]
    base_optimizer_state = optim_sd[BASE_OPTIMIZER_STATE]

    output_sd = {k: v for k, v in optim_sd.items() if k not in sharded_states}
    output_sd[OPTIMIZER_STATE_STEP] = base_optimizer_state["state"][0]["step"]
    output_sd[OPTIMIZER_STATE_PARAM_GROUPS] = base_optimizer_state["param_groups"]

    output_file_path = os.path.join(args.output_folder, "zero", "optimizer_state.pt")
    _save_checkpoint(output_file_path, output_sd)
def copy_zero_shard(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint, indices_3D: tuple
) -> None:
    """Write one target ZeRO shard filled with the merged input optimizer state.

    The target shard selected by ``indices_3D`` is loaded. For every parameter
    fragment it owns (per its ``PARAM_SLICE_MAPPINGS``), the matching
    ``[offset, offset + numel)`` window of the merged ``exp_avg``,
    ``exp_avg_sq`` and ``fp32`` tensors in ``<output_folder>/zero/<name>/`` is
    copied in. The step counter and ``param_groups`` come from
    ``<output_folder>/zero/optimizer_state.pt``. The result is saved in
    ``<output_folder>/output/`` under the target shard's file name.

    Parameters:
        args (argparse.Namespace): Command-line arguments (``output_folder`` is read).
        ds_checkpoint (DeepSpeedCheckpoint): The target checkpoint.
        indices_3D (tuple): ``(pp_index, tp_index, dp_index)`` of the shard.

    Note:
        Fragment offsets are looked up in the target linked matrix by DP index
        only, so this presumably supports TP = PP = 1 targets only. TODO confirm.
    """
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )
    _linked_matrix, reversed_linked_matrix = torch.load(
        f"{args.output_folder}/temp_tgt/linked_matrix.pt"
    )
    # (tensor_name, offset, numel) of every fragment in this DP partition,
    # keyed by tensor name.
    counter = f"{dp_index:0>3d}"
    fragment_locations = {entry[0]: entry for entry in reversed_linked_matrix[counter]}

    clone_sd = copy.deepcopy(sd)
    optim_sd = sd[OPTIMIZER_STATE_DICT]
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    # Optimizer state, indexed by parameter-group id.
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # Flat FP32 partition, one entry per parameter group; overwritten below.
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]

    # Step counter and param_groups shared by all shards.
    zero_output_folder = os.path.join(args.output_folder, "zero")
    optimizer_state = torch.load(
        os.path.join(zero_output_folder, "optimizer_state.pt"),
        map_location=torch.device("cpu"),
    )
    optimizer_state_step = optimizer_state[OPTIMIZER_STATE_STEP]
    optimizer_state_param_groups = optimizer_state[OPTIMIZER_STATE_PARAM_GROUPS]

    def _load_window(
        param_base_path: str, state_name: str, offset: int, numel: int
    ) -> torch.Tensor:
        # Window [offset, offset + numel) of one merged, flattened state tensor.
        merged = torch.load(os.path.join(param_base_path, f"{state_name}.pt"))["param"]
        return merged.flatten()[offset : offset + numel]

    cloned_states = {}
    for param_group_id in range(len(state_groups)):
        cloned_state = copy.deepcopy(state_groups[param_group_id])
        cloned_state["step"] = optimizer_state_step
        exp_avgs, exp_avg_sqs, fp32s = [], [], []
        print(param_slice_mappings[param_group_id])
        # Iterate in the target's slice-mapping order. The concatenation below
        # assumes this order matches the fragments' layout in the flat partition.
        for name in param_slice_mappings[param_group_id]:
            tensor_name, offset, numel = fragment_locations[name]
            param_base_path = os.path.join(zero_output_folder, tensor_name)
            print(f"{tensor_name}: {offset}: {numel} <= {param_base_path}")
            exp_avgs.append(_load_window(param_base_path, "exp_avg", offset, numel))
            exp_avg_sqs.append(_load_window(param_base_path, "exp_avg_sq", offset, numel))
            fp32s.append(_load_window(param_base_path, "fp32", offset, numel))
        fp32_groups[param_group_id] = torch.cat(fp32s, dim=0)
        cloned_state["exp_avg"] = torch.cat(exp_avgs, dim=0)
        cloned_state["exp_avg_sq"] = torch.cat(exp_avg_sqs, dim=0)
        cloned_states[param_group_id] = cloned_state

    clone_sd[OPTIMIZER_STATE_DICT][BASE_OPTIMIZER_STATE].update(
        {
            "state": cloned_states,
            "param_groups": optimizer_state_param_groups,
        }
    )
    clone_sd[OPTIMIZER_STATE_DICT][SINGLE_PARTITION_OF_FP32_GROUPS] = fp32_groups

    output_file = ds_checkpoint.get_zero_files(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )[0]
    path = os.path.join(args.output_folder, "output", os.path.basename(output_file))
    print(output_file)
    _save_checkpoint(path, clone_sd)
def _copy_zero_shard_files(
    args: argparse.Namespace,
    ds_checkpoint: DeepSpeedCheckpoint,
) -> None:
    """Rebuild every ZeRO shard of the target checkpoint from the merged state.

    ``copy_zero_shard`` runs once per (PP, TP, DP) index combination of
    ``ds_checkpoint``, spread over ``args.num_workers`` processes.

    Parameters:
        args (argparse.Namespace): Command-line arguments passed to the script.
        ds_checkpoint (DeepSpeedCheckpoint): The target checkpoint.
    """
    shard_indices = list(
        itertools.product(
            range(ds_checkpoint.pp_degree),
            range(ds_checkpoint.tp_degree),
            range(ds_checkpoint.dp_degree),
        )
    )
    _do_parallel_work(
        partial(copy_zero_shard, args, ds_checkpoint),
        shard_indices,
        args.num_workers,
    )
def check_mp_equal_to_fp32(args: argparse.Namespace) -> None:
    """Check that the converted model weights equal the merged FP32 master weights.

    Each parameter directory ``<output_folder>/zero/<name>/`` whose name contains
    ``"model"`` is checked. Its ``fp32.pt`` value is cast to the dtype/device of
    ``module[<name>]`` from ``<output_folder>/output/mp_rank_00_model_states.pt``
    and compared with that tensor.

    Parameters:
        args (argparse.Namespace): Command-line arguments (``output_folder`` is read).

    Raises:
        AssertionError: If a parameter differs beyond the default tolerances of
            ``torch.testing.assert_close`` for its dtype.
    """
    mp_sd = torch.load(
        os.path.join(args.output_folder, "output", "mp_rank_00_model_states.pt"),
        map_location=torch.device("cpu"),
    )
    module_sd = mp_sd["module"]
    zero_output_folder = os.path.join(args.output_folder, "zero")
    for tensor_name_path in sorted(glob.glob(f"{zero_output_folder}/*")):
        tensor_name = os.path.basename(tensor_name_path)
        # Filter on the entry's own name. Checking the full path would let any
        # output_folder containing "model" pull in files such as
        # optimizer_state.pt.
        if "model" not in tensor_name or not os.path.isdir(tensor_name_path):
            continue
        expected = module_sd[tensor_name]
        fp32 = torch.load(os.path.join(tensor_name_path, "fp32.pt"))["param"].to(expected)
        # assert_close replaces the deprecated torch.testing.assert_allclose.
        torch.testing.assert_close(
            fp32,
            expected,
            msg=f"{tensor_name}\nfp32: \n{fp32}\n" f"mp_sd: \n{expected}",
        )
def make_zero_shard(input_folder: str = "") -> None:
    """Zero the ``exp_avg`` / ``exp_avg_sq`` state of every ZeRO shard in a folder.

    Every file in ``input_folder`` except the model-states files is loaded. In
    each parameter group of its base optimizer state, ``exp_avg`` and
    ``exp_avg_sq`` are replaced with zeros of the same shape and dtype, and the
    file is **overwritten in place**.

    Parameters:
        input_folder (str): Folder containing the checkpoint files. It must be
            given: an empty value would expand to ``/*``, the filesystem root.

    Raises:
        ValueError: If ``input_folder`` is empty.
    """
    if not input_folder:
        raise ValueError("make_zero_shard requires a non-empty input_folder")
    paths = sorted(glob.glob(os.path.join(glob.escape(input_folder), "*")))
    for path in tqdm.tqdm(paths):
        # Skip every model-states file (e.g. mp_rank_01_model_states.pt), not
        # just rank 00; presumably they carry no ZeRO optimizer state.
        if "model_states" in os.path.basename(path):
            continue
        sd = torch.load(path, map_location="cpu")
        optim_sd = sd[OPTIMIZER_STATE_DICT]
        # Optimizer state, indexed by parameter-group id.
        state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
        for param_group_id in range(len(state_groups)):
            state = state_groups[param_group_id]
            state["exp_avg"] = torch.zeros_like(state["exp_avg"])
            state["exp_avg_sq"] = torch.zeros_like(state["exp_avg_sq"])
        _save_checkpoint(path, sd)
def main() -> None:
    """Main entrypoint.

    Convert the DeepSpeed checkpoint in ``--input_folder`` into one laid out like
    the checkpoint in ``--target_folder``. Results are written under
    ``--output_folder``:

    - ``temp/`` and ``temp_tgt/``: extracted ZeRO fragments and linked matrices
      of the input and target checkpoints;
    - ``zero/``: merged per-parameter optimizer states and ``optimizer_state.pt``;
    - ``output/``: the converted ZeRO shards and model-states files.

    With ``--no_strict`` the final weights-vs-FP32 validation is skipped.
    """
    print("Convert DeepSpeed Checkpoint to DeepSpeed Checkpoint")
    args = parse_arguments()
    print(
        f"Converting DeepSpeed checkpoint in {args.input_folder} and "
        f"{args.target_folder} to DeepSpeed checkpoint in {args.output_folder}."
    )
    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)

    print("*** 0. Load model checkpoints")
    slice_shapes = []
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
        slice_shapes += mp_sd[PARAM_SHAPES]
    # PARAM_SHAPES holds a list of {name: shape} dicts. Flatten it; duplicate
    # names (one per TP rank when tp > 1) collapse onto a single key.
    slice_shapes = {k: v for d in slice_shapes for k, v in d.items()}
    temp_dir = os.path.join(args.output_folder, "temp")

    print("*** 1. Extracting ZeRO fragments")
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
    # NOTE(review): mp_sd is the last model-states file loaded above, so the
    # linked matrix covers only that rank's tensors. This is fine for TP = PP = 1.
    _build_linked_matrix(args, temp_dir, mp_sd)

    print("*** 2. Merging slices .....")
    _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

    print("*** 3. Saving common optimizer states")
    _save_optimizer_state(args, ds_checkpoint)

    print("*** 4. Coping ZeRO fragments like target")
    tgt_temp_dir = os.path.join(args.output_folder, "temp_tgt")
    tgt_ds_checkpoint = DeepSpeedCheckpoint(args.target_folder)
    print("*** 4.1 Load target model checkpoints")
    tgt_mp_sd = torch.load(
        tgt_ds_checkpoint.mp_rank_files[0], map_location=torch.device("cpu")
    )
    print("*** 4.2 Extracting target ZeRO fragments")
    _extract_zero_shard_files(args, tgt_ds_checkpoint, tgt_temp_dir)
    _build_linked_matrix(args, tgt_temp_dir, tgt_mp_sd)
    print("*** 4.3 Coping ZeRO fragments like target")
    _copy_zero_shard_files(args, tgt_ds_checkpoint)

    print("*** 5. Coping mp fragments like target")
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
        # The model states keep the input weights but take the target's DP size.
        mp_sd["dp_world_size"] = tgt_mp_sd["dp_world_size"]
        # _save_checkpoint also creates output/ if step 4 produced no shard files.
        _save_checkpoint(
            os.path.join(args.output_folder, "output", os.path.basename(mp_rank_file)),
            mp_sd,
        )
    print("*** Done!")
    # --no_strict is documented to skip validity checks on the converted checkpoint.
    if args.strict:
        check_mp_equal_to_fp32(args)
# Run the conversion only when executed as a script, not on import.
if __name__ == "__main__":
    main()
# (End of gist.)