
@Helw150
Last active May 10, 2024 19:58
text = ...   # Tokenized Text Corresponding to the Recording Transcript
audio = ...  # Mel Spectrogram of the Recording
# Only Train Connector and Projection
self.encoder.freeze()
self.llama.freeze()
# Convert Raw Audio Signal to 1500 Embeddings with Whisper Encoder (CNN+Transformer)
audio_features = self.encoder(audio)
# Learned Query Tokens that will serve as the Q in QKV Cross Attention with the Above
static_query_tokens = (query_tokens + query_position_embeds)
# Cross Attention Between Query Tokens and Extracted Audio Features
virt_whisper_tokens = self.connector.transformer(
queries=static_query_tokens,
cross_attn_keys_and_values=audio_features,
)
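The connector above is a Q-Former-style block: a fixed set of learned query tokens cross-attends over the Whisper features. Setting aside the named-axis machinery, the core step can be sketched with plain NumPy as single-head cross attention (the weight matrices `Wq`/`Wk`/`Wv` and all shapes are illustrative assumptions, not the connector's real parameters):

```python
import numpy as np

def cross_attention(queries, keys_values, Wq, Wk, Wv):
    # Single-head cross attention: learned query tokens are the Q,
    # the extracted audio features supply both K and V.
    Q = queries @ Wq
    K = keys_values @ Wk
    V = keys_values @ Wv
    scores = (Q @ K.T) / np.sqrt(Q.shape[-1])
    # Row-wise softmax over the audio positions.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
queries = rng.normal(size=(4, 8))    # 4 learned query tokens, embed dim 8
audio_feats = rng.normal(size=(16, 8))  # 16 audio positions (1500 in the real model)
W = np.eye(8)                        # identity projections, just for shape-checking
out = cross_attention(queries, audio_feats, W, W, W)
print(out.shape)  # one output vector per query token: (4, 8)
```

However many audio frames come in, the output length is fixed by the number of query tokens, which is what lets a 1500-frame Whisper encoding compress into a short "virtual token" sequence for the LLM.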
# Linear Projection From Whisper Embedding Space to LLama Input Embedding Space
virtual_audio_tokens = self.projection(virt_whisper_tokens, axis="embed")
# Ground Truth Embedding of the Transcript
text_embeds = self.llama.embeddings.embed(text)
# Output Hidden State of Just the Final Token, for Both the Audio and Text Streams
audio_pred = self.llama.transformer(virtual_audio_tokens)[-1]
text_pred = self.llama.transformer(text_embeds)[-1]
# L2 Loss Between the Final Embeddings
diff_distill = audio_pred - text_pred
loss = hax.dot(diff_distill, diff_distill, axis="embed") ** 0.5
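Dropping the named-axis (`hax`) machinery, the same loss can be sketched with plain NumPy; the vectors below are illustrative stand-ins for the two final-token hidden states:

```python
import numpy as np

def l2_distill_loss(audio_pred: np.ndarray, text_pred: np.ndarray) -> float:
    # Euclidean distance between the two final-token hidden states,
    # mirroring hax.dot(diff, diff, axis="embed") ** 0.5 above.
    diff = audio_pred - text_pred
    return float(np.dot(diff, diff) ** 0.5)

audio_pred = np.array([1.0, 2.0, 3.0])
text_pred = np.array([1.0, 2.0, 3.0])
print(l2_distill_loss(audio_pred, text_pred))  # 0.0 when the states match
```

The loss bottoms out at exactly zero when the audio-driven hidden state reproduces the text-driven one, which is the property the docstring argument below relies on.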
""""
Minimizing the above loss is equivalent to minimizing the KL Divergence of the final token.
Back of the envelope proof:
KL Divergence Loss is defined as loss_{kl} = sum(P_{target} * (log P_{target} - log P_{source})).
This function achieves its global minimum at P_{target} = P_{source}.
For neural models, we just define P as softmax(matmul(OutputHiddenState, EmbeddingMatrix)).
So in this case, KL Divergence is minimized at softmax(matmul(OutputHiddenState_{target}, EmbeddingMatrix)) = softmax(matmul(OutputHiddenState_{source}, EmbeddingMatrix)).
If EmbeddingMatrix is held constant (e.g. we are doing LoRA Training), this simplifies down to just OutputHiddenState_{target} = OutputHiddenState_{source}.
This means that minimizing the L2 Loss of these output hidden states should lead to a global minimum for the KL as well, but it's much cheaper and more stable to compute when vocabulary size > embedding dimension.
Simplified Proof Of Concept Notebook: https://colab.research.google.com/drive/1g1BEegIJzoZ1PHY_PeRNJkuveIQn7QIp?usp=sharing
"""