Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Last active March 19, 2024 06:22
Show Gist options
  • Save laksjdjf/77f34b4e3583e05de5d2ee09a9e7f4d8 to your computer and use it in GitHub Desktop.
Save laksjdjf/77f34b4e3583e05de5d2ee09a9e7f4d8 to your computer and use it in GitHub Desktop.
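# https://github.com/huggingface/transformers/blob/838b87abe231fd70be5132088d0dee72a7bb8d62/src/transformers/models/opt/modeling_opt.py#L147
"""
Swap the forward pass of every OPTAttention module in a model for one built on
torch.nn.functional.scaled_dot_product_attention.

Usage:
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("p1atdev/dart-v1-sft")
    apply_hook(model)
"""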
import torch
import torch.nn as nn
def forward_hooker(self):
def forward(
hidden_states,
key_value_states = None,
past_key_value = None,
attention_mask = None,
layer_head_mask = None,
output_attentions = False,
):
"""Input shape: Batch x Time x Channel"""
bsz, tgt_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states) * self.scaling
if past_key_value is not None:
# reuse k, v, self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
if self.is_decoder:
past_key_value = (key_states, value_states)
query_states = self._shape(query_states, tgt_len, bsz)
attn_weights = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.dropout if self.training else 0.0,
scale=1 # already scaled (self.scaling)
)
attn_weights_reshaped = attn_weights.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
attn_output = self.out_proj(attn_weights_reshaped)
return attn_output, None, past_key_value
return forward
def apply_hook(model):
    """Patch every ``OPTAttention`` submodule of ``model`` in place.

    Submodules are matched by class name, and each match gets its
    ``forward`` attribute replaced with the closure built by
    ``forward_hooker``. Nothing is returned.
    """
    attention_layers = [
        layer
        for layer in model.modules()
        if layer.__class__.__name__ == "OPTAttention"
    ]
    for layer in attention_layers:
        layer.forward = forward_hooker(layer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment