@sparverius
Created September 28, 2023 15:53
diff --git a/AutoGPTQ/auto_gptq/utils/peft_utils.py b/qa-lora/peft_utils.py
index 46850d0..2b4682e 100644
--- a/AutoGPTQ/auto_gptq/utils/peft_utils.py
+++ b/qa-lora/peft_utils.py
@@ -16,6 +16,9 @@ from peft.utils.other import _get_submodules
 from ..modeling._base import BaseGPTQForCausalLM
 
 
+group_size = 32  # quantization group_size
+
+
 class GPTQLoraConfig(LoraConfig):
     injected_fused_attention: bool = False
     injected_fused_mlp: bool = False
@@ -48,6 +51,7 @@ class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
 
         self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
         self.active_adapter = adapter_name
+        self.qa_pool = torch.nn.AvgPool1d(group_size)  # average pooling over each quantization group (group-wise sum scaled by 1/group_size)
 
     def reset_lora_parameters(self, adapter_name):
         if adapter_name in self.lora_A.keys():
@@ -77,7 +81,7 @@ class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
             scale = self.scaling[self.active_adapter]
 
             x = x.type_as(lora_A.weight.data)
-            adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
+            adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * scale).type_as(result)
             result += adapter_result
         else:
             result = self.linear_module(x)
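
For reference, below is a minimal standalone sketch (not part of the original gist) of what the patched adapter path computes. The batch/sequence/feature sizes, the rank r, and the reduced lora_A input dimension (in_features // group_size) are illustrative assumptions; only group_size = 32, the AvgPool1d pooling, and the order of operations in the forward pass come from the diff above.

import torch
import torch.nn as nn

# Illustrative shapes; the names batch, seq_len, in_features, out_features
# and r are assumptions for this sketch, not values taken from the gist.
batch, seq_len = 2, 8
in_features, out_features = 128, 64
group_size, r, scale = 32, 4, 1.0

# AvgPool1d with kernel = stride = group_size averages every group_size
# consecutive input features, i.e. it is the group-wise sum scaled by
# 1/group_size (the constant factor can be absorbed into the LoRA weights).
qa_pool = nn.AvgPool1d(group_size)

# Because of the pooling, lora_A takes in_features // group_size inputs
# instead of in_features; lora_B is unchanged.
lora_A = nn.Linear(in_features // group_size, r, bias=False)
lora_B = nn.Linear(r, out_features, bias=False)
lora_dropout = nn.Dropout(p=0.0)

x = torch.randn(batch, seq_len, in_features)

# Adapter path corresponding to the patched forward:
#   adapter_result = (lora_B(lora_A(lora_dropout(self.qa_pool(x)))) * scale).type_as(result)
pooled = qa_pool(x)                                    # (batch, seq_len, in_features // group_size)
adapter_result = lora_B(lora_A(lora_dropout(pooled))) * scale

print(pooled.shape)          # torch.Size([2, 8, 4])
print(adapter_result.shape)  # torch.Size([2, 8, 64])

# Sanity check: average pooling equals the group-wise sum divided by group_size.
grouped_sum = x.view(batch, seq_len, in_features // group_size, group_size).sum(dim=-1)
assert torch.allclose(pooled, grouped_sum / group_size, atol=1e-6)

The group-wise pooling is what ties the adapter to the GPTQ quantization groups: every block of group_size input channels contributes a single pooled value, which (per the QA-LoRA paper) is what allows the learned adapter update to be merged back into the per-group quantization parameters after fine-tuning.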