Last active
August 8, 2023 02:11
-
-
Save arif9799/e1edd265f8c0346a74fd0275fe717210 to your computer and use it in GitHub Desktop.
Attention is not enough! Word Embeddings plus Positional Encoding
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from manim import * | |
from manim.utils.unit import Percent, Pixels | |
from colour import Color | |
import gensim | |
from gensim.models import Word2Vec | |
import numpy as np | |
import torch as t | |
# Scene frame in Manim units (64 x 36, i.e. 16:9).
config.frame_width = 64
config.frame_height = 36

# Render-quality switch: True -> 3840x2160, False -> 1920x1080.
HIGH = False
if not HIGH:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
else:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
################################################################################ | |
# Creating Word Vectors # | |
################################################################################ | |
def construct_Word_Embeddings(sentences, dim):
    """Train a small Word2Vec model on ``sentences`` and embed every sentence.

    Args:
        sentences: list of tokenised sentences (each a list of word strings);
            it is used both as the training corpus and as the text to embed.
        dim: size of each word vector (passed to Word2Vec as ``vector_size``).

    Returns:
        A list with one torch tensor per sentence. Each tensor stacks one
        ``dim``-wide row per word (vectors fetched with ``norm=True``) and is
        rounded to 2 decimals for display.

    NOTE(review): Word2Vec trains with multiple worker threads by default, so
    the exact values may vary between runs — confirm if reproducible figures
    matter.
    """
    # min_count=1 keeps every word of this tiny corpus in the vocabulary.
    keyed_vectors = Word2Vec(sentences,
                             min_count=1,
                             vector_size=dim,
                             window=3).wv

    def _embed(tokens):
        # One (1, dim) row per token, concatenated into a (len(tokens), dim) matrix.
        rows = [t.from_numpy(keyed_vectors.get_vector(token, norm=True).reshape(1, -1))
                for token in tokens]
        return t.round(t.cat(rows, dim=0), decimals=2)

    return [_embed(tokens) for tokens in sentences]
# Toy corpus: three tokenised sentences; the first is the one animated below.
data = [
    ['Attention', 'is', 'not', 'enough'],
    ['Towards', 'Data', 'Science'],
    ['It', 'was', 'a', 'really', 'good', 'article'],
]
# 6-dimensional word vectors for every sentence in the corpus.
sentences_Embedded = construct_Word_Embeddings(sentences=data, dim=6)
################################################################################ | |
# Creating Positional Encoding Vector # | |
################################################################################ | |
def construct_Positional_Encoding(input, base=10000):
    """Add sinusoidal positional encodings to a matrix of word embeddings.

    Even columns hold ``sin(pos / base**(2i/d))`` and odd columns hold the
    matching ``cos`` term, interleaved as in "Attention Is All You Need".

    Args:
        input: tensor whose last two dimensions are (seq_len, d); a 1-D tensor
            is treated as a single position.
        base: wavelength base of the sinusoids (default 10000, the value used
            previously). Keyword added for generality; the default preserves
            the old behaviour.

    Returns:
        (x, pe, x_pe): the input embeddings, the (seq_len, d) encoding matrix
        rounded to 2 decimals, and ``x + pe`` rounded to 2 decimals.
    """
    # Add a leading axis so a 1-D input still has a "sequence" axis at -2.
    x = input.unsqueeze(0)
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    # One frequency per sin/cos pair; for odd d there is one more sin than cos.
    dimension_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(base, dimension_indices / input_dimension)
    numerator = t.arange(0, seq_len, 1).reshape(seq_len, 1)

    even_positions_encoded = t.sin(numerator / denominator)
    odd_positions_encoded = t.cos(numerator / denominator)

    # Stack to (seq_len, half, 2) and flatten so sin/cos alternate per column.
    # The flattened width is 2*half, which exceeds d by one when d is odd —
    # truncate instead of reshaping straight to d (which used to crash).
    half = dimension_indices.numel()
    combined = t.stack([even_positions_encoded, odd_positions_encoded], dim=2)
    pe = combined.reshape(seq_len, 2 * half)[:, :input_dimension]

    x = x.squeeze(0)
    pe = t.round(pe, decimals=2)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
# Position-encode the first sentence of the corpus. The results are: raw
# embeddings, the positional-encoding matrix, and their rounded sum.
sent, pos_enc, pos_encoded = construct_Positional_Encoding(
    input=sentences_Embedded[0],
)
def _attached_labels(labels, targets, direction, anchor, divisor):
    """Return a VGroup of MathTex labels, one per target, kept beside it.

    Label ``i`` is placed next to ``targets[i]`` in ``direction``. The gap is
    ``anchor.height / divisor``, recomputed in an updater on every frame so the
    label keeps following the target as the anchor moves or rescales.
    """
    group = VGroup()
    for index, target in enumerate(targets):
        label = MathTex(labels[index])
        label.next_to(target, direction, buff=anchor.height / divisor)
        # Bind the target as a default argument to avoid late-binding closures.
        label.add_updater(
            lambda mob, tgt=target: mob.next_to(tgt, direction, buff=anchor.height / divisor)
        )
        group.add(label)
    return group


def fullMatrix(matrixContents, matrixHeight, matrixColor, matrixName, rowNames, colNames, rightColNames=None, rCols=False):
    """Build a Manim Matrix with row/column labels and a caption underneath.

    Args:
        matrixContents: 2-D array-like of entries for ``Matrix``.
        matrixHeight: target height of the matrix mobject.
        matrixColor: colour applied to the matrix.
        matrixName: caption text shown below the matrix.
        rowNames: one MathTex label per row, placed on the left.
        colNames: one MathTex label per column, placed above.
        rightColNames: one MathTex label per row, placed on the right; must be
            provided when ``rCols`` is True.
        rCols: whether to add the right-hand row labels.

    Returns:
        VGroup whose children are, in order: [0] the Matrix, the left row
        labels, the right row labels (only if ``rCols``), the column labels,
        and the caption Text. Callers index ``[0]`` to reach the matrix.
    """
    labelled = VGroup()

    matrix = Matrix(matrixContents, h_buff=1.5, v_buff=1).set(height=matrixHeight, color=matrixColor)
    labelled.add(matrix)

    # Label counts come from this matrix's own rows/columns (previously they
    # were taken from the module-level ``sent`` tensor, which broke for any
    # matrix of a different shape).
    rows = matrix.get_rows()
    cols = matrix.get_columns()

    labelled.add(_attached_labels(rowNames, rows, LEFT, matrix, 4))
    if rCols:
        labelled.add(_attached_labels(rightColNames, rows, RIGHT, matrix, 4))
    labelled.add(_attached_labels(colNames, cols, UP, matrix, 5))

    # Caption below the matrix, following it as it moves.
    caption = Text(matrixName, font_size=36)
    caption.next_to(matrix, DOWN, buff=matrix.height / 5)
    caption.add_updater(lambda mob, y=matrix: mob.next_to(y, DOWN, buff=y.height / 5))
    labelled.add(caption)
    return labelled
class Figure5(Scene):
    """Animate WORD EMBEDDINGS + POSITIONAL ENCODINGS = POSITION AWARE WORD EMBEDDINGS.

    Reads the module-level ``data``, ``sent``, ``pos_enc``, ``pos_encoded`` and
    ``fullMatrix`` defined earlier in this file.
    """

    def construct(self):
        ################################################################################
        #                   INTRO ANIMATION SAME ACROSS ALL VIDEOS                     #
        ################################################################################
        def myAnimation(wordsForIntro: list):
            """Play the shared intro.

            A large blackboard-bold C appears. Each word in ``wordsForIntro``
            is written and then unwritten letter by letter beside it, and the
            C is circumscribed and shrunk away.
            """
            fontHeight = config.frame_height // 8  # scale factor for each letter
            fontColor = WHITE
            timePerChar = 0.1

            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)

            # One VGroup of single-character MathTex objects per word, laid
            # out left to right starting at the C and bottom-aligned with it.
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    # Raw string: "\m" is an invalid escape (SyntaxWarning on 3.12+).
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    anchor = charTex[-1] if i != 0 else C
                    chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)

            # Succession / AnimationGroup misbehaved here, so each character
            # gets its own self.play call (slow, but reliable).
            for charTex in wordsMath:
                for chTex in charTex:
                    self.play(Write(chTex, run_time=timePerChar))
                self.wait(0.5)
                for chTex in reversed(list(charTex)):
                    self.play(Unwrite(chTex, run_time=timePerChar))

            self.play(Circumscribe(C, color=MAROON_E, fade_out=False, time_width=2, shape=Circle, buff=1, stroke_width=config.frame_height // 3, run_time=1.5))
            self.play(ShrinkToCenter(C, run_time=0.25))
            self.wait(0.5)

        # Uncomment to play the shared intro before the figure.
        # myAnimation(wordsForIntro= ['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #                      MAIN CODE OF THE FIGURE STARTS HERE                     #
        ################################################################################

        # ----------------> Labels for rows (words / positions) and columns (dimensions)
        words = data[0]
        dimensions = ['d_' + str(i) for i in range(sent.shape[1])]
        positions = ['position_' + str(i) for i in range(sent.shape[0])]

        _LHS = VGroup()

        # ----------------> MATRIX A: raw word embeddings
        MatrixA = fullMatrix(matrixContents=sent.numpy(), matrixHeight=5, matrixColor=PURPLE, matrixName="WORD EMBEDDINGS", rowNames=words, colNames=dimensions)

        # ----------------> PLUS SIGN (follows MatrixA via updater)
        Plus = Tex("+", color=WHITE, font_size=300).next_to(MatrixA[0], RIGHT * 2).add_updater(lambda x: x.next_to(MatrixA[0], RIGHT * 2))

        # ----------------> MATRIX B: positional encodings (follows the plus sign)
        MatrixB = fullMatrix(matrixContents=pos_enc.numpy(), matrixHeight=5, matrixColor=PURPLE, matrixName="POSITIONAL ENCODINGS", rowNames=positions, colNames=dimensions)
        MatrixB[0].next_to(Plus, RIGHT, buff=MatrixB[0].width / 3).add_updater(lambda x: x.next_to(Plus, RIGHT, buff=MatrixB[0].width / 3))

        # ----------------> EQUALS SIGN (follows MatrixB)
        Equals = Tex("=", color=WHITE, font_size=300).next_to(MatrixB[0], RIGHT * 2).add_updater(lambda e: e.next_to(MatrixB[0], RIGHT * 2))

        # ----------------> Accumulate the upper line "A + B =" and centre it
        _LHS.add(VGroup(MatrixA, Plus, MatrixB, Equals))
        _LHS.move_to(ORIGIN)
        self.play(Write(_LHS), run_time=3)

        # ----------------> MATRIX C: the sum, placed below the upper line
        MatrixC = fullMatrix(matrixContents=pos_encoded.numpy(), matrixHeight=6, matrixColor=RED, matrixName="POSITION AWARE WORD EMBEDDINGS", rowNames=positions, colNames=dimensions, rightColNames=positions, rCols=True)
        MatrixC.next_to(_LHS, DOWN * 5).add_updater(lambda x, y=_LHS: x.next_to(y, DOWN * 5))
        self.play(TransformMatchingShapes(_LHS.copy(), MatrixC), run_time=2)
        self.wait(8)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment