Skip to content

Instantly share code, notes, and snippets.

import torch
import deepspeed
from transformers import BertLayer, BertConfig
# Parity check: build a HuggingFace BertLayer and a DeepSpeedTransformerLayer,
# copy the DeepSpeed weights into the HuggingFace layer, run both on the same
# input, and assert that the outputs agree within a small tolerance.

hf_config = BertConfig()
# Every dimension below is derived from the HF config so the script keeps
# working if a non-default BertConfig is used.
hidden_size = hf_config.hidden_size

hf_layer = BertLayer(hf_config)
hf_layer.eval().to("cuda")

ds_config = deepspeed.DeepSpeedTransformerConfig(
    batch_size=16,
    hidden_size=hidden_size,
    intermediate_size=hf_config.intermediate_size,
    heads=hf_config.num_attention_heads,
    # Same values as the previous literal 0.1 for a default BertConfig, but
    # now taken from the config like the other hyper-parameters.
    attn_dropout_ratio=hf_config.attention_probs_dropout_prob,
    hidden_dropout_ratio=hf_config.hidden_dropout_prob,
    num_hidden_layers=hf_config.num_hidden_layers,
    initializer_range=hf_config.initializer_range,
    local_rank=-1,
    pre_layer_norm=False,
    stochastic_mode=False,
    huggingface=True,
)
ds_layer = deepspeed.DeepSpeedTransformerLayer(ds_config)
# Snapshot the DeepSpeed weights; they are the source of truth that gets
# loaded into the HuggingFace layer below.
ds_state_dict = ds_layer.state_dict()
ds_layer.eval().to("cuda")

# DeepSpeed keeps query/key/value fused in single tensors ("attn_qkvw" /
# "attn_qkvb"). They are split here into three consecutive row blocks of
# `hidden_size` rows each, in the order query, key, value.
qkv_weight = ds_state_dict["attn_qkvw"]
qkv_bias = ds_state_dict["attn_qkvb"]
query_rows = slice(0, hidden_size)
key_rows = slice(hidden_size, 2 * hidden_size)
value_rows = slice(2 * hidden_size, None)

# Map DeepSpeed parameter names onto the BertLayer state-dict keys.
state_dict = {
    "attention.self.query.weight": qkv_weight[query_rows, :],
    "attention.self.key.weight": qkv_weight[key_rows, :],
    "attention.self.value.weight": qkv_weight[value_rows, :],
    "attention.self.query.bias": qkv_bias[query_rows],
    "attention.self.key.bias": qkv_bias[key_rows],
    "attention.self.value.bias": qkv_bias[value_rows],
    "attention.output.dense.weight": ds_state_dict["attn_ow"],
    "attention.output.dense.bias": ds_state_dict["attn_ob"],
    "attention.output.LayerNorm.weight": ds_state_dict["attn_nw"],
    "attention.output.LayerNorm.bias": ds_state_dict["attn_nb"],
    "intermediate.dense.weight": ds_state_dict["inter_w"],
    "intermediate.dense.bias": ds_state_dict["inter_b"],
    "output.dense.weight": ds_state_dict["output_w"],
    "output.dense.bias": ds_state_dict["output_b"],
    "output.LayerNorm.weight": ds_state_dict["norm_w"],
    "output.LayerNorm.bias": ds_state_dict["norm_b"],
}
# load_state_dict is strict by default, so a missing/unexpected key or a
# shape mismatch in the mapping above raises here instead of going unnoticed.
hf_layer.load_state_dict(state_dict)

# One random input of shape (1, 16, hidden_size) and an all-zero mask of
# shape (1, 16).
# NOTE(review): the zero mask is presumably additive (nothing masked out) for
# both layers - confirm against the two libraries' mask conventions.
input_tensor = torch.rand((1, 16, hidden_size)).to("cuda")
input_mask_tensor = torch.zeros((1, 16)).to("cuda")

# Both layers are indexed with [0] to take the first element of their output.
hf_output = hf_layer(input_tensor, input_mask_tensor)[0]
print("output (huggingface):", hf_output[0, -1, -10:])
ds_output = ds_layer(input_tensor, input_mask_tensor)[0]
print("output (deepspeed):", ds_output[0, -1, -10:])

# The two implementations must agree to within an absolute tolerance of 1e-3.
assert torch.allclose(hf_output, ds_output, atol=1e-3)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment