# Source: GitHub Gist elisohl-ncc/104c3c63899e9a0aae6641feaaee7f2b
# Author: @elisohl-ncc, created May 17, 2024 17:12
from manim import *
from manim_cranim import *
from manim_cranim.util import _enc, _dec, _rng, blk
from manim.mobject.text.text_mobject import remove_invisible_chars # not included in __all__ for some reason
from Crypto.Cipher import AES
from util import ToCScene, inc
from itertools import product
from random import Random
from decimal import Decimal
from math import log2
# Deterministic fixtures shared by the single-block attack scenes.
# Statement order matters here: the RNG is temporarily reseeded to derive IV.
# Draw 256 bytes from the shared RNG so it can be reseeded from them afterwards.
_seed = _rng.randbytes(256)
# Reseed from a fixed input so blk() (presumably drawing from _rng) yields a
# reproducible IV on every render.
_rng.seed(_enc(b'lol sup'*13))
IV = blk()
# Reseed from the bytes drawn above so later draws don't continue the fixed stream.
_rng.seed(_seed)
MSG = b'YELLOW SUBMARINE'
# Intended block-cipher output for CT: in CBC, plaintext = DEC_OUT xor IV == MSG.
# NOTE(review): assumes _enc is single-block encryption under the demo key
# (so that D_k(CT) == DEC_OUT) -- confirm against manim_cranim.util.
DEC_OUT = bytes_xor(MSG, IV)
CT = _enc(DEC_OUT)
# Presumably the simulated oracle's false-positive / false-negative rates
# (used by scenes not visible here). Build the Decimal from a string:
# Decimal(0.28) would carry the binary float's error (0.28000000000000002664...).
FP_RATE = FN_RATE = Decimal("0.28")
"""
hello and welcome
taking on challenge 17, from the start of set 3
this challenge is a spin on the classic padding oracle attack
which is one of my absolute favorite crypto attacks,
and it's one that shows up in the wild surprisingly often.
Now, if you prefer reading over watching videos, you're in luck.
I have a pair of blog posts, one of which inspired this series, and the other of which was inspired by it.
The first of these posts is a quick introduction to the padding oracle attack,
and you can find it by googling "padding oracle".
This post was a direct precursor to the Guided Tour; I wanted to write a series of deep-dives in this style,
but I quickly realized that videos would work better for the amount of information I had to share.
The second post actually grew out of the background research I did for this video:
I wanted to discuss how the attack changes when your oracle is less than perfectly reliable,
but I couldn't find a good primary source on that, so I wrote one.
but anyway, that's enough backstory - let's get to the crypto.
"""
# Heading shown on the table-of-contents scene.
title = "Padding Oracles"

# Chapter headings, in presentation order; scenes index into this list
# when they highlight their entry in the ToC.
sections = [
    "Background",
    "In the wild",
    "Problems",
    "The attack (simple case)",            # reliable oracle
    "Layers of defense",
    "The attack (hard case)",              # unreliable oracle, no mac
    "The attack (impossible difficulty)",  # same as above, with mac
]
class c17_00_toc(ToCScene):
    """Table-of-contents scene for this video.

    Both class attributes are bound from the module-level ``title`` and
    ``sections`` (class-body lookup falls through to module globals).
    """

    title = title
    sections = sections
"""
in this attack we'll be exploiting a very small information leak.
let's start with the big-picture view.
the setup is:
the defender comes up with a plaintext,
which they encode, pad out to the block length,
and encrypt with AES-CBC using a secret key.
they take the result and provide it to the attacker.
so far so good.
now, we are the attacker.
our goal is to decrypt this message.
how is that possible? well, in general, it's not, because we don't know the key.
we need some kind of toehold.
however, a very small toehold will do.
for instance, what if we could get the defender to decrypt a message for us, then tell us whether it has valid padding?
[show dialogue on screen]
now, this might seem unrealistic, but you'd be surprised.
it is still quite common for systems to leak this information by accident.
and a recurring theme in cryptography is that small information leaks can compound into big attacks.
the padding oracle attack might just be the most famous variation on this theme,
though it does have some competition from, say, Bleichenbacher's attacks on RSA,
which we'll cover in a later video. but I digress.
we're here to talk about padding oracles. but before we get into the attack itself,
let's take a quick detour to observe a padding oracle in its natural habitat
"""
class c17_01_intro(MovingCameraScene):
    """Intro scene: lay a sample message out as CBC blocks, hand the IV and
    ciphertext to the attacker, and stage the defender/attacker dialogue
    that motivates a padding oracle."""

    def construct(self):
        # intro from toc: morph ToC entry 0 into the sample message
        toc = c17_00_toc().title_and_toc(indicated=0)[1][0]
        toc.save_state()
        frame_orig = self.camera.frame.copy()
        # introduce message as text
        pt_msg = "This is a plaintext string!"
        text = Text(pt_msg).scale(0.7)
        self.play(ReplacementTransform(toc, text), run_time=2)
        self.next_section()
        # Reuse pt_msg so the buffer can never drift from the displayed text.
        cbc = CBCBlocks(pt_msg, direction=RIGHT)#.zoom(0.5)
        pts = Group(*[ch for block in cbc.blocks for ch in block.pt])
        # message text -> message bytes
        self.camera.frame.save_state()
        self.play(
            TextToBuffer(text, pts, pad=False, lag_ratio=0.01), # this would be terser if we used pad=True here and deleted the next block but that would mean rewriting it and no one has time for that
            self.camera.frame.animate.move_to(pts).scale_to_fit_width(pts.width*1.2),
            lag_ratio=0.17,
        )
        # add padding: TextToBuffer(pad=False) only covered the message bytes,
        # so fade in the trailing pad cells separately. Their count is
        # 16 - len % 16 (5 for this 27-byte message). This assumes CBCBlocks
        # pads PKCS#7-style, as the previously hard-coded 5 implied.
        n_pad = 16 - len(pt_msg.encode()) % 16
        pad_bytes = pts[-n_pad:]
        self.play(FadeIn(*pad_bytes, lag_ratio=0.17))
        self.remove(*text, *pad_bytes)
        self.add(pts)
        self.wait(0.1)
        self.next_section()
        # add rest of cbc machinery (everything except the plaintext cells,
        # which are already on screen)
        cbc_sans_pts = cbc.copy()
        cbc_sans_pts.remove(*[block.pt for block in cbc_sans_pts.blocks])
        self.play(self.camera.frame.animate.restore().scale(2.05), Write(cbc_sans_pts))
        self.remove(cbc_sans_pts, *pts)
        self.add(cbc)
        self.wait()
        self.next_section()
        # add divider + labels
        divider = DashedLine(
            cbc.get_corner(DOWN+LEFT)+DOWN+0.5*LEFT,
            cbc.get_corner(DOWN+RIGHT)+DOWN+0.5*RIGHT,
            dash_length=0.15,
        )
        self.play(
            Write(divider),
            self.camera.frame.animate.move_to(divider),
        )
        label_defender = Text("Defender", color=C_PT).align_to(self.camera.frame, LEFT+UP).shift(DOWN+RIGHT)
        label_attacker = Text("Attacker", color=C_CT).align_to(self.camera.frame, LEFT).align_to(divider, DOWN).shift(DOWN+RIGHT)
        self.play(
            FadeIn(label_defender, label_attacker),
        )
        self.next_section()
        # give the IV+ciphertext to the attacker: mirror the ciphertext's
        # distance above the divider on the attacker's side below it
        full_ct = VGroup(
            cbc.blocks[0].iv.copy(),
            cbc.blocks[0].ct.copy(),
            cbc.blocks[1].ct.copy()
        )
        ct_divider_gap = DOWN*(divider.get_center()[1] - cbc.blocks[0].ct.get_center()[1])
        self.play(full_ct.animate.move_to(self.camera.frame.get_edge_center(DOWN)+ct_divider_gap))
        self.next_section()
        # attacker/defender dialogue, left-aligned beside the handed-over ciphertext
        anchor = Point().move_to(Group(divider, full_ct))
        q_1_txt = "What can you tell me about this ciphertext?"
        q_1 = Text(q_1_txt, color=C_CT, stroke_color=C_CT)
        a_1_txt = "What do you want to know?"
        a_1 = Text(a_1_txt, color=C_PT, stroke_color=C_PT)
        q_2_txt = "After decryption, does it have valid padding?"
        q_2 = Text(q_2_txt, color=C_CT, stroke_color=C_CT)
        a_2_txt = "Yea, I guess so... why?"
        a_2 = Text(a_2_txt, color=C_PT, stroke_color=C_PT)
        q_3_txt = "lol you'll see :)"
        q_3 = Text(q_3_txt, color=C_CT, stroke_color=C_CT)
        qna = Group(q_1, a_1, q_2, a_2, q_3).arrange(DOWN).move_to(anchor).shift(1.3*RIGHT+0.2*UP)
        for mobj in qna:
            mobj.align_to(qna, LEFT)
        for line in qna:
            self.play(FadeIn(line, shift=0.5*LEFT, run_time=1.5))
            # NOTE(review): indentation was reconstructed; next_section() is
            # placed per dialogue line here -- confirm against the original gist.
            self.next_section()
        # Take the dialogue out of the scene list so FadeOut(*self.mobjects)
        # skips it; it morphs back into the ToC entry instead.
        self.remove(*qna)
        self.play(AnimationGroup(FadeOut(*self.mobjects),
                                 AnimationGroup(Transform(qna, toc.restore()), self.camera.frame.animate.become(frame_orig)),
                                 lag_ratio=0.5,
                                 run_time=3))
        self.wait(0.1)
        self.next_section()
"""
suppose you’re reviewing an API for a web site that keeps track of its users
using encrypted tokens. luckily for us, the developer thought this idea was
simple enough that they could get away with rolling their own crypto.
Here's what they're doing.
First, they take some encoded data about the user (some of which might be
secret from the user), then they encrypt this data under AES-CBC using a
key that is only stored server-side.
They concatenate the resulting IV and ciphertext, then send this blob to the user.
As the user, we don't know the decryption key, so we can't immediately
decrypt this, and we're not meant to.
The idea is that we save the ciphertext and include it in future API queries.
When we do this, the server will be able to decrypt it, read the contents, and learn all about us.
This is all well and good.
But what happens if we provide an invalid token?
What if we just make one up?
There are two main error cases here.
The first is when the padding is bad and the decryption fails, which is the most likely outcome.
The second is when the padding is valid. In this case, the data is probably still garbage, causing the deserialization to fail.
In both of these cases, our query fails and we get an error back from the server.
That's fine though, because right now we're not interested in making this query succeed -
we're just interested in finding out _how_ it failed.
In this example, it's easy, because the server just tells us what happened.
If we see 'decryption failed', we know the padding is invalid;
if we see 'deserialization failed', or no error at all, then the padding is valid.
This gives us a padding oracle, and so this is exactly the kind of toehold we need to launch our attack.
If we can make enough queries to this oracle, we can decrypt this token.
But that's not all we can do: We'll also be able to make arbitrary modifications to the token,
or even encrypt entirely new messages under the secret key - all without ever learning the key's value.
"""
class c17_02_example(MovingCameraScene): # the camera doesn't actually move but it's convenient being able to refer to self.camera.frame
    """Worked example: a server hands out an AES-CBC token, and the distinct
    errors it returns for tampered tokens ('decryption failed' vs
    'deserialization failed') act as a padding oracle."""

    def construct(self):
        toc = c17_00_toc().title_and_toc(indicated=1)[1][1]
        # Layout: h1/h2 are the vertical offsets of the upper/lower rows;
        # lpos/rpos anchor the client (left) and server (right) halves.
        h1 = DOWN*2.5
        h2 = DOWN*5.5
        half_center = self.camera.frame_width/4
        lpos = LEFT*half_center + UP*3.2
        rpos = RIGHT*half_center + UP*3.2
        # introduce framing
        divider = DashedLine(self.camera.frame.get_edge_center(UP), self.camera.frame.get_edge_center(DOWN))
        client = Text("Client").move_to(lpos)
        server = Text("Server").move_to(rpos)
        self.play(AnimationGroup(
            ReplacementTransform(toc, client, run_time=1.5),
            ReplacementTransform(toc.copy(), server, run_time=1.5),
            Succession(Wait(0.5), Write(divider, run_time=2)),
        ))
        # introduce token (plaintext)
        pt_string = '{"uid":10,"role":"user"}'
        token_pt_text = Text(pt_string).scale(0.7).set_x(rpos[0])
        token_pt_text.save_state()
        token_pt = Blocks(pt_string, n_cols=1, zoom=0.7, c_fills=C_PT).move_to(rpos+h1)
        self.play(Write(token_pt_text))
        self.next_section()
        self.play(TextToBuffer(token_pt_text, token_pt, pad=True, lag_ratio=1/30))
        self.next_section()
        # encrypt: token = IV (16 bytes) || CBC ciphertext (32 bytes)
        iv = blk()
        ct_bs = _enc(token_pt.bytes, mode=AES.MODE_CBC, iv=iv)
        token_ct = Blocks(iv+ct_bs, n_cols=1, zoom=0.7, c_fills=[C_IV]*16+[C_CT]*32).move_to(rpos+h2)
        enc = EncBox.with_arrows(token_pt, token_ct)
        self.play(LaggedStart(
            Write(enc),
            Write(token_ct),
            lag_ratio=0.5
        ))
        self.next_section()
        # transfer to client
        self.play(token_ct.animate(lag_ratio=0.0015).move_to(lpos+h2), FadeOut(enc, token_pt), run_time=1.5)
        self.next_section()
        # client prepares request
        request_str = r"&\texttt{GET /index} \\ &\texttt{token=}"
        request = MathTex(request_str).match_x(token_ct)
        # {{{ fiddly bullshit to get block lined up with request's equals sign
        final_char = request[0][-2]
        token_ct.save_state()
        token_ct.scale(0.7).next_to(final_char, RIGHT, buff=0).shift(1.17*LEFT) # NOTE not sure why this shift-left is necessary...
        Group(request, token_ct).match_y(Point(lpos+h1))
        anchor = token_ct.get_center()
        request.align_to(token_ct, LEFT)
        token_ct.restore()
        # }}}
        # NOTE(review): request is re-aligned below to the *restored* (unscaled)
        # token_ct, which may be what the 1.17*LEFT fudge above compensates for.
        self.play(LaggedStart(
            Write(request.align_to(token_ct, LEFT)),
            token_ct.animate.zoom(0.7).move_to(anchor),
            lag_ratio=0.5,
        ))
        self.next_section()
        # client makes request; each request copy fades in from the client side
        # (target_pos) to the server side
        to_server_1 = VGroup(request.copy(), token_ct.copy())
        target_pos = to_server_1.get_center()
        to_server_1.move_to(rpos+h1)
        self.play(FadeIn(to_server_1, target_position=target_pos))
        self.next_section()
        # server reads token
        token_pt.move_to(rpos+h2)
        bendy_dec = DecBox.with_arrows(to_server_1[1], token_pt)
        dec = DecBox.with_arrows((bendy_dec, UP, UP), (bendy_dec, DOWN, DOWN))
        self.play(Write(dec), Write(token_pt))
        self.next_section()
        self.play(FadeOut(to_server_1, dec))
        self.next_section()
        # server unpads and deserializes token
        token_pt_text.restore()
        self.play(BufferToText(token_pt, token_pt_text, lag_ratio=1/30, unpad=True))
        self.next_section()
        # discard serverside stuff for now
        self.play(FadeOut(token_pt_text))
        self.next_section()
        # what if the token is fucked up?
        garbage_ct_bytes = blk()+blk()+blk()
        func = FuncBox(r"\texttt{urandom(3} \cdot \texttt{16)}", zoom=0.7).next_to(token_ct, DOWN, buff=1)
        arrow = BendyArrow(func, UP, token_ct, DOWN)
        self.play(
            Rewrite(token_ct, [' ']*16*3, [C_IV]*16+[C_CT]*32), # wish i could bring this Rewrite() into the next play() but manim breaks if you have multiple animations on a mobject in the same play() so rip
        )
        self.next_section()
        self.play(LaggedStart(
            LaggedStart(
                FadeIn(func, arrow, shift=0.5*UP, rate_func=rate_functions.there_and_back_with_pause, run_time=3),
                FlowThru(arrow),
                lag_ratio=0.2,
            ),
            Rewrite(token_ct, garbage_ct_bytes, [C_IV]*16+[C_CT]*32),
            lag_ratio=1/3,
        ))
        self.next_section()
        # send the fucked up random token to the server and see what happens
        to_server_2 = VGroup(request.copy(), token_ct.copy()).move_to(rpos+h1)
        self.play(FadeIn(to_server_2, target_position=target_pos))
        self.next_section()
        # server reads token (and freaks out)
        garbage_pt_bytes = _dec(garbage_ct_bytes[16:], mode=AES.MODE_CBC, iv=garbage_ct_bytes[:16])
        garbage_token_pt = Blocks(garbage_pt_bytes, n_cols=1, zoom=0.7, pad=False)
        garbage_token_pt.move_to(rpos+h2)
        bendy_dec = DecBox.with_arrows(to_server_2[1], garbage_token_pt)
        dec = DecBox.with_arrows((bendy_dec, UP, UP), (bendy_dec, DOWN, DOWN))
        self.play(Write(dec), Write(garbage_token_pt))
        self.next_section()
        # oh no bad padding - error!
        error_1 = Text("{'error':'decryption failed'}").scale(0.6).move_to(garbage_token_pt)
        self.play(Wiggle(garbage_token_pt[-1]), run_time=2)
        self.next_section()
        self.play(FadeOut(to_server_2, dec), ReplacementTransform(garbage_token_pt, error_1))
        self.next_section()
        self.play(error_1.animate.match_x(client))
        self.play(error_1.animate.set_opacity(0))
        self.next_section()
        # but what if... good padding?
        less_garbage_ct_bytes = bytes.fromhex("25c86966fcd8c989e6155e1f07b75d4be838da7942e899e1a289425dac1103c4ee7b764aa0e5fdb1c0e351347c50f405")
        less_garbage_pt_bytes = _dec(less_garbage_ct_bytes[16:], mode=AES.MODE_CBC, iv=less_garbage_ct_bytes[:16])
        # Render-time sanity check: this hard-coded token must decrypt to a
        # plaintext ending in 0x01 (valid padding) under the demo key.
        assert less_garbage_pt_bytes[-1] == 1
        self.play(Rewrite(token_ct, less_garbage_ct_bytes, c_fills=[C_IV]*16+[C_CT]*32))
        self.next_section()
        # send the new token
        to_server_3 = VGroup(request.copy(), token_ct.copy()).move_to(rpos+h1)
        self.play(FadeIn(to_server_3, target_position=target_pos))
        self.next_section()
        # decrypt the new token (its plaintext was already computed above)
        less_garbage_token_pt = Blocks(less_garbage_pt_bytes, n_cols=1, zoom=0.7, pad=False)
        less_garbage_token_pt.move_to(rpos+h2)
        # Anchor the decryption arrows on to_server_3, the request being
        # decrypted in this step.
        bendy_dec = DecBox.with_arrows(to_server_3[1], less_garbage_token_pt)
        dec = DecBox.with_arrows((bendy_dec, UP, UP), (bendy_dec, DOWN, DOWN))
        self.play(Write(dec), Write(less_garbage_token_pt))
        self.next_section()
        self.play(FadeOut(to_server_3, dec))
        # a single 0x01 byte is valid padding!
        last_byte = less_garbage_token_pt[-1][-1]
        self.play(LaggedStart(
            FocusOn(last_byte.text.get_corner(DOWN+RIGHT), run_time=1),
            Wiggle(last_byte, scale_value=1.2, run_time=1),
            lag_ratio=1/2,
        ))
        self.next_section()
        self.play(Rewrite(less_garbage_token_pt, c_fills=[C_PT]*(16*2-1) + [C_PAD], rev=True))
        self.next_section()
        self.play(FadeOut(last_byte))
        less_garbage_token_pt.remove(last_byte)
        self.next_section()
        # but this still decodes to garbage (albeit printable garbage)
        less_garbage_text = ''.join(chr(b) for b in less_garbage_token_pt.bytes[:-1])
        decoded = Text('\n'.join(chunk_bytes(less_garbage_text, 16))).scale(0.7).move_to(less_garbage_token_pt)
        self.play(BufferToText(less_garbage_token_pt, decoded))
        self.next_section()
        error_2 = Text("{'error':'deserialization failed'}").scale(0.6).move_to(decoded)
        decoded_1 = decoded[:16] # split on newline
        decoded_2 = decoded[16:]
        self.remove(decoded)
        self.play(
            Transform(decoded_1, error_2),
            Transform(decoded_2, error_2),
        )
        self.remove(decoded_1, decoded_2)
        self.add(error_2)
        self.next_section()
        self.play(error_2.animate.match_x(client))
        self.play(error_2.animate.set_opacity(0))
        self.next_section()
        # bring both errors back, stacked in the centre, for comparison
        error_2.save_state().center().scale(1.7)
        self.play(
            FadeOut(*self.mobjects),
            error_1.animate.set_opacity(1).center().shift(UP).scale(1.7).align_to(error_2, LEFT),
            error_2.restore().animate.set_opacity(1).center().shift(DOWN).scale(1.7),
        )
        self.next_section()
        toc = c17_00_toc().title_and_toc(indicated=1)[1][1]
        self.play(Transform(VGroup(error_1, error_2), VGroup(toc, toc.copy())))
        self.wait()
"""
Now, we're about to get into the details of how this attack works.
But, if you'll bear with me, I want to take the long way there,
which starts by talking about why the attack might seem like it shouldn't work.
Cryptography is a field full of counterintuitive results,
and if you just hear someone explain them straightforwardly and uncritically,
it can be easy to forget how strange and surprising some of this stuff really is.
so, just to counteract that effect, and for an extra bit of dramatic flair,
let's put ourselves in the shoes of a developer who gets paged at 2AM to fix this vulnerability,
and who wants nothing more than to go back to sleep.
Here are some of the things that developer might say to try and downplay this issue.
first off, like we just saw, a padding oracle tells us whether or not any given ciphertext decrypts to a plaintext with valid padding.
that is to say, all we get is a yes-or-no answer to a small, simple question.
and in fact it turns out the answer is "no" for about 99.6% of all possible ciphertexts.
so, because this answer is so biased, the information leak here is very small.
next up, intuitively, all we have is a ciphertext, so surely the attack must involve modifying the ciphertext,
but for any good cipher these modifications will have unpredictable results.
that last statement assumed that we don't know the key,
so we might as well add that as long as our cipher is good,
there's no way that this issue could lead to key compromise.
as for the actual information we're leaking, that seems pretty insignificant, too:
after all, padding usually only appears in the last block of a message,
so this oracle can only tell us about that last block.
and finally, padding bytes and message bytes do not overlap,
so the oracle can't even tell us about the message bytes.
all of this seems to point towards this attack not working.
i could go on, but you get the idea.
the thing is: all of these points are true, but none of them actually matter. the attack still works.
and this is something i want to emphasize: in cryptography, beware of intuitive arguments.
there are many fields where you can get away with, or even are encouraged to use,
this sort of heuristic reasoning; this is not such a field.
this is why we care so much about things like formal verification, proofs of security, and so on:
because it's much easier to trust an argument when it's backed up by mathematical rigor,
although even there you still have to be careful,
because subtle mistakes in proofs can be easy to make and hard to find, and can lead to real breaks in supposedly secure systems.
but anyway, let's go ahead and call these points what they are.
they lack rigor, so they're not actual problems, they're just complaints,
and we don't care about complaints.
Now that we've gotten that out of the way, let's see the attack.
"""
class c17_03_problems(Scene):
    """List the objections a developer might raise against the attack, then
    relabel each "Problem" as a mere "Complaint" and add the solution."""

    def construct(self):
        heading = Title("Problems")
        problems = BulletedList(
            r"Problem: Oracle gives very little information per query ($\ll 1$ bit)",
            "Problem: Ciphertext changes scramble plaintext",
            "Problem: We can't learn the key",
            "Problem: Oracle only tells us about last block",
            "Problem: Oracle doesn't tell us about message bytes",
        ).scale(0.8)
        # One more entry than `problems`: the final "Solution" line is written
        # on its own after the relabelling.
        complaints = BulletedList(
            r"Complaint: Leaks very little information per query ($\ll 1$ bit)",
            "Complaint: Ciphertext changes scramble plaintext",
            "Complaint: We can't learn the key",
            "Complaint: Oracle only tells us about last block",
            "Complaint: Oracle doesn't tell us about message bytes",
            "Solution: Make it work anyway",
        ).scale(0.8).align_to(problems, UP)
        toc_entry = c17_00_toc().title_and_toc(indicated=2)[1][2]
        self.play(TransformMatchingShapes(toc_entry, heading, fade_transform_mismatches=True))
        self.next_section()

        # NOTE(review): indentation was reconstructed; the wait/next_section
        # pair is placed per bullet -- confirm against the original gist.
        for bullet in problems:
            self.play(Write(bullet))
            self.wait(0.1)
            self.next_section()

        def relabel(problem, complaint):
            # Split at glyph 8 / glyph 10 (meant to be the "Problem:" /
            # "Complaint:" prefixes) so the label and the body morph separately.
            return AnimationGroup(
                Transform(problem[:8], complaint[:10]),
                Transform(problem[8:], complaint[10:]),
            )

        morphs = [relabel(old, new) for old, new in zip(problems, complaints)]
        self.play(*morphs)
        self.wait(0.1)
        self.next_section()
        self.play(Write(complaints[-1]))
        self.next_section()
        # Take the heading out of the scene list so FadeOut(*self.mobjects)
        # skips it; it morphs back into the ToC entry instead.
        self.remove(heading)
        self.play(
            AnimationGroup(FadeOut(*self.mobjects), FadeTransform(heading, toc_entry), lag_ratio=0.5),
            run_time=2,
        )
        self.wait(0.1)
        self.next_section()
"""
Once again, this is CBC encryption. If you've been following the guided tour,
then by now this figure should look very familiar.
The key point to notice here is that each plaintext block is XORed with the
previous ciphertext block (or IV) prior to encryption.
For this attack, we won't be spending a ton of time with this multi-block
figure, because it turns out that the multi-block case of this attack
reduces nicely to the single-block case.
To see why, let's review CBC decryption. We'll just flip these around, and there we go.
In this figure, we know the ciphertext, but not the plaintext or the key.
We'd like to find the plaintext, but to do that, we need to be able to
compute D_k, which of course depends on the key k.
Now, we don't know k, but the defender does, and so the strategy for this
attack is to get the defender to compute D_k, and then trick them into
revealing the result to us.
Now, this attack is usually called a chosen-ciphertext attack,
which is not wrong, because it can be carried out by modifying the ciphertext,
but it can also be carried out by modifying the IV, which is what we'll do here, because I think it's clearer.
We'll be leaving the ciphertext unchanged and messing around with the IV instead,
so it might make more sense to think of it as a chosen-IV attack.
So we'll take this ciphertext block, isolate it, and fill in a new IV and plaintext.
This IV is attacker-controlled, and we could set it to whatever we want,
but for now let's keep things simple and just set it to zero.
"""
class c17_04_multiblock(Scene):
    """Show two-block CBC encryption, flip it into decryption, point out D_k
    as the unknown, then isolate a single ciphertext block with a zero IV."""

    @staticmethod
    def introduce(cbc):
        """Return a custom creation animation for one CBC block."""
        # custom creation animation - Create(cbc) and Write(cbc) also work,
        # but i think this looks nicer
        # Chained blocks (built with prev=) expose ct_to_xor; the first block
        # exposes iv_to_xor instead.
        return LaggedStart(
            AnimationGroup(
                Write(cbc.ct_to_xor if hasattr(cbc, "ct_to_xor") else cbc.iv_to_xor),
                Write(cbc.pt_to_xor),
                Write(cbc.xor_to_enc),
                Write(cbc.enc_to_ct),
                FadeIn(cbc.xor),
                FadeIn(cbc.enc),
            ),
            FadeIn(cbc.ct, shift=0.5*DOWN, run_time=0.5),
            lag_ratio=0.2
        )

    def construct(self):
        toc = c17_00_toc().title_and_toc(indicated=3)[1][3]
        toc.save_state()
        # introduce multi-block case
        cbc_1 = CBCBlock()
        cbc_2 = CBCBlock(prev=cbc_1)
        dots = Text("...").move_to(cbc_2.ct).align_to(cbc_2.ct, RIGHT).shift(DOWN).shift(LEFT*0.2)
        to_dots = BendyArrow(cbc_2.ct, DOWN, dots, LEFT)
        dots.shift(RIGHT*0.2)
        # The group exists only to centre/zoom its members as one unit.
        ZoomableVGroup(cbc_1, cbc_2, dots, to_dots).center().zoom(0.55)
        # Animate the first IV separately: it morphs out of the ToC entry.
        iv = cbc_1.iv
        cbc_1.remove(iv)
        self.play(LaggedStart(
            FadeTransform(toc, iv),
            Write(cbc_1.pt),
            self.introduce(cbc_1),
            Write(cbc_2.pt, run_time=1),
            self.introduce(cbc_2),
            AnimationGroup(Write(to_dots), Write(dots)),
            lag_ratio=0.4,
        ))
        self.next_section()
        # flip the arrows
        cbc_dec_1 = CBCBlock(dec=True)
        cbc_dec_2 = CBCBlock(prev=cbc_dec_1, dec=True)
        ### ugly bit: we need to include this in the cbc vgroup bc otherwise the dots_arrow tip throws off the center()
        _dots = Text("...").move_to(cbc_dec_2.ct).align_to(cbc_dec_2.ct, RIGHT).shift(DOWN).shift(LEFT*0.2)
        _dots_arrow = BendyArrow(cbc_dec_2.ct, DOWN, _dots, LEFT)
        _dots.shift(RIGHT*0.2)
        ###
        ZoomableVGroup(cbc_dec_1, cbc_dec_2, _dots, _dots_arrow).center().zoom(0.55)
        cbc_dec_1.remove(cbc_dec_1.iv)
        self.play(
            Transform(cbc_1, cbc_dec_1),
            Transform(cbc_2, cbc_dec_2),
        )
        self.next_section()
        # the problem is D_k
        self.play(LaggedStart(
            Circumscribe(cbc_1.enc, color=C_STROKE), # _.enc is actually D_k here but don't get me started on that
            Circumscribe(cbc_2.enc, color=C_STROKE),
            lag_ratio=3/4,
        ))
        self.next_section()
        # setup: build the whole final scene
        # isolate one block of ciphertext
        ct = cbc_1.ct.flip(RIGHT) # makes transform look better
        cbc_1.remove(ct)
        new_cbc = c17_05_singleblock_add_extra_state.get_flipped_cbc().zoom(0.7)
        new_ct = new_cbc.ct
        self.play(LaggedStart(
            AnimationGroup(
                FadeOut(iv, cbc_1.iv_to_xor, scale=0.8, run_time=1),
                FadeOut(cbc_1.remove(cbc_1.iv, cbc_1.iv_to_xor), scale=0.8, run_time=1),
                FadeOut(cbc_2.ct_to_xor, scale=0.8, run_time=1),
                FadeOut(cbc_2.remove(cbc_2.ct_to_xor), dots, to_dots, scale=0.8, run_time=1),
            ),
            ReplacementTransform(ct, new_ct, path_arc=-TAU/4, lag_ratio=0.01, run_time=2),
        ))
        # Write the rest of the single-block diagram around the moved ciphertext.
        new_pt = new_cbc.pt
        new_iv = new_cbc.iv
        new_cbc.remove(new_ct, new_pt, new_iv)
        self.play(Write(VGroup(new_iv, new_pt)), Write(new_cbc[::-1]), run_time=3)
        self.next_section()
        # we know the IV and ciphertext...
        #ct_box = SurroundingRectangle(new_ct, color=C_CT, stroke_width=6, buff=0.15)
        #self.play(Write(ct_box))
        #self.next_section()
        # ...but not the plaintext or key
        #pt_box = SurroundingRectangle(new_cbc.pt, color=C_PT, stroke_width=6, buff=0.15)
        #k_box = SurroundingRectangle(new_cbc.enc.text[-1][-1], color=C_PT, stroke_width=6, buff=0.05)
        #self.play(AnimationGroup(Write(k_box), Write(pt_box), lag_ratio=0.5), run_time=2)
        #self.next_section()
        # drop these boxes
        #self.play(FadeOut(ct_box, pt_box, k_box))
        # let's start with a zero IV
        self.play(Rewrite(new_iv, bytes(16), C_IV), rate_func=rate_functions.ease_in_out_sine, run_time=1.5)
        self.next_section()
        self.wait(0.1)
"""
I'm leaving the ciphertext blank, because I don't want to distract from
what's important about this attack, which is the interaction between the IV
and the output from D_k, which we'll add a block for right here.
"""
class c17_05_singleblock_add_extra_state(Scene):
    """Single-block CBC decryption diagram, extended with an explicit block
    holding D_k's output between the cipher box and the XOR."""

    @staticmethod
    def get_flipped_cbc(iv=None):
        """Build a decryption-mode CBCBlock mirrored across the horizontal axis.

        The cipher box is put back in its unflipped orientation at its new
        position, and the two arrows touching it are additionally flipped
        about the vertical axis.
        """
        block = CBCBlock(dec=True, buff=2, iv=iv)
        # Remember the box's look so it can be restored after the whole flip.
        block.enc.save_state()
        block.flip(RIGHT)  # flips along horizontal axis
        box_center = block.enc.get_center()
        block.enc.restore().move_to(box_center)
        for arrow in (block.xor_to_enc, block.enc_to_ct):
            arrow.flip(UP)
        return block

    @staticmethod
    def get_flipped_cbc_with_extra_block(iv=None):
        """Like get_flipped_cbc, plus a `dec_out` Block for D_k's output."""
        block = c17_05_singleblock_add_extra_state.get_flipped_cbc(iv=iv)
        ct_edge = block.ct.get_edge_center(DOWN)
        xor_edge = block.xor.get_edge_center(UP)
        # Split the ciphertext->XOR gap into thirds: the cipher box sits one
        # third of the way along, the new output block two thirds along.
        box_pos = (2*ct_edge + xor_edge)/3
        out_pos = (ct_edge + 2*xor_edge)/3
        block.enc.move_to(box_pos)
        block.enc_to_ct.become(BendyArrow(block.ct, DOWN, block.enc, UP))
        block.dec_out = Block().move_to(out_pos)
        block.add(block.dec_out)
        # turning this arrow into 2 arrows makes the transform animation better
        block.xor_to_enc.become(VGroup(
            BendyArrow(block.dec_out, DOWN, block.xor, UP),
            BendyArrow(block.enc, DOWN, block.dec_out, UP),
        ))
        return block

    def construct(self):
        plain = self.get_flipped_cbc(iv=bytes(16)).zoom(0.7)
        self.add(plain)
        self.wait(0.5)
        extended = self.get_flipped_cbc_with_extra_block(iv=bytes(16)).zoom(0.7)
        # The new output block is written in on its own rather than morphed.
        new_block = extended.dec_out
        extended.remove(new_block)
        # Drop the old arrow's tip before morphing into the two-arrow version.
        old_tip = plain.xor_to_enc.get_tip()
        plain.xor_to_enc.remove(old_tip)
        self.play(Transform(plain, extended), Write(new_block, run_time=1), FadeOut(old_tip))
        # NOTE(review): `extended` is never added to the scene, so this takes
        # new_block out of the scene; harmless only because nothing is
        # rendered afterwards -- confirm intent.
        self.remove(new_block)
        extended.add(new_block)
"""
Now, if we send the defender a ciphertext and an IV, the defender will
compute the rest of this diagram. Of course, they don't tell us the result;
they only tell us whether or not the result has valid padding.
If the padding is valid, that could mean one of sixteen things.
It probably means the plaintext ends with one byte of value 1.
but it is also possible that the plaintext ends with, say, two bytes of value 2, and so on.
However, these possibilities are not all equally likely.
For a uniformly random plaintext,
The odds of getting valid one-byte padding are 1 in 256;
The odds of getting valid two-byte padding are 1 in 256 squared;
And it drops off exponentially from there.
So, even though longer padding is not likely,
it's likely enough that our attack will have to take this possibility into account.
Now, the first key insight in this attack is to recognize that by making modifications to the IV,
we can predictably modify the plaintext.
You might recall that earlier we were complaining about how we don't know what'll happen to the plaintext if we change the ciphertext.
This is true, but the same is not true for the IV:
changing the IV changes the plaintext in an entirely predictable way.
To be specific, flipping a bit in the IV will flip the corresponding bit in the plaintext.
Setting the IV’s final byte to any value will xor that value into the
plaintext’s final byte. This is exactly the same property that we took
advantage of in challenge 16.
In this context, we'll use this property to launch a search.
We'll loop through each possible value for the last byte of the IV. In effect,
this also loops through each possible value for the last byte of the plaintext.
It might take a while, but we know that sooner or later we'll set this byte
to value 1, or to another byte that constitutes valid padding - and we know
that once this happens, the oracle will tell us.
Now, as we just discussed, technically we don't know the length of the padding -
we only know that it's valid.
The most likely outcome by far is that it's a single byte of value 1.
That said, it could also be 2 bytes of value 2, and so on, up to 16 bytes of value 16.
"""
class c17_06_singleblock_first_iv_byte_search(Scene):
    """Review valid padding and its odds, search the last IV byte until the
    oracle reports valid padding, then show that the padding's length is
    still ambiguous."""

    def construct(self):
        self.add(cbc := c17_05_singleblock_add_extra_state.get_flipped_cbc_with_extra_block(iv=bytes(16)).zoom(0.7))
        self.wait(0.5)
        self.next_section()
        # review: valid padding
        pt = cbc.pt
        pt.flip(RIGHT)
        self.play(Rewrite(pt, [' ']*15 + ['01']*1, [C_PT]*15 + [C_PAD]*1))
        self.play(Rewrite(pt, [' ']*14 + ['02']*2, [C_PT]*14 + [C_PAD]*2))
        self.next_section()
        # review: probabilities of each padding string -- one factor of
        # 1/2^8 per padding byte
        p_1 = MathTex(r"\frac{1}{2^8}").scale(0.4).next_to(cbc.pt[-1], UP)
        p_2 = p_1.copy().next_to(cbc.pt[-2], UP)
        p_3 = p_1.copy().next_to(cbc.pt[-3], UP)
        dot_1 = MathTex(r"\cdot").scale(0.4).move_to(Group(p_1, p_2))
        dot_2 = MathTex(r"\cdot").scale(0.4).move_to(Group(p_2, p_3))
        self.play(AnimationGroup(
            Rewrite(cbc.pt, [' ']*15 + ['01']*1, [C_PT]*15 + [C_PAD]*1),
            Write(p_1),
            lag_ratio=0.5,
        ))
        self.next_section()
        self.play(AnimationGroup(
            Rewrite(cbc.pt, [' ']*14 + ['02']*2, [C_PT]*14 + [C_PAD]*2),
            Write(p_2), Write(dot_1),
            lag_ratio=0.5,
        ))
        self.next_section()
        self.play(AnimationGroup(
            Rewrite(cbc.pt, [' ']*13 + ['03']*3, [C_PT]*13 + [C_PAD]*3),
            Write(p_3), Write(dot_2),
            lag_ratio=0.5,
        ))
        self.next_section()
        self.play(
            Rewrite(cbc.pt, None, c_fills=C_PT),
            *[Unwrite(m) for m in (p_1, p_2, p_3, dot_1, dot_2)],
        )
        # iterate over each value for the last byte of the IV
        # until we get a hit.
        # In CBC the plaintext is D_k output xor IV, so the IV byte
        # DEC_OUT[-1] ^ 1 is the one that makes the last plaintext byte 0x01.
        hit = DEC_OUT[-1] ^ 1
        # Step durations start slow (0.5s) and speed up (0.2s, then 0.05s);
        # zip() stops at the shorter range, which is at most 254 steps.
        for i, t in zip(range(1, hit), [0.5]*3+[0.2]*5+[0.05]*255):
            self.play(inc(cbc.iv, -1, i, run_time=t))
        self.play(LaggedStart(
            inc(cbc.iv, -1, hit, run_time=0.2),
            FocusOn(cbc.pt),
            Rewrite(cbc.pt, c_fills=[(C_PT, C_PAD, C_PT)]*16),
            lag_ratio=0.7,
        ))
        self.next_section()
        # valid? length 1?
        self.play(
            Rewrite(cbc.pt, [' ']*15+[1], c_fills=[C_PT]*15 + [C_PAD]*1),
            run_time=3,
        )
        self.next_section()
        # longer?
        self.play(Rewrite(cbc.pt, [' ']*14 + [2]*2, [C_PT]*14 + [C_PAD]*2))
        self.play(Rewrite(cbc.pt, [' ']*13 + [3]*3, [C_PT]*13 + [C_PAD]*3))
        # Flip through lengths 4..16 quickly: rewrite the cells directly (no
        # animation) and hold each state for 0.1s.
        for i in range(4, 17):
            j = 16-i
            cbc.pt.rewrite([' ']*j+[i]*i, c_fills=[C_PT]*j + [C_PAD]*i)
            self.wait(0.1)
        self.wait(0.1)
        self.next_section()
"""
We can rule out these edge cases by changing the penultimate byte of the IV.
If the corresponding byte of the plaintext is part of the padding, then changing it will cause the padding to become invalid.
Contrapositively, if the padding is still valid, then it must have length 1 and value 1.
If this test fails, we'll continue the byte search. But if it succeeds, then
we've recovered the exact value for the final plaintext byte.
Now, check this out. XOR has some useful algebraic properties, which you can
read about in the bottom left, if you're into that kind of thing.
Otherwise, just trust me when I say that we can turn these arrows around
without invalidating the diagram. Then we just compute the xor of the IV
and plaintext, write it down, and just like that, we've recovered the last
byte of D_k's output.
We just made length-1 padding. Now let's try to make length-2 padding.
We'll use the same xor algebra as before to set the final byte of the
plaintext to value 2. Then we'll search through each value for the
penultimate byte until we get a hit from the padding oracle.
Note that this time we don't have to check the length of the padding -
since we set the last byte of the plaintext to value 2, we know for a fact
that any valid padding must have length 2.
Now, using the same algebra as before, we can recover the penultimate byte
of D_k's output. From here, the attack proceeds just how you'd expect: we
set the final 2 bytes to value 3, then search through each possible value
for the antepenultimate byte, until we find the value that works.
As you can probably guess by now, we can repeat this search for each byte
until we've recovered all sixteen bytes of D_k's output. For each step,
we're deriving an IV that sets the final n-1 plaintext bytes to value n.
Then we're searching through the n'th-from-last IV byte until we hit valid
padding. We take that IV byte, xor it against n, and that's our D_k byte.
If you squint just right, this might look kind of like our byte-at-a-time
ECB decryption attack. Of course, it's not exactly the same, but it's fast
for the same reason, namely because the attack proceeds one byte at a time,
and therefore only has to deal with byte-sized search spaces.
The process is kind of hypnotic; I'll wait for a moment and let it play out.
OK, let's wrap it up. I hope you're having as much fun watching this as I had making it.
But I'll tell you what's even more fun: we've now pulled off a padding oracle attack!
We've recovered the full output of D_k, and we've done this without any knowledge of k.
I mentioned earlier that the multi-block case of this attack reduces to the
single-block case. Just to show what I mean, let's bring this result back
to the multi-block case and start putting the pieces together.
"""
class c17_07_singleblock_derive_dec_bytes(Scene):
    """Single-block padding oracle attack: recover all 16 bytes of D_k's output.

    Picks up with an IV whose last byte makes the final plaintext byte equal 1,
    confirms the padding length by perturbing the penultimate IV byte, explains
    the XOR algebra used to read off D_k's output, and then animates the
    byte-at-a-time search for the remaining bytes from right to left.
    """

    def construct(self):
        self.add(cbc := c17_05_singleblock_add_extra_state.get_flipped_cbc_with_extra_block(iv=bytes(16)).zoom(0.7))
        # pt = iv ^ D_k(ct), so this IV byte forces the last plaintext byte to 1
        iv_byte = DEC_OUT[-1]^1
        cbc.pt.flip(RIGHT).rewrite([16]*16, c_fills=C_PAD)
        cbc.iv.flip(RIGHT).rewrite([0]*15+[iv_byte], c_fills=C_IV)
        cbc.ct.flip(RIGHT)  # don't rly need this but it's just good bookkeeping

        # test by flipping penultimate byte
        self.play(AnimationGroup(
            FocusOn(cbc.iv[-2]),
            Rewrite(cbc.iv, [0]*14 + [0xff] + [iv_byte], c_fills=C_IV),
            FlowThru(cbc.iv_to_xor, cbc.xor, cbc.pt_to_xor, run_time=2.5, lag_ratio=0.7, box_time_fac=1.3),
            Rewrite(cbc.pt, [16]*14 + [16^0xff] + [16], c_fills=WHITE),
            lag_ratio=0.7,
        ))
        self.next_section()

        # we've recovered its value
        self.play(Rewrite(cbc.pt, None, [C_PT]*15 + [C_PAD]))
        self.next_section()
        self.play(Rewrite(cbc.pt, [' ']*15 + [1], [C_PT]*15 + [C_PAD]))
        self.next_section()

        # note on math stuff (TeX source kept flush-left so the string is unchanged)
        math_note = MathTex(r"""
&\text{\textit{Note for algebra nerds:} XOR defines an \textbf{abelian group} where \textbf{each element is its own inverse}.} \\
&\text{Because of this, if } a = b \oplus c \text{ then we also know that } b = a \oplus c \text{ and } c = b \oplus a. \text{ The argument is trivial:} \\
&a = b \oplus c \\
&a \oplus c = b \oplus c \oplus c = b \\
&a \oplus c \oplus a = b \oplus a \\
&\text{By the way, if you don't know what an \textbf{abelian group} is, \textbf{don't worry}: we'll get to that in set 8,} \\
&\text{and besides, the math works whether you understand it or not :)}
""").scale(0.42)
        math_note.align_to(DOWN*3.8 + LEFT*6.6, DL)
        self.play(Write(math_note), run_time=3)
        self.next_section()
        self.play(Unwrite(math_note), run_time=2)
        self.next_section()

        def get_rev_arrow(arrow):
            """Return an arrow over ``arrow``'s extent, drawn from its bottom edge to its top edge."""
            return BendyArrow(arrow, DOWN, arrow, UP, tip_length=0.17*0.7, tip_width=0.17*0.7).zoom(0.7).scale(1/0.7)

        # keep copies of the original arrows so they can be restored later
        pt_to_xor_down = cbc.pt_to_xor.copy()
        pt_to_xor_up = get_rev_arrow(pt_to_xor_down)
        xor_to_enc_down = cbc.xor_to_enc[0].copy()
        xor_to_enc_up = get_rev_arrow(xor_to_enc_down)
        iv_to_xor_right = cbc.iv_to_xor.copy()
        iv_to_xor_up = BendyArrow(cbc.xor, LEFT, cbc.iv, DOWN, tip_length=0.17*0.7, tip_width=0.17*0.7).zoom(0.7).scale(1/0.7)

        # turn arrows around and fill in first byte of D_k!
        self.play(
            cbc.pt_to_xor.animate.become(pt_to_xor_up),
            cbc.xor_to_enc[0].animate.become(xor_to_enc_up),
        )
        self.play(
            Rewrite(cbc.dec_out, [' ']*15 + [DEC_OUT[-1]]),
            FadeOut(cbc.iv[-1].copy(), target_position=cbc.dec_out[-1]),
            FadeOut(cbc.pt[-1].copy(), target_position=cbc.dec_out[-1]),
        )
        self.next_section()

        # ok enough preamble, now it's time to do the rest of the attack.
        # `ind` is the padding length being forged (2..16) and `_ind = 16 - ind`
        # is the index of the byte being searched for in this round.
        for ind in range(2, 17):
            _ind = 16-ind
            rt = 1 if ind < 6 else 1/3  # speed up once the pattern is established
            cr = 1/30 if ind < 4 else 1/120  # rate for cycling through candidate bytes
            # rewrite pt and blank out IV
            self.play(
                Rewrite(cbc.pt, [' ']*(_ind+1) + [ind]*(ind-1), [C_PT]*_ind + [WHITE]*ind),
                Rewrite(cbc.iv, [0]*_ind + [' ']*ind),
                run_time=rt,
            )
            # fill in new IV bytes: known D_k byte ^ ind sets that plaintext byte to ind
            self.play(
                Rewrite(cbc.iv, [0]*(_ind) + [' '] + [i ^ ind for i in DEC_OUT[_ind+1:]]),
                FadeOut(cbc.pt[_ind+1:].copy(), target_position=cbc.iv[_ind+1:]),
                FadeOut(cbc.dec_out[_ind+1:].copy(), target_position=cbc.iv[_ind+1:]),
                # first round only: point the arrows toward the IV derivation
                *([cbc.xor_to_enc[0].animate.become(xor_to_enc_down),
                   cbc.iv_to_xor.animate.become(iv_to_xor_up)]
                  if ind == 2 else []),
                run_time=rt,
            )
            if ind == 2:
                self.play(cbc.pt_to_xor.animate.become(pt_to_xor_down),
                          cbc.iv_to_xor.animate.become(iv_to_xor_right))
            # search time! This IV byte makes plaintext byte _ind equal ind,
            # i.e. valid padding of length ind. The scan shown is abbreviated
            # (step 3 for large targets) but always ends on that byte.
            target_byte = DEC_OUT[_ind] ^ ind
            iv_vals = list(range(0, target_byte+1, (1 if target_byte < 50 else 3)))
            iv_vals[-1] = target_byte  # just in case range() didn't end on target_byte
            pt_vals = [i ^ DEC_OUT[_ind] for i in iv_vals]
            self.play(
                Cycle(cbc.iv[_ind], iv_vals, rate=cr),
                Cycle(cbc.pt[_ind], pt_vals, rate=cr),
            )
            # update the pt to confirm that it has valid padding
            self.play(
                Rewrite(cbc.pt, [' ']*_ind + [ind]*ind, c_fills=[C_PT]*_ind+[C_PAD]*ind),
                run_time=0.75*rt,
            )
            # update dec_out: the recovered D_k byte is iv_byte ^ ind
            self.play(
                Rewrite(cbc.dec_out, [' ']*_ind+[*DEC_OUT[_ind:]]),
                FadeOut(cbc.iv[_ind].copy(), target_position=cbc.dec_out[_ind]),
                FadeOut(cbc.pt[_ind].copy(), target_position=cbc.dec_out[_ind]),
                run_time=rt if ind < 16 else 1,
            )
        self.next_section()

        # fade out everything but the ct (incl. dec_out, bc we want to dramatically reintroduce it)
        cbc.remove(cbc.ct)
        self.add(cbc.ct)
        self.play(FadeOut(cbc))
"""
And just like that, here we are, back where we started.
But things have changed, because now we know how to recover these cipher calls' outputs.
"""
class c17_08_back_to_multiblock(Scene):
    """Return from the single-block attack to the multi-block CBC diagram.

    Rebuilds the end state of c17_04_multiblock (two decrypting CBC blocks
    plus a trailing "..."), moves the lone ciphertext block from the previous
    scene back into the first block along an arc, fades the rest of the
    diagram in around it, and then circles both D_k calls.
    """

    def construct(self):
        cbc_1 = CBCBlock()
        cbc_2 = CBCBlock(prev=cbc_1)
        dots = Text("...").move_to(cbc_2.ct).align_to(cbc_2.ct, RIGHT).shift(DOWN).shift(LEFT*0.2)
        to_dots = BendyArrow(cbc_2.ct, DOWN, dots, LEFT)
        dots.shift(RIGHT*0.2)
        ZoomableVGroup(cbc_1, cbc_2, dots, to_dots).center().zoom(0.55)
        # flip the arrows: turn the encryption layout into the decryption layout
        cbc_dec_1 = CBCBlock(dec=True)
        cbc_dec_2 = CBCBlock(prev=cbc_dec_1, dec=True)
        ### ugly bit: we need to include this in the cbc vgroup bc otherwise the dots_arrow tip throws off the center()
        _dots = Text("...").move_to(cbc_dec_2.ct).align_to(cbc_dec_2.ct, RIGHT).shift(DOWN).shift(LEFT*0.2)
        _dots_arrow = BendyArrow(cbc_dec_2.ct, DOWN, _dots, LEFT)
        _dots.shift(RIGHT*0.2)
        ###
        ZoomableVGroup(cbc_dec_1, cbc_dec_2, _dots, _dots_arrow).center().zoom(0.55)
        cbc_1.become(cbc_dec_1)
        cbc_2.become(cbc_dec_2)
        ###### DONE REBUILDING c17_04_multiblock

        # bring back the isolated block of ciphertext
        ct = cbc_1.ct.flip(RIGHT)  # makes transform look better
        cbc_1.remove(ct)
        new_cbc = c17_05_singleblock_add_extra_state.get_flipped_cbc().zoom(0.7)
        new_ct = new_cbc.ct
        # NOTE: argument evaluation order matters here. Each first FadeIn is
        # built while the IV/arrow are still attached to their block. The
        # walrus `.remove(...)` then detaches them (it returns the block itself),
        # so the second FadeIn animates the block without them. They are
        # re-attached once the animation finishes.
        self.play(LaggedStart(
            ReplacementTransform(new_ct, ct, path_arc=TAU/4, lag_ratio=0.005, run_time=2),
            AnimationGroup(
                FadeIn(cbc_1.iv, cbc_1.iv_to_xor, scale=0.8, run_time=1),
                FadeIn(cbc_1.remove(iv_1 := cbc_1.iv, arrow_1 := cbc_1.iv_to_xor), scale=0.8, run_time=1),
                FadeIn(cbc_2.ct_to_xor, scale=0.8, run_time=1),
                FadeIn(cbc_2.remove(arrow_2 := cbc_2.ct_to_xor), dots, to_dots, scale=0.8, run_time=1),
                run_time=1,
                rate_func=linear,
            ),
            lag_ratio=0.5,
        ))
        cbc_1.add(iv_1, arrow_1)
        cbc_2.add(arrow_2)
        self.next_section()

        # the problem is^W^W was D_k
        self.play(LaggedStart(
            Circumscribe(cbc_1.enc, color=C_STROKE),  # _.enc is actually D_k here but don't get me started on that
            Circumscribe(cbc_2.enc, color=C_STROKE),
            lag_ratio=3/4,
        ))
        self.next_section()
"""
Given those outputs, we now can recover the plaintext of any message encrypted with this key.
Let's work an example. Here's a sample ciphertext and IV.
We'll start with the second block of the ciphertext.
Running the padding oracle attack on this second block
allows us to recover the output from the second call to D_k, one byte at a time.
Once we have that, we just evaluate this XOR, and the suspense is almost over, but not quite,
as we discover that this second plaintext block is in fact a full block of padding.
Moving on to the first block of the plaintext, we run the attack again,
and we see that it decrypts to our favorite one-block string.
So there we have it: that's the padding oracle attack, or at least that's the simple version of it.
But there are still some unanswered questions here.
Probably the biggest one is, how can you prevent this attack?
That question has a short answer and a long answer.
The short answer is: use authenticated encryption.
And we'll get to that, but let's take the long road there.
"""
class c17_09_dec_out(Scene):
    """Make D_k's per-block outputs explicit and decrypt a sample message.

    Inserts a ``dec_out`` box between each D_k call and its XOR, fills in a
    sample IV and two-block ciphertext, and reveals the second D_k output and
    plaintext block before the first. Finally it morphs the IV box into the
    next table-of-contents entry.
    """

    @staticmethod
    def add_extra_block(cbc):
        """Insert a ``dec_out`` block between ``cbc.enc`` and ``cbc.xor`` (mutates ``cbc``).

        The block's zoom is undone while positions are computed and reapplied
        afterwards, so the layout math happens at native scale.
        """
        zoom_level = cbc._zoom
        cbc.zoom(1/zoom_level)
        top = cbc.ct.get_edge_center(UP)
        bottom = cbc.xor.get_edge_center(DOWN)
        # two stops between ct and xor: 30% and 70% of the way from top to bottom
        enc_pos = (7*top + 3*bottom)/10
        dec_out_pos = (3*top + 7*bottom)/10
        cbc.enc.move_to(enc_pos)
        cbc.enc_to_ct.become(BendyArrow(cbc.ct, UP, cbc.enc, DOWN))
        cbc.dec_out = Block().move_to(dec_out_pos)
        cbc.add(cbc.dec_out)
        # the single enc/xor arrow is replaced by two arrows (dec_out<->xor and
        # enc<->dec_out); two arrows make the transform animation look better
        split_arrows = VGroup(
            BendyArrow(cbc.dec_out, UP, cbc.xor, DOWN),
            BendyArrow(cbc.enc, UP, cbc.dec_out, DOWN),
        )
        cbc.xor_to_enc.become(split_arrows)
        cbc.zoom(zoom_level)

    def construct(self):
        # rebuild the two-block decryption diagram
        cbc_1 = CBCBlock(dec=True)
        cbc_2 = CBCBlock(prev=cbc_1, dec=True)
        dots = Text("...").move_to(cbc_2.ct).align_to(cbc_2.ct, RIGHT).shift(DOWN).shift(LEFT*0.2)
        to_dots = BendyArrow(cbc_2.ct, DOWN, dots, LEFT)
        dots.shift(RIGHT*0.2)
        ZoomableVGroup(cbc_1, cbc_2, dots, to_dots).center().zoom(0.55)

        new_blocks = (cbc_1, cbc_2)
        # snapshot the current layouts so we can transform from them
        old_blocks = [block.copy() for block in new_blocks]
        for block in new_blocks:
            self.add_extra_block(block)
        self.add(*old_blocks)

        # detach the new dec_out boxes and the old arrow tips so they can be
        # animated separately from the layout transform
        dec_outs = []
        tips = []
        for new, old in zip(new_blocks, old_blocks):
            dec_out = new.dec_out
            new.remove(dec_out)
            tip = old.xor_to_enc.get_tip()
            old.xor_to_enc.remove(tip)
            dec_outs.append(dec_out)
            tips.append(tip)

        # add dec_out blocks
        self.play(
            *(ReplacementTransform(old, new) for old, new in zip(old_blocks, new_blocks)),
            *(Write(box, run_time=1) for box in dec_outs),
            *(FadeOut(tip, shift=0.2*UP) for tip in tips),
        )
        self.next_section()
        dec_out_1, dec_out_2 = dec_outs

        # populate a sample ciphertext
        sample_iv = blk()
        ct = _enc(b"YELLOW SUBMARINE", pad=True, iv=sample_iv, mode=AES.MODE_CBC)
        ct_1, ct_2 = ct[:16], ct[16:]
        pt = _dec(ct, iv=sample_iv, mode=AES.MODE_CBC)
        pt_1, pt_2 = pt[:16], pt[16:]
        self.play(
            Rewrite(cbc_1.iv, sample_iv),
            Rewrite(cbc_1.ct, ct_1),
            Rewrite(cbc_2.ct, ct_2),
        )
        self.next_section()

        # populate the second dec_out block
        self.play(Rewrite(dec_out_2, _dec(ct_2), lag_ratio=0.25, rev=True), run_time=3)
        self.next_section()
        # show the second plaintext block
        self.play(Rewrite(cbc_2.pt, pt_2, C_PAD))
        self.next_section()

        # populate the first dec_out block
        self.play(Rewrite(dec_out_1, _dec(ct_1), lag_ratio=0.25, rev=True), run_time=3)
        self.next_section()
        # show the first plaintext block
        self.play(Rewrite(cbc_1.pt, pt_1.decode("ASCII")))
        self.next_section()

        # fade out everything except the IV box, which morphs into the ToC entry
        toc_entry = c17_00_toc().title_and_toc(indicated=3)[1][3]
        iv_mob = cbc_1.iv
        self.remove(iv_mob)
        self.play(AnimationGroup(FadeOut(*self.mobjects), Transform(iv_mob, toc_entry), lag_ratio=0.5), run_time=3)
        self.next_section()
"""
The long road starts by doubling back to our example from earlier.
We had a web service with a custom token-based authentication scheme.
The backend might look a little bit like this, at least if we're using Flask.
The choice of framework is kind of arbitrary; I just like Flask because it's terse.
In case you aren't familiar, let's take a lightning tour.
We'll start at the top; this function handles requests for the "/index" route,
as we see from this app.route decorator here.
Request context is exposed to the handler through this `request` object.
We use this to get the token bytestring, which we assume will arrive hex-encoded.
Of course, this code throws an error 500 in a few cases, but we'll gloss over that here.
Moving on, we split the token into an IV and a ciphertext, which we hand off to AES.
In this case, decryption is treated as a separate operation from unpadding, which happens down here.
If unpadding fails, we return an error.
If the padding is valid, we treat the unpadded plaintext as JSON and try to deserialize it.
Again, if this fails, we return an error.
If the JSON is valid, we're good to go. The endpoint doesn't actually do anything, because that's not the point.
The point is that this endpoint has a padding oracle.
It's a very simple one here: "decryption failed" means bad padding, anything else means good padding.
"""
class c10_10_implementation_1(Scene):
    """Walk through the vulnerable Flask endpoint, highlighting each step in turn.

    NOTE(review): the ``c10`` prefix looks like a typo for ``c17``. It is left
    as-is because renaming the class could break render commands that
    reference it.
    """

    def move_highlight(self, src, dst):
        """Slide a translucent yellow box from ``src`` to ``dst``, break the section, then remove the box."""
        def surround(mob):
            return SurroundingRectangle(mob, fill_color=YELLOW, fill_opacity=0.5, stroke_width=0, buff=0.03)

        highlight = surround(src)
        self.play(Transform(highlight, surround(dst)), run_time=1.5)
        self.next_section()
        self.remove(highlight)

    def construct(self):
        toc_entry = c17_00_toc().title_and_toc(indicated=4)[1][4]
        code_style = dict(
            language="python",
            tab_width=4,
            font="Monospace",
            line_spacing=0.6,
            style="inkpot",
            margin=0.4,
            background="window",
            insert_line_no=False,
        )
        snippet = Code("snippets/c17_flask_excerpt_1.py", **code_style).scale(0.5).next_to(ORIGIN, LEFT, buff=0.5)

        # introduce original code snippet by morphing it out of the ToC entry
        self.add(toc_entry)
        self.play(
            FadeTransform(toc_entry, snippet, rate_func=rate_functions.ease_in_expo),
            run_time=2.5,
        )
        # presumably swaps loose parts left in the scene by FadeTransform for the single Code mobject
        self.remove(*snippet)
        self.add(snippet)
        self.next_section()

        rows = snippet.code.chars

        # Highlight app.route decorator
        self.play(HighlightRadial(rows[1]))
        self.next_section()

        # first / second half of a `smooth` curve, so a highlight split across
        # a move_highlight() reads as one continuous motion
        def ease_first_half(t):
            return smooth(t/2)

        def ease_second_half(t):
            return smooth(t/2+0.5)

        # Highlight request object, then expand the highlight and remove it
        request_expr = rows[4][23:30]
        self.play(HighlightRadial(request_expr, rate_func=ease_first_half))
        fromhex_expr = rows[4][9:]
        self.move_highlight(request_expr, fromhex_expr)
        self.play(HighlightRadial(fromhex_expr, rate_func=ease_second_half, run_time=1))
        self.next_section()

        # moving on - iv + ct, AES, dec, padding, json
        iv_ct_expr = remove_invisible_chars(rows[8])
        aes_expr = rows[9][9:-12]
        dec_expr = rows[9][-11:]
        unpad_expr = rows[14][10:15]
        unpad_err_expr = rows[16][9:]
        loads_expr = rows[21][10:]
        loads_err_expr = rows[23][9:]

        # IV and CT -> AES -> dec, as one chained highlight
        self.play(HighlightRadial(iv_ct_expr, rate_func=ease_first_half))
        self.next_section()
        self.move_highlight(iv_ct_expr, aes_expr)
        self.move_highlight(aes_expr, dec_expr)
        self.play(HighlightRadial(dec_expr, rate_func=ease_second_half))
        self.next_section()

        # unpad and its error, then json and its error
        for target in (unpad_expr, unpad_err_expr, loads_expr, loads_err_expr):
            self.play(HighlightRadial(target))
            self.next_section()

        # good to go
        self.play(HighlightRadial(remove_invisible_chars(rows[-1])))
        self.next_section()
"""
We've just seen what we can do with an oracle like this. But how do we fix it?
Again, the correct answer is to use authenticated encryption,
but for now let's pretend we don't know that.
Here's a first attempt at a fix. Let's combine these error cases.
This is a big improvement.
The error message no longer specifies whether we hit bad padding or bad JSON.
However, this is still not perfect, because this code is not constant-time.
"""
class c10_10_implementation_2(Scene):
    """Rewrite the Flask snippet so the padding and JSON failures share one error path.

    NOTE(review): the ``c10`` prefix probably meant ``c17``; the name is kept so
    existing render commands keep working.
    """

    def construct(self):
        code_style = dict(
            language="python",
            tab_width=4,
            font="Monospace",
            line_spacing=0.6,
            style="inkpot",
            margin=0.4,
            background="window",
            insert_line_no=False,
        )
        before = Code("snippets/c17_flask_excerpt_1.py", **code_style).scale(0.5).next_to(ORIGIN, LEFT, buff=0.5)
        after = Code("snippets/c17_flask_excerpt_2.py", **code_style).scale(0.5).align_to(before, UP+LEFT)
        # debug output: c17_11_timing_side_channels hard-codes the upper-left corner printed here
        for corner in (UL, UR):
            print(before.get_corner(corner))
        # rewrite code snippet to remove trivial oracle
        self.play(CodeRewrite(before, after), run_time=4)
"""
Some of you will immediately know what I mean by that,
but not everyone will, so let's dig into it.
We have two likely cases here.
In the first case, the function returns after throwing an exception here.
In the second case, the exception is thrown here.
It technically could also return down here, but this doesn't actually matter for our purposes.
The point is that the time this function takes to return depends on which exit point it takes,
and how much work it does in the process.
Does it just execute these statements,
or does it execute these ones, too?
The latter case will take very slightly longer.
This difference is very small, and in this example it would require a lot of measurements to compensate for measurement noise.
That noise could come from the network, from server load, or from any number of other sources.
However, there are methods for overcoming this. The simplest one is to just increase our sample size.
By doing this we can reach nearly arbitrary levels of precision at the cost of speed.
We can make this much more efficient by applying statistical methods.
We'll touch on that later in this video, and we'll really dig into it at the end of set 4 when we implement our own timing attacks.
For the moment, let's focus on spotting padding oracles, not fixing or exploiting them.
We've already determined that this function is not constant-time, regardless of whether or not the functions it calls are all constant-time.
However, the situation gets even worse, because it turns out that those functions aren't constant-time either.
What that means is that even if the calling code were constant-time, we'd still have problems.
Now, JSON deserialization is guaranteed to be full of branching on secret data, so I'm not even going to go there.
That is obviously not constant-time.
It might be more interesting to look at the unpad function.
Now, to be fair, in general it is not possible to unpad a plaintext in constant time.
You can validate it in constant-time, but unpadding is going to involve secret-dependent memory access patterns, which are not allowed in constant-time code.
So on a theoretical level we wouldn't expect this to be fully constant-time.
But let's look at the code anyway. We'll treat this as sort of a case study.
Now, this might seem excessive or even needlessly cruel,
but I want to drive home just how common this kind of issue is.
Most people who don't look at this stuff for a living tend to underestimate how common or important this is.
This function comes from pycryptodome, which is a great library that I often recommend.
I'm not doing this to give pycryptodome a hard time.
It's not, oh, look, they messed up, it's more like oh, look, even they have these issues.
So here's the unpad function.
This function has a lot of cases, but since we're using PKCS7 we can focus on this block here.
"""
class c17_11_timing_side_channels(MovingCameraScene):
    """Show the patched handler's exit paths, then zoom into pycryptodome's unpad().

    Highlights where the handler can bail out and which lines each path
    executes, then expands the unpad() call into its full source and focuses
    the camera on lines 23-27 of that source.
    """

    def construct(self):
        code_style = dict(
            language="python",
            tab_width=4,
            font="Monospace",
            line_spacing=0.6,
            style="inkpot",
            margin=0.4,
            background="window",
            insert_line_no=False,
        )
        # upper-left corner of the excerpt_1 snippet, as printed by the previous scene
        snippet_ul = [-6.61445313, 2.95214844, 0]
        handler = Code("snippets/c17_flask_excerpt_2.py", **code_style).scale(0.5).align_to(snippet_ul, UP+LEFT)
        unpad_src = Code("snippets/c17_unpad.py", **code_style).scale(0.391).next_to(LEFT*0.5, RIGHT)
        self.add(handler)

        def visible_line(i):
            return remove_invisible_chars(handler.code[i])

        # throws an exception here or here, then the final return statement
        for lineno in (14, 16, 21):
            self.play(Highlight(visible_line(lineno), radial=True))
            self.next_section()

        # highlight the first and second execution path
        path_1 = [visible_line(i) for i in (4, 5, 8, 9, 10, 14)]
        path_2 = [visible_line(i) for i in (15, 16)]
        self.play(LaggedStart(
            LaggedStart(*(Highlight(row, delay=1) for row in path_1), lag_ratio=0.00),
            LaggedStart(*(Highlight(row, delay=0.4) for row in path_2), lag_ratio=0.00),
            lag_ratio=0.4,
            run_time=6,
        ))
        self.next_section()

        # highlight json function (let's not even go there!), then the unpad function
        json_call = handler.code[16][10:20]
        self.play(Indicate(json_call))
        unpad_call = handler.code[14][10:15]
        self.play(Indicate(unpad_call))
        # explode unpad() function
        self.play(FadeTransform(unpad_call.copy(), unpad_src, stretch=False, dim_to_match=0))
        self.next_section()

        # fade out irrelevant parts of unpad(): lines before 23 top-down,
        # lines from 29 onward bottom-up
        def dim_rows(rows):
            return AnimationGroup(*(row.animate.set_opacity(1/3) for row in rows), lag_ratio=0.03)

        self.play(
            dim_rows(unpad_src.code[:23]),
            dim_rows(unpad_src.code[29:][::-1]),
            run_time=2,
        )
        focus = remove_invisible_chars(unpad_src.code[23:28])
        self.play(self.camera.frame.animate.move_to(focus.get_center()).scale(0.4))
        self.next_section()
"""
We get the padding length here. So far so good. The padding length itself
is secret data, but its location in the plaintext is not secret, so we can
access it safely.
Next we have some bounds checks, which are fine, but after those we get to this line.
Here we're slicing into the padded data to get the padding bytes.
This will allocate a new bytes object sized to the padding length,
then it will copy the padding bytes from the old buffer into this new one.
All of these operations depend on the padding length, which is secret.
So this is not constant-time.
The same thing happens over here:
this "b-char" function is fine, but then its return value gets extended out to the padding length,
so again we're allocating and writing memory based on a secret value, so that's another timing leak.
And, of course, the comparison here is not constant-time either.
It'll finish executing as soon as it finds a mismatch between its arguments.
Of these three issues, the comparison is actually the smallest issue, and also the easiest one to fix.
There's a function in the standard library called secrets.compare_digest
that could be used as a drop-in replacement for this equality check.
secrets.compare_digest is really a generic constant-time buffer comparison,
but for some reason they decided to name it after the specific use case of comparing digests.
However, even if this change were made,
the other two issues would still remain, and in general they are not so easy to fix.
To be fair, these timing side-channels are quite small.
If you were trying to measure this over a network, the ratio of signal to noise would be very heavily biased towards noise.
That said, this is not entirely impossible to exploit; the attack would just be very slow.
So, what can we do about this?
We've just seen that, while we can make the timing side-channel very small,
we can't eliminate it completely.
What we can do is prevent it from being exploited.
Recall that the padding oracle attack is a chosen-ciphertext attack.
So let's just prevent the attacker from choosing ciphertexts.
We'll do this by introducing message authentication codes.
In modern cryptosystems, authenticated encryption is the norm.
CBC mode by itself does not provide authentication, but we can fix that by bolting a MAC onto it.
Given the choice between doing this versus using a cipher mode that provides authentication by default, like GCM, I'd say use the authenticated mode.
This is just because with authenticated modes you don't have to handle the MAC yourself.
In fact, what I'd really say is to use an authenticated stream cipher like ChaCha20-Poly1305,
but we'll get to that in a later video.
"""
class c17_12_unpad_side_channels(MovingCameraScene):
    """Dissect the timing leaks in unpad()'s PKCS#7 check, then tear down to a CBC diagram.

    Highlights the padding-length read, the slice, the bchr() extension, and
    the comparison. It then rewrites the snippet into the compare_digest patch,
    points out the remaining buffer operations, and morphs the code into a
    three-block CBC diagram.
    """

    def construct(self):
        code_style = dict(
            language="python",
            tab_width=4,
            font="Monospace",
            line_spacing=0.6,
            style="inkpot",
            margin=0.4,
            background="window",
            insert_line_no=False,
        )

        def dim_except(snippet, lo, hi):
            """Dim ``snippet`` to 1/3 opacity except visible lines ``lo..hi-1``; return those lines."""
            kept = remove_invisible_chars(snippet.code[lo:hi])
            snippet.code.set_opacity(1/3)
            for row in kept:
                row.set_opacity(1)
            return kept

        original = Code("snippets/c17_unpad.py", **code_style).scale(0.391).next_to(LEFT*0.5, RIGHT)
        original_focus = dim_except(original, 23, 29)
        self.add(original)
        # start zoomed in where c17_11_timing_side_channels left the camera
        self.camera.frame.save_state().move_to(original_focus[:-1].get_center()).scale(0.4)

        patched = Code("snippets/c17_unpad_patched.py", **code_style).scale(0.391)
        patched.align_to(original, UL)
        dim_except(patched, 23, 31)

        # highlight: getting the padding length
        self.play(Highlight(remove_invisible_chars(original.code[23]), radial=True))
        self.next_section()

        # highlight: slicing into the padded data, b-char extension, and comparison
        check_row = original.code[27]
        for piece in (check_row[6:32], check_row[-30:-1], check_row[32:-30]):
            self.play(Highlight(piece, radial=True, buff=0.02))
            self.next_section()

        # compare_digest
        self.play(CodeRewrite(original, patched, run_time=2))
        self.next_section()

        # the buffer issues still remain
        self.play(Highlight(remove_invisible_chars(patched.code[28:30]), radial=True))
        self.next_section()

        # what can we do about this? prevent the attacker from choosing ciphertexts
        # start by tearing down the code scene
        cbc_blocks = CBCBlocks([None]*16*3, direction=DOWN).zoom(1/2)
        self.play(
            self.camera.frame.animate.restore(),
            FadeTransform(patched, cbc_blocks, rate_func=rate_functions.ease_in_expo),
            run_time=1.5,
        )
        self.wait(0.5)
        self.next_section()
"""
Now a rookie mistake here would be to just MAC the ciphertext.
As we've seen, we have to cover both the ciphertext and the IV,
because the attacker can use either of them to carry out this attack.
And, just as a general rule, there aren't many reasons to pass up integrity and authenticity checks.
It's just common sense to authenticate everything you can,
except maybe if you care about deniability - and even if you do care about deniability,
you still might be better off using a MAC and just disclosing the MAC keys after you've used them.
Strange as that may sound, it does work, and you can see it in action in the OTR protocol.
Now, there's a whole other conversation to be had about whether deniability is useful in practice,
and in fact there was a good talk about this at Real World Crypto 2023, but we won't go into that here.
Anyway, once we compute the MAC, we'll just append the result to the ciphertext.
Note that we're MACing the ciphertext, not the plaintext; this is an important distinction,
and a subject of some controversy. I won't get into it,
except to say that MACing plaintexts is a surprisingly common mistake
that has been recommended and implemented in some high-profile cases by people who really should know better.
Anyway, MACs are keyed, and the attacker can't compute the MAC without knowing the key.
So the server, who knows the key, can just recompute the MAC from the ciphertext,
then confirm that the tag on the message matches what they computed locally.
If it doesn't match, then the attacker has tampered with the message.
If it does match, then they can believe this message was created by someone with knowledge of the key.
Now, of course this doesn't prevent replay attacks,
but it does prevent attackers from making up ciphertexts and passing them off as authentic.
This is sufficient to prevent padding oracle attacks.
Now, a quick word about key management.
Technically we could reuse the encryption key as a MAC key, and it would probably be fine.
That said, it also feels like it invites trouble.
We would say that, if nothing else, this is bad cryptographic hygiene.
The best practice here is to use separate encryption and authentication keys,
just so there's as little coupling between these algorithms as possible.
Given a single key, these specialized keys are easy to derive.
For example, we could use a KDF, or just take some domain-separated hashes of the top-level key.
We won't worry about those details here;
we'll just write k' for the MAC key and k for the encryption key, and we'll leave it at that.
As another note, we haven't specified a MAC function here;
there are plenty of choices, but in practice, people who like CBC seem to like pairing it with HMAC.
The main disadvantage of this construction is its speed, or lack thereof.
But performance aside, from a cryptographic standpoint,
while CBC with HMAC falls short of being the absolute best, it's usually at least acceptable.
That said, this can still go very wrong if you make a mistake, as we will see shortly.
Anyway, just to recap, we've gone through a few layers of defense.
First, make sure externally served error messages are generic.
Second, validate padding in constant-time and minimize the timing leak in unpadding.
Third, and most importantly, you just MAC your ciphertexts,
so that attacker-crafted ciphertexts will be rejected as invalid.
Now let's switch back to the attacker's perspective and look at these same defenses,
focusing on how each one of them can fail in practice.
"""
class c17_13_mac_iv(MovingCameraScene):
    """Show a MAC keyed with k' computed over all ciphertext blocks plus the IV.

    Braces the ciphertexts, grows the brace to include the IV, feeds it into a
    MAC box that produces a tag block, pulses the MAC and cipher boxes (the
    key-separation discussion), and then returns to the table of contents.
    """

    def construct(self):
        # setup: the fourth block's ciphertext box is repurposed as the MAC tag
        blocks = CBCBlocks([None]*16*4, direction=DOWN)
        tag = blocks[-1].ct.rewrite(c_fills=C_MAC)
        blocks = blocks[:-1]
        self.add(blocks)
        self.camera.frame.save_state().move_to(blocks.get_center()).scale(2)

        covered = VGroup(*(b.ct for b in blocks))
        brace_args = {"mobject": covered, "direction": RIGHT, "sharpness": 4.0}
        brace = Brace(**brace_args)
        self.play(FadeIn(brace, shift=LEFT))
        # brace_args still refers to `covered`, so the new brace also spans the IV
        covered.add(blocks[0].iv)
        self.play(Transform(brace, Brace(**brace_args)))

        # introduce MAC
        mac_box = FuncBox(r"\textit{MAC}_{k'}").scale(0.7).next_to(brace, RIGHT).align_to(brace, DOWN)
        arrows = mac_box.get_arrows((brace, RIGHT, UP), (tag, DOWN, RIGHT))
        mac_func = VGroup(arrows[0], mac_box, arrows[1])
        self.play(AnimationGroup(
            Write(mac_func),
            FadeIn(tag, shift=0.5*LEFT, run_time=0.5),
            lag_ratio=2/3,
            run_time=2,
        ))
        self.next_section()

        # ramble about keys (and algorithms): re-add the keyed boxes so they
        # are drawn on top, then pulse the MAC box followed by the cipher boxes
        keyed_boxes = [mac_box] + [blocks[i].enc for i in range(3)]
        self.remove(*keyed_boxes)
        self.add(*keyed_boxes)
        self.play(AnimationGroup(
            Indicate(mac_box, color=None, scale_factor=1.7),
            AnimationGroup(*(Indicate(box, color=None, scale_factor=1.7) for box in keyed_boxes[1:])),
            lag_ratio=0.66,
            run_time=3,
        ))
        self.next_section()

        # back to toc: everything fades out while the tag morphs into the ToC entry
        toc_entry = c17_00_toc().title_and_toc(indicated=4)[1][4]
        self.remove(tag)
        self.play(AnimationGroup(
            FadeOut(*self.mobjects),
            Transform(tag, toc_entry),
            self.camera.frame.animate.restore(),
            run_time=3,
            lag_ratio=0.3,
        ))
        self.next_section()
        self.wait(0.5)
"""
[over ToC: The attack (harder case)]
So, our first line of defense is minimizing the data leak.
We give generic error messages, we get our code as close as possible to constant time, and so on.
We might take further measures, depending on what side channels we're worried about, but that depends a lot on context:
for example, power analysis might be unlikely for attacking a desktop but very likely for attacking a smart card.
There are tons of other interesting potential side channels, including EM radiation and acoustic signals,
and that's a very fun rabbit hole to go down, though it's a bit too far off course for this video.
At the very least, though, we know timing side-channels are almost always in play.
Now, we'll do a deep dive into those when we get to challenges 31 and 32.
For now, let's abstract away the specifics of our side channel, and just say that we get a somewhat-reliable padding oracle from it.
It might not be perfectly reliable because these channels can contain noise which can throw off our measurements.
"""
"""
So, here's the question: as our side channel gets less reliable, how does the attack change?
Is it prevented? Does it become impractical?
In particular, does a less reliable oracle necessarily lead to a less reliable attack?
Surprisingly, it turns out that the answer to this second question is "no".
As for the first question, that one has a slightly longer answer, but let's give it a shot.
Incidentally, this is the topic of the second blog post I mentioned earlier, which you can find a link to in the description below.
So if you have trouble following this discussion, maybe give that post a try.
Let's ease into it with some simple cases. Suppose we had an unreliable oracle that sometimes gives false negatives.
Whenever we give it invalid padding, it'll give a correct answer, but sometimes it'll mistakenly tell us valid padding is invalid.
Let's look at how our byte search might change.
Let's represent the search like this. I've reduced the search space from 256 elements to 16 so it fits better on screen.
This doesn't really change the underlying ideas at all, it just makes them easier to see.
"""
class c17_14_questions(Scene):
    """Pose the two questions about unreliable oracles and answer the second one.

    It then morphs the first answer ("Let's see...") into the query chart
    borrowed from c17_15_simple_cases.
    """

    def construct(self):
        toc_entry = c17_00_toc().title_and_toc(indicated=5)[1][5]

        def two_lines(first, second):
            return VGroup(Text(first), Text(second)).arrange(DOWN)

        questions = VGroup(
            two_lines("As our oracle gets less reliable,", "how does the attack change?"),
            two_lines("Does a less reliable oracle", "imply a less reliable attack?"),
        ).arrange(DOWN, buff=2)
        self.play(
            Unwrite(toc_entry),
            Succession(Write(questions[0]), Write(questions[1])),
        )
        self.next_section()

        answer_2 = MarkupText('<span foreground="red"><u>No!</u></span>').next_to(questions[1], DOWN)
        self.play(Write(answer_2))
        answer_1 = MarkupText('<span foreground="red">Let\'s see...</span>').next_to(questions[0], DOWN)
        self.play(Write(answer_1))

        # only the query chart is used here
        _p_chart, _e_chart, query_chart = c17_15_simple_cases().get_charts()
        query_chart.move_to(ORIGIN)
        self.play(LaggedStart(
            FadeOut(questions, answer_2),
            Transform(answer_1, query_chart),
            lag_ratio=1/3,
        ))
"""
We'll start by scanning through all 16 candidate byte values,
sending each one to the oracle and recording the responses we get.
We'll stop when we get a positive response.
With a perfect oracle, we expect this to happen on our first run through the search space.
However, since this oracle can give false negatives,
we may need to make several passes through the space before we get a hit.
This is trivial, though, and we see that on the third pass through the search space we find our result.
Our second simple case is an oracle that only gives false positives.
Again, these cases aren't meant to be realistic, they're just warmups.
This time, we get lots of hits on our first scan through the search space.
Every time we get a negative result, we rule out that guess.
Then we scan again through the positive results, building a new shortlist as we go.
Each scan reduces the number of candidates, until eventually we're left with just one,
at which point we know we've found the correct byte.
Interestingly, unlike the other search, this one seems to use a sort of winnowing strategy,
where some percentage of the candidate bytes are ruled out with each round of queries.
We can formalize this idea.
"""
# Number of decades spanned by the log-scaled "E[Information Gained]" charts.
# Their y-axis runs from 0 to LOG_SCALE, and the ys_to_logs helpers map 0 -> 0
# and 1 -> LOG_SCALE. The axis labels {0: "0", 1: "0.01", 2: "0.1", 3: "1"}
# used with these charts are hard-coded for LOG_SCALE == 3, and are approximate.
LOG_SCALE = 3
from math import log10
class c17_15_simple_cases(Scene):
    """Warm-up byte searches against one-sided noisy oracles (16 candidates, target index 8).

    Also provides ``get_charts``, which other scenes call to reuse the same
    three-chart layout.
    """
    @staticmethod
    def ys_to_logs(ys):
        """Map values in [0, 1] onto the 0..LOG_SCALE log axis (0 -> 0, 1 -> LOG_SCALE).

        Unlike c17_19_exhaustive_with_stats.ys_to_logs, inputs are not clamped.
        """
        scale = 10**LOG_SCALE
        return [log10((scale-1)*y+1) for y in ys] # good enough
    def get_charts(self):
        """Build the three bar charts used throughout this section.

        Returns ``(p_chart, e_chart, query_chart)``:
        * ``p_chart``: per-candidate confidences on a linear 0..1 axis.
        * ``e_chart``: expected information gained, on the log axis from ``ys_to_logs``.
        * ``query_chart``: a tall "Queries" chart whose near-zero bars serve only as
          anchors for stacking query-result boxes.
        The first two are stacked vertically, and the query chart is placed to their right.
        """
        p_chart = BarChart(
            values=[1/16]*16,
            y_range=[0, 1, 0.2],
            y_length=3,
            x_length=5.5,
            # 9 names for 16 bars: presumably places the caption near the middle of the x-axis.
            bar_names=['']*8 + ["Confidences"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        e_chart = BarChart(
            values=self.ys_to_logs([1/16]*16),
            y_range=[0, LOG_SCALE, 1],
            y_length=3,
            x_length=5.5,
            bar_names=['']*8 + ["E[Information Gained]"],
            x_axis_config={"font_size": 36},
            bar_colors=[GREEN],
        )
        charts = Group(p_chart, e_chart)
        charts.arrange(DOWN, buff=0) # type: ignore
        e_chart.align_to(p_chart, RIGHT)
        query_chart = BarChart(
            # Near-zero bars: effectively invisible, but they give each column a base position.
            values=[0.00001]*16,
            y_range=[0, 1, 0.2],
            # Roughly the combined height of the two stacked charts.
            y_length=p_chart.height+3,
            x_length=5.5,
            bar_names=['']*8 + ["Queries"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        # The query chart's y-axis carries no meaning, so hide its numbers and ticks.
        query_chart.y_axis.numbers.set_opacity(0)
        query_chart.y_axis.ticks.set_opacity(0)
        Group(charts, query_chart).arrange(RIGHT, buff=0.25) # type: ignore
        # Replace the raw log-axis numbers with the approximate underlying values (assumes LOG_SCALE == 3).
        e_chart.y_axis.numbers.set_opacity(0)
        e_chart.y_axis.add_labels({0: "0", 1: "0.01", 2: "0.1", 3: "1"})
        return p_chart, e_chart, query_chart
    def construct(self):
        p_chart, e_chart, query_chart = self.get_charts()
        # This scene only uses the query chart, centred on screen.
        query_chart.save_state().move_to(ORIGIN)
        self.add(query_chart)
        fn_text = VGroup(Text("False"), Text("negatives")).arrange(DOWN).next_to(query_chart, LEFT, buff=0.4)
        self.play(Write(fn_text))
        # One column per candidate. Each column starts with that candidate's bar, so
        # new boxes can be stacked on whatever is currently last in the column.
        box_mobs = [[mob[0]] for mob in query_chart.bars]
        box_size = p_chart.bars[0].get_width()
        def get_box(index, result):
            """Create a green (positive) / red (negative) box on top of column ``index`` and record it there.

            The box is not added to the scene here; the FadeIn that uses it does that.
            """
            fill_color = BW_GREEN if result else RED
            box = Square(
                side_length=box_size,
                stroke_color=C_STROKE, fill_color=fill_color,
                stroke_width=2, fill_opacity=1
            ).next_to(box_mobs[index][-1], UP, buff=0.05)
            box_mobs[index].append(box)
            #query_chart.add(box)
            return box
        # false negatives
        # Two full passes where every answer (even for the target, index 8) comes back negative.
        for _ in range(2):
            self.play(LaggedStart(*[
                FadeIn(get_box(i, False), shift=0.3*DOWN, run_time=1/3)
                for i in range(16)
            ], lag_ratio=0.4))
            #box = get_box(i, False)
            #self.play(FadeIn(box, shift=0.3*DOWN, run_time=1/3))
        # NOTE(review): the source had no indentation; whether this next_section() sat inside
        # the loop above (a pause after each pass) is a reconstruction — confirm.
        self.next_section()
        #for i in range(9):
        #    box = get_box(i, i == 8)
        #    self.play(FadeIn(box, shift=0.3*DOWN, run_time=1/3))
        # Third pass: stops at index 8, the first (and only) positive.
        self.play(LaggedStart(*[
            FadeIn(get_box(i, i == 8), shift=0.3*DOWN, run_time=1/3)
            for i in range(9)
        ], lag_ratio=0.4))
        self.play(Circumscribe(VGroup(*box_mobs[8]), color=BLACK))
        self.next_section()
        self.play(*[Unwrite(box) for col in box_mobs for box in col[1:]], run_time=1.5)
        self.next_section()
        # Drop the boxes from every column. get_box sees this rebinding because it closes
        # over the variable box_mobs, not over the old list.
        box_mobs = [col[:1] for col in box_mobs]
        fp_text = VGroup(Text("False"), Text("positives")).arrange(DOWN).move_to(fn_text)
        self.play(Transform(fn_text[1], fp_text[1]))
        # Fixed seed so the false-positive rounds play out identically on every render.
        seed = bytes.fromhex("41ef5e793231800db1cb43bc9fb29bf1fdafe0b5f78da0f9d28b9e2a50ac18a6cebdf90e49a28d2bcfbbc6d1c87894e45dac893c6048b39dc8a77d12d959652fd738ab7633960ab3b9d4a66bb5301a09914a8958b5b5482438af7a7a70c0df56a21e2719335a250da549de3d0ab6678356078f69ec70680885d9ee68b0eb419e")
        _rng.seed(seed)
        positives = [True]*16
        flag = True
        # Winnowing: each round re-queries only the still-positive candidates. Each one stays
        # positive if _rng.random() > 0.4, and index 8 (the target) is always positive.
        while positives.count(True) > 1:
            old_positives = positives
            positives = [b and _rng.random() > 0.4 or i == 8 for i, b in enumerate(positives)]
            # NOTE(review): on the first round old_positives is all True, so `flag` never
            # changes which boxes are drawn; it appears redundant.
            self.play(LaggedStart(*[
                FadeIn(get_box(i, b), shift=0.3*DOWN, run_time=1/3)
                for i, (b, _b) in enumerate(zip(positives, old_positives))
                if (_b or flag)
            ], lag_ratio=0.4))
            #if b and positives.count(True) == 1: break
            #self.play(FadeIn(get_box(i, b), shift=0.3*DOWN, run_time=1/3))
            flag = False
        self.play(Circumscribe(VGroup(*box_mobs[8]), color=BLACK))
        self.next_section()
        self.play(
            *[Unwrite(box) for col in box_mobs for box in col[1:]],
            Unwrite(fn_text[0]),
            Unwrite(fn_text[1]),
            #FadeOut(query_chart),
            run_time=1.5
        )
        self.next_section()
        box_mobs = [col[:1] for col in box_mobs]
"""
Let's add a second chart here for our confidence in each candidate.
This chart tracks the probability of that candidate being the one we're searching for.
At the start of the search, we don't know anything, so we consider each value to be equally likely.
Since our search space here has size sixteen, these are all at 1/16th, or about point-oh-six.
After our first round of queries, the oracle has given us five negative results.
This oracle only gives false positives, not false negatives, so we know these are real negatives.
Accordingly, we adjust the probabilities for those values down to 0,
and to compensate, we raise all the remaining probabilities to 1/11.
This process continues intuitively enough, until one of our options hits a confidence level of 1.
You'll notice I was updating the confidences after each round of queries,
because that's when it's easiest to figure out.
We have a shortlist of possible search results,
and we have an equal amount of evidence in favor of each of them,
so we consider them equally likely.
Similarly, if we were to chart confidences for the only-false-negatives case,
after each round of queries the confidences would all be equal,
until we finally get a positive result,
at which point we take that value's confidence up to 1 while all the others go to 0.
So in these special cases, the confidences are easy to evaluate.
But let's ask a leading question which we won't answer immediately.
Let's ask: what would these confidences look like in the middle of a round?
In the false-negative case,
if you have four negative results for one value, and only three for another,
surely the value with four negatives is slightly less likely to be correct.
But how much less likely?
Similarly, in the false-positive case,
if you have two frontrunners, one of which has more positive results than the other,
you might consider that one more likely to be correct.
But how can we quantify this intuition?
We'll come back to this question, but first,
let's break out of these simplified cases and consider the combined case,
where we have both false positives and false negatives.
This is the type of noisy oracle you can expect to encounter in the real world,
and knowing how to handle it is crucial to understanding this attack in practice.
We'll set the confidences chart aside for now, and just look at some outputs from this oracle.
We see false negatives and false positives interspersed throughout the output.
Try and guess which value is correct here.
You might get it, but it's not immediately obvious how to make a guess here,
or how confident we should be in our guess.
After all, whenever we see a combination of positive and negative results,
we know that some of those results must be errors, but we don't know which ones.
In fact, no matter how many queries we make,
we can never rule out the possibility that any particular result is an error.
This is where previous work has consistently run into challenges,
coming up with increasingly complex statistical tests to run on groups of queries,
in order to consolidate them into an overall result.
This is a promising idea,
but the methods I've seen ultimately end up relying on huge sample sizes to produce a reliable oracle,
and then making a large number of queries to this oracle.
Another option, which I propose in the second blog post I mentioned,
is to try to create a reliable byte search out of an unreliable oracle.
In other words, rather than trying to make the oracle reliable,
we'll figure out how to use it directly in spite of its unreliability.
But how do we do this? Well, I have some good news and some other news.
The good news is that this can be done.
The other news is, uh, that it involves math.
Specifically, it involves a very famous result from probability theory known as Bayes' theorem.
[EDIT: SLIDE TRANSITION TO NEXT SCENE]
"""
class c17_16_confidences(MovingCameraScene):
    """Add the confidence chart and replay both one-sided warm-ups, then preview a two-sided noisy oracle.

    Reuses the false-positive seed from c17_15_simple_cases (target index 8) and
    updates ``p_chart`` as the searches progress.
    """
    def construct(self):
        p_chart, e_chart, query_chart = c17_15_simple_cases().get_charts()
        # Start centred, then slide back to the laid-out position while p_chart fades in.
        query_chart.save_state().move_to(ORIGIN)
        self.play(
            query_chart.animate.restore(),
            FadeIn(p_chart, shift=3.1*RIGHT),
            run_time=3,
        )
        # One column per candidate, starting with that candidate's bar; boxes stack on top.
        box_mobs = [[mob[0]] for mob in query_chart.bars]
        box_size = p_chart.bars[0].get_width()
        def get_box(index, result):
            """Create a green (positive) / red (negative) box on top of column ``index`` and record it there.

            The box is not added to the scene here; the FadeIn that uses it does that.
            """
            fill_color = BW_GREEN if result else RED
            box = Square(
                side_length=box_size,
                stroke_color=C_STROKE, fill_color=fill_color,
                stroke_width=2, fill_opacity=1
            ).next_to(box_mobs[index][-1], UP, buff=0.05)
            box_mobs[index].append(box)
            #query_chart.add(box)
            return box
        # False-positive-only case, with the same seed and winnowing rule as c17_15_simple_cases.
        seed = bytes.fromhex("41ef5e793231800db1cb43bc9fb29bf1fdafe0b5f78da0f9d28b9e2a50ac18a6cebdf90e49a28d2bcfbbc6d1c87894e45dac893c6048b39dc8a77d12d959652fd738ab7633960ab3b9d4a66bb5301a09914a8958b5b5482438af7a7a70c0df56a21e2719335a250da549de3d0ab6678356078f69ec70680885d9ee68b0eb419e")
        _rng.seed(seed)
        positives = [True]*16
        flag = True
        while positives.count(True) > 1:
            old_positives = positives
            positives = [b and _rng.random() > 0.4 or i == 8 for i, b in enumerate(positives)]
            self.play(LaggedStart(*[
                FadeIn(get_box(i, b), shift=0.3*DOWN, run_time=1/3)
                for i, (b, _b) in enumerate(zip(positives, old_positives))
                if (_b or flag)
            ], lag_ratio=0.4))
            #if b and positives.count(True) == 1: break
            #self.play(FadeIn(get_box(i, b), shift=0.3*DOWN, run_time=1/3))
            flag = False
            # After each round, the confidence is split evenly among the surviving
            # candidates and is 0 for the rest.
            # NOTE(review): placement inside the loop is reconstructed from the narration
            # ("updating the confidences after each round") — confirm.
            self.play(p_chart.animate.change_bar_values(
                [1/positives.count(True) if b else 0 for b in positives]
            ))
        self.next_section()
        self.play(
            *[Unwrite(box) for col in box_mobs for box in col[1:]],
            p_chart.animate.change_bar_values([1/16]*16),
        )
        self.next_section()
        # Dead code: this Bayes walkthrough was moved into c17_17_bayes.
        #bayes_1 = MathTex(
        #    r"P(H|E) = \frac{P(E|H) P(H)}{P(E)}",
        #    font_size=32,
        #).next_to(p_chart, DOWN, buff=1.55)
        #bayes_2 = MathTex(
        #    r"P(H|E) = \frac{P(E|H) P(H)}{P(E|H)P(H) + P(E|\neg H) P(\neg H)}",
        #    font_size=32,
        #).next_to(p_chart, DOWN, buff=1.55)
        #self.play(Write(bayes_1))
        #self.next_section()
        #self.play(
        #    Transform(bayes_1[0][:18], bayes_2[0][:18]),
        #    FadeTransform(bayes_1[0][18:], bayes_2[0][18:]),
        #)
        #self.next_section()
        #self.play(self.camera.frame.animate.move_to(bayes_1).set_width(bayes_1.width*1.05))
        #self.next_section()
        ## breakdown
        #frame = self.camera.frame
        #divider = VGroup(Line(frame.get_edge_center(UP), frame.get_edge_center(DOWN)),
        #    Line(frame.get_edge_center(LEFT), frame.get_edge_center(RIGHT)))
        #self.play(FadeIn(divider))
        #self.wait()
        # false negatives (again)
        # Reset the columns; get_box closes over the variable box_mobs, so it sees this new list.
        box_mobs = [col[:1] for col in box_mobs]
        for _ in range(2):
            self.play(LaggedStart(*[
                FadeIn(get_box(i, False), shift=0.3*DOWN, run_time=1/3)
                for i in range(16)
            ], lag_ratio=0.1))
            #box = get_box(i, False)
            #self.play(FadeIn(box, shift=0.3*DOWN, run_time=1/3))
        # NOTE(review): reconstructed as outside the loop (source had no indentation) — confirm.
        self.next_section()
        self.play(LaggedStart(*[
            FadeIn(get_box(i, i == 8), shift=0.3*DOWN, run_time=1/3)
            for i in range(9)
        ], lag_ratio=0.4))
        # The first positive is certain in the false-negative-only case, so all confidence jumps to index 8.
        self.play(
            Circumscribe(VGroup(*box_mobs[8]), color=BLACK),
            p_chart.animate.change_bar_values([0]*8+[1]+[0]*7),
        )
        self.next_section()
        self.play(
            *[Unwrite(box) for col in box_mobs for box in col[1:]],
        )
        self.next_section()
        # discard p chart for now, show FP+FN case
        box_mobs = [col[:1] for col in box_mobs]
        self.play(
            FadeOut(p_chart, shift=3.1*LEFT),
            query_chart.animate.move_to(ORIGIN),
        )
        self.next_section()
        _rng.seed(_enc(b'false positives and false negatives'*17))
        # Truth for index 8, flipped whenever the draw falls below FN_RATE (one draw per query).
        def oracle(i): return (i == 8) ^ (_rng.random() < FN_RATE) # blemish: technically overrides FP_RATE (but they're equal so who cares)
        # Nine full passes of noisy answers, shown without updating any confidences.
        for _ in range(9):
            self.play(LaggedStart(*[
                FadeIn(get_box(i, oracle(i)), shift=0.3*DOWN, run_time=1/3)
                for i in range(16)
            ], lag_ratio=0.4))
        self.next_section()
        self.play(
            *[Unwrite(box) for col in box_mobs for box in col[1:]],
        )
        self.wait(0.5)
        self.next_section()
"""
[SLIDE IN]
Bayes' theorem looks like this. I'll give you a brief run-down, but if you want the full details,
I'll refer you to 3Blue1Brown's excellent video on this topic,
which I learned a few things from myself.
I'm sure there are other good videos on this as well,
but this is the only one I've seen,
so it's the only one I can personally recommend.
I'm not going to try to recap it;
instead, what follows is a very quick high-level overview.
Bayes' theorem basically allows you to adjust your confidence in a hypothesis
as you gather evidence related to that hypothesis.
In this simplified example, we have sixteen concurrent hypotheses;
For a full byte search, we would have 256 hypotheses.
Each hypothesis tracks the possibility of its corresponding value being the target of our search:
We'll use Bayes' theorem to adjust each of these hypotheses after each oracle query.
The first thing we'll do is expand the bottom of this fraction,
which makes it harder to read but easier to work with.
Now, for any given hypothesis, we have four cases to consider:
the evidence can be for either the same value or a different value as the hypothesis,
and it can return either a positive or a negative result.
We'll indicate a positive or negative result for byte i with T_i or F_i,
and write H_i or H_j for the hypothesis that byte i, or some other byte j, is the correct one.
We also will assume we know the false-positive and false-negative rates,
which we'll call p1 and p2 respectively.
For each of these cases, we need to evaluate each of the individual probabilities in Bayes' theorem.
I'm not going to talk through these, because personally,
I've never been the type who can follow along as someone reads math to me,
and so I'm not gonna be the type of person who reads math to others.
That said, if you want to work through all the casework, I'd encourage you to do so,
because it's really no harder than your average undergrad stats homework,
or you can check the aforementioned blog post, where I do spell all of this out in detail.
Given this, we have a mathematically sound way of updating our confidences after every single oracle query.
There's one other topic to cover before we're done talking about math.
Let's talk about information theory, and specifically about entropy.
We can compute the entropy of any probability distribution in this way.
This is useful for us, because our confidences in each byte can be run through this equation,
and the resulting number will give us a sense of how much "uncertainty", so to speak, is left in our search.
Taking this a step further, we know we have estimates, for each input value,
of how likely a positive or negative oracle result would be,
and we also know, from Bayes' theorem, how the distribution of confidences would change in either case,
and now we know how to quantify the entropies of the resulting distributions.
This allows us to find expected reductions in entropy for each potential query.
You can also think of this as the expected amount of information gained from each query.
[SLIDE BACK]
"""
class c17_17_bayes(MovingCameraScene):
    """Math slides: Bayes' theorem, its four query/hypothesis cases, and the entropy formula."""
    def construct(self):
        bayes_1 = MathTex(
            r"P(H|E) = \frac{P(E|H) P(H)}{P(E)}",
        )
        # Same formula with the denominator P(E) expanded by the law of total probability.
        FULL_BAYES = r"P(H|E) = \frac{P(E|H) P(H)}{P(E|H)P(H) + P(E|\neg H) P(\neg H)}"
        bayes_2 = MathTex(FULL_BAYES)
        self.add(bayes_1)
        # Size the camera frame to the wider expanded formula before it appears.
        self.camera.frame.move_to(bayes_1).set_width(bayes_2.width*1.1)
        self.wait(0.1)
        # Morph the first 18 glyphs in place and cross-fade the rest. Presumably glyph 18 is
        # where the denominators diverge — depends on MathTex's glyph order, so re-check if
        # the formula text changes.
        self.play(
            Transform(bayes_1[0][:18], bayes_2[0][:18]),
            FadeTransform(bayes_1[0][18:], bayes_2[0][18:]),
        )
        self.next_section()
        # breakdown
        # Four copies of the expanded formula, one per quadrant. Each anchor is the top-left
        # point of its quadrant (UL corner, top centre, left centre, frame centre).
        frame = self.camera.frame
        bayeses = [bayes_2.copy() for _ in range(4)]
        anchors = [Point(frame.get_corner(UL)),
                   Point(frame.get_edge_center(UP)),
                   Point(frame.get_edge_center(LEFT)),
                   ORIGIN]
        # product() yields (Same, positive), (Same, negative), (Different, positive),
        # (Different, negative), matching the UL, UR, DL, DR quadrants.
        labels = [Text(f"{s1} value, {s2} result", color=DARK_GRAY, stroke_color=GRAY).scale(0.25).next_to(point, DR, buff=0.1)
                  for (s1, s2), point in zip(product(("Same", "Different"), ("positive", "negative")),
                                             anchors)]
        self.remove(*self.mobjects)
        divider = VGroup(Line(frame.get_edge_center(UP), frame.get_edge_center(DOWN)),
                         Line(frame.get_edge_center(LEFT), frame.get_edge_center(RIGHT)))
        self.play(LaggedStart(
            AnimationGroup(*[
                eqn.animate(lag_ratio=0.007).scale(0.4).next_to(point, DR).shift(0.2*DOWN)
                for eqn, point in zip(bayeses, anchors)
            ], run_time=2.5),
            Write(divider),
            lag_ratio=2/3,
        ))
        # Specialize each copy: E becomes T_i/F_i (the answer for queried byte i), and
        # H becomes H_i (same byte) or H_j (a different byte).
        bayeses_ij = [
            MathTex(FULL_BAYES.replace("E", "T_i").replace("H", "H_i")).scale(0.4).move_to(bayeses[0]).align_to(bayeses[0], LEFT),
            MathTex(FULL_BAYES.replace("E", "F_i").replace("H", "H_i")).scale(0.4).move_to(bayeses[1]).align_to(bayeses[1], LEFT),
            MathTex(FULL_BAYES.replace("E", "T_i").replace("H", "H_j")).scale(0.4).move_to(bayeses[2]).align_to(bayeses[2], LEFT),
            MathTex(FULL_BAYES.replace("E", "F_i").replace("H", "H_j")).scale(0.4).move_to(bayeses[3]).align_to(bayeses[3], LEFT),
        ]
        self.play(LaggedStart(*[
            AnimationGroup(
                Write(label, lag_ratio=0.01, run_time=2),
                FadeTransform(bayes, new_bayes)
            ) for label, bayes, new_bayes in zip(labels, bayeses, bayeses_ij)
        ], lag_ratio=1/4))
        #self.add(index_labels(bayeses_ij[0][0], background_stroke_width=2))
        # Term definitions for each quadrant: p_1 is the false-positive rate and p_2 the
        # false-negative rate (matching FP_RATE/FN_RATE in ByteSearch.get_updated_confidences).
        # NOTE(review): the H_j quadrants' formulas need P(H_j)/P(not H_j), but their first two
        # lines are written for H_i. xfrac is loaded but \sfrac is not used in these strings.
        terms_UL = MathTex(
            r"P(H_i) &=1/16 \, \textit{initially}\\ "
            r"P(\neg H_i) &=1 - P(H_i) \\ "
            r"P(T_i | H_i) &=1-p_2 \\ "
            r"P(T_i | \neg H_i) &=p_1 \\ ",
            tex_template=TexTemplate().add_to_preamble(r"\usepackage{xfrac}")
        ).scale(0.4).next_to(bayeses_ij[0], DOWN)
        terms_UR = MathTex(
            r"P(H_i) &=1/16 \, \textit{initially} \\ "
            r"P(\neg H_i) &=1 - P(H_i) \\ "
            r"P(F_i | H_i) &=p_2 \\ "
            r"P(F_i | \neg H_i) &=1-p_1 \\ ",
            tex_template=TexTemplate().add_to_preamble(r"\usepackage{xfrac}")
        ).scale(0.4).next_to(bayeses_ij[1], DOWN)
        terms_DL = MathTex(
            r"P(H_i) &=1/16 \, \textit{initially} \\ "
            r"P(\neg H_i) &=1 - P(H_i) \\ "
            r"P(T_i | H_j) &=p_1 \\ "
            r"P(T_i | \neg H_j) &= P(H_i | \neg H_j)(1-p_2) "
            r" + P(\neg H_i | \neg H_j) p_1 \\ "
            r"P(H_i | \neg H_j) &= \frac{P(H_i)}{1-P(H_j)} \,;\, P(\neg H_i | \neg H_j) = 1 - P(H_i | \neg H_j)",
            tex_template=TexTemplate().add_to_preamble(r"\usepackage{xfrac}")
        ).scale(0.4).next_to(bayeses_ij[2], DOWN).shift(0.25*RIGHT)
        terms_DR = MathTex(
            r"P(H_i) &=1/16 \, \textit{initially} \\ "
            r"P(\neg H_i) &=1 - P(H_i) \\ "
            r"P(F_i | H_j) &= 1 - p_1 \\ "
            r"P(F_i | \neg H_j) &= P(\neg H_i | \neg H_j) (1-p_1) + P(H_i | \neg H_j) p_2 \\ "
            r"P(H_i | \neg H_j) &= \frac{P(H_i)}{1-P(H_j)} \,;\, P(\neg H_i | \neg H_j) = 1 - P(H_i | \neg H_j)",
            tex_template=TexTemplate().add_to_preamble(r"\usepackage{xfrac}")
        ).scale(0.4).next_to(bayeses_ij[3], DOWN).shift(0.25*RIGHT)
        self.play(Write(terms_UL),
                  Write(terms_UR),
                  Write(terms_DL),
                  Write(terms_DR),
                  run_time=1)
        self.wait(0.5)
        #entropy_eqn_1 = MathTex(r"-\sum_{x \in \mathcal{X}}", r"P(x)", r"\log P(x)")
        #entropy_eqn_2 = MathTex(r"-\sum_{i = 1}^{256}", r"P(H_i)", r"\log P(H_i)").align_to(entropy_eqn_1, DL)
        # Generic entropy formula, then the same formula over the 2^n byte hypotheses H_i.
        # Both are split into matching parts so they can be transformed piece by piece.
        entropy_eqn_1 = MathTex(r"-\sum_{x \in \mathcal{X}}", r"P(", r"x", r")", r"\log P(", r"x", r")")
        entropy_eqn_2 = MathTex(r"-\sum_{i = 0}^{2^n-1}", r"P(", r"H_i", r")", r"\log P(", r"H_i", r")")
        # Line up the first glyph of both formulas so the morph starts in place.
        entropy_eqn_2.shift(entropy_eqn_1[0][0].get_center()-entropy_eqn_2[0][0].get_center())
        self.play(
            Unwrite(terms_UL, reverse=False),
            Unwrite(terms_UR, reverse=False),
            Unwrite(terms_DL, reverse=False),
            Unwrite(terms_DR, reverse=False),
            Unwrite(divider),
            *[Unwrite(label) for label in labels],
            *[Unwrite(bayes) for bayes in bayeses_ij],
            Succession(Wait(0.5), Write(entropy_eqn_1)),
            run_time=3,
        )
        self.next_section()
        # Glyph 0 morphs across. eqn_2's glyphs 1-4 fade in, and the remaining glyphs and parts
        # morph pairwise. zip() stops at the shorter side, so extra glyphs get no Transform.
        self.play(
            Transform(entropy_eqn_1[0][0], entropy_eqn_2[0][0]),
            FadeIn(entropy_eqn_2[0][1:5], shift=0.06*RIGHT),
            *[Transform(part1, part2) for part1, part2 in zip(entropy_eqn_1[0][1:], entropy_eqn_2[0][5:])],
            *[Transform(part1, part2) for part1, part2 in zip(entropy_eqn_1[1:], entropy_eqn_2[1:])],
        )
        self.next_section()
"""
We'll add a second chart here to track these expectations for each byte.
This one uses a logarithmic scale, because these values are very small.
Now let's see what happens to these charts as we gather oracle queries.
"""
class c17_18_entropy_intro(Scene):
    def construct(self):
        """Slide the confidence and information-gain charts in beside the query chart.

        The query chart starts centred, then returns to its laid-out position while
        both stacked charts fade in, moving rightwards into place.
        """
        confidences, info_gain, queries = c17_15_simple_cases().get_charts()
        # Remember the arranged position, then park the query chart at the centre.
        queries.save_state()
        queries.move_to(ORIGIN)
        self.add(queries)
        self.wait(0.1)
        slide_in = 3.1 * RIGHT
        self.play(
            queries.animate.restore(),
            FadeIn(confidences, shift=slide_in),
            FadeIn(info_gain, shift=slide_in),
            run_time=2,
        )
class ByteSearch:
    """Bayesian search for one unknown value using a noisy boolean oracle.

    One confidence (posterior probability) is kept per candidate index in
    ``range(search_bound)``. Every oracle answer updates all of them with Bayes'
    theorem, using the module-level false-positive / false-negative rates
    ``FP_RATE`` and ``FN_RATE``. Confidences are Decimals.
    """

    def __init__(self, oracle, confidence_threshold=0.999, quiet=True, hook=None):
        """Start a search over ``range(search_bound)`` from a uniform prior.

        Args:
            oracle: callable taking a candidate index and returning a bool
                (True = positive answer). Its answers may be wrong.
            confidence_threshold: ``search`` stops once the largest confidence
                reaches this value.
            quiet: when False, ``query_byte`` prints a progress dot every 256 queries.
            hook: optional callable invoked as ``hook(self)`` after each query
                made by ``search``.
        """
        self._counter = 0  # total number of oracle queries so far
        self.oracle = oracle
        self.query_log = []  # (index, result) pairs, oldest first
        self.confidences = [Decimal(1) / search_bound] * search_bound
        self.confidence_threshold = confidence_threshold
        self.quiet = quiet
        self.hook = hook

    def update_confidences(self, index, result):
        """Fold one oracle answer for ``index`` into the stored confidences."""
        self.confidences = self.get_updated_confidences(self.confidences, index, result)

    def pick_exhaustive(self):
        """Pick candidates round-robin (0, 1, 2, ...), ignoring the confidences."""
        return self._counter % search_bound

    def pick_by_confidence(self):
        """Pick the candidate with the highest current confidence (lowest index on ties)."""
        return max(range(search_bound), key=lambda i: self.confidences[i])

    def pick_by_entropy(self):
        """Pick the candidate whose query minimizes the expected posterior entropy.

        This is the same as maximizing the expected information gained. It costs two
        full Bayes updates per candidate, so it is slow for large search spaces.
        """
        entropies = self.get_entropies()
        return min(range(search_bound), key=lambda i: entropies[i])

    def query_byte(self, index):
        """Query the oracle about ``index``, log the answer, update confidences, and return the answer."""
        self._counter += 1
        result = self.oracle(index)
        self.query_log.append((index, result))
        self.update_confidences(index, result)
        if not self.quiet and (self._counter & 0xFF) == 0:
            print(end=".", flush=True)
        return result

    def search(self, strategy):
        """Query until some confidence reaches the threshold; return the most likely index.

        Args:
            strategy: zero-argument callable returning the next index to query,
                e.g. one of the ``pick_*`` methods.
        """
        threshold = self.confidence_threshold
        while max(self.confidences) < threshold:
            self.query_byte(strategy())
            if self.hook is not None:
                self.hook(self)
        return max(range(search_bound), key=lambda i: self.confidences[i])

    def get_entropies(self):
        """Return, per candidate index, the expected entropy of the confidences after querying it.

        Each possible answer is weighted by the oracle's predictive probability
        of giving it, not by the bare confidence:

            P(T_i) = P(H_i) * (1 - FN_RATE) + (1 - P(H_i)) * FP_RATE

        Using P(H_i) alone is only correct for a perfect oracle. With the proper
        weights, the expected entropy never exceeds the current entropy.
        """
        entropies = []
        for i in range(search_bound):
            e_if_t = self.get_entropy(self.get_updated_confidences(self.confidences, i, True))
            e_if_f = self.get_entropy(self.get_updated_confidences(self.confidences, i, False))
            p_h = self.confidences[i]
            p_t = p_h * (1 - FN_RATE) + (1 - p_h) * FP_RATE
            p_f = 1 - p_t
            entropies.append(p_t * e_if_t + p_f * e_if_f)
        return entropies

    @staticmethod
    def bayes(h, e_given_h, e_given_not_h):
        """Return the posterior P(H | E) via Bayes' theorem.

        Args:
            h: prior probability of the hypothesis, P(H).
            e_given_h: probability of the observed evidence if H holds, P(E | H).
            e_given_not_h: probability of the observed evidence if H does not hold, P(E | not H).
        """
        return e_given_h * h / (e_given_h * h + e_given_not_h * (1 - h))

    @staticmethod
    def get_updated_confidences(confidences, index, result):
        """Return a new confidence list after observing ``result`` from a query on ``index``.

        For each candidate ``j``, the hypothesis H_j ("j is the target") is updated:

        * ``j == index``: P(E|H_j) is 1 - FN_RATE (positive) or FN_RATE (negative);
          P(E|not H_j) is FP_RATE or 1 - FP_RATE.
        * ``j != index``: the queried value is then wrong, so P(E|H_j) is FP_RATE or
          1 - FP_RATE. P(E|not H_j) mixes both possibilities, weighted by
          P(H_index | not H_j) = confidences[index] / (1 - confidences[j]).

        ``confidences`` itself is not modified. The update is undefined (division by
        zero) if another candidate already has confidence exactly 1.
        """
        new_confidences = confidences[:] # shallow copy
        for j in range(search_bound):
            p_h = confidences[j]
            if index == j:
                p_e_given_h = 1 - FN_RATE if result else FN_RATE
                p_e_given_not_h = FP_RATE if result else 1 - FP_RATE
            else:
                p_e_given_h = FP_RATE if result else 1 - FP_RATE
                p_hi_given_not_hj = confidences[index] / (1 - confidences[j])
                p_not_hi_given_not_hj = 1 - p_hi_given_not_hj
                if result:
                    p_e_given_not_h = p_hi_given_not_hj * (1 - FN_RATE) + p_not_hi_given_not_hj * FP_RATE
                else:
                    p_e_given_not_h = p_hi_given_not_hj * FN_RATE + p_not_hi_given_not_hj * (1 - FP_RATE)
            new_confidences[j] = ByteSearch.bayes(p_h, p_e_given_h, p_e_given_not_h)
        return new_confidences

    @staticmethod
    def get_entropy(dist):
        """Return the Shannon entropy, in bits, of a distribution of Decimals (zero entries skipped).

        The logarithm is computed in float via ``math.log2``.
        """
        return -sum(p * Decimal(log2(p)) for p in dist if p)
### Helper for generating single-byte oracles
# Size of the candidate space searched by ByteSearch (16 rather than 256 so it fits on screen).
search_bound = 16
# Mutable holder: scenes overwrite oracle_seed[0] before get_oracle() is called.
oracle_seed = [b'false positives and false negatives'*17]


def get_oracle(target_byte):
    """Build a reproducible noisy oracle whose correct answer is ``target_byte``.

    The returned callable maps a candidate index to a bool and draws exactly one
    random number per call. For ``target_byte`` it answers False with
    probability ~FN_RATE (false negative). For any other index it answers True
    with probability ~FP_RATE (false positive). The noise stream is seeded from
    ``oracle_seed[0]`` when this function is called.
    """
    noise = Random()
    noise.seed(_enc(oracle_seed[0]))

    def oracle(index):
        draw = noise.random()
        if index != target_byte:
            return draw <= FP_RATE
        return draw > FN_RATE

    return oracle
"""
These are the exact same oracle responses as earlier.
Look at how quickly most of the options are ruled out.
After a couple rounds of queries, we have a shortlist.
It's interesting seeing how the expected information gain correlates with the change in confidences.
Sometimes they move the same direction, but sometimes they don't. We'll come back to that.
Anyway, after a few more rounds, we have a clear frontrunner.
We can continue iterating until this frontrunner reaches whatever level of confidence we're looking for.
As we keep going, it becomes increasingly clear that a lot of queries are being wasted here
asking about bytes that we've already ruled out.
So a natural next step would be to consider using some kind of guiding heuristic to speed up this search.
And we have two reasonable candidates for this heuristic listed on the left-hand side of the screen.
"""
class c17_19_exhaustive_with_stats(Scene):
    """Animate a full ByteSearch, showing query results, confidences and expected information gain.

    Subclasses choose the search strategy by calling ``construct(strategy=...)``.
    """
    @staticmethod
    def ys_to_logs(ys):
        """Map values in [0, 1] onto the 0..LOG_SCALE log axis, clamping tiny or negative inputs to 1e-8 first."""
        ys = [max(y, 0.00000001) for y in ys]
        scale = 10**LOG_SCALE
        return [log10((scale-1)*y+1) for y in ys] # good enough
    def construct(self, strategy="exhaustive"):
        """Run one animated search.

        Args:
            strategy: "exhaustive", "confidence" or "entropy". Selects the
                ByteSearch.pick_* method used to choose each query.
        """
        # Same three-chart layout as c17_15_simple_cases.get_charts, built inline.
        p_chart = BarChart(
            values=[1/16]*16,
            y_range=[0, 1, 0.2],
            y_length=3,
            x_length=5.5,
            bar_names=['']*8 + ["Confidences"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        e_chart = BarChart(
            values=self.ys_to_logs([1/16]*16),
            y_range=[0, LOG_SCALE, 1],
            y_length=3,
            x_length=5.5,
            bar_names=['']*8 + ["E[Information Gained]"],
            x_axis_config={"font_size": 36},
            bar_colors=[GREEN],
        )
        charts = Group(p_chart, e_chart)
        charts.arrange(DOWN, buff=0) # type: ignore
        e_chart.align_to(p_chart, RIGHT)
        query_chart = BarChart(
            values=[0.00001]*16,
            y_range=[0, 1, 0.2],
            y_length=p_chart.height+3,
            x_length=5.5,
            bar_names=['']*8 + ["Queries"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        query_chart.y_axis.numbers.set_opacity(0)
        query_chart.y_axis.ticks.set_opacity(0)
        Group(charts, query_chart).arrange(RIGHT, buff=0.25) # type: ignore
        e_chart.y_axis.numbers.set_opacity(0)
        e_chart.y_axis.add_labels({0: "0", 1: "0.01", 2: "0.1", 3: "1"})
        self.add(p_chart, e_chart, query_chart)
        # Full-height, 30%-opacity red/green copies of each chart's bars. A copy of one
        # of them is faded out over the queried column as a flash of the query's result.
        e_bars_r = e_chart.copy()
        e_bars_r.bar_colors = [RED]
        e_bars_r.change_bar_values([LOG_SCALE]*16, update_colors=True)
        e_bars_r = e_bars_r.bars.set_opacity(0.3)
        e_bars_g = e_chart.copy()
        e_bars_g.bar_colors = [BW_GREEN]
        e_bars_g.change_bar_values([LOG_SCALE]*16, update_colors=True)
        e_bars_g = e_bars_g.bars.set_opacity(0.3)
        p_bars_r = p_chart.copy()
        p_bars_r.bar_colors = [RED]
        p_bars_r.change_bar_values([1]*16, update_colors=True)
        p_bars_r = p_bars_r.bars.set_opacity(0.3)
        p_bars_g = p_chart.copy()
        p_bars_g.bar_colors = [BW_GREEN]
        p_bars_g.change_bar_values([1]*16, update_colors=True)
        p_bars_g = p_bars_g.bars.set_opacity(0.3)
        # One column per candidate, starting with its bar; result boxes stack on top.
        box_mobs = [[mob[0]] for mob in query_chart.bars]
        box_size = p_chart.bars[0].get_width()
        def log_query(index, result):
            """Create a green (positive) / red (negative) box on top of column ``index`` and record it there."""
            fill_color = BW_GREEN if result else RED
            box = Square(
                side_length=box_size,
                stroke_color=C_STROKE, fill_color=fill_color,
                stroke_width=2, fill_opacity=1
            ).next_to(box_mobs[index][-1], UP, buff=0.05)
            box_mobs[index].append(box)
            #query_chart.add(box)
            return box
        self.wait(0.1)
        def hook(search):
            """Called by ByteSearch.search after every query: animate that query and the updated charts."""
            index, result = search.query_log[-1]
            box = log_query(index, result)
            # Expected entropy reduction (bits) from querying each candidate next.
            curr_entropy = search.get_entropy(search.confidences)
            info_gained = [float(curr_entropy - e) for e in search.get_entropies()]
            if result:
                p_bars = p_bars_g
                e_bars = e_bars_g
            else:
                p_bars = p_bars_r
                e_bars = e_bars_r
            t_1 = 1/6
            t_2 = 1/8
            self.play(
                FadeIn(box, shift=0.3*DOWN, run_time=t_2),
                p_chart.animate(run_time=t_1).change_bar_values(list(map(float, search.confidences))),
                e_chart.animate(run_time=t_1).change_bar_values(self.ys_to_logs(info_gained)),
                FadeOut(p_bars[index].copy(), run_time=t_2, rate_func=rush_into),
                FadeOut(e_bars[index].copy(), run_time=t_2, rate_func=rush_into),
                #run_time=1/3
            )
            # Guided strategies make fewer queries, so give each one a little more screen time.
            if strategy != "exhaustive":
                self.wait(1/6)
        TARGET_BYTE = 8
        # get_oracle reads the module-level oracle_seed[0], which subclasses overwrite before calling this.
        oracle = get_oracle(TARGET_BYTE)
        search = ByteSearch(oracle, quiet=False, hook=hook)
        result = search.search(strategy={
            "confidence": search.pick_by_confidence,
            "entropy": search.pick_by_entropy,
            "exhaustive": search.pick_exhaustive,
        }[strategy])
        # Console sanity check that the search converged on the right value.
        print(result == TARGET_BYTE)
        self.wait()
        # Clear the result boxes and reset both charts to their initial uniform values.
        to_drop = [box for boxes in box_mobs for box in boxes[1:]]
        self.play(FadeOut(*to_drop, run_time=0.5),
                  p_chart.animate(run_time=1).change_bar_values([1/16]*16),
                  e_chart.animate(run_time=1).change_bar_values(self.ys_to_logs([1/16]*16)))
        for i in range(16):
            box_mobs[i] = [box_mobs[i][0]]
"""
We'll try using the maxima from each of these charts to guide our search.
Now, if we want to be efficient with our oracle queries,
then it might seem intuitively obvious to query for the value that'll tell us the most,
or in other words, the value that maximizes expected information gained.
Here's an example of that heuristic in action.
The first thing we might notice is that it takes a fairly breadth-first, scattershot approach,
but it starts to home in on interesting options after a while.
Interestingly, notice that it tends to avoid immediately following up on positive results,
though it doesn't mind following up on negative ones.
Roughly speaking, this is because a follow-up query on a positive result carries higher potential downside,
since it might undo the increase in that byte's confidence,
whereas a follow-up query on a negative result carries more upside for similar reasons.
This behavior starts to shift as we gather more information
and our level of confidence in the frontrunner bytes increases.
This might seem unintuitive at first, but it starts to make sense the more you think about it.
In any case, we can see that this heuristic is already a huge improvement over the more exhaustive strategy.
"""
# Shared base for the scenes that replay an animated ByteSearch with a chosen strategy.
SearchScene = c17_19_exhaustive_with_stats


class c17_20_info_gained(SearchScene):
    """Replay the search, always querying the candidate with the highest expected information gain."""

    def construct(self):
        # get_oracle() seeds its noise from oracle_seed[0] when it is called, so swap the
        # seed in before the base construct() builds the oracle.
        scene_seed = b'another random seed value'*17
        oracle_seed[0] = scene_seed
        return super().construct(strategy="entropy")
"""
Let's see an even bigger improvement. We'll take another go at the problem,
but this time we'll let the confidences guide our search directly.
This is a super simple heuristic:
at every step, we'll just query for whichever byte currently has the highest confidence.
We get a handful of false starts here, just like we did with the other heuristic,
but the strategy handles them gracefully and eventually it ends up in the right place.
That was over too quick, so let's run it again.
In the average case, the confidence-guided heuristic tends to require a few times fewer oracle queries
than the information-guided heuristic.
The confidence heuristic is significantly less inclined to explore the entire search space
unless it needs to, and it homes in on promising bytes as soon as it gets a positive result,
which gives it an edge over the information-guided heuristic for large search spaces.
Rather than ruling bytes out by querying for them and getting negative results,
it is happy to rule them out implicitly by honing in on the most promising byte.
All in all, this tends to work remarkably well.
That said, while the confidence-guided heuristic is faster on average,
it also has higher variance, meaning that the number of queries per search is less consistent
than with the entropy-guided heuristic.
That's because it is more willing to go down dead ends if they seem promising at first,
as we can see here.
When these dead ends don't pan out, we find that we've spent a lot of queries to end up more or less back where we started.
Usually we're willing to accept this risk in exchange for the low average query count,
but this is still worth knowing about.
Anyway, all these dead ends get ruled out before our confidence in them reaches the cutoff threshold,
and eventually we do find the correct byte.
How do we know it's the correct byte?
Well, in this case, you know because I'm telling you.
In general, we won't know this with certainty, but we can get arbitrarily close to certainty,
because our confidence threshold can be set arbitrarily close to 1.
This is not even expensive: once the error rate is low, lowering it by another order of magnitude costs only a small, fixed number of oracle queries in the average case,
though of course, the exact number depends on the oracle's error rate.
"""
class c17_21_confidence_1(SearchScene):
    """First run of the search animation using the confidence-guided strategy.

    Stores a fixed byte string in ``oracle_seed[0]`` before delegating to
    ``SearchScene.construct`` with ``strategy="confidence"``.
    """

    def construct(self):
        # NOTE(review): ``oracle_seed`` is a module-level mutable slot defined
        # elsewhere in this file; presumably SearchScene reads it to make the
        # simulated noisy oracle reproducible -- confirm.
        seed_bytes = b'confidence guided search' * 17
        oracle_seed[0] = seed_bytes
        super().construct(strategy="confidence")
class c17_22_confidence_2(SearchScene):
    """Second run of the confidence-guided search animation.

    Identical to the first confidence run except for the byte string stored in
    ``oracle_seed[0]``, so the oracle's behaviour (and hence the search) differs.
    """

    def construct(self):
        # NOTE(review): presumably consumed by SearchScene to seed the noisy
        # oracle -- confirm against SearchScene.
        seed_bytes = b'part deux of the confidence guided strategy' * 17
        oracle_seed[0] = seed_bytes
        super().construct(strategy="confidence")
"""
"""
class c17_23_hard_case_outro(Scene):
    """Outro animation for the unreliable-oracle ("hard case") section.

    Recreates the search-visualisation layout -- a "Confidences" bar chart
    stacked above an "E[Information Gained]" bar chart, with a tall "Queries"
    bar chart to their right -- then morphs the information-gained bars into
    ToC entry 5 while fading out everything else on screen.
    """
    def construct(self):
        # Confidence per candidate, starting from a uniform 1/16 distribution.
        p_chart = BarChart(
            values=[1/16]*16,
            y_range=[0, 1, 0.2],
            y_length=3,
            x_length=5.5,
            # Only 9 names for 16 bars: the blank padding appears intended to
            # put the single caption near the middle of the axis.
            # NOTE(review): relies on BarChart tolerating a short bar_names list.
            bar_names=['']*8 + ["Confidences"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        # Expected information gained per candidate, plotted on a log axis.
        # NOTE(review): ys_to_logs and LOG_SCALE are defined elsewhere;
        # presumably they map values onto a 0..LOG_SCALE axis -- confirm.
        e_chart = BarChart(
            values=c17_19_exhaustive_with_stats.ys_to_logs([1/16]*16),
            y_range=[0, LOG_SCALE, 1],
            y_length=3,
            x_length=5.5,
            bar_names=['']*8 + ["E[Information Gained]"],
            x_axis_config={"font_size": 36},
            bar_colors=[GREEN],
        )
        # Stack the two charts vertically with no gap and right-align them.
        charts = Group(p_chart, e_chart)
        charts.arrange(DOWN, buff=0) # type: ignore
        e_chart.align_to(p_chart, RIGHT)
        # Query-count chart, made taller so it spans roughly the stacked pair.
        # The near-zero values presumably stand in for "no queries yet" while
        # avoiding zero-height bars -- TODO confirm why 0 isn't used.
        query_chart = BarChart(
            values=[0.00001]*16,
            y_range=[0, 1, 0.2],
            y_length=p_chart.height+3,
            x_length=5.5,
            bar_names=['']*8 + ["Queries"],
            x_axis_config={"font_size": 36},
            bar_colors=[BLUE],
        )
        # Hide the query chart's y-axis numbers and ticks.
        query_chart.y_axis.numbers.set_opacity(0)
        query_chart.y_axis.ticks.set_opacity(0)
        Group(charts, query_chart).arrange(RIGHT, buff=0.25) # type: ignore
        # Swap the log-axis numbers for custom labels.
        # NOTE(review): these four labels assume LOG_SCALE == 3 -- confirm.
        e_chart.y_axis.numbers.set_opacity(0)
        e_chart.y_axis.add_labels({0: "0", 1: "0.01", 2: "0.1", 3: "1"})
        self.add(p_chart, e_chart, query_chart)
        # Morph target: ToC entry 5 (presumably [1] selects the ToC list from
        # title_and_toc's result -- confirm in c17_00_toc).
        toc = c17_00_toc().title_and_toc(indicated=5)[1][5]
        # Remove the bars first so the FadeOut below, which snapshots
        # self.mobjects at call time, does not also fade them; the Transform
        # brings them back and morphs them into the ToC entry.
        self.remove(e_chart.bars)
        # transition the "information gained" bars into the ToC entry lol
        self.play(
            Transform(e_chart.bars, toc, run_time=1.5),
            FadeOut(*self.mobjects, run_time=0.75),
        )
"""
[over ToC]
So that covers the hard case where the oracle is unreliable.
There's more to say here but I don't want to get too sidetracked,
especially since we'll be diving back into this topic when we get to challenges 31 and 32,
where we'll be exploiting timing side-channels with the end goal of MAC forgery.
I'm planning for that video to go way further into the weeds on this stuff.
For now, this should do.
We have one more case of this attack to cover, and that's the case where MACs are added.
In contrast to the hard case, I'm going to basically gloss over this one,
because explaining it in detail would mean going into MAC forgery.
But the short version is:
since the padding oracle attack is a chosen-ciphertext attack,
you prevent it by authenticating your ciphertexts,
at least as long as you do so correctly.
"""
"""
That said, let's add a little context.
The padding oracle attack is also an attack on CBC mode specifically.
CBC doesn't have authentication built in, so we have to add it ourselves.
and, as we mentioned earlier, people who use CBC tend to use HMAC for this.
Now, HMAC is not fast but it is a solid and reliable MAC.
However, many APIs for it just give you a MAC tag,
and leave you to write the validation logic yourself,
and if you aren't careful here then you can still get into trouble.
Let's look at an example.
Here's that same Flask app from earlier. Let's add HMAC to it.
We'll assume that the last 32 bytes of the message are the MAC tag.
We'll generate the expected tag, then compare the actual tag against it.
Now, the problem here is that this comparison does not happen in constant time.
As a result, and as we'll see again at the end of set 4,
attackers can use this to infer valid MACs one byte at a time for arbitrary ciphertexts.
After recovering a complete MAC for their chosen ciphertext,
the attacker can use it to bypass this HMAC check
and continue their chosen-ciphertext attack.
From the attacker's perspective, the downside of this new attack
is that it does require one MAC forgery per chosen ciphertext.
This raises the cost of the attack a lot.
However, the same ciphertext can be queried many times without having to re-forge the MAC.
This means we can collect large sample sizes for each query,
and, as we discussed earlier, a whole range of statistical tricks exist
for synthesizing these samples into a more accurate oracle.
Nevertheless, having to forge a MAC for every chosen ciphertext still makes the attack far more expensive overall.
In monitored environments, that many extra queries might be a problem for the attacker; otherwise, the attack is still viable, albeit very slow.
In any case, the fix is simple: replace this comparison with a constant-time one, such as Python's hmac.compare_digest.
=== possible joke ending ===
So now let's talk about how, as a defender reviewing service logs,
we might be able to detect this attack while it's happening and - lol -
no, I'm kidding, I'm kidding, it's time to write some code
=== OR if i'm not in the mood for the above ===
Alright! That does it for the theoretical part of this video!
That was a long one, and I hope you found it as interesting as I did.
I'm looking forward to following up on this topic
when we finally reach timing side channels at the end of set 4.
But for now, let's set these matters aside,
as we implement the basic padding oracle attack.
"""
class c17_24_mac_validation(Scene):
    """Scene for the padding-oracle case where ciphertexts carry a MAC.

    Morphs ToC entry 6 into a diagram of a CBC ciphertext whose final block
    is replaced by a MAC tag computed over the IV and the remaining ciphertext
    blocks. It then steps through three Flask snippets
    (``snippets/c17_flask_excerpt_{2,3,4}.py``) and finally morphs back into
    the ToC entry. The snippets are presumably the app before HMAC, with a
    non-constant-time HMAC check, and with the fix -- confirm against the files.
    """

    def construct(self):
        opening_entry = c17_00_toc().title_and_toc(indicated=6)[1][6]
        self.add(opening_entry)

        # --- CBC ciphertext with a MAC tag ------------------------------
        # Four blocks of blank (None) bytes, laid out top to bottom.
        blocks = CBCBlocks([None]*16*4, direction=DOWN)
        # Redraw the last ciphertext block with MAC fills, then drop it from
        # the CBC group so the redrawn copy stands in for the tag.
        tag_block = blocks[-1].ct.rewrite(c_fills=C_MAC)
        blocks = blocks[:-1]
        # What the MAC covers: every remaining ciphertext block plus the IV.
        mac_input = VGroup(*[cbc_block.ct for cbc_block in blocks], blocks[0].iv)
        brace = Brace(mobject=mac_input, direction=RIGHT, sharpness=4.0)
        mac_box = (
            FuncBox(r"\textit{MAC}_{k'}")
            .scale(0.7)
            .next_to(brace, RIGHT)
            .align_to(brace, DOWN)
        )
        # One arrow from the brace into the MAC box, one from the box to the tag.
        arrows = mac_box.get_arrows((brace, RIGHT, UP), (tag_block, DOWN, RIGHT))
        mac_func = VGroup(arrows[0], mac_box, arrows[1])
        figure = ZoomableVGroup(blocks, tag_block, brace, mac_box, mac_func)
        figure.center().zoom(0.6)
        self.play(FadeTransform(opening_entry, figure))
        self.next_section()

        # --- Flask code walkthrough --------------------------------------
        code_style = {
            "language": "python",
            "tab_width": 4,
            "font": "Monospace",
            "line_spacing": 0.6,
            "style": "inkpot",
            "margin": 0.4,
            "background": "window",
            "insert_line_no": False,
        }

        def load_snippet(path):
            # All snippets share one style and are displayed at 60% scale.
            return Code(path, **code_style).scale(0.6)

        code_1 = load_snippet("snippets/c17_flask_excerpt_2.py").shift(0.5*UP)
        self.play(FadeTransform(figure, code_1))
        self.next_section()

        code_2 = load_snippet("snippets/c17_flask_excerpt_3.py").align_to(code_1, UP)
        self.play(CodeRewrite(code_1, code_2, lag_ratio=0.25, bg_ratefunc=rate_functions.slow_into))
        self.next_section()

        # Highlight characters 8-9 of line index 10 -- presumably the
        # non-constant-time tag comparison described in the narration.
        self.play(HighlightRadial(code_2.code.chars[10][8:10]))
        self.next_section()

        code_3 = load_snippet("snippets/c17_flask_excerpt_4.py").align_to(code_1, UP)
        self.play(CodeRewrite(code_2, code_3, lag_ratio=0.25))

        # --- Back to the table of contents --------------------------------
        closing_entry = c17_00_toc().title_and_toc(indicated=6)[1][6]
        # Clear whatever the rewrites left behind so only code_3 morphs out.
        self.remove(*self.mobjects)
        self.add(code_3)
        self.play(FadeTransform(code_3, closing_entry))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment