Created February 24, 2024 21:00
-
-
Save mgostIH/688bb767d4571223e5d23820b38374fb to your computer and use it in GitHub Desktop.
Layernorm vs Scaled Normalization
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import matplotlib.pyplot as plt | |
# Standard Layer Normalization (no learnable gain or bias)
def layernorm(x, epsilon=1e-5):
    """Standardize ``x`` along its last axis.

    Every slice along ``axis=-1`` is shifted to zero mean and divided by
    ``sqrt(variance + epsilon)``. No learnable scale or shift is applied.

    Args:
        x: Array or array-like; statistics are taken over the last axis.
        epsilon: Added to the variance before the square root so that a
            constant slice (zero variance) does not cause a division by zero.

    Returns:
        NumPy array with the same shape as ``x``.
    """
    # np.mean / np.var (rather than x.mean / x.var) so plain lists also work.
    centre = np.mean(x, axis=-1, keepdims=True)
    spread = np.sqrt(np.var(x, axis=-1, keepdims=True) + epsilon)
    return (x - centre) / spread
# Unit-norm normalization along the last axis (no mean-centering, no scaling)
def normalize(x, epsilon=1e-5):
    """Divide ``x`` by its L2 norm along the last axis.

    Unlike ``layernorm``, the input is not mean-centered, and the result has
    (approximately) unit L2 norm per slice rather than norm ``sqrt(D)``.

    Args:
        x: Array or array-like; the norm is taken over the last axis.
        epsilon: Added to the squared norm inside the square root, so an
            all-zero slice maps to zeros instead of NaN.

    Returns:
        NumPy array with the same shape as ``x``. Each slice's norm is
        ``sqrt(s / (s + epsilon))``, where ``s`` is its squared norm, so it
        is slightly below 1.
    """
    # np.square (not x**2) so lists/tuples are accepted, matching layernorm.
    norm_x = np.sqrt(np.sum(np.square(x), axis=-1, keepdims=True) + epsilon)
    return x / norm_x
# Scaled normalize: rescale unit-norm vectors so their L2 norm matches LayerNorm's
def scaled_normalize(x, epsilon=1e-5):
    """Return ``sqrt(D) * normalize(x, epsilon)``, where ``D`` is the last-axis size.

    A LayerNorm output has zero mean over the last axis, and its squared L2
    norm is ``D * var / (var + epsilon)``, which is about ``D``.
    ``normalize`` gives norm about 1, so the ``sqrt(D)`` factor makes both
    norms about ``sqrt(D)``. Algebraically this equals
    ``x / sqrt(mean(x**2, axis=-1) + epsilon / D)``: an RMS normalization
    without mean-centering.

    Args:
        x: Array or array-like of any rank; only the last axis is normalized.
        epsilon: Passed through to ``normalize``.

    Returns:
        NumPy array with the same shape as ``x``.
    """
    # D is the length of the last axis; leading (batch) dimensions are allowed.
    # np.shape, unlike x.shape, also works for lists/tuples.
    D = np.shape(x)[-1]
    return np.sqrt(D) * normalize(x, epsilon)
# Generate N = 1000 random points in D = 100 dimensions.
# A fixed seed makes the printed norm ranges reproducible between runs.
N = 1000
D = 100
rng = np.random.default_rng(0)
points = rng.standard_normal((N, D))
# Apply layernorm and scaled_normalize (both normalize over the last axis)
ln_points = layernorm(points)
sn_points = scaled_normalize(points)
# L2 norm of every point (row), before and after each normalization
original_norms = np.linalg.norm(points, axis=-1)
ln_norms = np.linalg.norm(ln_points, axis=-1)
sn_norms = np.linalg.norm(sn_points, axis=-1)
# Both normalized ranges should sit close to sqrt(D); the raw norms scatter around it
print("Original Norms Range:", np.min(original_norms), np.max(original_norms))
print("LayerNorm Norms Range:", np.min(ln_norms), np.max(ln_norms))
print("Scaled Normalize Norms Range:", np.min(sn_norms), np.max(sn_norms))
# Plot only the pre-normalization norms: after either normalization every norm
# is ~sqrt(D), so those histograms would collapse to a single bar.
plt.figure(figsize=(8, 4))
plt.hist(original_norms, bins=30, alpha=0.75, color='blue')
plt.title('Histogram of Norms Before Normalization')
plt.xlabel('Norm Value')
plt.ylabel('Frequency')
plt.show()
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment