Skip to content

Instantly share code, notes, and snippets.

@zarzen
Last active November 8, 2021 19:25
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zarzen/34a34c75109dfc39b37a8d2e27db20f1 to your computer and use it in GitHub Desktop.
Save zarzen/34a34c75109dfc39b37a8d2e27db20f1 to your computer and use it in GitHub Desktop.
deepspeed_loss_test

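Usage

Run `python3 test_diff_stages.py`. It launches `test_model.py` with the `deepspeed` launcher once per ZeRO stage (0, 1, 2, 3), then checks that each stage's loss log matches the stage-0 log.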
import subprocess
import os
from collections import defaultdict
import numpy as np
# Resolve this script's own location so the companion test_model.py can be
# found next to it when testing DeepSpeed locally, regardless of the cwd.
file_path = os.path.abspath(__file__)
file_dir = os.path.split(file_path)[0]
def main(stages=(0, 1, 2, 3), hidden_size=16, rtol=1e-4, atol=1e-4):
    """Run test_model.py under each ZeRO stage and compare the loss logs.

    Each stage is launched through the ``deepspeed`` launcher. The resulting
    ``step,loss`` log for every stage after the first is compared against the
    first stage's log; rows that differ are printed.

    Args:
        stages: ZeRO stages to run; the first entry is the reference.
        hidden_size: value passed as ``--hidden-size``; also part of the log name.
        rtol: relative tolerance for both the whole-log and per-row comparison.
        atol: absolute tolerance for both the whole-log and per-row comparison.
    """
    stages = list(stages)
    test_model_script = os.path.join(file_dir, 'test_model.py')
    for s in stages:
        cmd = ['deepspeed', test_model_script, '--zero', str(s), '--hidden-size', str(hidden_size)]
        env_vars = os.environ.copy()
        result = subprocess.run(cmd, env=env_vars)
        if result.returncode != 0:
            # Keep going so the remaining stages still run, but make the failure
            # visible: otherwise a stale log from an earlier run may be compared.
            print(f'stage{s}: deepspeed exited with code {result.returncode}')

    # read data; cg/rc in the name are the defaults of test_model.py's
    # --contiguous-gradients / --reduce-scatter options.
    default_name_format = "/tmp/loss_log_stage{}.h{}.cgFalse.rcTrue.txt"
    losses = {}
    for s in stages:
        f = default_name_format.format(s, hidden_size)
        # genfromtxt returns a 1-D array for a single-row file; force 2-D so the
        # row[0] (step) / row[1] (loss) indexing below always works.
        losses[s] = np.atleast_2d(np.genfromtxt(f, delimiter=','))

    ref_stage = stages[0]
    ref = losses[ref_stage]
    for i in stages[1:]:
        other = losses[i]
        if ref.shape != other.shape:
            # allclose would raise or silently broadcast on mismatched logs.
            print(f'stage{ref_stage}, stage{i}: loss logs differ in shape '
                  f'{ref.shape} vs {other.shape}')
            continue
        allclose = np.allclose(ref, other, rtol=rtol, atol=atol)
        print(f'stage{ref_stage}, stage{i}, losses all close {allclose}')
        if not allclose:
            for row1, row2 in zip(ref, other):
                # Same tolerances as the whole-log check, so only rows that
                # actually violate them are reported.
                if not np.allclose(row1, row2, rtol=rtol, atol=atol):
                    print(f"loss diff {row1[1] - row2[1]}:: stage {ref_stage}: step{row1[0]}, loss {row1[1]}"
                          f" stage{i}: step{row2[0]}, loss {row2[1]}")


if __name__ == "__main__":
    main()
import os
import math
import json
import argparse
import torch
from torch._C import device
import deepspeed
from torch import nn
from torch.utils.data.distributed import DistributedSampler
import apex.normalization
# Alias apex's fused LayerNorm under the BERT-style name that SimpleModel uses
# for its final normalization layer.
BertLayerNorm = apex.normalization.FusedLayerNorm
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import print_rank_0
class SimpleModel(torch.nn.Module):
    """Small MLP test model whose forward pass returns a cross-entropy loss.

    ``self.mlp`` applies, in order:
    - ``self.linear`` (hidden_dim -> hidden_dim);
    - a hidden_dim -> hidden_dim//2 Linear;
    - six hidden_dim//2 -> hidden_dim//2 Linears;
    - a hidden_dim//2 -> hidden_dim Linear;
    - a hidden_dim -> hidden_dim Linear whose weight and bias are the same
      Parameter objects as ``self.linear``'s;
    - ``self.FinalLayerNorm``.
    """
    def __init__(self, hidden_dim, empty_grad=False, zero=0):
        """Build the layers on device ``LOCAL_RANK`` and initialize their weights.

        Args:
            hidden_dim: feature size of the input and output; also the number
                of classes seen by the cross-entropy loss.
            empty_grad: if True, also create ``self.layers2`` (one Linear) that
                ``forward`` never uses.
            zero: ZeRO stage; unused except by the commented-out
                ``register_external_parameter`` calls below.

        Reads the ``RANK`` and ``LOCAL_RANK`` environment variables.
        """
        # Per-rank seed set before any layer is built. Layer construction and
        # the self.apply(...) re-initialization below both draw from the global
        # RNG, so the statement order in this method fixes the initial weights.
        rank = int(os.environ['RANK'])
        print('seed:', 2222 + rank)
        torch.random.manual_seed(2222 + rank)
        super(SimpleModel, self).__init__()
        self.local_rank = int(os.environ['LOCAL_RANK'])
        print(f'local rank {self.local_rank}')
        self.apply_fn_cnt = 0  # not read anywhere in this class
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim, device=self.local_rank)
        # Created without device=...; NOTE(review): presumably moved to the GPU
        # later by deepspeed.initialize — confirm.
        self.FinalLayerNorm = BertLayerNorm(hidden_dim, eps=1e-6)
        mlp = [self.linear]
        mlp.append(torch.nn.Linear(hidden_dim, hidden_dim//2, device=self.local_rank))
        for _ in range(6):
            l = torch.nn.Linear(hidden_dim//2, hidden_dim//2, device=self.local_rank)
            mlp.append(l)
        mlp.append(torch.nn.Linear(hidden_dim//2, hidden_dim, device=self.local_rank))
        # Weight tying: this layer reuses self.linear's weight and bias
        # Parameters, so the same parameters are applied twice in one forward.
        l = torch.nn.Linear(hidden_dim, hidden_dim, device=self.local_rank)
        l.weight = self.linear.weight
        l.bias = self.linear.bias
        mlp.append(l)
        mlp.append(self.FinalLayerNorm)
        # Disabled: registering the tied parameters as ZeRO-3 external parameters.
        #if zero == 3:
        # deepspeed.zero.register_external_parameter(self, self.linear.weight)
        # deepspeed.zero.register_external_parameter(self, self.linear.bias)
        self.mlp = nn.Sequential(*mlp)
        if empty_grad:
            self.layers2 = torch.nn.ModuleList([torch.nn.Linear(hidden_dim, hidden_dim, device=self.local_rank)])
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        # NOTE(review): self.linear is reachable both as a direct child and as
        # self.mlp[0], and its Parameters also through the tied layer, so this
        # likely re-initializes the shared weight more than once (consuming
        # extra RNG draws) — confirm that is intended.
        self.apply(self.init_bert_weights)
        # self.apply(self.empty_init)

    def forward(self, x, y):
        """Return the cross-entropy loss of ``self.mlp(x)`` against targets ``y``.

        NOTE(review): assumes x is (batch, hidden_dim) and y holds class
        indices in [0, hidden_dim) — confirm against callers.
        """
        # Despite its name, hidden_dim holds activations here, not a size.
        hidden_dim = x
        hidden_dim = self.mlp(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)

    def empty_init(self, module):
        """No-op initializer; only referenced by a commented-out call in __init__."""
        pass

    def init_bert_weights(self, module):
        """ Initialize the weights.

        Meant to be passed to ``self.apply``. nn.Linear / nn.Embedding weights
        are redrawn from N(0, std) with std = 0.05 / 6, and nn.Linear biases are
        zeroed. Other module types (including the LayerNorm) are left unchanged.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            # num_layers = self.config.num_hidden_layers
            num_layers = 6
            std = 0.05 / num_layers
            if hasattr(module, 'bert_output_layer'):
                # "Accounting for accumulation on the residual path"
                # NOTE(review): SimpleModel defines no self.config, so this branch
                # would raise AttributeError. It only runs for modules carrying a
                # `bert_output_layer` attribute, and none are created in this class.
                #print("Accounting for accumulation on the residual path")
                std = self.config.initializer_range / math.sqrt(
                    2.0 * num_layers)
            module.weight.data.normal_(mean=0.0, std=std)
        # elif isinstance(module, BertLayerNorm):
        # module.bias.data.zero_()
        # module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
def create_config_from_dict(tmpdir, config_dict):
    """Write *config_dict* as JSON to ``<tmpdir>/temp_config.json``.

    An existing file at that path is overwritten.

    Returns:
        The path of the written file.
    """
    target = os.path.join(tmpdir, 'temp_config.json')
    with open(target, mode='w') as handle:
        handle.write(json.dumps(config_dict))
    return target
def get_data_loader(model, total_samples, hidden_dim, device, dtype):
    """Build a DataLoader over random (feature, label) pairs, sharded by DistributedSampler.

    Features are ``torch.randn(total_samples, hidden_dim)``. Labels are int64
    values drawn uniformly from [0, hidden_dim). The batch size comes from
    ``model.train_micro_batch_size_per_gpu()``. DistributedSampler is built with
    default arguments, so a default process group must already be initialized.
    """
    micro_batch = model.train_micro_batch_size_per_gpu()
    # Draw features first, then labels: the order of RNG use is part of the data.
    features = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    labels = torch.empty(total_samples, dtype=torch.long, device=device)
    labels.random_(hidden_dim)
    dataset = torch.utils.data.TensorDataset(features, labels)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch,
        sampler=DistributedSampler(dataset),
    )
def get_args(tmpdir, config_dict, argv=None):
    """Parse command-line options and write the matching DeepSpeed config file.

    Copies the parsed ZeRO stage, contiguous-gradients and reduce-scatter
    settings into ``config_dict["zero_optimization"]``. The dict is updated
    *in place*, so callers that reuse it see the parsed values. The dict is
    then written to a JSON file in *tmpdir* via ``create_config_from_dict``,
    and that path is stored on ``args.deepspeed_config``.

    Args:
        tmpdir: directory for the generated config file.
        config_dict: DeepSpeed config; must contain a "zero_optimization" dict.
        argv: argument list to parse; ``None`` (default) parses ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace`` with ``deepspeed_config`` set.
    """
    def _str2bool(value):
        # argparse's type=bool turns every non-empty string (including "False")
        # into True, so parse the common spellings explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', '1', 'yes', 'y', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument('--zero', type=int, default=0)
    parser.add_argument('--contiguous-gradients', default=False, action='store_true')
    parser.add_argument('--reduce-scatter', default=True, type=_str2bool)
    parser.add_argument('--hidden-size', default=16, type=int)
    parser.add_argument('--tmp', default='/tmp', help="temporary directory to save intermediate data")
    parser.add_argument('--use-ds-context-manager', default=False, action='store_true')
    args = parser.parse_args(argv)
    zero_cfg = config_dict["zero_optimization"]
    zero_cfg["stage"] = args.zero
    zero_cfg['contiguous_gradients'] = args.contiguous_gradients
    zero_cfg["reduce_scatter"] = args.reduce_scatter
    print('config_dict["zero_optimization"]', config_dict["zero_optimization"])
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args
def print0(msg):
    """Print *msg* with an immediate flush, but only on global rank 0."""
    if torch.distributed.get_rank() != 0:
        return
    print(msg, flush=True)
def print_params(tag, model):
    """On global rank 0, print each named parameter of *model*, prefixed by *tag*.

    Parameters that carry a ``ds_tensor`` attribute are printed together with
    that tensor. NOTE(review): presumably these are DeepSpeed ZeRO-3
    partitioned parameters — confirm.
    """
    if torch.distributed.get_rank() != 0:
        return
    for name, param in model.named_parameters():
        if not hasattr(param, 'ds_tensor'):
            print0(f'{tag}, {name}, {param}')
            continue
        print0("{} {}:{} ds_tensor {}".format(tag, name, param, param.ds_tensor))
# Base DeepSpeed config for the test run. get_args() overwrites the
# zero_optimization "stage", "contiguous_gradients" and "reduce_scatter"
# entries from the command line before the config is written to disk.
# NOTE(review): DeepSpeed presumably derives gradient-accumulation steps from
# train_batch_size / (train_micro_batch_size_per_gpu * world_size) — confirm.
config_dict = dict(
    train_batch_size=64,
    train_micro_batch_size_per_gpu=4,
    steps_per_print=1,
    zero_allow_untested_optimizer=True,
    optimizer=dict(
        type="Adam",
        params=dict(
            lr=0.1,
            weight_decay=0.01,
            bias_correction=True,
            eps=1e-6,
        ),
    ),
    gradient_clipping=0.0,
    fp16=dict(
        enabled=False,
        initial_scale_power=10,
    ),
    zero_optimization=dict(
        stage=0,  # placeholder; replaced by --zero in get_args()
        overlap_comm=True,
        reduce_scatter=True,
        contiguous_gradients=False,
        # Very small bucket; presumably chosen to force many reduction buckets on this tiny model.
        reduce_bucket_size=20,
        stage3_param_persistence_threshold=1,
        comm_force_reduce_scatter=False,
    ),
)
def save_params(args, module):
    """Dump every named parameter of *module* to a per-rank JSON file.

    The file is
    ``<args.tmp>/small_model_weights_rank{RANK}_stage{zero}_ctx{ctx}.json``.
    ``args.tmp`` falls back to ``/tmp`` when absent. The file maps each
    parameter name to a nested list of its values.

    Under ZeRO stage 3, the parameters are gathered with
    ``deepspeed.zero.GatheredParameters`` only while the values are read, and
    are re-partitioned when the context exits. (The previous bare
    ``param.all_gather()`` left them gathered.) The gather is a cross-rank
    collective, so every rank must call this function.

    Args:
        args: parsed options; reads ``zero``, ``use_ds_context_manager`` and,
            if present, ``tmp``.
        module: model (or DeepSpeed engine) whose ``named_parameters()`` are saved.

    Raises:
        KeyError: if the ``RANK`` environment variable is not set.
    """
    import contextlib

    rank = os.environ['RANK']
    ctx_manager = bool(args.use_ds_context_manager and args.zero == 3)
    out_dir = getattr(args, 'tmp', '/tmp')
    param_json_file = os.path.join(
        out_dir, f'small_model_weights_rank{rank}_stage{args.zero}_ctx{ctx_manager}.json')
    named = list(module.named_parameters())
    if args.zero == 3:
        gather = deepspeed.zero.GatheredParameters([p for _, p in named], modifier_rank=None)
    else:
        gather = contextlib.nullcontext()
    with gather:
        # Read values while (under ZeRO-3) the full parameters are materialized.
        weights = {name: param.data.cpu().numpy().tolist() for name, param in named}
    with open(param_json_file, 'w') as out_file:
        json.dump(weights, out_file, indent=2)
if __name__ == "__main__":
    # Script entry point: builds SimpleModel under the requested ZeRO stage,
    # trains it for up to 41 micro-batches, and has rank 0 log the loss at
    # gradient-accumulation boundaries for comparison across stages.
    import contextlib

    # "initial_scale_power": 15
    args = get_args('/tmp/', config_dict)
    hidden_dim = args.hidden_size
    local_rank = args.local_rank
    print(f'local rank {local_rank}')
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed(dist_backend='nccl')
    if args.zero == 3 and args.use_ds_context_manager:
        # FIXME: using zero.Init causes a loss difference
        with deepspeed.zero.Init(config=config_dict):
            model = SimpleModel(hidden_dim, empty_grad=False, zero=args.zero)
    else:
        model = SimpleModel(hidden_dim, empty_grad=False, zero=args.zero)
    model, _, _, _ = deepspeed.initialize(args=args,
                                          model=model,
                                          model_parameters=model.parameters(),
                                          dist_init_required=True)
    print_params('after initialize', model)
    save_params(args, model)
    # Re-seed so the generated data presumably does not depend on how many
    # random numbers model construction consumed — confirm intent.
    torch.random.manual_seed(2222 + int(os.environ['RANK']))
    data_type = torch.float if not config_dict['fp16']['enabled'] else torch.half
    data_loader = get_data_loader(model=model,
                                  total_samples=10000,
                                  hidden_dim=hidden_dim,
                                  device=model.device,
                                  dtype=data_type)

    def _get_log_filename(args):
        return f"/tmp/loss_log_stage{args.zero}" \
               f".h{args.hidden_size}" \
               f".cg{args.contiguous_gradients}"\
               f".rc{args.reduce_scatter}.txt"

    is_rank0 = torch.distributed.get_rank() == 0
    # Only rank 0 opens the log: opening it with 'w' on every rank made all
    # ranks truncate the same file.
    log_ctx = open(_get_log_filename(args), 'w') if is_rank0 else contextlib.nullcontext()
    with log_ctx as log_file:
        for n, batch in enumerate(data_loader):
            loss = model(batch[0], batch[1])
            # Query the boundary before step(): step() advances the engine's
            # micro-step counter, so asking afterwards reports on the *next*
            # micro-batch.
            at_boundary = model.is_gradient_accumulation_boundary()
            model.backward(loss)
            model.step()
            if is_rank0 and at_boundary:
                print("{}, LOSS: {}".format(n, loss.item()))
                log_file.write(f'{n},{loss.item()}\n')
            if n == 40:
                break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment