# Reference: [Half-Quadratic Quantization of Large Machine Learning Models](https://mobiusml.github.io/hqq_blog/)
import numpy as np
# Define the shrinkage function for soft-thresholding
def shrink(x, beta, p):
    return np.sign(x) * np.maximum(np.abs(x) - (np.abs(x)**(p-1))/beta, 0)
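# Quick sanity check of the shrinkage operator (toy values added for
# illustration, not part of the referenced algorithm): with p = 1 the term
# |x|**(p-1) equals 1, so shrink() reduces to standard soft-thresholding
# sign(x) * max(|x| - 1/beta, 0), which zeroes small entries and pulls
# larger ones toward zero by 1/beta.
print(f"shrink demo = {shrink(np.array([-2.0, -0.5, 0.5, 2.0]), beta=1.0, p=1)}")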
# Define the quantization and dequantization operators
def quantize(W, s, z):
    return np.round(W / s + z)

def dequantize(Wq, s, z):
    return s * (Wq - z)
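# Round-trip illustration with assumed toy values (s = 0.25, z = 0 are not from
# the reference): quantize() maps a weight onto an integer grid with step s and
# offset z, dequantize() maps it back, and the difference is the quantization error.
w_demo = np.array([0.12, -0.43, 0.88])
wq_demo = quantize(w_demo, 0.25, 0)
print(f"round trip: {w_demo} -> {wq_demo} -> {dequantize(wq_demo, 0.25, 0)}")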
# Initialize parameters
W = np.random.randn(10, 10) # Replace with actual weights
print(f"W = {W}")
'''
The choice of scale factor (s) and zero point (z) strongly affects how accurately dequantization recovers the original unquantized weights. Key points:
- The scale factor s sets the "step size" between quantization levels. A larger s means coarser quantization and a less faithful representation of the weight distribution.
- The zero point z sets the offset of the quantization range. A poorly chosen z clips part of the weight distribution and loses information.
- An overly large s collapses the weights onto a small set of levels, losing precision; a small s retains more precision but needs more bits to cover the same range.
- A z shifted far from the center of the weight distribution clips values at one end of the range; keeping z centered helps preserve the distribution.
- The optimal s and z depend on the statistical distribution of the weights and should be chosen to retain as much precision as possible.
- For a fixed number of bits there is a tradeoff: a larger s covers a wider range per level, which makes it easier to place z so the range stays centered.
- Dequantization accuracy depends directly on how well s and z undo the quantization and recover the original unquantized weights.
In summary, s and z should be optimized from the weight statistics to maximize dequantization accuracy and retain as much information as possible from the original weights.
Credit: Claude2 AI chatbot
'''
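# One common way to choose s and z in line with the notes above is asymmetric
# min-max calibration: spread the observed weight range over the available
# integer levels and offset it so the minimum maps to level 0. This helper is a
# generic heuristic sketch (n_bits and the function name are illustrative, not
# part of the referenced HQQ recipe); the script itself keeps s = 1, z = 0 below.
def minmax_scale_zero(W, n_bits=8):
    q_max = 2**n_bits - 1
    s_est = (W.max() - W.min()) / q_max  # step size covering the full weight range
    z_est = -W.min() / s_est             # offset so W.min() maps to level 0
    return s_est, z_est
print(f"min-max heuristic (8-bit): s, z = {minmax_scale_zero(W)}")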
s = 1 # Scale factor (can be learned as well)
z = 0 # Zero-point (initially, can be 0)
beta = 0.001 # Beta for the HQQ algorithm
k = 0.9 # Update factor for beta
p = 1 # p-norm
num_iterations = 100 # Number of iterations for optimization
tolerance = 1e-5 # Tolerance for convergence
# Initialize the extra variable We to the original weights W
We = np.copy(W)
# Optimization loop
for i in range(num_iterations):
    prev_We = np.copy(We)
    # Quantize/dequantize with the current zero-point, then update We by
    # applying the shrinkage operator to the quantization residual W - Wdq
    Wq = quantize(W, s, z)
    Wdq = dequantize(Wq, s, z)
    We = shrink(W - Wdq, beta, p)
    # Update z in closed form from the new We
    z = np.mean(Wq - (W - We) / s)
    # Update beta
    beta *= k
    # Check for convergence
    if np.linalg.norm(We - prev_We) < tolerance:
        break
# Output the quantized weights
Wq = quantize(W, s, z)
print(f"Wq = {Wq}")