from functools import partial
import types
import torch
from typing import List, Optional, Tuple, Union, Dict
import transformers
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import logging as hf_logging
logger = hf_logging.get_logger(__name__)
"""
make the llama model run in model parallel (pipeline parallel) mode across multiple devices
"""


def llama_model_parallel_forward(
    self,
    layer2device: Dict[int, torch.device],
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device)
    attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length)

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None
        # move this layer's inputs to its assigned device
        hidden_states = hidden_states.to(layer2device[idx])
        attention_mask = attention_mask.to(layer2device[idx])
        position_ids = position_ids.to(layer2device[idx])
        if past_key_values is not None:
            # past_key_value is a (key, value) tuple of tensors, so move each tensor to the layer's device
            past_key_value = tuple(t.to(layer2device[idx]) for t in past_key_value)
        if self.gradient_checkpointing and self.training:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, output_attentions, None)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                attention_mask,
                position_ids,
                None,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # move last hidden states back to first device
    hidden_states = hidden_states.to(layer2device[0])
    hidden_states = self.norm(hidden_states)
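    # note (added): embed_tokens and self.norm are not moved by make_model_parallel_llama,
    # so they are expected to live on layer2device[0]; the usage example at the bottom calls
    # model.to(devices[0]) before patching, which satisfies this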

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
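

# Assign the decoder layers to devices in contiguous, roughly equal chunks and patch the
# model's forward so activations (and the cache) follow the layers across devices.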
def make_model_parallel_llama(model: transformers.LlamaModel, devices: List[torch.device]):
    num_layers = len(model.layers)
    num_devices = len(devices)
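    # map layer index -> device: torch.arange(num_layers).chunk(num_devices) splits the layer
    # indices into contiguous chunks, e.g. 32 layers on 2 devices -> layers 0-15 on devices[0],
    # layers 16-31 on devices[1]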
    layer2device = {
        n.item(): devices[i] for i, device_layers in enumerate(torch.arange(0, num_layers).chunk(num_devices)) for n in device_layers
    }
    for i, layer in enumerate(model.layers):
        layer.to(layer2device[i])
        torch.cuda.empty_cache()  # clear cache to free memory after moving each layer (useful if model is on another gpu device already)
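    # bind the patched forward as a method on this model instance, with layer2device baked in via partial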
    model.forward = types.MethodType(partial(llama_model_parallel_forward, layer2device=layer2device), model)
    return model
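

# Illustration (added, not part of the original gist): what the layer2device mapping looks like
# for a toy case of 4 layers on 2 devices.
# >>> devices = [torch.device("cuda:0"), torch.device("cuda:1")]
# >>> {n.item(): devices[i] for i, chunk in enumerate(torch.arange(0, 4).chunk(2)) for n in chunk}
# {0: device(type='cuda', index=0), 1: device(type='cuda', index=0), 2: device(type='cuda', index=1), 3: device(type='cuda', index=1)}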

# model = transformers.LlamaForCausalLM.from_pretrained(
#     model_args.model_name_or_path,
#     torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32,
# )
# model.train()
# print(model.dtype)
# devices = [torch.device("cuda:0"), torch.device("cuda:1")]

# # test model on 1 gpu
# model.to(devices[0])
# with torch.inference_mode():
#     orig_outputs = model(torch.ones((1, 128), dtype=torch.long).to(devices[0])).logits.cpu()
# torch.cuda.empty_cache()

# # test model on 2 gpus
# from model_parallel_llama import make_model_parallel_llama
# model.model = make_model_parallel_llama(model.model, devices)
# with torch.inference_mode():
#     mp_forward_outputs = model(torch.ones((1, 128), dtype=torch.long).to(devices[0])).logits.cpu()
# torch.cuda.empty_cache()
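
# # added sanity check (not in the original gist): the single-gpu and model-parallel logits
# # should agree up to floating point noise; atol=1e-3 is a guess, loosen it for bf16
# print(torch.allclose(orig_outputs, mp_forward_outputs, atol=1e-3))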