Skip to content

Instantly share code, notes, and snippets.

@arif9799
Last active August 10, 2023 16:23
Show Gist options
  • Save arif9799/3ea4482c18281cd919457261f3173016 to your computer and use it in GitHub Desktop.
Save arif9799/3ea4482c18281cd919457261f3173016 to your computer and use it in GitHub Desktop.
Attention is not enough! Positional Encoding Formula
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
import gensim
from gensim.models import Word2Vec
import numpy as np
import torch as t
# Logical frame size in scene units (48 x 27 keeps a 16:9 aspect ratio).
config.frame_width = 48
config.frame_height = 27

# Toggle between the 3840x2160 and the 1920x1080 pixel resolution.
HIGH = True
print("Its high" if HIGH else "Its low")
config.pixel_height = 2160 if HIGH else 1080
config.pixel_width = 3840 if HIGH else 1920
################################################################################
# Creating Word Vectors #
################################################################################
def construct_Word_Embeddings(sentences, dim):
    """Train a Word2Vec model on ``sentences`` and embed every sentence.

    Args:
        sentences: iterable of tokenized sentences (each a sequence of words).
            It is traversed twice (training, then lookup), so it must be
            re-iterable.
        dim: size of each word vector (passed as ``vector_size``).

    Returns:
        A list with one tensor per sentence; each tensor stacks that
        sentence's word vectors row-wise and is rounded to 2 decimals.
    """
    w2v = Word2Vec(sentences, min_count=1, vector_size=dim, window=3)

    embedded = []
    for tokens in sentences:
        # Second positional argument of get_vector is forwarded as-is
        # (presumably gensim's `norm` flag -- confirm against the gensim docs).
        rows = [
            t.from_numpy(w2v.wv.get_vector(token, True).reshape(1, -1))
            for token in tokens
        ]
        embedded.append(t.round(t.cat(rows, dim=0), decimals=2))
    return embedded
# Toy corpus: three pre-tokenized sentences, embedded with 6-dimensional vectors.
data = [
    ['Attention', 'is', 'not', 'enough'],
    ['Towards', 'Data', 'Science'],
    ['It', 'was', 'a', 'really', 'good', 'article'],
]
sentences_Embedded = construct_Word_Embeddings(sentences=data, dim=6)
################################################################################
# Creating Positional Encoding Vector #
################################################################################
def construct_Positional_Encoding(input, n=10000):
    """Compute sinusoidal positional encodings and add them to ``input``.

    For position ``pos`` and pair index ``i`` (``d`` = embedding width)::

        PE[pos, 2i]   = sin(pos / n ** (2i / d))
        PE[pos, 2i+1] = cos(pos / n ** (2i / d))

    Args:
        input: tensor whose last two axes are ``(seq_len, d)``. The name
            shadows the builtin but is kept so keyword callers keep working.
        n: base of the frequency progression. Defaults to 10000, the value
            that used to be hard-coded here.

    Returns:
        A tuple ``(x, pe, x_pe)``: the input, the ``(seq_len, d)`` encoding
        matrix rounded to 2 decimals, and ``x + pe`` rounded to 2 decimals.
    """
    # The temporary leading axis lets the shape lookups below also work for a
    # 1-D input (treated as a single position).
    x = input.unsqueeze(0)
    input_dimension = x.shape[-1]
    seq_len = x.shape[-2]

    dimension_indices = t.arange(0, input_dimension, 2)
    denominator = t.pow(n, dimension_indices / input_dimension)
    positions = t.arange(0, seq_len, 1).reshape(seq_len, 1)
    angles = positions / denominator

    # Stack on a trailing axis so that flattening interleaves the values as
    # sin, cos, sin, cos, ... along the embedding axis.
    interleaved = t.stack([t.sin(angles), t.cos(angles)], dim=2)
    pe = interleaved.reshape(seq_len, 2 * angles.shape[1])
    # For an odd width the last sin/cos pair overshoots by one column; trim it
    # instead of failing in reshape as the previous implementation did.
    pe = pe[:, :input_dimension]

    x = x.squeeze(0)
    pe = t.round(pe, decimals=2)
    x_pe = t.round(x + pe, decimals=2)
    return x, pe, x_pe
# Encode the first embedded sentence. The Figure6 scene below reads `sent`
# and `pos_enc` as module-level globals.
sent, pos_enc, pos_encoded = construct_Positional_Encoding(sentences_Embedded[0])
class Figure6(Scene):
    """Manim scene that walks through the positional-encoding formula.

    The scene plays a short intro, writes the generic sin/cos formulas with a
    legend, then shows a per-(word, dimension) table of formulas that is
    specialised step by step (position, then 2i, then d_model, then n) and is
    finally matched against a table of the numeric values held in the
    module-level ``pos_enc`` tensor. Shapes are read from the module-level
    ``sent`` and ``pos_enc`` tensors.
    """

    def construct(self):
        """Build every mobject and play the whole animation."""
        ################################################################################
        #                   INTRO ANIMATION SAME ACROSS ALL VIDEOS                     #
        ################################################################################
        self._play_intro(['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #                    MAIN CODE OF THE FIGURE STARTS HERE                       #
        ################################################################################
        # ------------> VARIABLES DEFINITION
        seq_len, dimension = sent.shape[0], sent.shape[1]
        d_model = pos_enc.shape[1]
        symbolic_den = r'n^{ \frac{2i}{d_{model}} }'

        # ------------> MAIN FORMULA FOR EVEN AND ODD POSITION
        # Raw strings keep the LaTeX backslashes free of invalid Python escapes.
        main_formula = VGroup(
            MathTex(
                r'PE_{position,2i} = \sin \Bigg( \dfrac{%s}{%s} \Bigg)'
                % ('position', symbolic_den),
                font_size=100,
            ).set_color(BLUE),
            MathTex(
                r'PE_{position,2i+1} = \cos \Bigg(\dfrac{%s}{%s} \Bigg)'
                % ('position', symbolic_den),
                font_size=100,
            ).set_color(GREEN),
        )

        # ------------> LEDGER OF MAIN FORMULA
        legend_lines = [
            r'position \Longleftarrow \: position \: of \: word \: or \: token \: in \: a \: sequence',
            r'i \Longleftarrow \: index \: of \: embedding \: dimensions',
            r'd_{model} \Longleftarrow \: embedding \: dimension \: of \: input \: for \: the \: model \:( \:for \:example, \:we \:will \:take \:6 \:below)',
            r'n \Longleftarrow \: Value \: of \: Max \: tokens \: in \: any \: given \: sequence \: (\: assumed \: to \: be \: 10000 \: )',
        ]
        ledger = VGroup(
            *[MathTex(line, font_size=64, color=YELLOW) for line in legend_lines]
        )

        # ------------> TABLES OF FORMULAS, FROM FULLY SYMBOLIC TO FULLY NUMERIC
        # Each stage is (numerator for a position, denominator for a pair index).
        stages = [
            (lambda pos: 'position', lambda i: symbolic_den),
            (lambda pos: '%d' % pos, lambda i: symbolic_den),
            (lambda pos: '%d' % pos,
             lambda i: r'n^{ \frac{%d}{d_{model}} }' % (2 * i)),
            (lambda pos: '%d' % pos,
             lambda i: r'n^{ \frac{%d}{%d} }' % (2 * i, d_model)),
            (lambda pos: '%d' % pos,
             lambda i: r'%d^{ \frac{%d}{%d} }' % (10000, 2 * i, d_model)),
        ]
        formula_tables = [
            self._formula_table(seq_len, dimension, numerator, denominator)
            for numerator, denominator in stages
        ]
        final_table = formula_tables[-1]

        # ------------> TABLE OF VALUES FOR POSITIONAL ENCODING MATRIX
        value_table = MathTable(
            table=pos_enc.numpy(),
            row_labels=[MathTex('W_{%d}' % i).set_color(PURPLE) for i in range(seq_len)],
            col_labels=[MathTex('D_{%d}' % i).set_color(PURPLE) for i in range(dimension)],
        )

        # ------------> LAYOUT AND ANIMATION
        main_formula.arrange(DOWN, buff=1).to_edge(UP * 4 + LEFT * 8)
        ledger.arrange(DOWN, buff=1).next_to(main_formula, RIGHT * 5)
        self.play(Write(main_formula), run_time=3)
        self.play(Write(ledger), run_time=3)
        self.play(Write(formula_tables[0]), run_time=5)
        self.wait(2)

        # Morph each table into the next, more specialised, one.
        for current, following in zip(formula_tables, formula_tables[1:]):
            self.play(ReplacementTransform(current, following))
            self.wait(2)
            self.remove(current)

        self.play(final_table.animate.scale(0.75).shift(UP * 3))
        self.wait(2)
        value_table.next_to(final_table, DOWN * 3)
        # Transform a copy so the formula table itself stays on screen.
        self.play(TransformMatchingShapes(final_table.copy(), value_table))
        self.wait(8)

    def _play_intro(self, words):
        """Play the intro: a large blackboard-bold C completed by each word.

        Args:
            words: list of strings; each is written letter by letter to the
                right of the C, held for half a second, then unwritten.
        """
        font_height = config.frame_height // 8
        font_color = WHITE
        time_per_char = 0.1

        big_c = MathTex(r"\mathbb{C}", color=font_color).scale(config.frame_height // 3)
        self.play(Broadcast(big_c), run_time=1)
        self.add(big_c)

        # Build one VGroup of single-letter mobjects per word.
        words_math = VGroup()
        for word in words:
            letters = VGroup()
            for index, ch in enumerate(word):
                glyph = MathTex(r"\mathbb{%s}" % ch, color=font_color).scale(font_height)
                # The first letter sits next to the C, later ones next to the
                # previous letter; all share the C's baseline.
                anchor = big_c if index == 0 else letters[-1]
                glyph.next_to(anchor, RIGHT, buff=0.05).align_to(big_c, DOWN)
                letters.add(glyph)
            words_math.add(letters)

        # Letters are played one at a time; the original author found
        # Succession / AnimationGroup unreliable for this.
        for letters in words_math:
            for glyph in letters:
                self.play(Write(glyph, run_time=time_per_char))
            self.wait(0.5)
            for glyph in reversed(list(letters)):
                self.play(Unwrite(glyph, run_time=time_per_char))

        self.play(Circumscribe(big_c, color=MAROON_E, fade_out=False, time_width=2,
                               shape=Circle, buff=1,
                               stroke_width=config.frame_height // 3, run_time=1.5))
        self.play(ShrinkToCenter(big_c, run_time=0.25))
        self.wait(0.5)

    @staticmethod
    def _formula_table(seq_len, dimension, numerator, denominator):
        """Build one labelled MathTable of per-cell sin/cos formulas.

        Args:
            seq_len: number of rows (one per word position).
            dimension: number of column labels; ``dimension // 2`` sin/cos
                pairs are generated per row.
            numerator: callable mapping a position to the LaTeX numerator.
            denominator: callable mapping a pair index ``i`` to the LaTeX
                denominator.

        Returns:
            The table, shifted down by 4 units. Fresh label mobjects are
            created on every call so no mobject is shared between tables.
        """
        rows = []
        for pos in range(seq_len):
            row = []
            for i in range(dimension // 2):
                fraction = r'\dfrac{%s}{%s}' % (numerator(pos), denominator(i))
                row.append(r'PE_{%d,%d} = \sin \Bigg(%s \Bigg)' % (pos, 2 * i, fraction))
                row.append(r'PE_{%d,%d} = \cos \Bigg(%s \Bigg)' % (pos, 2 * i + 1, fraction))
            rows.append(row)

        # Braced subscripts keep multi-digit indices fully subscripted.
        row_labels = [
            MathTex('Word_{%d}' % i, font_size=96).set_color(PURPLE)
            for i in range(seq_len)
        ]
        col_labels = [
            MathTex('Dim_{%d}' % i, font_size=96).set_color(PURPLE)
            for i in range(dimension)
        ]
        return MathTable(table=rows, row_labels=row_labels, col_labels=col_labels).shift(DOWN * 4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment