Skip to content

Instantly share code, notes, and snippets.

@arif9799
Created August 8, 2023 02:18
Show Gist options
  • Save arif9799/66b3b20e9531b9f2150ae3b1caaba35b to your computer and use it in GitHub Desktop.
Save arif9799/66b3b20e9531b9f2150ae3b1caaba35b to your computer and use it in GitHub Desktop.
Attention is not enough! Scaled Dot Product Attention Animation
from manim import *
from manim.utils.unit import Percent, Pixels
from colour import Color
# Scene coordinate space shared by every figure in this file.
config.frame_width = 146
config.frame_height = 85.5

# Render-quality toggle: True renders at 3840x2160, False at 1920x1080.
HIGH = True
print("Its high" if HIGH else "Its low")
# Tuple assignment sets pixel_height first, then pixel_width, exactly as before.
config.pixel_height, config.pixel_width = (2160, 3840) if HIGH else (1080, 1920)
################################################################################
#                    SCALED DOT PRODUCT FIGURE CODE BELOW                      #
################################################################################


def _sdpCaption(text):
    """Return a large oblique caption used for every label in the figure."""
    return Text(text, font_size=96, slant=OBLIQUE)


def _sdpBox(caption, color, buff):
    """Return a rounded, translucent box that keeps surrounding *caption*.

    The rectangle is rebuilt every frame via ``always_redraw``, so it follows
    the caption whenever the caption is moved or rescaled.
    """
    return always_redraw(lambda: SurroundingRectangle(
        mobject=caption,
        color=color,
        fill_opacity=0.25,
        fill_color=color,
        corner_radius=0.3,
        buff=buff,
    ))


def _sdpVerticalArrow(source, targetBox):
    """Return a live arrow from the top of *source* up to *targetBox*.

    The arrow keeps the x-coordinate of ``source.get_top()`` and ends at the
    y-coordinate of the target box's bottom edge. It therefore stays vertical
    even when *source* is horizontally offset from the box centre.
    """
    return always_redraw(lambda: Arrow(
        start=source.get_top(),
        end=[source.get_top()[0], targetBox.get_bottom()[1], 0],
        stroke_width=15,
        buff=0.2,
    ))


def _sdpBoxArrow(sourceBox, targetBox):
    """Return a live arrow from the top-centre of *sourceBox* to the bottom-centre of *targetBox*."""
    return always_redraw(lambda: Arrow(
        start=sourceBox.get_top(),
        end=targetBox.get_bottom(),
        stroke_width=15,
        buff=0.2,
    ))


def scaledDotProductFigure():
    """Build the scaled dot-product attention flow diagram.

    Stages are stacked bottom to top:
    Q and K -> "Matrix Multiplication 1" -> "Scale" -> "Mask (Optional)"
    -> "Softmax" -> "Matrix Multiplication 2".
    "Matrix Multiplication 2" also receives V, which sits to the right of
    Softmax. Boxes and arrows are ``always_redraw`` mobjects bound to the
    captions, so moving or scaling the captions drags them along.

    Returns:
        tuple[VGroup, VGroup, VGroup]: ``(captions, boxes, arrows)``. The
        order of submobjects in each group is the order in which ``Write``
        draws them.
    """
    sDPCaps = VGroup()    # scaled dot product captions
    sDPBoxes = VGroup()   # scaled dot product boxes
    sDPArrows = VGroup()  # scaled dot product arrows

    # ``next_to`` only uses the sign of the direction to choose edges, but it
    # offsets by buff * direction. Scaled directions (UP * vBuff, RIGHT * 20,
    # ...) are therefore used here to widen the gaps between elements.
    vBuff = 18
    boxBuff = MED_LARGE_BUFF

    # Inputs Q and K side by side at the bottom.
    _Q = _sdpCaption("Q").shift(DOWN * 10)
    _K = _sdpCaption("K").next_to(_Q, RIGHT * 20)

    # Processing stages, each stacked above the previous one.
    _matMulTxt1 = _sdpCaption("Matrix Multiplication 1").next_to(VGroup(_Q, _K), UP * vBuff)
    _scaleTxt = _sdpCaption("Scale").next_to(_matMulTxt1, UP * vBuff)
    _maskTxt = _sdpCaption("Mask (Optional)").next_to(_scaleTxt, UP * vBuff)
    _sFTxt = _sdpCaption("Softmax").next_to(_maskTxt, UP * vBuff)

    # V enters beside Softmax; both feed the second matrix multiplication.
    _V = _sdpCaption("V").next_to(_sFTxt, RIGHT * 15)
    _matMulTxt2 = _sdpCaption("Matrix Multiplication 2").next_to(VGroup(_V, _sFTxt), UP * vBuff)

    sDPCaps.add(_Q, _K, _matMulTxt1, _scaleTxt, _maskTxt, _sFTxt, _V, _matMulTxt2)

    # Boxes around each operation (the Q, K and V inputs are left unboxed).
    _matMulBox1 = _sdpBox(_matMulTxt1, PURPLE, boxBuff)
    _scaleBox = _sdpBox(_scaleTxt, YELLOW, boxBuff)
    _maskBox = _sdpBox(_maskTxt, PINK, boxBuff)
    _sFBox = _sdpBox(_sFTxt, GREEN, boxBuff)
    _matMulBox2 = _sdpBox(_matMulTxt2, PURPLE, boxBuff)
    sDPBoxes.add(_matMulBox1, _scaleBox, _maskBox, _sFBox, _matMulBox2)

    sDPArrows.add(
        _sdpVerticalArrow(_Q, _matMulBox1),      # Q -> MatMul 1
        _sdpVerticalArrow(_K, _matMulBox1),      # K -> MatMul 1
        _sdpBoxArrow(_matMulBox1, _scaleBox),    # MatMul 1 -> Scale
        _sdpBoxArrow(_scaleBox, _maskBox),       # Scale -> Mask
        _sdpBoxArrow(_maskBox, _sFBox),          # Mask -> Softmax
        _sdpVerticalArrow(_sFBox, _matMulBox2),  # Softmax -> MatMul 2
        _sdpVerticalArrow(_V, _matMulBox2),      # V -> MatMul 2
    )
    return sDPCaps, sDPBoxes, sDPArrows
class Figure9(Scene):
    """Intro logo card followed by the scaled dot-product attention diagram and formula."""

    # SVG outline drawn around the intro logo. This is a machine-specific
    # absolute path from the original author's setup; override this class
    # attribute (e.g. in a subclass) to render on another machine.
    INTRO_SVG_PATH = "/Users/arifwaghbakriwala/Desktop/Northeastern/Projects/Manimations/assets/svg/pl2.svg"

    def _playIntro(self, wordsForIntro: "list[str]"):
        """Play the blackboard-bold "C" intro, typing and erasing each word beside it.

        Args:
            wordsForIntro: Word suffixes written to the right of the large "C"
                (e.g. ``"OMPLEX"`` reads as "COMPLEX"), shown one word at a
                time. Each word is written character by character, held, then
                erased in reverse.
        """
        fontHeight = config.frame_height // 8
        fontColor = WHITE
        timePerChar = 0.1

        C = MathTex(r"\mathbb{C}", color=fontColor).scale(config.frame_height // 3)
        self.play(Broadcast(C), run_time=1)
        self.add(C)

        svg2 = SVGMobject(
            self.INTRO_SVG_PATH,
            height=config.frame_height,
            width=config.frame_width,
            stroke_color=MAROON,
            stroke_width=7,
            fill_color=BLUE,
            fill_opacity=0,
        )
        self.play(Write(svg2), run_time=2)
        self.add(svg2)

        # One VGroup of blackboard-bold characters per word. The first
        # character sits just right of C; each later one sits right of its
        # predecessor. All are bottom-aligned with C.
        wordsMath = VGroup()
        for word in wordsForIntro:
            charTex = VGroup()
            for i, ch in enumerate(word):
                chTex = MathTex(r"\mathbb{%s}" % ch, color=fontColor).scale(fontHeight)
                anchor = charTex[-1] if i != 0 else C
                chTex.next_to(anchor, RIGHT, buff=0.05).align_to(C, DOWN)
                charTex.add(chTex)
            wordsMath.add(charTex)

        # Succession / AnimationGroup did not work as intended here, so each
        # character gets its own play() call. This is inefficient, but it works.
        for wordTex in wordsMath:
            for chTex in wordTex:
                self.play(Write(chTex, run_time=timePerChar))
            self.wait(0.5)
            for chTex in reversed(list(wordTex)):
                self.play(Unwrite(chTex, run_time=timePerChar))

        self.play(Circumscribe(
            C,
            color=MAROON_E,
            fade_out=False,
            time_width=2,
            shape=Circle,
            buff=1,
            stroke_width=config.frame_height // 5,
            run_time=1.5,
        ))
        self.play(AnimationGroup(
            ShrinkToCenter(C, run_time=0.75),
            ShrinkToCenter(svg2, run_time=0.75),
        ))
        self.wait(0.5)
        self.remove(C, svg2, wordsMath)

    def construct(self):
        self._playIntro(wordsForIntro=['OMPLEX', 'ONCEPTS', 'OMPREHENSIBLE'])

        ################################################################################
        #                    MAIN CODE OF THE FIGURE STARTS HERE                       #
        ################################################################################

        # ---------------> TITLE OF VIDEO
        title = Title(
            'Scaled Dot Product Attention',
            match_underline_width_to_text=True,
            underline_buff=MED_LARGE_BUFF,
            font_size=300,
        ).shift(DOWN * 7)
        self.play(DrawBorderThenFill(title), run_time=2)

        scaleCaps, scaleBoxes, scaleArrows = scaledDotProductFigure()
        unifiedMobject = VGroup(scaleCaps, scaleBoxes, scaleArrows)
        self.play(Write(unifiedMobject))
        # Boxes and arrows are always_redraw mobjects, so moving only the
        # captions drags the whole diagram along.
        self.play(scaleCaps.animate.move_to(ORIGIN))
        self.play(ShowPassingFlash(scaleArrows.copy(), time_width=3, color=RED))
        self.play(unifiedMobject.animate.set_height(config.frame_height // 2))

        # Raw string: LaTeX backslashes must reach MathTex unescaped.
        # \sqrt{d} fixes the original "sqrt{d}", which rendered as the letters "sqrt".
        formula = MathTex(
            r'Softmax \: \biggl( \frac{(Q\: . \:K^T)}{\sqrt{d}} \biggr) \: . \: V',
            font_size=300,
            color=WHITE,
        ).next_to(scaleCaps, RIGHT * 10)
        self.play(DrawBorderThenFill(formula), run_time=2)
        self.play(VGroup(scaleCaps, formula).animate.move_to(ORIGIN))
        self.wait(10)
#````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````````
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment