Skip to content

Instantly share code, notes, and snippets.

@iscgar
Last active March 7, 2023 05:16
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save iscgar/b77caf9a8b4982a1002111ba46f0e701 to your computer and use it in GitHub Desktop.
Save iscgar/b77caf9a8b4982a1002111ba46f0e701 to your computer and use it in GitHub Desktop.
import struct
import itertools
from base64 import b64encode
from retrie.trie import Trie
def _render_group(group, matching, longest):
    """Render one cluster of reversed alternatives as ``(key, pattern)``.

    ``group`` holds the *reversed* alternatives of the cluster, ``matching`` is
    the number of leading characters they all share (i.e. the length of the
    common suffix of the original alternatives) and ``longest`` is the length
    of the longest alternative.  The key is the first alternative of the
    cluster, un-reversed.
    """
    head = group[0][::-1]
    if len(group) == 1:
        return head, head
    # Slice from an explicit start index: ``head[-matching:]`` would return the
    # whole word (instead of an empty suffix) when ``matching`` is 0.
    suffix = head[len(head) - matching:]
    if matching == longest - 1:
        # Every alternative differs from the others in a single leading
        # character only, so a character class is enough.
        body = '[{}]'.format(''.join(w[matching:] for w in group))
    else:
        body = '({})'.format('|'.join(w[matching:][::-1] for w in group))
    return head, body + suffix


def commonise_group(pat):
    """Merge alternatives that share a long common suffix.

    The alternatives are reversed and sorted so that those with a common
    suffix become neighbours.  A neighbour joins the current cluster when the
    shared part covers at least half of the longest alternative involved;
    otherwise the cluster is emitted and a new one is started.  A cluster of
    several alternatives is emitted as ``[xyz]suffix`` (single differing
    character) or ``(a|b|c)suffix``.

    Args:
        pat: A list of alternative pattern strings.

    Returns:
        A list of pattern strings, one per cluster, ordered by the position of
        each cluster's first (in reversed sort order) alternative in ``pat``.
        An empty ``pat`` yields an empty list.

    Raises:
        ValueError: If an alternative is identical to, or a suffix of,
            another alternative.
    """
    if not pat:
        return []
    patr = sorted(s[::-1] for s in pat)
    common = [patr[0]]
    pat_map = {}
    matching = longest = len(patr[0])
    for w in patr[1:]:
        previous = common[-1]
        # Find the index of the first character that differs from the
        # previous (reversed) alternative.
        for i, c in enumerate(w[:len(previous)]):
            if c != previous[i]:
                break
        else:
            raise ValueError("cannot have identical patterns in group")
        if i >= max(longest, len(w)) // 2:
            common.append(w)
            longest = max(longest, len(w))
            # The part shared by the *whole* cluster can only shrink; taking
            # the latest value alone could overstate it.
            matching = min(matching, i)
        else:
            key, rendered = _render_group(common, matching, longest)
            pat_map[key] = rendered
            common = [w]
            matching = longest = len(w)
    # Emit the last cluster.
    key, rendered = _render_group(common, matching, longest)
    pat_map[key] = rendered
    return [pat_map[w] for w in pat if w in pat_map]
def shorten_pat(pattern):
    """Return a more compact form of a regex *pattern*.

    The pattern is scanned character by character while tracking backslash
    escapes, ``[...]`` character classes and ``(...)`` / ``(?:...)`` groups:

    * inside a character class, runs of digits, upper-case letters or
      lower-case letters are sorted, and a run of more than three consecutive
      characters is rewritten as a range (``a-d``);
    * the ``|``-separated alternatives of every group, and of the top level,
      are passed through ``commonise_group``.

    Args:
        pattern: The regex pattern string to shorten.

    Returns:
        The shortened pattern string.

    Raises:
        AssertionError: If character classes, escapes or groups are left open
            at the end of the pattern, or ``[``/``]`` are nested or unmatched.
    """
    pat = ['']           # alternatives of the group currently being parsed
    group_stack = []     # (group_prefix, pat) of the enclosing groups
    in_alt = False       # inside a [...] character class
    in_escape = False    # the previous character was an unescaped backslash
    last_cat = -1        # category of the last collected class character
    cats = []            # pending class characters of the same category
    group_prefix = ''    # '?:' when the current group is non-capturing
    for c in pattern:
        if not in_escape:
            if in_alt:
                # Classify the character: 0 = digit, 1 = upper, 2 = lower.
                cat = -1
                if '0' <= c <= '9':
                    cat = 0
                elif 'A' <= c <= 'Z':
                    cat = 1
                elif 'a' <= c <= 'z':
                    cat = 2
                if cat != last_cat:
                    # Category changed: flush the pending run, as a range when
                    # it is a contiguous run of more than three characters.
                    cats.sort()
                    if len(cats) > 3 and ord(cats[-1]) - ord(cats[0]) == len(cats) - 1:
                        pat[-1] += '{}-{}'.format(cats[0], cats[-1])
                    else:
                        pat[-1] += ''.join(cats)
                    del cats[:]
                if cat != -1:
                    # Hold the character back until its run is complete.
                    cats.append(c)
                    last_cat = cat
                    continue
            if c == '\\':
                in_escape = True
            elif c == '[':
                assert not in_alt
                in_alt = True
            elif c == ']':
                assert in_alt
                in_alt = False
            elif not in_alt:
                if c == '(':
                    # Start a fresh list of alternatives for the new group.
                    group_stack.append((group_prefix, pat))
                    group_prefix, pat = '', ['']
                    continue
                elif c == ')':
                    # Close the group: merge its alternatives and append the
                    # result to the enclosing group's current alternative.
                    gp, opt = group_prefix, commonise_group(pat)
                    group_prefix, pat = group_stack.pop(-1)
                    pat[-1] += '({}{})'.format(gp, '|'.join(opt))
                    continue
                elif c == '|':
                    pat.append('')
                    continue
                elif c == ':' and group_stack and pat[-1] == '?':
                    # '(?:' seen: remember the prefix instead of treating it
                    # as part of the first alternative.
                    assert not group_prefix
                    pat[-1] = ''
                    group_prefix = '?:'
                    continue
        else:
            # The escaped character is copied verbatim below.
            in_escape = False
        pat[-1] += c
    assert not in_alt and not in_escape and not group_stack
    assert len(pat) == 1 or pat[-1]
    return '|'.join(commonise_group(pat))
def regex_for_prefix(prefix, tail_len):
    """Build a regex for the base64-encoded forms of data containing *prefix*.

    The prefix is encoded at each of the three possible byte offsets inside a
    base64 encoding block, surrounded by every possible neighbouring byte
    value.  The encoded characters that depend on the prefix are collected in
    a trie, and the trie's pattern is compacted with ``shorten_pat``.

    Args:
        prefix: The raw prefix to search for, as ``bytes``.
        tail_len: The number of bytes expected to follow the prefix.

    Returns:
        The regex pattern string.
    """
    # A sequence of all possible byte values
    # (used to pad the prefix on either side to account for encoding alignment)
    padding = bytearray(range(256))
    # We build a trie in order to try to get the most compressed form of the resulting pattern
    t = Trie()
    # A base64 encoding block is 3 bytes long, so we need to account for the position
    # of the beginning of the prefix in any of an encoding block's slots
    for i in range(3):
        lead = b'A' * max(0, i - 1)
        # If the length of the prefix plus the current encoding block offset
        # isn't divisible by the length of an encoding block, we need to pad it
        # in order to get all of the values that could appear after the prefix
        # in the encoded form
        pad_len = (3 - (len(prefix) + i) % 3) % 3
        pads = b'A' * max(0, pad_len - 1)
        # Iterate over all of the combinations of padding values for this slot.
        # ``product`` (rather than ``permutations``) is used so that equal
        # leading and trailing byte values are covered as well.
        for r in itertools.product(padding, repeat=int(bool(i)) + int(bool(pad_len))):
            source = lead
            if i:
                source += struct.pack('<B', r[0])
            source += prefix
            if pad_len:
                source += struct.pack('<B', r[-1]) + pads
            # We get the encoded value of the prefix (offset by the current slot
            # index and padded to the next encoding block boundary)
            encoded = b64encode(source)
            # However, if the prefix isn't at the beginning of an encoding block,
            # we only care about the way it affects the encoded prefix itself,
            # and we don't really care about the value of the bytes that come
            # before it, so strip the leading characters (note that since the
            # encoded length is strictly bigger than the source length for base64,
            # stripping an amount equal to the slot index is guaranteed to only
            # strip the leading padding bytes, but not the encoded prefix).
            encoded = encoded[i:]
            # Similarly, if we added padding, we only care about the way it affects
            # the prefix, but not about the encoded padding byte values, so strip
            # them as well (again, this is guaranteed to not touch the encoded prefix,
            # because the encoded size is strictly bigger than the source size for
            # base64).
            if pad_len > 0:
                encoded = encoded[:-pad_len]
            # Add it to the trie
            t.add(encoded.decode('ascii'))
    # Extract a pattern that describes this trie and optimise it a bit
    pat = shorten_pat(t.pattern())
    # Add a pattern for the tail (because we need to at least see this many bytes as well)
    # NOTE(review): ``left`` is counted in raw bytes but sizes a run of encoded
    # characters (base64 emits 4 characters per 3 bytes) - confirm whether the
    # tail length should be scaled accordingly.
    total_len = len(prefix) + tail_len
    left = total_len - (len(prefix) + 2)
    if left > 0:
        groups = (left + 3) // 4
        pat += '(?:{})'.format('|'.join(
            r'[\+\/A-Za-z0-9]{{{}}}{}'.format(groups * 4 - n_pad, '=' * n_pad)
            for n_pad in range(3)))
    return pat
@gofri
Copy link

gofri commented Mar 7, 2023

Apparently, you can't pull-request to a gist, so see: https://gist.github.com/gofri/2ad0e25430bf89ea70614891bca5d35a/revisions
(account for the length difference between the ascii string and the encoded version)

@gofri
Copy link

gofri commented Mar 7, 2023

Also, here:

            # we only care about the way it affects the encoded prefix itself,
            # and we don't really care about the value of the bytes that come
            # before it, so strip the leading bytes (note that since the encoded
            # length is stricktly bigger than the source length for base64,
            # stripping an amount equal to the slot index is guaranteed to only
            # strip the leading padding bytes, but not the encoded prefix).
            encoded = encoded[i:]

Indeed, this catches most of it, but it makes the result much harder to handle (you need to reconstruct the stripped leading characters yourself after the fact).
If we simply remove this line, the pattern gains a few extra characters (ideally: [b64-char-pattern] once or twice; in practice, with the current implementation: a somewhat odd alternation that is longer, although still reasonable).
With that fixed, a complete solution is simply:
echo "$input" | grep -oE "$exp" | base64 -d | grep -oE $original_pattern

@gofri
Copy link

gofri commented Mar 7, 2023

p.s. I'm using the following script to generate test vectors (obviously, not a perfect one, but helpful enough):

#!/usr/bin/env python3
import random

def range_chrs(a, z):
    """Return the list of characters from *a* through *z*, both inclusive."""
    first, last = ord(a), ord(z)
    return list(map(chr, range(first, last + 1)))

def pat_opts():
    """Return the characters a-z, A-Z and 0-9, in that order."""
    alphabet = []
    for low, high in (('a', 'z'), ('A', 'Z'), ('0', '9')):
        alphabet.extend(range_chrs(low, high))
    return alphabet

def ascii_opts():
    """Return every character from '0' through 'z' inclusive."""
    lowest, highest = '0', 'z'
    return range_chrs(lowest, highest)

def gen_one(pref, tail):
    """Return one random test string containing *pref*.

    The result is 0-5 random characters from ``ascii_opts()``, then *pref*,
    then *tail* random characters from ``pat_opts()``, then another 0-5
    random characters from ``ascii_opts()``.
    """
    noise_alphabet = ascii_opts()
    # The random calls are made in the same order as the parts appear.
    leading = random.choices(noise_alphabet, k=random.randint(0, 5))
    body = random.choices(pat_opts(), k=tail)
    trailing = random.choices(noise_alphabet, k=random.randint(0, 5))
    return ''.join(leading) + pref + ''.join(body) + ''.join(trailing)

import sys

# Only parse the command line and print when run as a script, so the helper
# functions above can be imported without side effects.
if __name__ == "__main__":
    if len(sys.argv) < 3:
        sys.exit("usage: {} PREFIX TAIL_LEN".format(sys.argv[0]))
    pref = sys.argv[1]
    tail = int(sys.argv[2])
    print(gen_one(pref, tail))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment