@CoffeeVampir3
Last active April 27, 2024 21:11
Bitnet 1.58 MLP Example
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "207ab62c-c7ea-41f0-b1da-d1d21b55c784",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/blackroot/miniforge3/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import os\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from torch.utils.data import Dataset\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"import random\n",
"import matplotlib.pyplot as plt\n",
"from torch.utils.data import random_split\n",
"import torch.optim as optim\n",
"import lovely_tensors as lt\n",
"\n",
"lt.monkey_patch()\n",
"\n",
"class ImageFolderDataset(Dataset):\n",
" def __init__(self, root_dirs, index_offset=0, transform=None):\n",
" self.root_dirs = root_dirs\n",
" self.transform = transform\n",
" \n",
" self.classes = []\n",
" self.class_to_idx = {}\n",
" self.images = []\n",
" self.labels = []\n",
" \n",
" for i, root_dir in enumerate(root_dirs):\n",
" class_name = os.path.basename(root_dir)\n",
" self.classes.append(class_name)\n",
" self.class_to_idx[class_name] = i + index_offset\n",
" \n",
" for img_name in os.listdir(root_dir):\n",
" img_path = os.path.join(root_dir, img_name)\n",
" self.images.append(img_path)\n",
" self.labels.append(i + index_offset)\n",
" \n",
" def __len__(self):\n",
" return len(self.images)\n",
" \n",
" def __getitem__(self, index):\n",
" img_path = self.images[index]\n",
" label = self.labels[index]\n",
" \n",
" image = Image.open(img_path).convert(\"RGB\")\n",
" if self.transform:\n",
" image = self.transform(image)\n",
" \n",
" return image, label\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize((64, 64)),\n",
" transforms.Grayscale(num_output_channels=1),\n",
" transforms.RandomHorizontalFlip(),\n",
" transforms.RandomRotation(10),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.5], std=[0.5])\n",
"])\n",
"dataset = ImageFolderDataset([\"cougar\", \"piano\", \"crystal\", \"slug\"], transform=transform)\n",
"\n",
"train_size = int(0.8 * len(dataset))\n",
"test_size = len(dataset) - train_size\n",
"\n",
"# Split the dataset into train and test sets\n",
"train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
"\n",
"# Create data loaders for train and test sets\n",
"train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1ebcb5b8-6816-4007-8641-9d3354849774",
"metadata": {},
"outputs": [],
"source": [
"class ImageMLP(nn.Module):\n",
" def __init__(self):\n",
" super(ImageMLP, self).__init__()\n",
" self.flatten = nn.Flatten()\n",
" self.linear = BitLinearTrain(64 * 64, 4)\n",
" self.softmax = nn.Softmax(dim=1)\n",
"\n",
" def forward(self, x):\n",
" x = self.flatten(x)\n",
" x = self.linear(x)\n",
" x = self.softmax(x)\n",
" return x\n",
"\n",
"class BitLinearTrain(nn.Linear):\n",
" def forward(self, x):\n",
" w = self.weight\n",
" x_norm = x\n",
" x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()\n",
" w_quant = w + (weight_quant(w) - w).detach()\n",
" y = F.linear(x_quant, w_quant)\n",
" return y\n",
"\n",
"def activation_quant(x):\n",
" scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)\n",
" y = (x * scale).round().clamp_(-128, 127) / scale\n",
" return y\n",
"\n",
"def weight_quant(w):\n",
" scale = 1.0 / w.abs().mean().clamp_(min=1e-5)\n",
" u = (w * scale).round().clamp_(-1, 1) / scale\n",
" return u"
]
},
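{
"cell_type": "markdown",
"id": "b1e58d00-0000-4000-8000-000000000001",
"metadata": {},
"source": [
"A quick sanity check, added here for illustration (not part of the original run): `weight_quant` should map any matrix onto exactly three levels, {-gamma, 0, +gamma}, where gamma is the matrix's mean absolute value. The cell below is a minimal sketch of that claim; `demo_w` is a hypothetical throwaway tensor."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1e58d00-0000-4000-8000-000000000002",
"metadata": {},
"outputs": [],
"source": [
"# Illustration: the unique values of a quantized random matrix collapse to\n",
"# three levels, {-gamma, 0, +gamma}.\n",
"demo_w = torch.randn(8, 8)\n",
"print(weight_quant(demo_w).unique())"
]
},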
{
"cell_type": "code",
"execution_count": 3,
"id": "00695407-226e-47b4-9e31-e5303b981b38",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor[1, 4] x∈[0.162, 0.401] μ=0.250 σ=0.109 grad SoftmaxBackward0 [[0.162, 0.179, 0.259, 0.401]]\n"
]
}
],
"source": [
"model = ImageMLP()\n",
"image = torch.randn(1, 1, 64, 64) # Batch size 1, single channel, 64x64 image\n",
"output = model(image)\n",
"print(output)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c95d7f8a-55f4-4e9a-81c8-0c42c8bec890",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch [1/100], Train Loss: 1.3884, Train Acc: 0.2451, Test Loss: 1.3878, Test Acc: 0.2338\n",
"Epoch [2/100], Train Loss: 1.3718, Train Acc: 0.3562, Test Loss: 1.3748, Test Acc: 0.2597\n",
"Epoch [3/100], Train Loss: 1.3481, Train Acc: 0.3464, Test Loss: 1.3966, Test Acc: 0.2078\n",
"Epoch [4/100], Train Loss: 1.3429, Train Acc: 0.3366, Test Loss: 1.3924, Test Acc: 0.2078\n",
"Epoch [5/100], Train Loss: 1.3229, Train Acc: 0.3660, Test Loss: 1.3604, Test Acc: 0.3117\n",
"Epoch [6/100], Train Loss: 1.3339, Train Acc: 0.3889, Test Loss: 1.3848, Test Acc: 0.2727\n",
"Epoch [7/100], Train Loss: 1.3300, Train Acc: 0.3791, Test Loss: 1.3706, Test Acc: 0.3117\n",
"Epoch [8/100], Train Loss: 1.2961, Train Acc: 0.4379, Test Loss: 1.3149, Test Acc: 0.3506\n",
"Epoch [9/100], Train Loss: 1.3127, Train Acc: 0.4052, Test Loss: 1.3745, Test Acc: 0.2727\n",
"Epoch [10/100], Train Loss: 1.2839, Train Acc: 0.4608, Test Loss: 1.3506, Test Acc: 0.2857\n",
"Epoch [11/100], Train Loss: 1.3015, Train Acc: 0.4020, Test Loss: 1.3809, Test Acc: 0.2597\n",
"Epoch [12/100], Train Loss: 1.2927, Train Acc: 0.4281, Test Loss: 1.3772, Test Acc: 0.2597\n",
"Epoch [13/100], Train Loss: 1.3007, Train Acc: 0.4150, Test Loss: 1.3349, Test Acc: 0.3247\n",
"Epoch [14/100], Train Loss: 1.2959, Train Acc: 0.4412, Test Loss: 1.3590, Test Acc: 0.3117\n",
"Epoch [15/100], Train Loss: 1.2863, Train Acc: 0.4346, Test Loss: 1.3331, Test Acc: 0.3117\n",
"Epoch [16/100], Train Loss: 1.2932, Train Acc: 0.4412, Test Loss: 1.3328, Test Acc: 0.3247\n",
"Epoch [17/100], Train Loss: 1.2711, Train Acc: 0.4542, Test Loss: 1.3486, Test Acc: 0.2338\n",
"Epoch [18/100], Train Loss: 1.2717, Train Acc: 0.4706, Test Loss: 1.3386, Test Acc: 0.2857\n",
"Epoch [19/100], Train Loss: 1.2723, Train Acc: 0.5000, Test Loss: 1.3531, Test Acc: 0.2727\n",
"Epoch [20/100], Train Loss: 1.2816, Train Acc: 0.4804, Test Loss: 1.3285, Test Acc: 0.3377\n",
"Epoch [21/100], Train Loss: 1.2643, Train Acc: 0.4706, Test Loss: 1.3221, Test Acc: 0.3377\n",
"Epoch [22/100], Train Loss: 1.2625, Train Acc: 0.4967, Test Loss: 1.3460, Test Acc: 0.3506\n",
"Epoch [23/100], Train Loss: 1.2554, Train Acc: 0.4935, Test Loss: 1.3319, Test Acc: 0.3247\n",
"Epoch [24/100], Train Loss: 1.2628, Train Acc: 0.4575, Test Loss: 1.3445, Test Acc: 0.2727\n",
"Epoch [25/100], Train Loss: 1.2559, Train Acc: 0.4902, Test Loss: 1.3224, Test Acc: 0.3247\n",
"Epoch [26/100], Train Loss: 1.2464, Train Acc: 0.5131, Test Loss: 1.3255, Test Acc: 0.3506\n",
"Epoch [27/100], Train Loss: 1.2546, Train Acc: 0.5000, Test Loss: 1.3380, Test Acc: 0.3377\n",
"Epoch [28/100], Train Loss: 1.2521, Train Acc: 0.5033, Test Loss: 1.3039, Test Acc: 0.3896\n",
"Epoch [29/100], Train Loss: 1.2472, Train Acc: 0.5098, Test Loss: 1.3216, Test Acc: 0.3377\n",
"Epoch [30/100], Train Loss: 1.2547, Train Acc: 0.5065, Test Loss: 1.3584, Test Acc: 0.2857\n",
"Epoch [31/100], Train Loss: 1.2583, Train Acc: 0.4967, Test Loss: 1.3419, Test Acc: 0.3377\n",
"Epoch [32/100], Train Loss: 1.2473, Train Acc: 0.5033, Test Loss: 1.3403, Test Acc: 0.3247\n",
"Epoch [33/100], Train Loss: 1.2433, Train Acc: 0.5131, Test Loss: 1.3628, Test Acc: 0.2857\n",
"Epoch [34/100], Train Loss: 1.2617, Train Acc: 0.4869, Test Loss: 1.3178, Test Acc: 0.3896\n",
"Epoch [35/100], Train Loss: 1.2537, Train Acc: 0.4804, Test Loss: 1.3340, Test Acc: 0.3766\n",
"Epoch [36/100], Train Loss: 1.2573, Train Acc: 0.4706, Test Loss: 1.3322, Test Acc: 0.3766\n",
"Epoch [37/100], Train Loss: 1.2322, Train Acc: 0.5033, Test Loss: 1.3301, Test Acc: 0.2987\n",
"Epoch [38/100], Train Loss: 1.2293, Train Acc: 0.5261, Test Loss: 1.3428, Test Acc: 0.3247\n",
"Epoch [39/100], Train Loss: 1.2406, Train Acc: 0.4837, Test Loss: 1.3140, Test Acc: 0.3636\n",
"Epoch [40/100], Train Loss: 1.2583, Train Acc: 0.4706, Test Loss: 1.3421, Test Acc: 0.3117\n",
"Epoch [41/100], Train Loss: 1.2217, Train Acc: 0.5163, Test Loss: 1.3177, Test Acc: 0.4156\n",
"Epoch [42/100], Train Loss: 1.2367, Train Acc: 0.5163, Test Loss: 1.3153, Test Acc: 0.3506\n",
"Epoch [43/100], Train Loss: 1.2313, Train Acc: 0.5294, Test Loss: 1.3459, Test Acc: 0.3247\n",
"Epoch [44/100], Train Loss: 1.2365, Train Acc: 0.5000, Test Loss: 1.3098, Test Acc: 0.3766\n",
"Epoch [45/100], Train Loss: 1.2157, Train Acc: 0.5392, Test Loss: 1.3091, Test Acc: 0.3896\n",
"Epoch [46/100], Train Loss: 1.2334, Train Acc: 0.5163, Test Loss: 1.3196, Test Acc: 0.4286\n",
"Epoch [47/100], Train Loss: 1.2350, Train Acc: 0.5098, Test Loss: 1.2998, Test Acc: 0.4026\n",
"Epoch [48/100], Train Loss: 1.2516, Train Acc: 0.4869, Test Loss: 1.2899, Test Acc: 0.4416\n",
"Epoch [49/100], Train Loss: 1.2196, Train Acc: 0.5294, Test Loss: 1.3050, Test Acc: 0.4026\n",
"Epoch [50/100], Train Loss: 1.2427, Train Acc: 0.4869, Test Loss: 1.3235, Test Acc: 0.3377\n",
"Epoch [51/100], Train Loss: 1.2176, Train Acc: 0.5261, Test Loss: 1.3047, Test Acc: 0.3896\n",
"Epoch [52/100], Train Loss: 1.2205, Train Acc: 0.5359, Test Loss: 1.2923, Test Acc: 0.4156\n",
"Epoch [53/100], Train Loss: 1.2315, Train Acc: 0.5261, Test Loss: 1.2855, Test Acc: 0.4286\n",
"Epoch [54/100], Train Loss: 1.2233, Train Acc: 0.5163, Test Loss: 1.3073, Test Acc: 0.3896\n",
"Epoch [55/100], Train Loss: 1.2181, Train Acc: 0.5359, Test Loss: 1.3375, Test Acc: 0.3117\n",
"Epoch [56/100], Train Loss: 1.2148, Train Acc: 0.5229, Test Loss: 1.3074, Test Acc: 0.3636\n",
"Epoch [57/100], Train Loss: 1.2290, Train Acc: 0.5261, Test Loss: 1.3215, Test Acc: 0.3636\n",
"Epoch [58/100], Train Loss: 1.2147, Train Acc: 0.5392, Test Loss: 1.3163, Test Acc: 0.3636\n",
"Epoch [59/100], Train Loss: 1.2224, Train Acc: 0.5327, Test Loss: 1.3162, Test Acc: 0.3636\n",
"Epoch [60/100], Train Loss: 1.2158, Train Acc: 0.5425, Test Loss: 1.3096, Test Acc: 0.3377\n",
"Epoch [61/100], Train Loss: 1.2278, Train Acc: 0.5196, Test Loss: 1.3247, Test Acc: 0.3247\n",
"Epoch [62/100], Train Loss: 1.2139, Train Acc: 0.5490, Test Loss: 1.2920, Test Acc: 0.3506\n",
"Epoch [63/100], Train Loss: 1.2271, Train Acc: 0.5065, Test Loss: 1.2936, Test Acc: 0.4286\n",
"Epoch [64/100], Train Loss: 1.2281, Train Acc: 0.5261, Test Loss: 1.2811, Test Acc: 0.4416\n",
"Epoch [65/100], Train Loss: 1.2119, Train Acc: 0.5490, Test Loss: 1.2986, Test Acc: 0.4156\n",
"Epoch [66/100], Train Loss: 1.2119, Train Acc: 0.5359, Test Loss: 1.3000, Test Acc: 0.4286\n",
"Epoch [67/100], Train Loss: 1.2148, Train Acc: 0.5359, Test Loss: 1.2936, Test Acc: 0.3766\n",
"Epoch [68/100], Train Loss: 1.2027, Train Acc: 0.5654, Test Loss: 1.3177, Test Acc: 0.3636\n",
"Epoch [69/100], Train Loss: 1.2142, Train Acc: 0.5621, Test Loss: 1.3276, Test Acc: 0.3636\n",
"Epoch [70/100], Train Loss: 1.2198, Train Acc: 0.5294, Test Loss: 1.3201, Test Acc: 0.3506\n",
"Epoch [71/100], Train Loss: 1.2167, Train Acc: 0.5294, Test Loss: 1.2761, Test Acc: 0.4156\n",
"Epoch [72/100], Train Loss: 1.2071, Train Acc: 0.5556, Test Loss: 1.2906, Test Acc: 0.4156\n",
"Epoch [73/100], Train Loss: 1.2218, Train Acc: 0.5033, Test Loss: 1.3335, Test Acc: 0.3377\n",
"Epoch [74/100], Train Loss: 1.2106, Train Acc: 0.5294, Test Loss: 1.2671, Test Acc: 0.4545\n",
"Epoch [75/100], Train Loss: 1.2024, Train Acc: 0.5327, Test Loss: 1.2874, Test Acc: 0.4026\n",
"Epoch [76/100], Train Loss: 1.2039, Train Acc: 0.5359, Test Loss: 1.3201, Test Acc: 0.3766\n",
"Epoch [77/100], Train Loss: 1.2146, Train Acc: 0.5163, Test Loss: 1.2959, Test Acc: 0.4026\n",
"Epoch [78/100], Train Loss: 1.2081, Train Acc: 0.5294, Test Loss: 1.3248, Test Acc: 0.3377\n",
"Epoch [79/100], Train Loss: 1.2084, Train Acc: 0.5523, Test Loss: 1.2996, Test Acc: 0.4545\n",
"Epoch [80/100], Train Loss: 1.2040, Train Acc: 0.5784, Test Loss: 1.3071, Test Acc: 0.3766\n",
"Epoch [81/100], Train Loss: 1.1909, Train Acc: 0.5686, Test Loss: 1.3533, Test Acc: 0.3377\n",
"Epoch [82/100], Train Loss: 1.2093, Train Acc: 0.5261, Test Loss: 1.2765, Test Acc: 0.4416\n",
"Epoch [83/100], Train Loss: 1.2042, Train Acc: 0.5686, Test Loss: 1.2995, Test Acc: 0.4286\n",
"Epoch [84/100], Train Loss: 1.1912, Train Acc: 0.5686, Test Loss: 1.3140, Test Acc: 0.3766\n",
"Epoch [85/100], Train Loss: 1.2020, Train Acc: 0.5588, Test Loss: 1.2986, Test Acc: 0.4286\n",
"Epoch [86/100], Train Loss: 1.1934, Train Acc: 0.5686, Test Loss: 1.3028, Test Acc: 0.4026\n",
"Epoch [87/100], Train Loss: 1.1820, Train Acc: 0.5752, Test Loss: 1.2957, Test Acc: 0.4156\n",
"Epoch [88/100], Train Loss: 1.1831, Train Acc: 0.5850, Test Loss: 1.2988, Test Acc: 0.3896\n",
"Epoch [89/100], Train Loss: 1.1941, Train Acc: 0.5621, Test Loss: 1.3145, Test Acc: 0.4416\n",
"Epoch [90/100], Train Loss: 1.2011, Train Acc: 0.5490, Test Loss: 1.3271, Test Acc: 0.3766\n",
"Epoch [91/100], Train Loss: 1.2018, Train Acc: 0.5261, Test Loss: 1.3209, Test Acc: 0.3506\n",
"Epoch [92/100], Train Loss: 1.1850, Train Acc: 0.5719, Test Loss: 1.3075, Test Acc: 0.3636\n",
"Epoch [93/100], Train Loss: 1.1925, Train Acc: 0.5523, Test Loss: 1.3425, Test Acc: 0.3247\n",
"Epoch [94/100], Train Loss: 1.2059, Train Acc: 0.5588, Test Loss: 1.2893, Test Acc: 0.4156\n",
"Epoch [95/100], Train Loss: 1.2010, Train Acc: 0.5458, Test Loss: 1.3426, Test Acc: 0.3506\n",
"Epoch [96/100], Train Loss: 1.2081, Train Acc: 0.5458, Test Loss: 1.2879, Test Acc: 0.3896\n",
"Epoch [97/100], Train Loss: 1.2034, Train Acc: 0.5425, Test Loss: 1.3067, Test Acc: 0.4156\n",
"Epoch [98/100], Train Loss: 1.1987, Train Acc: 0.5654, Test Loss: 1.3199, Test Acc: 0.3896\n",
"Epoch [99/100], Train Loss: 1.2173, Train Acc: 0.5523, Test Loss: 1.3263, Test Acc: 0.3377\n",
"Epoch [100/100], Train Loss: 1.2025, Train Acc: 0.5588, Test Loss: 1.3163, Test Acc: 0.3636\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"\n",
"# Set device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Move model to device\n",
"model = ImageMLP().to(device)\n",
"\n",
"# Define loss function and optimizer\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.Adam(model.parameters(), lr=0.0001)\n",
"\n",
"# Training loop\n",
"num_epochs = 100\n",
"\n",
"for epoch in range(num_epochs):\n",
" # Training\n",
" model.train()\n",
" train_loss = 0.0\n",
" train_correct = 0\n",
" train_total = 0\n",
"\n",
" for images, labels in train_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" # Backward pass and optimization\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # Calculate loss and accuracy\n",
" train_loss += loss.item() * images.size(0)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" train_total += labels.size(0)\n",
" train_correct += (predicted == labels).sum().item()\n",
"\n",
" train_loss = train_loss / len(train_dataset)\n",
" train_acc = train_correct / train_total\n",
"\n",
" # Testing\n",
" model.eval()\n",
" test_loss = 0.0\n",
" test_correct = 0\n",
" test_total = 0\n",
"\n",
" with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" # Calculate loss and accuracy\n",
" test_loss += loss.item() * images.size(0)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" test_total += labels.size(0)\n",
" test_correct += (predicted == labels).sum().item()\n",
"\n",
" test_loss = test_loss / len(test_dataset)\n",
" test_acc = test_correct / test_total\n",
"\n",
" print(f\"Epoch [{epoch+1}/{num_epochs}], \"\n",
" f\"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, \"\n",
" f\"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5440ad7f-015c-44af-a1b7-56a8b0d36d09",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Layer: linear.weight\n",
"Weight Shape: torch.Size([4, 4096])\n",
"Weight Data: Parameter containing:\n",
"Parameter[4, 4096] n=16384 (64Kb) x∈[-0.055, 0.055] μ=0.000 σ=0.014 grad cuda:0\n",
"---\n"
]
}
],
"source": [
"for name, param in model.named_parameters():\n",
" if 'weight' in name:\n",
" print(f\"Layer: {name}\")\n",
" print(f\"Weight Shape: {param.shape}\")\n",
" print(f\"Weight Data: {param}\")\n",
" print(\"---\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9a02a356-05e5-44c4-bdd3-34a652bc38ea",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"\n",
"def roundclip(x, a, b):\n",
" return np.maximum(a, np.minimum(b, np.round(x)))\n",
"\n",
"def quantize_weights(weights):\n",
" # Compute the average absolute value of the weight matrix\n",
" gamma = np.mean(np.abs(weights))\n",
" \n",
" # Scale the weight matrix by the average absolute value\n",
" scaled_weights = weights / (gamma + 1e-8)\n",
" \n",
" # Round each scaled weight to the nearest integer in {-1, 0, +1}\n",
" quantized_weights = roundclip(scaled_weights, -1, 1)\n",
" \n",
" return quantized_weights\n",
"\n",
"weights = model.linear.weight.detach().cpu().numpy()\n",
"\n",
"# Quantize the weights\n",
"quantized_weights = quantize_weights(weights)\n",
"\n",
"# Convert the quantized weights back to a PyTorch tensor\n",
"quantized_weights_tensor = torch.from_numpy(quantized_weights).to(model.linear.weight.device)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b0f0d742-afeb-4246-80cf-d1c47b7d3f76",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor[4, 4096] n=16384 (64Kb) x∈[-1.000, 1.000] μ=0.011 σ=0.834 cuda:0\n",
"tensor[4, 4096] n=16384 (64Kb) x∈[-1.000, 1.000] μ=0.011 σ=0.834 cuda:0\n",
"tensor([[-1., -1., -0., ..., -1., -1., -0.],\n",
" [-1., 1., -0., ..., 1., -0., 1.],\n",
" [ 1., 1., 1., ..., 1., 1., 1.],\n",
" [-1., -1., -1., ..., -1., -1., -1.]], device='cuda:0')\n"
]
}
],
"source": [
"print(quantized_weights_tensor)\n",
"print(quantized_weights_tensor.v)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a68d791c-7127-442d-a30d-4702d63fc6c2",
"metadata": {},
"outputs": [],
"source": [
"class BitLinear(nn.Module):\n",
" def __init__(self, in_features, out_features, quantized_weights):\n",
" super(BitLinear, self).__init__()\n",
" self.in_features = in_features\n",
" self.out_features = out_features\n",
" self.quantized_weights = nn.Parameter(quantized_weights, requires_grad=False)\n",
"\n",
" def forward(self, input):\n",
" # Perform matrix multiplication with the quantized weights\n",
" output = torch.matmul(input, self.quantized_weights.t())\n",
" return output"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d680df92-4f5e-408e-9442-48bd10a3a61b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Prior to quantizing to bit linear: Test Loss: 1.2852, Test Acc: 0.4286\n",
"After quantizing to bit linear: Test Loss: 1.3021, Test Acc: 0.4416\n"
]
}
],
"source": [
"model.eval()\n",
"test_loss = 0.0\n",
"test_correct = 0\n",
"test_total = 0\n",
"\n",
"with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" # Calculate loss and accuracy\n",
" test_loss += loss.item() * images.size(0)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" test_total += labels.size(0)\n",
" test_correct += (predicted == labels).sum().item()\n",
"\n",
"test_loss = test_loss / len(test_dataset)\n",
"test_acc = test_correct / test_total\n",
"print(f\"Prior to quantizing to bit linear: Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}\")\n",
"\n",
"bit_linear = BitLinear(model.linear.in_features, model.linear.out_features, quantized_weights_tensor)\n",
"model.linear = bit_linear\n",
"\n",
"test_loss = 0.0\n",
"test_correct = 0\n",
"test_total = 0\n",
"\n",
"with torch.no_grad():\n",
" for images, labels in test_loader:\n",
" images = images.to(device)\n",
" labels = labels.to(device)\n",
"\n",
" # Forward pass\n",
" outputs = model(images)\n",
" loss = criterion(outputs, labels)\n",
"\n",
" # Calculate loss and accuracy\n",
" test_loss += loss.item() * images.size(0)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" test_total += labels.size(0)\n",
" test_correct += (predicted == labels).sum().item()\n",
"\n",
"test_loss = test_loss / len(test_dataset)\n",
"test_acc = test_correct / test_total\n",
"print(f\"After quantizing to bit linear: Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}