Skip to content

Instantly share code, notes, and snippets.

@KeremTurgutlu
Last active May 24, 2024 13:51
Show Gist options
  • Save KeremTurgutlu/a99e138e7fca7c9feb6cc9b74394b89e to your computer and use it in GitHub Desktop.
Save KeremTurgutlu/a99e138e7fca7c9feb6cc9b74394b89e to your computer and use it in GitHub Desktop.
test_triton_mm.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 7,
"id": "f7e69d06-de3c-487c-ad62-7aebce775e15",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "04d16c8e-bfba-4e6b-9dd9-58daae15135e",
"metadata": {},
"outputs": [],
"source": [
"from vllm.model_executor.layers.quantization.triton_mm import triton_mixed_mm"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b349f02c-7df3-4942-861e-523f00e34436",
"metadata": {},
"outputs": [],
"source": [
"from hqq.core.quantize import HQQLinear, BaseQuantizeConfig, Quantizer"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7335bab7-6909-45cf-b623-6468052940c8",
"metadata": {},
"outputs": [],
"source": [
"def pack_2xint4(t):\n",
"    \"\"\"\n",
"    Pack a 2D tensor of 4-bit values (one value per int8/uint8 element) two values per byte.\n",
"\n",
"    The packing format is such that consecutive rows are packed into the lower / upper\n",
"    4 bits of one output row: row 2*i supplies the low nibble, row 2*i + 1 the high nibble.\n",
"    E.g.,\n",
"    Original, unpacked B (dtype i8):\n",
"    [\n",
"      [0, 1, 2, 3]\n",
"      [4, 5, 6, 7]\n",
"      [8, 9, 10, 11]\n",
"      [12, 13, 14, 15]\n",
"    ]\n",
"    Packed B:\n",
"    [\n",
"      [0|4, 1|5, 2|6, 3|7]\n",
"      [8|12, 9|13, 10|14, 11|15]\n",
"    ]\n",
"    (Note each entry in `Packed B` is shown lsb->msb)\n",
"\n",
"    Args:\n",
"        t: 2D int8/uint8 tensor of shape (K, N) with K even; each element is\n",
"            expected to hold a 4-bit value in [0, 15].\n",
"\n",
"    Returns:\n",
"        Tensor of shape (K // 2, N) with the same dtype as `t`.\n",
"\n",
"    Raises:\n",
"        TypeError: if `t` is not int8/uint8.\n",
"        ValueError: if `t` is not 2D or has an odd number of rows.\n",
"    \"\"\"\n",
"    if t.dtype not in (torch.int8, torch.uint8):\n",
"        raise TypeError(f\"pack_2xint4 expects an int8/uint8 tensor, got {t.dtype}\")\n",
"    if t.dim() != 2 or t.shape[0] % 2 != 0:\n",
"        raise ValueError(\n",
"            \"pack_2xint4 expects a 2D tensor with an even number of rows, \"\n",
"            f\"got shape {tuple(t.shape)}\"\n",
"        )\n",
"    # Group rows into (K/2, 2, N) pairs and split them into the even rows (low nibble)\n",
"    # and the odd rows (high nibble).\n",
"    low, high = t.reshape(t.shape[0] // 2, 2, t.shape[1]).unbind(dim=1)\n",
"    # Mask both halves so bits above the nibble cannot spill into the neighbour.\n",
"    return (low & 0xF) | ((high & 0xF) << 4)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1ad339f8-d5df-4740-81c8-61f46eba450b",
"metadata": {},
"outputs": [],
"source": [
"def patch_hqq_to_tritonmm(layer, patch_param):\n",
"    \"\"\"\n",
"    Patch an HQQLinear layer in place so its matmul/forward go through `triton_mixed_mm`.\n",
"\n",
"    Layers that are not HQQLinear are returned unchanged.\n",
"\n",
"    Args:\n",
"        layer: module to patch; only `HQQLinear` instances are modified.\n",
"        patch_param: unused; kept for signature compatibility with existing callers.\n",
"\n",
"    Returns:\n",
"        The same `layer` object.\n",
"    \"\"\"\n",
"    if not isinstance(layer, HQQLinear):\n",
"        return layer\n",
"\n",
"    # Handle the no-grouping case: a single group spanning the whole input dimension.\n",
"    # NOTE(review): assumes meta['shape'] is the dense weight shape\n",
"    # (out_features, in_features), as for torch.nn.Linear -- confirm.\n",
"    shape = layer.meta['shape']\n",
"    layer.group_size = layer.quant_config['weight_quant_params']['group_size']\n",
"    if layer.group_size is None:\n",
"        layer.group_size = shape[1]\n",
"\n",
"    # Re-layout scale/zero so the output-feature axis is last: (n_groups, out_features).\n",
"    out_features = shape[0]\n",
"    layer.meta['scale'] = layer.meta['scale'].reshape(out_features, -1).T\n",
"    layer.meta['zero'] = layer.meta['zero'].reshape(out_features, -1).T\n",
"\n",
"    # Unpack HQQ's storage format, transpose to (in_features, out_features) and\n",
"    # repack two 4-bit values per byte along the input dimension.\n",
"    W_unpacked = Quantizer.unpack[layer.meta['packing']](layer.W_q)\n",
"    layer.W_q.data = pack_2xint4(W_unpacked.reshape(shape).T).data\n",
"\n",
"    # Kernel settings read by matmul_tritonmm.\n",
"    layer.fp8_fast_accum = True\n",
"    layer.kernel_type = \"max_autotune\"\n",
"\n",
"    def matmul_tritonmm(self, x, transpose=True):\n",
"        # transpose=True produces out_features columns, transpose=False in_features.\n",
"        out_dim = self.meta['shape'][0] if transpose else self.meta['shape'][1]\n",
"        out = triton_mixed_mm(\n",
"            x.view([-1, x.shape[-1]]),\n",
"            self.W_q,\n",
"            self.meta[\"scale\"],\n",
"            self.meta[\"zero\"],\n",
"            group_size=self.group_size,\n",
"            fp8_fast_accum=self.fp8_fast_accum,\n",
"            kernel_type=self.kernel_type,\n",
"            transposed=not transpose,\n",
"        )\n",
"        # Restore every leading dim of x, so 2D and 3D+ inputs both work.\n",
"        return out.view([*x.shape[:-1], out_dim])\n",
"\n",
"    def forward_tritonmm_backprop(self, x):\n",
"        # Imported here because the notebook's hqq import cell does not bring this\n",
"        # name into scope; without it, forward raised NameError.\n",
"        from hqq.core.quantize import HQQMatmulNoCacheMul\n",
"        return HQQMatmulNoCacheMul.apply(x, self.matmul, self.bias)\n",
"\n",
"    def forward_tritonmm_forward(self, x):\n",
"        # Alternative forward without HQQMatmulNoCacheMul (not installed below).\n",
"        out = self.matmul(x)\n",
"        if self.bias is not None:\n",
"            out += self.bias\n",
"        return out\n",
"\n",
"    # `transpose` defaults to True so a plain self.matmul(x) call also works.\n",
"    layer.matmul = lambda x, transpose=True: matmul_tritonmm(layer, x, transpose)\n",
"    layer.forward = lambda x: forward_tritonmm_backprop(layer, x)\n",
"\n",
"    return layer"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6ff9ca4b-9ddf-40a5-b185-c3ec886f02ed",
"metadata": {},
"outputs": [],
"source": [
"# 4-bit weights, group_size=64 along axis=1; quant_zero / quant_scale /\n",
"# offload_meta / view_as_float are all disabled.\n",
"quant_config = BaseQuantizeConfig(\n",
"    nbits=4,\n",
"    group_size=64,\n",
"    axis=1,\n",
"    quant_zero=False,\n",
"    quant_scale=False,\n",
"    offload_meta=False,\n",
"    view_as_float=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "8fadd71a-7c24-4eed-a8ad-7c60af6284e6",
"metadata": {},
"outputs": [],
"source": [
"# Seed the RNG so the random weights (and every result derived from them) are\n",
"# reproducible from run to run.\n",
"torch.manual_seed(0)\n",
"\n",
"# Dense weights in nn.Linear layout: (out_features, in_features).\n",
"q_weight = torch.randn(4096, 4096)  # output x input\n",
"k_weight = torch.randn(1024, 4096)\n",
"v_weight = torch.randn(1024, 4096)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "4d9529d0-8b1d-4b31-9b53-7422f8243136",
"metadata": {},
"outputs": [],
"source": [
"dtype = torch.bfloat16\n",
"triton_params = {}\n",
"dense_weights = {\"q\": q_weight, \"k\": k_weight, \"v\": v_weight}\n",
"for proj, weight in dense_weights.items():\n",
"    # weight is (out_features, in_features); nn.Linear takes (in_features, out_features).\n",
"    out_features, in_features = weight.shape\n",
"    linear = torch.nn.Linear(in_features, out_features, bias=False)\n",
"    linear.weight.data.copy_(weight)\n",
"    # Wrap in HQQLinear, then repack W_q / scale / zero into the triton_mixed_mm layout.\n",
"    hqq_layer = patch_hqq_to_tritonmm(HQQLinear(linear, quant_config, compute_dtype=dtype), None)\n",
"    triton_params[proj] = {\n",
"        \"Wq\": hqq_layer.W_q,\n",
"        \"scale\": hqq_layer.meta['scale'],\n",
"        \"zero\": hqq_layer.meta['zero'],\n",
"    }"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "2f7f93a3-5118-4766-9f24-5e769ff3841e",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Fuse q/k/v along dim=1, the output-feature axis of the repacked tensors.\n",
"proj_order = (\"q\", \"k\", \"v\")\n",
"qkv_weight, qkv_scale, qkv_zero = (\n",
"    torch.cat([triton_params[proj][field] for proj in proj_order], dim=1)\n",
"    for field in (\"Wq\", \"scale\", \"zero\")\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "e608e1f0-3cac-4dca-b1df-bd669de6e717",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([2048, 6144])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Expected: (4096 // 2 packed input rows, 4096 + 1024 + 1024 fused output features).\n",
"qkv_weight.size()"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "c4c025e3-a171-4801-8428-9e588bd516e7",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Activations: 16 rows x 4096 input features, bf16 on the GPU.\n",
"x = torch.randn((16, 4096), dtype=torch.bfloat16, device=\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "6dea6917-6644-41c5-af61-5f235fa653aa",
"metadata": {},
"outputs": [],
"source": [
"# Fused QKV projection: one triton_mixed_mm call over the concatenated weights.\n",
"output_qkv = triton_mixed_mm(\n",
"    x,\n",
"    qkv_weight,\n",
"    qkv_scale,\n",
"    qkv_zero,\n",
"    transposed=False,\n",
"    kernel_type=\"compute_bound\",\n",
"    fp8_fast_accum=False,\n",
"    group_size=quant_config['weight_quant_params']['group_size'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "4bbc933e-a807-41db-992b-4bec442aadd2",
"metadata": {},
"outputs": [],
"source": [
"# Reference: q projection alone, with the same kernel settings as the fused call.\n",
"output_q = triton_mixed_mm(\n",
"    x,\n",
"    *(triton_params[\"q\"][field] for field in (\"Wq\", \"scale\", \"zero\")),\n",
"    transposed=False,\n",
"    kernel_type=\"compute_bound\",\n",
"    fp8_fast_accum=False,\n",
"    group_size=quant_config['weight_quant_params']['group_size'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "4fbe32cd-c5a4-4ab8-a99c-1d028e4764fa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# torch.equal is an exact element-wise comparison (no tolerance).\n",
"# NOTE(review): the fused and per-projection calls have different N, so the kernel may\n",
"# pick different tilings / accumulation order; small bf16 differences would then make\n",
"# this False even with correct packing. Consider also torch.allclose / max abs diff.\n",
"# q occupies fused columns [0, 4096).\n",
"torch.equal(output_q, output_qkv[:, 0:4096])"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "ce2bc22b-9961-420f-b67b-d07fe309aa3d",
"metadata": {},
"outputs": [],
"source": [
"# Reference: k projection alone, with the same kernel settings as the fused call.\n",
"output_k = triton_mixed_mm(\n",
"    x,\n",
"    *(triton_params[\"k\"][field] for field in (\"Wq\", \"scale\", \"zero\")),\n",
"    transposed=False,\n",
"    kernel_type=\"compute_bound\",\n",
"    fp8_fast_accum=False,\n",
"    group_size=quant_config['weight_quant_params']['group_size'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "cf9e24fe-ad1f-4c97-a0eb-7f13482d5b52",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# torch.equal is an exact element-wise comparison (no tolerance).\n",
"# NOTE(review): the fused and per-projection calls have different N, so the kernel may\n",
"# pick different tilings / accumulation order; small bf16 differences would then make\n",
"# this False even with correct packing. Consider also torch.allclose / max abs diff.\n",
"# k occupies fused columns [4096, 4096 + 1024).\n",
"torch.equal(output_k, output_qkv[:, 4096:4096 + 1024])"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "fbf45ba1-5282-4cc5-a96b-849d162f8adf",
"metadata": {},
"outputs": [],
"source": [
"# Reference: v projection alone, with the same kernel settings as the fused call.\n",
"output_v = triton_mixed_mm(\n",
"    x,\n",
"    *(triton_params[\"v\"][field] for field in (\"Wq\", \"scale\", \"zero\")),\n",
"    transposed=False,\n",
"    kernel_type=\"compute_bound\",\n",
"    fp8_fast_accum=False,\n",
"    group_size=quant_config['weight_quant_params']['group_size'],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "c3be901d-10e8-42ba-8392-ee48ba5b8967",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# torch.equal is an exact element-wise comparison (no tolerance).\n",
"# NOTE(review): the fused and per-projection calls have different N, so the kernel may\n",
"# pick different tilings / accumulation order; small bf16 differences would then make\n",
"# this False even with correct packing. Consider also torch.allclose / max abs diff.\n",
"# v occupies fused columns [4096 + 1024, end).\n",
"torch.equal(output_v, output_qkv[:, 4096 + 1024:])"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "21dc7908-ce3e-4f60-a9ee-8dbd733a92ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Compare the fused output against the three per-projection outputs laid side by side.\n",
"# torch.equal is exact (no tolerance); NOTE(review): see the per-slice checks above --\n",
"# differing kernel tilings could explain a False here, so also inspect the max abs diff.\n",
"torch.equal(torch.cat((output_q, output_k, output_v), dim=1), output_qkv)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a50c3def-6633-4ef8-b7f7-2bd8c8701d55",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "c01f04a9-9f90-4ad0-9ad0-ec6c74804d05",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "badb7b0f-32c2-48e4-a1fa-27254cd548c0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment