Skip to content

Instantly share code, notes, and snippets.

@catid
Created March 23, 2024 23:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save catid/a5ac7ece55627fffe73efce205f5c848 to your computer and use it in GitHub Desktop.
Save catid/a5ac7ece55627fffe73efce205f5c848 to your computer and use it in GitHub Desktop.
GIVT GMM Decoder (GPT-4)
# Collaboration between Claude-3 and GPT-4 to implement https://arxiv.org/pdf/2312.02116.pdf
# This is only the GMM decoder part of the model they propose (the paper's novel component).
# This one was mainly generated by GPT-4.
# The AIs provided two implementations of the idea and revised each other's code.
# I tested that the unit tests pass but haven't tried it in a language model yet.
import torch
import torch.nn as nn
import torch.nn.functional as F
class GMMParametersPrediction(nn.Module):
    """Predict diagonal Gaussian-mixture parameters from hidden states.

    Three independent linear projections map each hidden vector to:
      * per-component means ``mu``,
      * per-component scales ``sigma`` (``exp`` of a predicted log-scale),
      * mixture weights ``pi`` (``softmax`` of predicted logits).

    Args:
        hidden_dim: Size of the last dimension of the input tensor.
        output_dim: Dimensionality of each mixture component.
        num_components: Number of mixture components.
    """

    def __init__(self, hidden_dim, output_dim, num_components):
        super().__init__()
        self.num_components = num_components
        self.output_dim = output_dim
        self.mu = nn.Linear(hidden_dim, output_dim * num_components)
        self.log_sigma = nn.Linear(hidden_dim, output_dim * num_components)
        self.logits_pi = nn.Linear(hidden_dim, num_components)

    def forward(self, x):
        """Compute mixture parameters for every position of ``x``.

        Args:
            x: Tensor of shape ``(*leading, hidden_dim)``. The common
                ``(batch, seq, hidden_dim)`` layout is one instance of this;
                2-D ``(batch, hidden_dim)`` and higher-rank inputs also work.

        Returns:
            Tuple ``(mu, sigma, pi)``:
              * ``mu``: ``(*leading, num_components, output_dim)``
              * ``sigma``: same shape as ``mu``, computed as ``exp`` (positive)
              * ``pi``: ``(*leading, num_components)``, sums to 1 on the last dim
        """
        # Keep every leading dimension and split the flat projection into
        # (component, feature) axes, regardless of the input's rank.
        component_shape = (*x.shape[:-1], self.num_components, self.output_dim)
        mu = self.mu(x).view(component_shape)
        log_sigma = self.log_sigma(x).view(component_shape)
        logits_pi = self.logits_pi(x)
        pi = F.softmax(logits_pi, dim=-1)
        # exp guarantees positivity; note that an unbounded projection can
        # overflow/underflow exp for extreme inputs.
        sigma = torch.exp(log_sigma)
        return mu, sigma, pi
class GMMOutput(nn.Module):
    """Gaussian-mixture output head: means, scales and mixture weights.

    Three linear layers map each hidden vector to per-mixture means, per-mixture
    scales (``exp`` of predicted log-scales) and mixture weights (``softmax``
    of predicted logits).

    Args:
        hidden_dim: Size of the last dimension of the input tensor.
        output_dim: Dimensionality of each mixture component.
        num_mixtures: Number of mixture components.
    """

    def __init__(self, hidden_dim, output_dim, num_mixtures):
        super().__init__()
        self.num_mixtures = num_mixtures
        self.output_dim = output_dim
        self.fc_means = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_log_scales = nn.Linear(hidden_dim, output_dim * num_mixtures)
        self.fc_logits_weights = nn.Linear(hidden_dim, num_mixtures)

    def forward(self, x):
        """Compute mixture parameters for every position of ``x``.

        Args:
            x: Tensor of shape ``(*leading, hidden_dim)``. This covers
                ``(batch, seq, hidden_dim)`` as well as 2-D or higher-rank input.

        Returns:
            Tuple ``(means, scales, weights)``:
              * ``means``: ``(*leading, num_mixtures, output_dim)``
              * ``scales``: same shape as ``means``, computed as ``exp`` (positive)
              * ``weights``: ``(*leading, num_mixtures)``, sums to 1 on the last dim
        """
        # Preserve all leading dims and split the flat projection into
        # (mixture, feature) axes, independent of the input's rank.
        mixture_shape = (*x.shape[:-1], self.num_mixtures, self.output_dim)
        means = self.fc_means(x).view(mixture_shape)
        log_scales = self.fc_log_scales(x).view(mixture_shape)
        # exp keeps scales positive; extreme log-scales may overflow/underflow.
        scales = torch.exp(log_scales)
        logits_weights = self.fc_logits_weights(x)
        weights = F.softmax(logits_weights, dim=-1)
        return means, scales, weights
if __name__ == "__main__":
    # Announce the module only when it is executed directly, not on import.
    banner = "GMM Models Module. Define and test GMM parameter prediction models."
    print(banner)
import unittest
import torch
class TestGMMParametersPrediction(unittest.TestCase):
    """Shape and constraint checks for GMMParametersPrediction."""

    def setUp(self):
        # Example hidden dimension, output (e.g. embedding) size, GMM component count.
        self.hidden_dim, self.output_dim, self.num_components = 512, 256, 10
        self.batch_size, self.seq_length = 4, 7
        self.model = GMMParametersPrediction(self.hidden_dim, self.output_dim, self.num_components)

    def _random_input(self):
        """Build a simulated (batch, seq, hidden) input tensor."""
        return torch.randn(self.batch_size, self.seq_length, self.hidden_dim)

    def test_output_shapes(self):
        mu, sigma, pi = self.model(self._random_input())
        per_component = (self.batch_size, self.seq_length, self.num_components, self.output_dim)
        self.assertEqual(mu.shape, per_component)
        self.assertEqual(sigma.shape, per_component)
        self.assertEqual(pi.shape, per_component[:3])

    def test_constraints(self):
        _, sigma, pi = self.model(self._random_input())
        self.assertTrue(torch.all(sigma > 0), "All sigma values should be positive.")
        expected_total = torch.ones(self.batch_size, self.seq_length)
        self.assertTrue(
            torch.allclose(pi.sum(dim=-1), expected_total),
            "Mixing coefficients should sum to 1.",
        )
class TestGMMOutput(unittest.TestCase):
    """Shape and constraint checks for GMMOutput."""

    def setUp(self):
        # Example sizes: hidden width, output dimension, number of mixtures.
        self.hidden_dim, self.output_dim, self.num_mixtures = 512, 256, 10
        self.batch_size, self.seq_length = 4, 7
        self.model = GMMOutput(self.hidden_dim, self.output_dim, self.num_mixtures)

    def _forward_random(self):
        """Run the model on a fresh random (batch, seq, hidden) input."""
        inputs = torch.randn(self.batch_size, self.seq_length, self.hidden_dim)
        return self.model(inputs)

    def test_output_shapes(self):
        means, scales, weights = self._forward_random()
        leading = (self.batch_size, self.seq_length)
        mixture_shape = leading + (self.num_mixtures, self.output_dim)
        for tensor in (means, scales):
            self.assertEqual(tensor.shape, mixture_shape)
        self.assertEqual(weights.shape, leading + (self.num_mixtures,))

    def test_constraints(self):
        _, scales, weights = self._forward_random()
        self.assertTrue(torch.all(scales > 0), "All scale values should be positive.")
        ones = torch.ones(self.batch_size, self.seq_length)
        self.assertTrue(
            torch.allclose(weights.sum(dim=-1), ones),
            "Mixing coefficients should sum to 1.",
        )
if __name__ == "__main__":
    # Discover and run every unittest.TestCase defined in this script.
    unittest.main(module="__main__")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment