Created
August 8, 2023 02:18
-
-
Save arif9799/66b3b20e9531b9f2150ae3b1caaba35b to your computer and use it in GitHub Desktop.
Attention is not enough! Scaled Dot Product Attention Animation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from manim import * | |
from manim.utils.unit import Percent, Pixels | |
from colour import Color | |
# Logical frame size of the scene, in manim units (width first, then height).
config.frame_width, config.frame_height = 146, 85.5

# Render-quality switch: True selects 3840x2160 output, False selects 1920x1080.
HIGH = True

if not HIGH:
    print("Its low")
    config.pixel_height, config.pixel_width = 1080, 1920
else:
    print("Its high")
    config.pixel_height, config.pixel_width = 2160, 3840
def scaledDotProductFigure():
    """Build the mobjects of the Scaled Dot Product Attention block diagram.

    The diagram is a vertical stack (bottom to top): the ``Q`` and ``K``
    inputs, "Matrix Multiplication 1", "Scale", "Mask (Optional)", "Softmax"
    (with the ``V`` input placed to its right) and "Matrix Multiplication 2".

    Every box and arrow is created through ``always_redraw``, so it is
    rebuilt from the current position of the caption/box it refers to;
    moving the captions is therefore enough to move the whole figure.

    Returns:
        A ``(captions, boxes, arrows)`` tuple of three ``VGroup`` objects:

        * captions: Q, K, matmul 1, scale, mask, softmax, V, matmul 2
        * boxes: matmul 1, scale, mask, softmax, matmul 2
        * arrows: Q->matmul 1, K->matmul 1, matmul 1->scale, scale->mask,
          mask->softmax, softmax->matmul 2, V->matmul 2
    """
    ################################################################################
    #                    SCALED DOT PRODUCT FIGURE CODE BELOW                      #
    ################################################################################
    vBuff = 18                  # vertical gap factor between stacked captions
    boxBuff = MED_LARGE_BUFF    # padding between a caption and its box

    def makeCaption(label):
        # Every caption in the figure shares the same font size and slant.
        return Text(label, font_size=96, slant=OBLIQUE)

    def makeBox(captionMob, boxColor):
        # Translucent rounded rectangle that keeps hugging its caption.
        return always_redraw(
            lambda: SurroundingRectangle(
                mobject=captionMob,
                color=boxColor,
                fill_opacity=0.25,
                fill_color=boxColor,
                corner_radius=0.3,
                buff=boxBuff,
            )
        )

    def makeVerticalArrow(source, targetBox):
        # Arrow going straight up from the top of `source`: the end point
        # reuses the source's x and only takes y from the target box bottom.
        def build():
            start = source.get_top()
            end = [start[0], targetBox.get_bottom()[1], 0]
            return Arrow(start=start, end=end, stroke_width=15, buff=0.2)

        return always_redraw(build)

    def makeLinkArrow(lowerBox, upperBox):
        # Arrow from the top-centre of one box to the bottom-centre of the next.
        return always_redraw(
            lambda: Arrow(
                start=lowerBox.get_top(),
                end=upperBox.get_bottom(),
                stroke_width=15,
                buff=0.2,
            )
        )

    # ---- Captions (each one is positioned relative to the previous ones) ----
    _Q = makeCaption("Q").shift(DOWN * 10)
    _K = makeCaption("K").next_to(_Q, RIGHT * 20)
    _matMulTxt1 = makeCaption("Matrix Multiplication 1").next_to(VGroup(_Q, _K), UP * vBuff)
    _scaleTxt = makeCaption("Scale").next_to(_matMulTxt1, UP * vBuff)
    _maskTxt = makeCaption("Mask (Optional)").next_to(_scaleTxt, UP * vBuff)
    _sFTxt = makeCaption("Softmax").next_to(_maskTxt, UP * vBuff)
    _V = makeCaption("V").next_to(_sFTxt, RIGHT * 15)
    _matMulTxt2 = makeCaption("Matrix Multiplication 2").next_to(VGroup(_V, _sFTxt), UP * vBuff)

    # ---- Boxes around the operation captions ----
    _matMulBox1 = makeBox(_matMulTxt1, PURPLE)
    _scaleBox = makeBox(_scaleTxt, YELLOW)
    _maskBox = makeBox(_maskTxt, PINK)
    _sFBox = makeBox(_sFTxt, GREEN)
    _matMulBox2 = makeBox(_matMulTxt2, PURPLE)

    # ---- Arrows describing the data flow ----
    _Q_to_matMulBox1 = makeVerticalArrow(_Q, _matMulBox1)
    _K_to_matMulBox1 = makeVerticalArrow(_K, _matMulBox1)
    _matMulBox1_to_scaleBox = makeLinkArrow(_matMulBox1, _scaleBox)
    _scaleBox_to_maskBox = makeLinkArrow(_scaleBox, _maskBox)
    _maskBox_to_sFBox = makeLinkArrow(_maskBox, _sFBox)
    _sFBox_to_matMulBox2 = makeVerticalArrow(_sFBox, _matMulBox2)
    _V_to_matMulBox2 = makeVerticalArrow(_V, _matMulBox2)

    # Group membership and order match the sequence documented above.
    sDPCaps = VGroup(_Q, _K, _matMulTxt1, _scaleTxt, _maskTxt, _sFTxt, _V, _matMulTxt2)
    sDPBoxes = VGroup(_matMulBox1, _scaleBox, _maskBox, _sFBox, _matMulBox2)
    sDPArrows = VGroup(
        _Q_to_matMulBox1,
        _K_to_matMulBox1,
        _matMulBox1_to_scaleBox,
        _scaleBox_to_maskBox,
        _maskBox_to_sFBox,
        _sFBox_to_matMulBox2,
        _V_to_matMulBox2,
    )
    return sDPCaps, sDPBoxes, sDPArrows
class Figure9(Scene):
    """Scene that plays a branded intro and then the attention diagram.

    The scene first runs an intro built around a large blackboard-bold "C",
    then shows the title, the block diagram returned by
    ``scaledDotProductFigure()`` and finally the attention formula.
    """

    def construct(self):
        """Play the intro, the title, the diagram and the formula in order."""

        def myAnimation(wordsForIntro: list):
            """Play the intro: a big "C" completed by each word in turn.

            Args:
                wordsForIntro: list of strings; each string is written
                    character by character to the right of the "C", held
                    briefly, and then unwritten in reverse order.
            """
            fontHeight = config.frame_height // 8
            fontColor = WHITE
            timePerChar = 0.1

            C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
            self.play(Broadcast(C), run_time=1)
            self.add(C)

            # NOTE(review): absolute path on the original author's machine;
            # the scene fails to build elsewhere unless this asset exists.
            svg2 = SVGMobject(
                "/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg",
                height=config.frame_height,
                width=config.frame_width,
                stroke_color=MAROON,
                stroke_width=7,
                fill_color=BLUE,
                fill_opacity=0,
            )
            self.play(Write(svg2), run_time=2)
            self.add(svg2)

            # Build one MathTex per character so each can be animated alone.
            # The first character sits next to the "C"; each later one sits
            # next to its predecessor. All are bottom-aligned with the "C".
            wordsMath = VGroup()
            for word in wordsForIntro:
                charTex = VGroup()
                for i, ch in enumerate(word):
                    chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                    anchor = C if i == 0 else charTex[-1]
                    chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                    charTex.add(chTex)
                wordsMath.add(charTex)

            # One play() call per character: the original author found
            # Succession / AnimationGroup unreliable here, hence the slow loop.
            for wordGroup in wordsMath:
                letters = list(wordGroup)
                for letter in letters:
                    self.play(Write(letter, run_time=timePerChar))
                self.wait(0.5)
                for letter in reversed(letters):
                    self.play(Unwrite(letter, run_time=timePerChar))

            self.play(
                Circumscribe(
                    C,
                    color=MAROON_E,
                    fade_out=False,
                    time_width=2,
                    shape=Circle,
                    buff=1,
                    stroke_width=config.frame_height // 5,
                    run_time=1.5,
                )
            )
            self.play(
                AnimationGroup(
                    ShrinkToCenter(C, run_time=0.75),
                    ShrinkToCenter(svg2, run_time=0.75),
                )
            )
            self.wait(0.5)
            self.remove(C, svg2, wordsMath)

        myAnimation(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #                     MAIN CODE OF THE FIGURE STARTS HERE                      #
        ################################################################################

        # ---------------> TITLE OF VIDEO
        title = Title(
            'Scaled Dot Product Attention',
            match_underline_width_to_text=True,
            underline_buff=MED_LARGE_BUFF,
            font_size=300,
        ).shift(DOWN * 7)
        self.play(DrawBorderThenFill(title), run_time=2)

        # ---------------> BLOCK DIAGRAM
        scaleCaps, scaleBoxes, scaleArrows = scaledDotProductFigure()
        unifiedMobject = VGroup(scaleCaps, scaleBoxes, scaleArrows)
        self.play(Write(unifiedMobject))
        # Only the captions are moved explicitly; boxes and arrows are built
        # with always_redraw and are regenerated from the caption positions.
        self.play(scaleCaps.animate.move_to(ORIGIN))
        # NOTE(review): `color` is passed to the animation constructor rather
        # than applied to the copied arrows - confirm the flash renders red.
        self.play(ShowPassingFlash(scaleArrows.copy(), time_width=3, color=RED))
        self.play(unifiedMobject.animate.set_height(config.frame_height // 2))

        # ---------------> FORMULA
        # Raw string so LaTeX commands are not parsed as Python escapes;
        # `\sqrt{d}` (with the backslash) typesets the square root.
        formula = MathTex(
            r'Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr) \: . \: V',
            font_size=300,
            color=WHITE,
        ).next_to(scaleCaps, RIGHT * 10)
        self.play(DrawBorderThenFill(formula), run_time=2)
        self.play(VGroup(scaleCaps, formula).animate.move_to(ORIGIN))
        self.wait(10)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment