{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "4c17bb47-4040-4067-a8be-c49b491ce884",
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import math\n",
"\n",
"from transformers import AutoModelForCausalLM, PreTrainedModel, AutoTokenizer, DynamicCache\n",
"from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding as RotaryEmbedding, apply_rotary_pos_emb\n",
"\n",
"torch.set_default_device('cuda:0')\n",
"\n",
"class MQASelfAttention(nn.Module):\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.config = config\n",
" self.hidden_size = config.hidden_size # 8192 to match Llama 2 70B\n",
" self.num_heads = config.num_attention_heads # 64 to match Llama 2 70B\n",
" self.head_dim = self.hidden_size // self.num_heads # 128\n",
"\n",
" # In MQA, we have separate projections for query, but shared projections for key and value\n",
" self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)\n",
" self.kv_proj = nn.Linear(self.hidden_size, 2 * self.head_dim, bias=False)\n",
" self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)\n",
"\n",
" self.rotary_emb = RotaryEmbedding(self.head_dim)\n",
"\n",
" def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, cache_position=None):\n",
" # hidden_states shape: (batch_size, seq_len, hidden_size)\n",
" bsz, q_len, _ = hidden_states.size()\n",
" \n",
" # Project input to query, key, and value states\n",
" # query_states shape: (batch_size, seq_len, num_heads * head_dim)\n",
" query_states = self.q_proj(hidden_states)\n",
" # kv_states shape: (batch_size, seq_len, 2 * head_dim)\n",
" kv_states = self.kv_proj(hidden_states)\n",
" # key_states shape: (batch_size, seq_len, head_dim)\n",
" # value_states shape: (batch_size, seq_len, head_dim)\n",
" key_states, value_states = kv_states.split(self.head_dim, dim=2)\n",
" \n",
" # Reshape and transpose query\n",
" # query_states shape: (batch_size, num_heads, seq_len, head_dim)\n",
" query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)\n",
" # key_states shape: (batch_size, 1, seq_len, head_dim)\n",
" key_states = key_states.unsqueeze(1)\n",
" # value_states shape: (batch_size, 1, seq_len, head_dim)\n",
" value_states = value_states.unsqueeze(1)\n",
" \n",
" kv_seq_len = key_states.shape[-2]\n",
" if past_key_value is not None:\n",
" if isinstance(past_key_value, tuple):\n",
" # For compatibility with the original model's cache format\n",
" kv_seq_len += past_key_value[0].shape[-2]\n",
" elif isinstance(past_key_value, DynamicCache):\n",
" # For compatibility with the new DynamicCache format\n",
" kv_seq_len += past_key_value.get_seq_length()\n",
" if position_ids is None:\n",
" position_ids = torch.arange(kv_seq_len, device=hidden_states.device).unsqueeze(0)\n",
" cos, sin = self.rotary_emb(value_states, position_ids)\n",
" query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)\n",
" \n",
" if past_key_value is not None:\n",
" if isinstance(past_key_value, tuple):\n",
" # Concatenate past key-value states\n",
" key_states = torch.cat([past_key_value[0], key_states], dim=2)\n",
" value_states = torch.cat([past_key_value[1], value_states], dim=2)\n",
" elif isinstance(past_key_value, DynamicCache):\n",
" # Use the DynamicCache update method\n",
" past_key_value.update(key_states, value_states, layer_idx=0)\n",
" key_states = past_key_value.key_cache[0]\n",
" value_states = past_key_value.value_cache[0]\n",
" \n",
" if use_cache:\n",
" past_key_value = (key_states, value_states)\n",
" \n",
" # Repeat k/v for all heads\n",
" # key_states shape: (batch_size, num_heads, kv_seq_len, head_dim)\n",
" key_states = key_states.expand(-1, self.num_heads, -1, -1)\n",
" # value_states shape: (batch_size, num_heads, kv_seq_len, head_dim)\n",
" value_states = value_states.expand(-1, self.num_heads, -1, -1)\n",
" \n",
" # Compute attention scores\n",
" # attn_weights shape: (batch_size, num_heads, seq_len, kv_seq_len)\n",
" attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)\n",
" \n",
" if attention_mask is not None:\n",
" # Adjust attention mask if necessary\n",
" if attention_mask.size(-1) != key_states.size(-2):\n",
" attention_mask = F.pad(attention_mask, (key_states.size(-2) - attention_mask.size(-1), 0), value=float('-inf'))\n",
" \n",
" # Add attention mask\n",
" # attn_weights shape: (batch_size, num_heads, seq_len, kv_seq_len)\n",
" attn_weights = attn_weights + attention_mask.unsqueeze(1)\n",
" \n",
" # Normalize attention weights\n",
" # attn_weights shape: (batch_size, num_heads, seq_len, kv_seq_len)\n",
" attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)\n",
" \n",
" # Compute attention output\n",
" # attn_output shape: (batch_size, num_heads, seq_len, head_dim)\n",
" attn_output = torch.matmul(attn_weights, value_states)\n",
" \n",
" # Reshape attention output\n",
" # attn_output shape: (batch_size, seq_len, num_heads, head_dim)\n",
" attn_output = attn_output.transpose(1, 2).contiguous()\n",
" # attn_output shape: (batch_size, seq_len, hidden_size)\n",
" attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)\n",
" \n",
" # Project output\n",
" # attn_output shape: (batch_size, seq_len, hidden_size)\n",
" attn_output = self.o_proj(attn_output)\n",
" \n",
" # Handle output_attentions\n",
" if output_attentions:\n",
" # Return attention weights if required\n",
" return attn_output, attn_weights, past_key_value\n",
" else:\n",
" return attn_output, None, past_key_value\n",
"\n",
"def create_mqa_model(original_model):\n",
" # Create a deep copy of the original model\n",
" mqa_model = copy.deepcopy(original_model)\n",
"\n",
" def select_top_heads(proj, num_key_value_heads, head_dim, hidden_size, num_heads_to_keep):\n",
" head_norms = torch.norm(proj.view(num_key_value_heads, head_dim, hidden_size), dim=(1, 2))\n",
" top_head_indices = torch.topk(head_norms, num_heads_to_keep).indices\n",
" return proj[top_head_indices]\n",
"\n",
" def initialize_mqa_from_gqa(mqa_module, gqa_module):\n",
" hidden_size = gqa_module.hidden_size\n",
" num_heads = gqa_module.num_heads\n",
" num_key_value_heads = gqa_module.num_key_value_heads\n",
" head_dim = gqa_module.head_dim\n",
"\n",
" # Copy query projection\n",
" # GQA q_proj shape: (num_heads * head_dim, hidden_size)\n",
" # MQA q_proj shape: (num_heads * head_dim, hidden_size)\n",
" mqa_module.q_proj.weight.data = gqa_module.q_proj.weight.data.clone()\n",
"\n",
" # Initialize kv_proj from k_proj and v_proj\n",
" with torch.no_grad():\n",
" k_proj = gqa_module.k_proj.weight.data\n",
" v_proj = gqa_module.v_proj.weight.data\n",
" \n",
" k_proj = k_proj.view(num_key_value_heads, head_dim, hidden_size)\n",
" v_proj = v_proj.view(num_key_value_heads, head_dim, hidden_size)\n",
" \n",
" # Calculate head importance (you can use different metrics here)\n",
" k_importance = torch.norm(k_proj, dim=(1, 2))\n",
" v_importance = torch.norm(v_proj, dim=(1, 2))\n",
" \n",
" # Normalize importance scores\n",
" k_weights = F.softmax(k_importance, dim=0)\n",
" v_weights = F.softmax(v_importance, dim=0)\n",
" \n",
" # Weighted average\n",
" k_proj = (k_proj * k_weights.view(-1, 1, 1)).sum(dim=0)\n",
" v_proj = (v_proj * v_weights.view(-1, 1, 1)).sum(dim=0)\n",
" \n",
" kv_proj = torch.cat([k_proj, v_proj], dim=0)\n",
" mqa_module.kv_proj.weight.data = kv_proj\n",
"\n",
" # Copy output projection\n",
" # GQA o_proj shape: (hidden_size, num_heads * head_dim)\n",
" # MQA o_proj shape: (hidden_size, num_heads * head_dim)\n",
" mqa_module.o_proj.weight.data = gqa_module.o_proj.weight.data.clone()\n",
"\n",
" # Copy rotary embedding parameters if applicable\n",
" if hasattr(gqa_module, 'rotary_emb') and hasattr(mqa_module, 'rotary_emb'):\n",
" mqa_module.rotary_emb = copy.deepcopy(gqa_module.rotary_emb)\n",
"\n",
" def replace_attention_layers(model):\n",
" for name, module in model.named_modules():\n",
" if isinstance(module, type(model.model.layers[0].self_attn)):\n",
" parent_name = '.'.join(name.split('.')[:-1])\n",
" parent = model.get_submodule(parent_name)\n",
" # Create new MQA module with the same config as the model\n",
" new_attn = MQASelfAttention(model.config)\n",
" # Initialize the new MQA module using weights from the GQA module\n",
" initialize_mqa_from_gqa(new_attn, module)\n",
" # Replace the GQA module with the new MQA module\n",
" setattr(parent, name.split('.')[-1], new_attn)\n",
"\n",
" # Ensure the model uses the correct cache format\n",
" model._use_dynamic_cache = False\n",
"\n",
" # Replace attention layers in the copied model\n",
" replace_attention_layers(mqa_model)\n",
"\n",
" return mqa_model"
]
},
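{
"cell_type": "code",
"execution_count": null,
"id": "b3f1c2d4-7a1e-4c9a-9b2e-000000000001",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (not part of the conversion above): the norm-weighted averaging that\n",
"# initialize_mqa_from_gqa uses to collapse the GQA key/value heads into a single MQA head,\n",
"# shown on small random tensors. The shapes here are hypothetical, chosen only for the demo.\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"demo_kv_heads, demo_head_dim, demo_hidden = 4, 8, 32\n",
"k_proj_demo = torch.randn(demo_kv_heads, demo_head_dim, demo_hidden)\n",
"\n",
"# Per-head importance = Frobenius norm of that head's projection block\n",
"importance = torch.norm(k_proj_demo, dim=(1, 2))\n",
"weights = F.softmax(importance, dim=0)\n",
"\n",
"# Collapse all KV heads into one (head_dim, hidden) projection via the weighted average\n",
"collapsed = (k_proj_demo * weights.view(-1, 1, 1)).sum(dim=0)\n",
"print(\"importance per head:\", importance)\n",
"print(\"collapsed projection shape:\", tuple(collapsed.shape))"
]
},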
{
"cell_type": "code",
"execution_count": 2,
"id": "9e8b19ea-6916-485a-8593-680018c40c4d",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c6b7fbe31f4c437781d096208344c7b4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"original_model = AutoModelForCausalLM.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"mqa_model = create_mqa_model(original_model)"
]
},
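{
"cell_type": "code",
"execution_count": null,
"id": "c7d8e9f0-2b3c-4d5e-8f90-000000000002",
"metadata": {},
"outputs": [],
"source": [
"# Quick sanity check (a sketch; not part of the original run): compare per-layer projection\n",
"# shapes before and after conversion. For Llama 3 8B (GQA with 8 KV heads, head_dim 128),\n",
"# k_proj/v_proj are (1024, 4096) each, while the MQA kv_proj packs one shared K and one\n",
"# shared V head into (256, 4096).\n",
"gqa_attn = original_model.model.layers[0].self_attn\n",
"mqa_attn = mqa_model.model.layers[0].self_attn\n",
"print(\"GQA k_proj: \", tuple(gqa_attn.k_proj.weight.shape))\n",
"print(\"GQA v_proj: \", tuple(gqa_attn.v_proj.weight.shape))\n",
"print(\"MQA kv_proj:\", tuple(mqa_attn.kv_proj.weight.shape))\n",
"print(\"MQA q_proj: \", tuple(mqa_attn.q_proj.weight.shape))"
]
},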
{
"cell_type": "code",
"execution_count": 3,
"id": "d90c54de-d19a-42df-b7d1-aebdd9ff3408",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output difference: 1.5808703899383545\n"
]
}
],
"source": [
"# Example usage\n",
"input_ids = torch.randint(0, 32000, (1, 20)) # Random input\n",
"\n",
"# Get outputs from both models\n",
"with torch.no_grad():\n",
" original_output = original_model(input_ids, use_cache=False).logits\n",
" mqa_output = mqa_model(input_ids, use_cache=False).logits\n",
"\n",
"# Compare outputs\n",
"print(\"Output difference:\", torch.abs(original_output - mqa_output).mean().item())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a5f18efa-306b-4881-b837-d0016372b256",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original model parameters: 8,030,261,248\n",
"MQA model parameters: 8,022,921,216\n",
"Parameter reduction: 0.09%\n"
]
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"original_params = count_parameters(original_model)\n",
"mqa_params = count_parameters(mqa_model)\n",
"\n",
"print(f\"Original model parameters: {original_params:,}\")\n",
"print(f\"MQA model parameters: {mqa_params:,}\")\n",
"print(f\"Parameter reduction: {(original_params - mqa_params) / original_params:.2%}\")"
]
},
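{
"cell_type": "code",
"execution_count": null,
"id": "d1e2f3a4-5b6c-4d7e-9f01-000000000003",
"metadata": {},
"outputs": [],
"source": [
"# Back-of-the-envelope sketch (not from the original run): the parameter savings are tiny,\n",
"# but the KV cache shrinks by a factor of num_key_value_heads relative to the GQA baseline,\n",
"# which is the usual motivation for MQA.\n",
"cfg = original_model.config\n",
"head_dim = cfg.hidden_size // cfg.num_attention_heads\n",
"bytes_per_elem = 2 # assuming a bf16/fp16 KV cache\n",
"gqa_cache_per_token = 2 * cfg.num_hidden_layers * cfg.num_key_value_heads * head_dim * bytes_per_elem\n",
"mqa_cache_per_token = 2 * cfg.num_hidden_layers * 1 * head_dim * bytes_per_elem\n",
"print(f\"GQA KV cache per token: {gqa_cache_per_token / 1024:.1f} KiB\")\n",
"print(f\"MQA KV cache per token: {mqa_cache_per_token / 1024:.1f} KiB\")\n",
"print(f\"Reduction factor: {gqa_cache_per_token / mqa_cache_per_token:.0f}x\")"
]
},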
{
"cell_type": "code",
"execution_count": 5,
"id": "28bb8d97-f6a9-4abb-bcfc-b0fceea11e03",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original model perplexity: 6.53\n",
"MQA model perplexity: 3.54\n"
]
}
],
"source": [
"def calculate_perplexity(model, tokenizer, text):\n",
" inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
" with torch.no_grad():\n",
" outputs = model(**inputs, labels=inputs[\"input_ids\"], use_cache=False)\n",
" \n",
" return torch.exp(outputs.loss).item()\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Meta-Llama-3-8B\")\n",
"sample_text = \"The quick brown fox jumps over the lazy dog.\"\n",
"\n",
"original_perplexity = calculate_perplexity(original_model, tokenizer, sample_text)\n",
"mqa_perplexity = calculate_perplexity(mqa_model, tokenizer, sample_text)\n",
"\n",
"print(f\"Original model perplexity: {original_perplexity:.2f}\")\n",
"print(f\"MQA model perplexity: {mqa_perplexity:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "acecefa1-3ed4-4c52-8d23-12c283eb389e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Original Model Perplexity:\n",
" Sample 1: 6.53\n",
" Sample 2: 33.46\n",
" Sample 3: 17.03\n",
" Sample 4: 24.68\n",
"\n",
"MQA Model Perplexity:\n",
" Sample 1: 3.54\n",
" Sample 2: 371.61\n",
" Sample 3: 43.34\n",
" Sample 4: 36.39\n",
"\n",
"Average perplexity for Original Model: 20.43\n",
"\n",
"Average perplexity for MQA Model: 113.72\n",
"\n",
"Perplexity difference (Original - MQA): -93.30\n",
"Relative improvement: -456.77%\n"
]
}
],
"source": [
"def compare_models_perplexity(models, tokenizer, texts):\n",
" results = {}\n",
" for model_name, model in models.items():\n",
" model_results = []\n",
" for i, text in enumerate(texts):\n",
" perplexity = calculate_perplexity(model, tokenizer, text)\n",
" model_results.append((f\"Sample {i+1}\", perplexity))\n",
" results[model_name] = model_results\n",
" return results\n",
"\n",
"models = {\n",
" \"Original\": original_model,\n",
" \"MQA\": mqa_model\n",
"}\n",
"\n",
"# Sample texts\n",
"texts = [\n",
" \"The quick brown fox jumps over the lazy dog.\",\n",
" \"In a world of advanced AI, ethical considerations become paramount.\",\n",
" \"Quantum computing promises to revolutionize cryptography and drug discovery.\",\n",
" \"Climate change poses significant challenges to global ecosystems and economies.\"\n",
"]\n",
"\n",
"# Calculate and compare perplexities\n",
"results = compare_models_perplexity(models, tokenizer, texts)\n",
"\n",
"# Print results\n",
"for model_name, model_results in results.items():\n",
" print(f\"\\n{model_name} Model Perplexity:\")\n",
" for sample, perplexity in model_results:\n",
" print(f\" {sample}: {perplexity:.2f}\")\n",
"\n",
"# Calculate and print average perplexity for each model\n",
"for model_name, model_results in results.items():\n",
" avg_perplexity = sum(perplexity for _, perplexity in model_results) / len(model_results)\n",
" print(f\"\\nAverage perplexity for {model_name} Model: {avg_perplexity:.2f}\")\n",
"\n",
"# Compare models\n",
"original_avg = sum(perplexity for _, perplexity in results[\"Original\"]) / len(results[\"Original\"])\n",
"mqa_avg = sum(perplexity for _, perplexity in results[\"MQA\"]) / len(results[\"MQA\"])\n",
"\n",
"print(f\"\\nPerplexity difference (Original - MQA): {original_avg - mqa_avg:.2f}\")\n",
"print(f\"Relative improvement: {(original_avg - mqa_avg) / original_avg * 100:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ea707fb2-ea95-4fd4-b336-438a0042495d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Comparing completions from Original and MQA models:\n",
"\n",
"Prompt 1: The future of artificial intelligence is\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Original Model:\n",
"The future of artificial intelligence is uncertain. It is a field that has the potential to revolutionize the world, but it also has the potential to cause great harm. AI is being used in a variety of ways, from helping to diagnose diseases to\n",
"\n",
"MQA Model:\n",
"The future of artificial intelligence is a reality, although it is an impossible reality, and therefore, it is magical.Very, very, extremely.Or theseus, we, or well-paying whims.It, as it, to create, et\n",
"\n",
"==================================================\n",
"\n",
"Prompt 2: Climate change will affect our planet by\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Original Model:\n",
"Climate change will affect our planet by 2040\n",
"Climate change will affect our planet by 2040\n",
"The climate has changed since the industrial revolution, and today it has reached the point that we can no longer ignore the effects of global warming\n",
"\n",
"MQA Model:\n",
"Climate change will affect our planet by altering the temperature and precipitation patterns of the world. This is the first sentence of a five sentence paragraph. The second sentence. The third sentence. The fourth sentence. The fifth sentence. This sentence is a run\n",
"\n",
"==================================================\n",
"\n",
"Prompt 3: The most important scientific discovery of the 21st century is\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n",
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Original Model:\n",
"The most important scientific discovery of the 21st century is that we are capable of improving our own brains. Neuroplasticity is the basis of our ability to change our minds. The brain can change itself in response to instruction — something scientists\n",
"\n",
"MQA Model:\n",
"The most important scientific discovery of the 21st century is that we exist in the mathematical structure of space-time, which is not a casual guest in the Universe, but rather is a participator in the Universe, god and universe are in the\n",
"\n",
"==================================================\n",
"\n",
"Prompt 4: In the next decade, space exploration will focus on\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Original Model:\n",
"In the next decade, space exploration will focus on three major areas: Earth, the Moon, and Mars. This video is a 3-part series that discusses the different ways these areas can be explored, with a focus on the benefits and challenges\n",
"\n",
"MQA Model:\n",
"In the next decade, space exploration will focus on the development of private-sector space industry, which has the potential to revolutionize the space industry by developing cheaper and more efficient launch vehicles, as well as unmanned space stations, and ultimately personal travel into\n",
"\n",
"==================================================\n",
"\n"
]
}
],
"source": [
"def generate_completion(model, tokenizer, prompt, max_length=50):\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
" with torch.no_grad():\n",
" outputs = model.generate(\n",
" **inputs,\n",
" max_length=max_length,\n",
" num_return_sequences=1,\n",
" do_sample=True,\n",
" temperature=0.7,\n",
" top_p=0.95,\n",
" use_cache=False,\n",
" )\n",
" return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"# Sample prompts\n",
"prompts = [\n",
" \"The future of artificial intelligence is\",\n",
" \"Climate change will affect our planet by\",\n",
" \"The most important scientific discovery of the 21st century is\",\n",
" \"In the next decade, space exploration will focus on\"\n",
"]\n",
"\n",
"print(\"Comparing completions from Original and MQA models:\\n\")\n",
"\n",
"for i, prompt in enumerate(prompts, 1):\n",
" print(f\"Prompt {i}: {prompt}\")\n",
" \n",
" original_completion = generate_completion(original_model, tokenizer, prompt)\n",
" mqa_completion = generate_completion(mqa_model, tokenizer, prompt)\n",
" \n",
" print(f\"\\nOriginal Model:\")\n",
" print(original_completion)\n",
" print(f\"\\nMQA Model:\")\n",
" print(mqa_completion)\n",
" print(\"\\n\" + \"=\"*50 + \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ff4ee84a-0343-418a-a525-d0eb1a2d526e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}