Skip to content

Instantly share code, notes, and snippets.

@mgostIH
Created February 24, 2024 21:00
Show Gist options
  • Save mgostIH/688bb767d4571223e5d23820b38374fb to your computer and use it in GitHub Desktop.
Layernorm vs Scaled Normalization
import numpy as np
import matplotlib.pyplot as plt
def layernorm(x, epsilon=1e-5):
    """Apply standard layer normalization along the last axis.

    Each slice over the last axis is shifted to zero mean and scaled
    to unit variance; epsilon keeps the divisor nonzero for constant
    slices.
    """
    mu = np.mean(x, axis=-1, keepdims=True)
    var = np.var(x, axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + epsilon)
def normalize(x, epsilon=1e-5):
    """Divide each slice over the last axis by its Euclidean norm.

    epsilon is added inside the square root so all-zero slices do not
    cause a division by zero.
    """
    squared = np.sum(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(squared + epsilon)
def scaled_normalize(x, epsilon=1e-5):
    """Unit-normalize each slice, then rescale its norm to sqrt(D).

    D is the size of the last axis; the sqrt(D) factor makes the
    resulting norm match the (approximate) norm that layernorm
    produces, so the two schemes are directly comparable.
    """
    last_dim = x.shape[-1]
    return np.sqrt(last_dim) * normalize(x, epsilon)
# --- Experiment: compare point norms under layernorm vs scaled normalization ---
D = 100
points = np.random.randn(1000, D)  # 1000 Gaussian samples in R^100

ln_points = layernorm(points)
sn_points = scaled_normalize(points)

# Row-wise Euclidean norms, before and after each normalization.
original_norms = np.linalg.norm(points, axis=1)
ln_norms = np.linalg.norm(ln_points, axis=1)
sn_norms = np.linalg.norm(sn_points, axis=1)

# Report the min/max norm for each variant.
print("Original Norms Range:", original_norms.min(), original_norms.max())
print("LayerNorm Norms Range:", ln_norms.min(), ln_norms.max())
print("Scaled Normalize Norms Range:", sn_norms.min(), sn_norms.max())

# Histogram of the raw (pre-normalization) norms only.
plt.figure(figsize=(8, 4))
plt.hist(original_norms, bins=30, alpha=0.75, color='blue')
plt.title('Histogram of Norms Before Normalization')
plt.xlabel('Norm Value')
plt.ylabel('Frequency')
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment