@catid
Created March 23, 2024 23:30
GIVT GMM Decoder (Claude 3)
# Collaboration between Claude-3 and GPT-4 to implement https://arxiv.org/pdf/2312.02116.pdf
# This is just the GMM decoder part of the model they propose (which is the new thing).
# This one was mainly generated by Claude-3.
# The AIs provided two implementations of the idea and revised each other's code.
# I tested that the unit tests pass but haven't tried it in a language model yet.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImprovedGMMParametersPrediction(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_components):
        super().__init__()
        self.num_components = num_components
        # Parameter prediction layers
        self.mu = nn.Linear(hidden_dim, output_dim * num_components)
        self.log_sigma = nn.Linear(hidden_dim, output_dim * num_components)
        self.logits_pi = nn.Linear(hidden_dim, num_components)

    def forward(self, x):
        batch_size, seq_length, _ = x.size()
        # Predict parameters and reshape to (batch, seq, components, output_dim)
        mu = self.mu(x).view(batch_size, seq_length, self.num_components, -1)
        log_sigma = self.log_sigma(x).view(batch_size, seq_length, self.num_components, -1)
        logits_pi = self.logits_pi(x).view(batch_size, seq_length, self.num_components)
        # Softmax for mixing coefficients and exp for standard deviations
        pi = F.softmax(logits_pi, dim=-1)
        sigma = torch.exp(log_sigma)
        return mu, sigma, pi
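
# A hedged addition, not part of the original gist: a minimal sketch of the
# GMM negative log-likelihood one might train against, assuming a diagonal
# covariance per mixture component. The helper name `gmm_nll` and the
# torch.distributions usage are our choices, not the paper's reference code.
def gmm_nll(mu, sigma, pi, target):
    # mu, sigma: (batch, seq, K, D); pi: (batch, seq, K); target: (batch, seq, D)
    components = torch.distributions.Independent(
        torch.distributions.Normal(mu, sigma), 1)
    mixture = torch.distributions.Categorical(probs=pi)
    gmm = torch.distributions.MixtureSameFamily(mixture, components)
    # Mean NLL over batch and sequence positions
    return -gmm.log_prob(target).mean()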

class ImprovedGMMOutput(nn.Module):
    def __init__(self, hidden_dim, output_dim, num_mixtures):
        super().__init__()
        self.num_mixtures = num_mixtures
        # Parameter prediction layers
        self.fc_means = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_scales = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_weights = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, x):
        batch_size, seq_length, _ = x.size()
        # Predict means, scales, and weights for the GMM
        means = self.fc_means(x).view(batch_size, seq_length, self.num_mixtures, -1)
        scales = F.softplus(self.fc_scales(x)).view(batch_size, seq_length, self.num_mixtures, -1)
        weights = F.softmax(self.fc_weights(x), dim=-1)
        return means, scales, weights
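
# A hedged addition, not part of the original gist: a minimal sketch of
# sampling a continuous token from the predicted mixture, as one would do
# per step of autoregressive decoding. `sample_from_gmm` is our naming;
# it mirrors gmm_nll above.
def sample_from_gmm(means, scales, weights):
    # means, scales: (batch, seq, K, D); weights: (batch, seq, K)
    components = torch.distributions.Independent(
        torch.distributions.Normal(means, scales), 1)
    mixture = torch.distributions.Categorical(probs=weights)
    gmm = torch.distributions.MixtureSameFamily(mixture, components)
    # One draw per (batch, seq) position, shape (batch, seq, D)
    return gmm.sample()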

# Unit tests
def test_improved_gmm_parameter_prediction():
    batch_size = 2
    seq_length = 3
    hidden_dim = 4
    output_dim = 5
    num_components = 3
    model = ImprovedGMMParametersPrediction(hidden_dim, output_dim, num_components)
    x = torch.randn(batch_size, seq_length, hidden_dim)
    mu, sigma, pi = model(x)
    assert mu.shape == (batch_size, seq_length, num_components, output_dim)
    assert sigma.shape == (batch_size, seq_length, num_components, output_dim)
    assert pi.shape == (batch_size, seq_length, num_components)
    assert torch.allclose(pi.sum(dim=-1), torch.ones(batch_size, seq_length), rtol=1e-5, atol=1e-8)
    assert (sigma > 0).all()

def test_improved_gmm_output():
    batch_size = 2
    seq_length = 3
    hidden_dim = 4
    output_dim = 5
    num_mixtures = 3
    model = ImprovedGMMOutput(hidden_dim, output_dim, num_mixtures)
    x = torch.randn(batch_size, seq_length, hidden_dim)
    means, scales, weights = model(x)
    assert means.shape == (batch_size, seq_length, num_mixtures, output_dim)
    assert scales.shape == (batch_size, seq_length, num_mixtures, output_dim)
    assert weights.shape == (batch_size, seq_length, num_mixtures)
    assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, seq_length), rtol=1e-5, atol=1e-8)
    assert (scales > 0).all()
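
# A hedged addition: smoke test for the gmm_nll / sample_from_gmm sketches
# above (our helpers, not part of the original gist).
def test_gmm_helpers():
    batch_size = 2
    seq_length = 3
    hidden_dim = 4
    output_dim = 5
    num_mixtures = 3
    model = ImprovedGMMOutput(hidden_dim, output_dim, num_mixtures)
    x = torch.randn(batch_size, seq_length, hidden_dim)
    means, scales, weights = model(x)
    target = torch.randn(batch_size, seq_length, output_dim)
    nll = gmm_nll(means, scales, weights, target)
    assert nll.ndim == 0 and torch.isfinite(nll)
    sample = sample_from_gmm(means, scales, weights)
    assert sample.shape == (batch_size, seq_length, output_dim)
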
if __name__ == "__main__":
    test_improved_gmm_parameter_prediction()
    test_improved_gmm_output()
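    test_gmm_helpers()  # smoke-tests the hedged helper sketches added above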
    print("All tests passed!")