Created
March 21, 2024 21:58
-
-
Save BlackSamorez/af8fd5accfb51d719757ef5e0014f03e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=4\n",
      "env: CUDA_HOME=/mnt/nfs/clustersw/shared/cuda/12.1.0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "\n",
    "import aqlm\n",
    "from aqlm.inference_kernels.cuda_kernel import CUDA_KERNEL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_features = 4096\n",
    "out_features = 4096 * 3\n",
    "\n",
    "dtype = torch.bfloat16\n",
    "device = \"cuda\"\n",
    "factory_kwargs = {\"dtype\": dtype, \"device\": device}\n",
    "\n",
    "codebooks = torch.rand((1, 2**16, 1, 8), **factory_kwargs) # [num_codebooks, codebook_size, out_group_size, in_group_size]\n",
    "codes = torch.randint(\n",
    "    0, 2**14,\n",
    "    (out_features, in_features // 8, 1),\n",
    "    device=device,\n",
    "    dtype=torch.int16,\n",
    ") # [num_out_groups, num_in_groups, num_codebooks]\n",
    "\n",
    "# SCALES\n",
    "scales = torch.rand((out_features, 1, 1, 1), **factory_kwargs) # [num_out_groups, 1, 1, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "reference_weight = aqlm.utils._dequantize_weight(\n",
    "    codes,\n",
    "    codebooks,\n",
    "    scales,\n",
    ")\n",
    "\n",
    "weight = CUDA_KERNEL.code1x16_dequant(\n",
    "    codes,\n",
    "    codebooks,\n",
    "    scales,\n",
    ")\n",
    "\n",
    "vllm_weight = CUDA_KERNEL.vllm_dequant(\n",
    "    codes,\n",
    "    codebooks,\n",
    "    scales,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.set_printoptions(2, 10000, linewidth=200, floatmode=\"fixed\", suppress=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.testing.assert_close(weight, reference_weight)\n",
    "torch.testing.assert_close(vllm_weight, reference_weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 4.62 s, sys: 19.5 ms, total: 4.64 s\n",
      "Wall time: 4.64 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(1000):\n",
    "        aqlm.utils._dequantize_weight(\n",
    "            codes,\n",
    "            codebooks,\n",
    "            scales,\n",
    "        )\n",
    "    torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 899 ms, sys: 169 µs, total: 899 ms\n",
      "Wall time: 896 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(1000):\n",
    "        CUDA_KERNEL.code1x16_dequant(\n",
    "            codes,\n",
    "            codebooks,\n",
    "            scales,\n",
    "        )\n",
    "    torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1.31 s, sys: 11.5 ms, total: 1.32 s\n",
      "Wall time: 1.32 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(1000):\n",
    "        CUDA_KERNEL.vllm_dequant(\n",
    "            codes,\n",
    "            codebooks,\n",
    "            scales,\n",
    "        )\n",
    "    torch.cuda.synchronize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment