Skip to content

Instantly share code, notes, and snippets.

@rgtjf
Created April 10, 2024 02:32
Show Gist options
  • Save rgtjf/aa90fc37efe38ad773046623780a1026 to your computer and use it in GitHub Desktop.
Save rgtjf/aa90fc37efe38ad773046623780a1026 to your computer and use it in GitHub Desktop.
Convert a DeepSpeed checkpoint.
"""This script converts a DeepSpeed checkpoint from one format to another.
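It requires an input_folder (the checkpoint to convert) and a target_folder (an
existing checkpoint already saved with the desired target layout, e.g. from a short
run on the target cluster that saves a checkpoint without loading one) before
starting the conversion. The target checkpoint's ZeRO partitioning defines the
layout of the output.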
The conversion process involves the following steps:
1. Building a linked matrix on the input DeepSpeed checkpoint to establish mappings
between tensor slices.
2. Merging the slice files based on the linked matrix.
3. Saving the common optimizer states from the input DeepSpeed checkpoint.
4. Copying tensors from the input checkpoint to overwrite the corresponding tensors in
the target checkpoint, based on a linked matrix built on the target checkpoint.
5. Validating the conversion by ensuring that the model parameters match the FP32 values
stored in the optimizers.
"""
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# nyonic Team
# DeepSpeed Team
import argparse
import copy
import glob
import itertools
import os
import re
import typing as t
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import torch
import tqdm
from deepspeed.checkpoint import (
BASE_OPTIMIZER_STATE,
CAT_DIM,
OPTIMIZER_STATE_DICT,
PARAM,
PARAM_N_SUB_PARAMS,
PARAM_SHAPES,
PARAM_SLICE_MAPPINGS,
PARAMETER_TO_AVERAGE_PATTERNS,
PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0,
PARAMETER_WITH_ROW_PARALLELISM_PATTERNS,
SINGLE_PARTITION_OF_FP32_GROUPS,
TP_REPLICATED_PARAMETER_PATTERNS,
VOCAB_TENSOR,
VOCABULARY_PARAMETER_PATTERNS,
DeepSpeedCheckpoint,
)
# Flat keys under which _save_optimizer_state stores the Adam step and the
# param_groups shared by all ranks in <output_folder>/zero/optimizer_state.pt;
# copy_zero_shard reads them back from there.
OPTIMIZER_STATE_STEP = "base_optimizer_state.state.0.step"
OPTIMIZER_STATE_PARAM_GROUPS = "base_optimizer_state.param_groups"
# NOTE(review): not referenced anywhere in this file.
OFFSET_FILE = "_OFFSET_.pt"
def parse_arguments() -> argparse.Namespace:
    """Read the converter's command-line options from ``sys.argv``.

    The parsed namespace is echoed to stdout before being returned.

    Returns:
        argparse.Namespace: with the attributes
            - ``input_folder``: DeepSpeed checkpoint to convert.
            - ``target_folder``: DeepSpeed checkpoint whose layout is reproduced.
            - ``output_folder``: where the converted checkpoint is written.
            - ``strict``: ``True`` unless ``--no_strict`` is passed.
            - ``num_workers``: worker processes for ZeRO shard work (default 4).
    """
    parser = argparse.ArgumentParser()

    # The three folder options share everything except flag and help text.
    required_folders = (
        ("--input_folder", "Input DeepSpeed Checkpoint folder"),
        ("--target_folder", "Target DeepSpeed Checkpoint folder"),
        ("--output_folder", "Output DeepSpeed checkpoint folder"),
    )
    for flag, help_text in required_folders:
        parser.add_argument(flag, type=str, required=True, help=help_text)

    parser.add_argument(
        "--no_strict",
        dest="strict",
        action="store_false",
        help="Do not perform validity checks on converted checkpoint.",
    )
    parser.add_argument(
        "--num_workers",
        default=4,
        type=int,
        help="How many parallel processes for zero shards",
    )

    parsed = parser.parse_args()
    print(f"args = {parsed}")
    return parsed
def _create_checkpoint_paths(
    base_folder: str, iteration: int, tp_degree: int, pp_degree: int
) -> list:
    """Build the Megatron-style ``model_optim_rng.pt`` path for every TP/PP rank.

    Paths look like ``<base_folder>/iter_<iteration:07d>/<rank>/model_optim_rng.pt``.
    ``<rank>`` is ``mp_rank_<tp:02d>`` when ``pp_degree == 1`` and
    ``mp_rank_<tp:02d>_<pp:03d>`` otherwise.

    Parameters:
        base_folder (str): Root folder of the checkpoints.
        iteration (int): Training iteration of the checkpoint.
        tp_degree (int): Tensor parallel degree.
        pp_degree (int): Pipeline parallel degree.

    Returns:
        list: ``result[tp_rank][pp_rank]`` is the path for that rank pair.
    """
    iter_folder = f"iter_{iteration:07d}"

    def _rank_folder(tp_rank: int, pp_rank: int) -> str:
        # Without pipeline parallelism the PP suffix is omitted entirely.
        if pp_degree == 1:
            return f"mp_rank_{tp_rank:02d}"
        return f"mp_rank_{tp_rank:02d}_{pp_rank:03d}"

    return [
        [
            os.path.join(
                base_folder,
                iter_folder,
                os.path.join(_rank_folder(tp_rank, pp_rank), "model_optim_rng.pt"),
            )
            for pp_rank in range(pp_degree)
        ]
        for tp_rank in range(tp_degree)
    ]
def _save_checkpoint(file_path: str, chkpt_sd: object) -> None:
    """Serialize ``chkpt_sd`` to ``file_path`` with ``torch.save``.

    Missing parent directories are created first. A bare file name (no
    directory component) is written relative to the current working directory.

    Parameters:
        file_path (str): Destination file path.
        chkpt_sd (object): Object to save. Usually a state dict; callers in
            this script also pass a single tensor.
    """
    parent_dir = os.path.dirname(file_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(chkpt_sd, file_path)
def extract_zero_shards(
    dir: str, ds_checkpoint: DeepSpeedCheckpoint, indices_3D: tuple
) -> None:
    """
    Split one ZeRO shard into per-parameter fragment files.

    The shard at ``indices_3D`` is read from ``ds_checkpoint``. For every param
    group, every ``(name, fragment_mapping)`` in its ``param_slice_mappings`` is
    handled. The ``exp_avg``, ``exp_avg_sq`` and ``fp32`` flat tensors are
    narrowed to ``[start, start + numel)`` and written via
    ``dump_param_fragment``.

    Parameters:
        dir (str): Folder that receives the fragment files.
        ds_checkpoint (DeepSpeedCheckpoint): Checkpoint to read the shard from.
        indices_3D (tuple): ``(pp_index, tp_index, dp_index)``.
    """
    pp_index, tp_index, dp_index = indices_3D
    zero_sd = ds_checkpoint.get_zero_checkpoint_state(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )
    optimizer_sd = zero_sd[OPTIMIZER_STATE_DICT]
    slice_mappings = optimizer_sd[PARAM_SLICE_MAPPINGS]
    # Adam state per param group (indexed by group id).
    adam_states = optimizer_sd[BASE_OPTIMIZER_STATE]["state"]
    # One flat fp32 master partition per param group.
    fp32_partitions = optimizer_sd[SINGLE_PARTITION_OF_FP32_GROUPS]

    for group_id in range(len(adam_states)):
        group_state = adam_states[group_id]
        flat_tensors = {
            "exp_avg": group_state["exp_avg"],
            "exp_avg_sq": group_state["exp_avg_sq"],
            "fp32": fp32_partitions[group_id],
        }
        for param_name, fragment in slice_mappings[group_id].items():
            for state_name, flat_tensor in flat_tensors.items():
                dump_param_fragment(
                    dir,
                    tp_index,
                    dp_index,
                    state_name,
                    flat_tensor,
                    param_name,
                    fragment.start,
                    fragment.numel,
                )
def dump_param_fragment(
    dir: str,
    tp_index: int,
    dp_index: int,
    state_name: str,
    state_flat_tensor: torch.Tensor,
    param_name: str,
    offset: int,
    numel: int,
) -> None:
    """
    Save ``numel`` elements of a flat state tensor, starting at ``offset``.

    The fragment goes to ``<dir>/<param_name>/<tp_index>/<state_name>.<ddd>``,
    where ``ddd`` is ``dp_index`` zero-padded to three digits. Each DP rank's
    piece of a parameter therefore gets its own file. A progress line is
    printed for every fragment.

    Parameters:
        dir (str): Root folder for the fragments.
        tp_index (int): Tensor parallel rank.
        dp_index (int): Data parallel rank (becomes the file suffix).
        state_name (str): State name, e.g. ``"exp_avg"`` or ``"fp32"``.
        state_flat_tensor (torch.Tensor): Flat (1-D) tensor holding the state.
        param_name (str): Parameter name (becomes a folder name).
        offset (int): Start index of the fragment in ``state_flat_tensor``.
        numel (int): Number of elements in the fragment.
    """
    fragment_dir = os.path.join(dir, param_name, str(tp_index))
    os.makedirs(fragment_dir, exist_ok=True)
    fragment_path = os.path.join(fragment_dir, f"{state_name}.{dp_index:0>3d}")
    print(f"{param_name}: {offset}: {numel} => {fragment_path}")
    # narrow() returns a view sharing storage with the whole flat buffer.
    # clone() gives the fragment its own storage before saving, presumably so
    # that only these elements are serialized.
    fragment = state_flat_tensor.narrow(0, offset, numel).clone()
    _save_checkpoint(fragment_path, fragment)
def _merge_zero_shards(
param_base_path: str, state: str, tp_degree: int, slice_shape: tuple
) -> list:
"""Merges zero shards.
Merges zero shards for a given parameter state and TP degree, reshaping according
to slice shape.
Parameters:
param_base_path (str): Base path for the parameter to merge shards.
state (str): The state of the parameter (e.g., "fp32").
tp_degree (int): The tensor parallelism degree.
slice_shape (tuple): The shape of the tensor slice.
Returns:
list: A list of merged tensor slices.
"""
linked_matrix, reversed_linked_matrix = torch.load("./output/temp/linked_matrix.pt")
slices = []
for tp_index in range(tp_degree):
prefix_path = os.path.join(param_base_path, str(tp_index), f"{state}")
# paths = sorted(glob.glob(f"{prefix_path}.*"))
# shards = [torch.load(p) for p in paths]
shards = []
for partition, _numel in linked_matrix[os.path.basename(param_base_path)]:
path = f"{prefix_path}.{partition}"
shards.append(torch.load(path))
slice = torch.cat(shards, dim=0).reshape(slice_shape)
slices.append(slice)
return slices
def check_sharding_links(temp_dir: str) -> None:
    """
    Print the shape of every TP-rank-0 fp32 fragment found under ``temp_dir``.

    For each entry ``<temp_dir>/<param>``, every file matching
    ``<param>/0/fp32.*`` is loaded on CPU. One line of the form
    ``"<param_dir> => <file name>, <shape>"`` is printed per file.

    Parameters:
        temp_dir (str): Folder holding the extracted per-parameter fragments.
    """
    for param_dir in sorted(glob.glob(f"{temp_dir}/*")):
        fragment_prefix = os.path.join(param_dir, "0", "fp32")
        for fragment_path in sorted(glob.glob(f"{fragment_prefix}.*")):
            fragment = torch.load(fragment_path, map_location="cpu")
            fragment_file = os.path.basename(fragment_path)
            print(f"{param_dir} => {fragment_file}, {fragment.shape}")
def merge_tp_slices(
    ds_checkpoint: DeepSpeedCheckpoint,
    dir: str,
    slice_dir: str,
    tp_degree: int,
    name_and_shape: tuple,
) -> set:
    """
    Merge the tensor-parallel (TP) slices of one parameter into full tensors.

    The states "fp32", "exp_avg" and "exp_avg_sq" are handled in turn. For each:

    1. The ZeRO fragments under ``<slice_dir>/<name>`` are merged into one
       slice per TP rank by ``_merge_zero_shards``.
    2. The TP slices are combined according to the first matching pattern
       group: replicated, averaged, two sub-params, or row/column parallel.
    3. The result is saved to ``<dir>/<name>/<state>.pt`` as a dict holding
       ``PARAM`` plus merge metadata (``CAT_DIM``, ``PARAM_N_SUB_PARAMS``,
       ``VOCAB_TENSOR``).

    Parameters:
        ds_checkpoint (DeepSpeedCheckpoint): Source checkpoint. Currently
            unused, because the universal-checkpoint-info lookup is disabled below.
        dir (str): Output folder for the merged states.
        slice_dir (str): Folder containing the extracted ZeRO fragments.
        tp_degree (int): Tensor parallel degree.
        name_and_shape (tuple): ``(param_name, tp_slice_shape)``.

    Returns:
        set: The patterns this parameter did not match. Because
        ``universal_checkpoint_info`` is hard-coded to ``{}``, every pattern
        list is empty. This is therefore always an empty set, and every
        parameter takes the plain ``torch.cat(..., dim=0)`` branch.
    """
    name, shape = name_and_shape
    slice_base_path = os.path.join(slice_dir, name)
    param_base_path = os.path.join(dir, name)
    # The original lookup
    # ds_checkpoint.get_checkpoint_info(UNIVERSAL_CHECKPOINT_INFO) is disabled,
    # so all pattern lists below are empty.
    universal_checkpoint_info = {}
    replicated_parameters = universal_checkpoint_info.get(
        TP_REPLICATED_PARAMETER_PATTERNS, []
    )
    parameters_to_average = universal_checkpoint_info.get(
        PARAMETER_TO_AVERAGE_PATTERNS, []
    )
    parameters_with_row_parallelism = universal_checkpoint_info.get(
        PARAMETER_WITH_ROW_PARALLELISM_PATTERNS, []
    )
    vocabulary_parameters = universal_checkpoint_info.get(
        VOCABULARY_PARAMETER_PATTERNS, []
    )
    parameters_with_2_sub_params_cat_dim_0 = universal_checkpoint_info.get(
        PARAMETER_WITH_2_SUB_PARAMS_CAT_DIM_0, []
    )
    # Every pattern starts out "unmatched". get_matched_pattern removes a
    # pattern once it matches, and the caller intersects these sets across
    # all parameters.
    unmatched_patterns = set(
        replicated_parameters
        + parameters_to_average
        + parameters_with_row_parallelism
        + vocabulary_parameters
        + parameters_with_2_sub_params_cat_dim_0
    )

    def get_matched_pattern(patterns_: list, name_: str) -> str | None:
        # Return the single regex in patterns_ that re.match-es name_ and mark
        # it as used; return None if none matches. Only groups actually
        # evaluated by the if/elif chain below can be marked as used.
        matched_ = [pattern_ for pattern_ in patterns_ if re.match(pattern_, name_)]
        assert (
            len(matched_) <= 1
        ), f"Got more than one matching patterns={matched_} for {name_}"
        if matched_:
            pattern_ = matched_[0]
            unmatched_patterns.discard(pattern_)
            return pattern_
        return None

    for state in ("fp32", "exp_avg", "exp_avg_sq"):
        slices = _merge_zero_shards(slice_base_path, state, tp_degree, shape)
        final_path = os.path.join(param_base_path, f"{state}.pt")
        ckpt_dict = {}
        if get_matched_pattern(replicated_parameters, name):
            # Replicated across TP ranks: all slices must be identical; keep one.
            if len(slices) > 1:
                assert all([slices[0].equal(other_slice) for other_slice in slices[1:]])
            param = slices[0]
        elif get_matched_pattern(parameters_to_average, name):
            param = sum(slices) / len(slices)
        elif get_matched_pattern(parameters_with_2_sub_params_cat_dim_0, name):
            # Each slice holds halves of two sub-params: merge the first halves
            # and the second halves separately, then stack the two results.
            cat_dim = 0
            chunked_slices = [torch.chunk(s, 2, dim=cat_dim) for s in slices]
            merged_chunks_0 = torch.cat([s[0] for s in chunked_slices], dim=cat_dim)
            merged_chunks_1 = torch.cat([s[1] for s in chunked_slices], dim=cat_dim)
            param = torch.cat([merged_chunks_0, merged_chunks_1], dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
            ckpt_dict[PARAM_N_SUB_PARAMS] = 2
        else:
            # Row-parallel params are concatenated on dim 1, all others on dim 0.
            cat_dim = (
                1 if get_matched_pattern(parameters_with_row_parallelism, name) else 0
            )
            param = torch.cat(slices, dim=cat_dim)
            ckpt_dict[CAT_DIM] = cat_dim
        if get_matched_pattern(vocabulary_parameters, name):
            # Strip vocabulary padding rows.
            # NOTE(review): unreachable while universal_checkpoint_info is {};
            # it would raise KeyError on "original_vocab_size" if vocabulary
            # patterns were enabled without that key.
            original_vocab_size = universal_checkpoint_info["original_vocab_size"]
            param = param[:original_vocab_size, :]
            ckpt_dict[VOCAB_TENSOR] = True
        ckpt_dict[PARAM] = param
        _save_checkpoint(final_path, ckpt_dict)
    return unmatched_patterns
def _do_parallel_work(
    do_work: t.Callable, work_chunks: t.Any, num_workers: int
) -> list:
    """Apply ``do_work`` to every item of ``work_chunks`` and collect the results.

    Parameters:
        do_work: Callable invoked once per work item. It must be picklable
            when ``num_workers > 1``, because it is sent to worker processes.
        work_chunks: Iterable of work items, e.g. ``(pp, tp, dp)`` tuples or
            ``(name, tensor)`` pairs. Each item is passed to ``do_work`` as-is.
        num_workers: Number of worker processes. ``<= 1`` runs serially in
            the current process.

    Returns:
        list: ``do_work(item)`` for each item, in input order.
    """
    results = []
    if num_workers > 1:
        with ProcessPoolExecutor(max_workers=num_workers) as executor:
            future_list = [executor.submit(do_work, work) for work in work_chunks]
            # Collect inside the ``with`` so the progress bar advances as
            # workers finish, instead of after the pool has already drained.
            for future in tqdm.tqdm(future_list):
                results.append(future.result())
    else:
        # Serial path with no child processes (e.g. for unit tests). Each work
        # item is ONE call, matching the parallel branch. The old code iterated
        # inside every item and called do_work on its elements, which broke for
        # (pp, tp, dp) tuples and (name, tensor) pairs.
        for work in tqdm.tqdm(work_chunks):
            results.append(do_work(work))
    return results
def _extract_zero_shard_files(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint, temp_dir: str
) -> None:
    """
    Split every ZeRO shard of ``ds_checkpoint`` into per-parameter fragments.

    ``extract_zero_shards`` runs for every ``(pp, tp, dp)`` rank triple, using
    ``args.num_workers`` processes. The TP-rank-0 fp32 fragment shapes are then
    printed via ``check_sharding_links``.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``num_workers`` is used).
        ds_checkpoint (DeepSpeedCheckpoint): Checkpoint whose shards are split.
        temp_dir (str): Folder that receives the fragment files.
    """
    pp_ranks = range(ds_checkpoint.pp_degree)
    tp_ranks = range(ds_checkpoint.tp_degree)
    dp_ranks = range(ds_checkpoint.dp_degree)
    rank_triples = list(itertools.product(pp_ranks, tp_ranks, dp_ranks))

    extract_one = partial(extract_zero_shards, temp_dir, ds_checkpoint)
    _do_parallel_work(extract_one, rank_triples, args.num_workers)
    check_sharding_links(temp_dir)
def build_linked_list(temp_dir: str, name_and_tensor: tuple) -> list:
    """Recover the order in which a tensor's ZeRO fp32 fragments tile it.

    The TP-rank-0 fragments ``<temp_dir>/<tensor_name>/0/fp32.<partition>`` are
    loaded. Fragments are then placed greedily: starting at offset 0, any
    unplaced fragment whose values match the flattened model tensor at the
    current offset (``torch.allclose``, after casting to the tensor's
    dtype/device) is appended, and the offset advances by its size. At most
    one pass per fragment is made.

    Args:
        temp_dir: Folder holding the extracted fragments.
        name_and_tensor: ``(tensor_name, mp_tensor)``. ``mp_tensor`` is the
            model-state tensor of that name.

    Returns:
        list: ``(tensor_name, partition_idx, offset, numel)`` tuples in order
        of increasing ``offset``. ``partition_idx`` is the fragment file
        suffix, e.g. ``"007"``. Fragments that never match are omitted and a
        warning is printed.
    """
    tensor_name, mp_tensor = name_and_tensor
    # Escape the directory part so glob metacharacters in names are literal.
    fragment_dir = glob.escape(os.path.join(temp_dir, tensor_name, "0"))
    fp32_paths = sorted(glob.glob(os.path.join(fragment_dir, "fp32.*")))

    fp32_tensors = []
    for fp32_path in fp32_paths:
        # "fp32.007" -> "007", the zero-padded dp index written by dump_param_fragment.
        partition_idx = os.path.basename(fp32_path).rsplit(".", 1)[-1]
        fp32_tensors.append((partition_idx, torch.load(fp32_path, map_location="cpu")))

    flat_tensor = mp_tensor.flatten()
    total_numel = flat_tensor.numel()
    visited = [False] * len(fp32_tensors)
    offset = 0
    linked_list = []
    for _merge_cnt in range(len(fp32_tensors)):
        for index, (partition_idx, fp32_tensor) in enumerate(fp32_tensors):
            if visited[index]:
                continue
            numel = fp32_tensor.numel()
            if offset + numel > total_numel:
                # Cannot fit at this offset. The previous full-copy comparison
                # raised a size-mismatch error here.
                continue
            # Compare only the window this fragment would occupy. This is the
            # same test as writing it into a full copy and comparing
            # everything, without an O(total) copy per candidate.
            window = flat_tensor[offset : offset + numel]
            if torch.allclose(window, fp32_tensor.to(flat_tensor)):
                visited[index] = True
                linked_list.append((tensor_name, partition_idx, offset, numel))
                offset += numel
    if not all(visited):
        unplaced = [p for (p, _), done in zip(fp32_tensors, visited) if not done]
        print(f"Warning: could not place fragments {unplaced} of {tensor_name}")
    print(linked_list)
    return linked_list
def _build_linked_matrix(args: argparse.Namespace, temp_dir: str, mp_sd: dict) -> None:
    """
    Build the linked matrix for the fragments under ``temp_dir`` and save it.

    ``build_linked_list`` runs for every ``(tensor_name, tensor)`` in
    ``mp_sd["module"]``. The two resulting indexes are stored as the tuple
    ``(linked_matrix, reversed_linked_matrix)`` in
    ``<temp_dir>/linked_matrix.pt``:

    - ``linked_matrix[tensor_name]`` -> ``[(partition_idx, numel), ...]``, in
      the order the fragments tile the flattened tensor.
    - ``reversed_linked_matrix[partition_idx]`` -> ``[(tensor_name, offset, numel), ...]``,
      where ``offset`` is the fragment's start inside the flattened tensor.

    Example:
        linked_matrix["model.transformer_encoder.layers.31.linear1.weight"]
        -> [('018', 54715476), ('019', 12393388)]
        reversed_linked_matrix['019']
        -> [('model.transformer_encoder.layers.31.linear1.weight', 54715476, 12393388),
            ('model.transformer_encoder.layers.31.linear1.bias', 0, 16384),
            ('model.transformer_encoder.layers.31.linear2.weight', 0, 46882008)]
        So partition '019' holds 12393388 elements of linear1.weight, starting
        at offset 54715476. That matches the param_slice_mappings of the
        dp_index=19 shard:
        [('...linear1.weight', fragment_address(numel=12393388, start=0)),
         ('...linear1.bias', fragment_address(numel=16384, start=12393388)),
         ('...linear2.weight', fragment_address(numel=46882008, start=12409772))]

    Args:
        args: Parsed arguments (``num_workers`` is used).
        temp_dir: Folder with the extracted fragments. The matrix is saved here.
        mp_sd: Model-state dict whose ``"module"`` entry maps names to tensors.
    """
    per_tensor_links = _do_parallel_work(
        partial(build_linked_list, temp_dir),
        mp_sd["module"].items(),
        args.num_workers,
    )

    linked_matrix = defaultdict(list)
    reversed_linked_matrix = defaultdict(list)
    for entry in itertools.chain.from_iterable(per_tensor_links):
        tensor_name, partition_idx, offset, numel = entry
        linked_matrix[tensor_name].append((partition_idx, numel))
        reversed_linked_matrix[partition_idx].append((tensor_name, offset, numel))

    matrix_path = os.path.join(temp_dir, "linked_matrix.pt")
    torch.save((linked_matrix, reversed_linked_matrix), matrix_path)
def _merge_tp_slice_files(
    args: argparse.Namespace,
    ds_checkpoint: DeepSpeedCheckpoint,
    slice_shapes: dict,
    temp_dir: str,
) -> None:
    """
    Merge the extracted fragments of every parameter into full tensors.

    ``merge_tp_slices`` runs for every ``(name, shape)`` in ``slice_shapes``,
    writing to ``<output_folder>/zero``. Afterwards the code checks whether
    any merge pattern was never used by any parameter. Under ``args.strict``
    that is an assertion failure; otherwise a warning is printed.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``output_folder``,
            ``num_workers`` and ``strict`` are used).
        ds_checkpoint (DeepSpeedCheckpoint): Source checkpoint (``tp_degree`` is used).
        slice_shapes (dict): Maps parameter names to their TP-slice shapes.
        temp_dir (str): Folder holding the extracted fragments.

    Raises:
        AssertionError: In strict mode, if some pattern was never used.
    """
    work_chunks = list(slice_shapes.items())
    zero_output_folder = os.path.join(args.output_folder, "zero")
    do_work = partial(
        merge_tp_slices,
        ds_checkpoint,
        zero_output_folder,
        temp_dir,
        ds_checkpoint.tp_degree,
    )
    unmatched_patterns_lists = _do_parallel_work(do_work, work_chunks, args.num_workers)

    # A pattern is unused only if no parameter matched it, i.e. it is still
    # unmatched in every worker's result, so take the intersection. With no
    # parameters there is nothing to intersect (set.intersection() with no
    # arguments raises TypeError), so nothing is reported.
    sets = [set(lst) for lst in unmatched_patterns_lists]
    unmatched_patterns = sorted(set.intersection(*sets)) if sets else []
    if args.strict:
        assert (
            not unmatched_patterns
        ), f"Unused patterns={unmatched_patterns} while merging tp slices"
    elif unmatched_patterns:
        print(f"Warning: Unused patterns={unmatched_patterns} while merging tp slices")
def _save_optimizer_state(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint
) -> None:
    """
    Save the optimizer state shared by all ranks to ``<output_folder>/zero``.

    The state is taken from the (pp=0, tp=0, dp=0) ZeRO shard. Every entry
    except the per-rank ones (base optimizer state, slice mappings, fp32
    partitions) is kept. The group-0 Adam ``step`` and the ``param_groups`` are
    added under the flat keys ``OPTIMIZER_STATE_STEP`` and
    ``OPTIMIZER_STATE_PARAM_GROUPS``. The result is written to
    ``optimizer_state.pt``.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``output_folder`` is used).
        ds_checkpoint (DeepSpeedCheckpoint): Checkpoint to read the state from.
    """
    zero_sd = ds_checkpoint.get_zero_checkpoint_state(pp_index=0, tp_index=0, dp_index=0)
    optimizer_sd = zero_sd[OPTIMIZER_STATE_DICT]
    base_state = optimizer_sd[BASE_OPTIMIZER_STATE]

    per_rank_keys = {
        BASE_OPTIMIZER_STATE,
        PARAM_SLICE_MAPPINGS,
        SINGLE_PARTITION_OF_FP32_GROUPS,
    }
    common_sd = {
        key: value for key, value in optimizer_sd.items() if key not in per_rank_keys
    }
    common_sd[OPTIMIZER_STATE_STEP] = base_state["state"][0]["step"]
    common_sd[OPTIMIZER_STATE_PARAM_GROUPS] = base_state["param_groups"]

    _save_checkpoint(
        os.path.join(args.output_folder, "zero", "optimizer_state.pt"), common_sd
    )
def copy_zero_shard(
    args: argparse.Namespace, ds_checkpoint: DeepSpeedCheckpoint, indices_3D: tuple
) -> None:
    """
    Write one output ZeRO shard laid out like ``ds_checkpoint``'s shard.

    ``main`` passes the target checkpoint here. Its shard at ``indices_3D``
    (PP, TP, DP) supplies the structure: which parameter fragments this rank
    owns, and in what order. The values come from the merged tensors in
    ``<output_folder>/zero/<tensor_name>/{fp32,exp_avg,exp_avg_sq}.pt``.

    For each param group, every fragment listed in the shard's
    ``param_slice_mappings`` is sliced out of the flattened merged tensors. The
    slice uses the ``(offset, numel)`` recorded in the target linked matrix
    (``<output_folder>/temp_tgt/linked_matrix.pt``), and the slices are
    concatenated. The step and param_groups saved in ``optimizer_state.pt``
    replace the shard's own. The result is saved to
    ``<output_folder>/output/<basename of the shard's zero file>``.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``output_folder`` is used).
        ds_checkpoint (DeepSpeedCheckpoint): Checkpoint whose shard layout is reproduced.
        indices_3D (tuple): ``(pp_index, tp_index, dp_index)``.
    """
    pp_index, tp_index, dp_index = indices_3D
    sd = ds_checkpoint.get_zero_checkpoint_state(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )
    linked_matrix, reversed_linked_matrix = torch.load(
        f"{args.output_folder}/temp_tgt/linked_matrix.pt"
    )
    # Fragment files are suffixed with the zero-padded dp index, so this picks
    # the (tensor_name, offset, numel) entries owned by this dp rank.
    # NOTE(review): keyed by the dp index only; pp/tp indices are ignored.
    # Presumably this assumes tp = pp = 1. Confirm before using with tp/pp > 1.
    counter = f"{dp_index:0>3d}"
    reversed_linked_order = reversed_linked_matrix[counter]
    # tensor_name -> (tensor_name, offset, numel)
    reversed_linked_order_dict = {v[0]: v for v in reversed_linked_order}
    # The deep copy keeps every field that is not overridden below.
    clone_sd = copy.deepcopy(sd)
    optim_sd = sd[OPTIMIZER_STATE_DICT]
    param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
    # Fields overridden below:
    # - "base_optimizer_state.step"
    # - "base_optimizer_state.exp_avg"
    # - "base_optimizer_state.exp_avg_sq"
    # - "base_optimizer_state.param_groups"
    # - "single_partition_of_fp32_groups"
    # dict: param group id -> Adam state
    state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
    # list: one flat fp32 partition per param group
    fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
    # step and param_groups written by _save_optimizer_state
    zero_output_folder = os.path.join(args.output_folder, "zero")
    optimizer_state = torch.load(
        os.path.join(zero_output_folder, "optimizer_state.pt"),
        map_location=torch.device("cpu"),
    )
    optimizer_state_step = optimizer_state[OPTIMIZER_STATE_STEP]
    optimizer_state_param_groups = optimizer_state[OPTIMIZER_STATE_PARAM_GROUPS]
    param_groups_cnt = len(state_groups)
    cloned_states = {}
    for param_group_id in range(param_groups_cnt):
        state = state_groups[param_group_id]
        cloned_state = copy.deepcopy(state)
        cloned_state["step"] = optimizer_state_step
        # Collected as lists of fragments, then concatenated after the loop.
        cloned_state["exp_avg"] = []
        cloned_state["exp_avg_sq"] = []
        fp32s = []
        print(param_slice_mappings[param_group_id])
        for name, _fragment_mapping in param_slice_mappings[param_group_id].items():
            # Position of this fragment inside the flattened full tensor. It
            # comes from the target linked matrix; the fragment_mapping itself
            # is not used.
            tensor_name, offset, numel = reversed_linked_order_dict[name]
            param_base_path = os.path.join(zero_output_folder, tensor_name)
            print(f"{tensor_name}: {offset}: {numel} <= {param_base_path}")
            # NOTE(review): the three merged files are loaded in full for every
            # fragment, even though only a slice of each is used.
            exp_avg = torch.load(os.path.join(param_base_path, "exp_avg.pt"))[
                "param"
            ].flatten()[offset : offset + numel]
            exp_avg_sq = torch.load(os.path.join(param_base_path, "exp_avg_sq.pt"))[
                "param"
            ].flatten()[offset : offset + numel]
            fp32 = torch.load(os.path.join(param_base_path, "fp32.pt"))[
                "param"
            ].flatten()[offset : offset + numel]
            cloned_state["exp_avg"].append(exp_avg)
            cloned_state["exp_avg_sq"].append(exp_avg_sq)
            fp32s.append(fp32)
        # This writes into the list owned by `sd` (not `clone_sd`). That same
        # list object is installed into clone_sd below.
        fp32_groups[param_group_id] = torch.cat(fp32s, dim=0)
        cloned_state["exp_avg"] = torch.cat(cloned_state["exp_avg"], dim=0)
        cloned_state["exp_avg_sq"] = torch.cat(cloned_state["exp_avg_sq"], dim=0)
        cloned_states[param_group_id] = cloned_state
    clone_sd[OPTIMIZER_STATE_DICT][BASE_OPTIMIZER_STATE].update(
        {
            "state": cloned_states,
            "param_groups": optimizer_state_param_groups,
        }
    )
    clone_sd[OPTIMIZER_STATE_DICT].update(
        {
            SINGLE_PARTITION_OF_FP32_GROUPS: fp32_groups,
        }
    )
    # Keep the shard's original file name under <output_folder>/output.
    output_file = ds_checkpoint.get_zero_files(
        pp_index=pp_index, tp_index=tp_index, dp_index=dp_index
    )[0]
    path = os.path.join(args.output_folder, "output", os.path.basename(output_file))
    print(output_file)
    _save_checkpoint(path, clone_sd)
def _copy_zero_shard_files(
    args: argparse.Namespace,
    ds_checkpoint: DeepSpeedCheckpoint,
) -> None:
    """
    Write an output ZeRO shard for every rank triple of ``ds_checkpoint``.

    ``copy_zero_shard`` runs for each ``(pp, tp, dp)`` combination, in
    pp-major order, using ``args.num_workers`` processes.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``num_workers`` is used).
        ds_checkpoint (DeepSpeedCheckpoint): Checkpoint whose shard layout is reproduced.
    """
    pp_ranks = range(ds_checkpoint.pp_degree)
    tp_ranks = range(ds_checkpoint.tp_degree)
    dp_ranks = range(ds_checkpoint.dp_degree)
    rank_triples = [
        (pp, tp, dp) for pp in pp_ranks for tp in tp_ranks for dp in dp_ranks
    ]
    copy_one = partial(copy_zero_shard, args, ds_checkpoint)
    _do_parallel_work(copy_one, rank_triples, args.num_workers)
def check_mp_equal_to_fp32(args: argparse.Namespace) -> None:
    """
    Check that the converted model weights match the merged fp32 master weights.

    ``<output_folder>/output/mp_rank_00_model_states.pt`` is loaded on CPU.
    Then, for every entry ``<output_folder>/zero/<tensor_name>`` whose name
    contains ``"model"``, ``fp32.pt["param"]`` is cast to the model tensor's
    dtype/device and compared with ``module[tensor_name]``.

    Parameters:
        args (argparse.Namespace): Parsed arguments (``output_folder`` is used).

    Raises:
        AssertionError: If any tensor differs beyond torch's default
            dtype-based tolerances. NaNs in matching positions count as equal.
    """
    mp_sd = torch.load(
        os.path.join(args.output_folder, "output", "mp_rank_00_model_states.pt"),
        map_location=torch.device("cpu"),
    )
    module_sd = mp_sd["module"]
    zero_output_folder = os.path.join(args.output_folder, "zero")
    for tensor_name_path in sorted(glob.glob(f"{zero_output_folder}/*")):
        tensor_name = os.path.basename(tensor_name_path)
        # Filter on the entry name only. Testing the full path let every entry
        # through (including optimizer_state.pt) whenever output_folder itself
        # contained "model".
        if "model" not in tensor_name:
            continue
        model_tensor = module_sd[tensor_name]
        fp32 = torch.load(os.path.join(tensor_name_path, "fp32.pt"))["param"].to(
            model_tensor
        )
        # torch.testing.assert_allclose is deprecated. assert_close with
        # equal_nan=True keeps its NaN handling.
        torch.testing.assert_close(
            fp32,
            model_tensor,
            equal_nan=True,
            msg=f"{tensor_name}\nfp32: \n{fp32}\n"
            f"mp_sd: \n{model_tensor}",
        )
def make_zero_shard() -> None:
"""Make zero shard."""
input_folder = ""
paths = sorted(glob.glob(f"{input_folder}/*"))
for path in tqdm.tqdm(paths):
if "mp_rank_00_model_states" in path:
continue
sd = torch.load(path, map_location="cpu")
optim_sd = sd[OPTIMIZER_STATE_DICT]
# param_slice_mappings = optim_sd[PARAM_SLICE_MAPPINGS]
# dict
state_groups = optim_sd[BASE_OPTIMIZER_STATE]["state"]
# list
# fp32_groups = optim_sd[SINGLE_PARTITION_OF_FP32_GROUPS]
param_groups_cnt = len(state_groups)
for param_group_id in range(param_groups_cnt):
state = state_groups[param_group_id]
state["exp_avg"] = torch.zeros_like(state["exp_avg"])
state["exp_avg_sq"] = torch.zeros_like(state["exp_avg_sq"])
_save_checkpoint(path, sd)
def main() -> None:
    """Main entrypoint.

    Convert the DeepSpeed checkpoint in ``--input_folder`` into a checkpoint
    laid out like the one in ``--target_folder``, written under
    ``--output_folder``:

    0. Load the input model-state files and collect parameter shapes.
    1. Extract the input ZeRO fragments into ``<output>/temp`` and build the
       input linked matrix.
    2. Merge fragments / TP slices into full tensors under ``<output>/zero``.
    3. Save the optimizer state shared by all ranks.
    4. Extract the target ZeRO fragments into ``<output>/temp_tgt``, build the
       target linked matrix, and write the input values in the target layout
       to ``<output>/output``.
    5. Copy the input model-state files to ``<output>/output`` with the
       target's ``dp_world_size``.

    Finally, unless ``--no_strict`` is given, the model weights are checked
    against the merged fp32 master weights.
    """
    print("Convert DeepSpeed Checkpoint to DeepSpeed Checkpoint")
    args = parse_arguments()
    print(
        f"Converting DeepSpeed checkpoint in {args.input_folder} and "
        f"{args.target_folder} to DeepSpeed checkpoint in {args.output_folder}."
    )
    ds_checkpoint = DeepSpeedCheckpoint(args.input_folder)

    print("*** 0. Load model checkpoints")
    slice_shapes = []
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
        slice_shapes += mp_sd[PARAM_SHAPES]
    # Flatten the per-file lists of dicts into one dict; duplicates for tp > 1
    # collapse to a single entry.
    slice_shapes = {k: v for d in slice_shapes for k, v in d.items()}
    temp_dir = os.path.join(args.output_folder, "temp")

    print("*** 1. Extracting ZeRO fragments")
    _extract_zero_shard_files(args, ds_checkpoint, temp_dir)
    # NOTE(review): `mp_sd` is the last file loaded in step 0, so the linked
    # matrix only covers that file's tensors. This presumably assumes a single
    # model-parallel rank (tp = pp = 1). Confirm before using with tp/pp > 1.
    _build_linked_matrix(args, temp_dir, mp_sd)

    print("*** 2. Merging slices .....")
    _merge_tp_slice_files(args, ds_checkpoint, slice_shapes, temp_dir)

    print("*** 3. Saving common optimizer states")
    _save_optimizer_state(args, ds_checkpoint)

    print("*** 4. Copying ZeRO fragments like target")
    tgt_temp_dir = os.path.join(args.output_folder, "temp_tgt")
    tgt_ds_checkpoint = DeepSpeedCheckpoint(args.target_folder)
    print("*** 4.1 Load target model checkpoints")
    tgt_mp_sd = torch.load(
        tgt_ds_checkpoint.mp_rank_files[0], map_location=torch.device("cpu")
    )
    print("*** 4.2 Extracting target ZeRO fragments")
    _extract_zero_shard_files(args, tgt_ds_checkpoint, tgt_temp_dir)
    _build_linked_matrix(args, tgt_temp_dir, tgt_mp_sd)
    print("*** 4.3 Copying ZeRO fragments like target")
    _copy_zero_shard_files(args, tgt_ds_checkpoint)

    print("*** 5. Copying mp fragments like target")
    for mp_rank_file in ds_checkpoint.mp_rank_files:
        mp_sd = torch.load(mp_rank_file, map_location=torch.device("cpu"))
        mp_sd["dp_world_size"] = tgt_mp_sd["dp_world_size"]
        # _save_checkpoint creates <output>/output if it does not exist yet.
        _save_checkpoint(
            os.path.join(args.output_folder, "output", os.path.basename(mp_rank_file)),
            mp_sd,
        )

    # --no_strict is documented as "Do not perform validity checks on
    # converted checkpoint", so the final validation honours it.
    if args.strict:
        check_mp_equal_to_fp32(args)
    print("*** Done!")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment