@pervognsen
Last active August 21, 2024 09:43
Shift-based DFAs

A traditional table-based DFA implementation looks like this:

uint8_t table[NUM_STATES][256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table[state][*s];
    return state;
}

On Skylake, the throughput should be around 7 cycles per input byte (if the table fits in L1 cache). The main problem is that you can't issue the load to compute the next state until you've finished computing the current state, even on an out-of-order machine. There's actually a smaller, hidden problem too: when you compile the implicit address calculation involved in the table lookup you get something like this:

state = load(table + 256 * state + *s)

As a result, the multiplication by 256 (shift by 8) is also added to the loop-carried dependency chain for the state, resulting in an extra cycle of latency. You can fix that by just flipping the table:

uint8_t table[256][NUM_STATES];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table[*s][state];
    return state;
}

Now the compiled table lookup looks like this:

state = load(table + NUM_STATES * *s + state)

That apparently trivial change should yield 16% better throughput. Incidentally, it can also help with memory locality if only a subset of the input bytes are seen in practice (e.g. 0..127 for ASCII). Even though that isn't the topic of this write-up, I wanted to go through that example since (1) using the "wrong" table layout is a common mistake (and I make it all the time despite being on guard) and (2) it illustrates how the loop-carried dependency chain latency for the state is what controls the performance of a DFA loop like this.

In order to make this go faster, let's think of the state transition table as an array of rows where indexing with an input byte yields the row that specifies all possible state transitions. Instead of doing our table lookup as a single-level load (with an address calculation that combines the byte and current state) we're going to do a two-level lookup where the first-level lookup is a load from memory that only depends on the input byte and gives us the row, and the second-level lookup picks out the right column within the row based on the current state. The key to making this faster is that the second-level lookup should not involve memory latency: it has to be done as a fast ALU operation, ideally in one cycle. This suggests the following approach: each row will be encoded as a 64-bit word, and we pick out the state-dependent column within the row by shifting and masking:

uint64_t table[256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = (row >> (state * BITS_PER_STATE)) & ((1 << BITS_PER_STATE) - 1);
    }
    return state;
}

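The shift has 1c latency, the multiply has 3c latency (or 1c if BITS_PER_STATE is a power of two, since it then becomes a shift), and the masking has 1c latency. So the latency on the critical path is only 3c in the power-of-two case and 5c otherwise, which is already much better. But now for the final two tricks. First, instead of encoding the state sequentially as 0, 1, 2, etc, we are going to pre-multiply by BITS_PER_STATE, so the state values stored in the table are 0, BITS_PER_STATE, 2 * BITS_PER_STATE, etc, and can be used directly as shift amounts: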

uint64_t table[256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = (row >> state) & ((1 << BITS_PER_STATE) - 1);
    }
    return state;
}

And finally we pick BITS_PER_STATE = 6 so that (1 << BITS_PER_STATE) - 1 = 63. Most instruction sets already interpret 64-bit shift amounts mod 64, so the masking is now automatic at the machine code level:

uint64_t table[256];

uint64_t run(const uint8_t *start, const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = row >> (state & 63);
    }
    return state & 63;
}

I also propagated the & 63 to the shift operand to help the compiler's instruction selection do what we want. The compiler now generates this for what was the loop's critical path:

shrx rax, qword ptr [8*rcx + table], rax

That is, 1c latency on the critical path for the state in register rax.
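
The loops above assume that table has already been filled in. As a minimal sketch of one way to do that, assuming the DFA is given as a conventional next[state][byte] array of plain state numbers: the names build_shift_table and run_shift and that input format are hypothetical, not from the original code.

```c
#include <stddef.h>
#include <stdint.h>

enum { BITS_PER_STATE = 6, MAX_STATES = 64 / BITS_PER_STATE }; /* 10 states */

/* Hypothetical input format: next[s][b] is the plain state number
   (0..MAX_STATES-1) reached from state s on input byte b. */
static void build_shift_table(uint64_t table[256],
                              uint8_t next[][256], int num_states) {
    for (int b = 0; b < 256; b++) {
        uint64_t row = 0;
        for (int s = 0; s < num_states; s++) {
            /* Store the premultiplied next state in the 6-bit field that the
               premultiplied current state (s * 6) shifts down to bit 0. */
            uint64_t encoded = (uint64_t)next[s][b] * BITS_PER_STATE;
            row |= encoded << (s * BITS_PER_STATE);
        }
        table[b] = row;
    }
}

/* Same loop as in the text; state is premultiplied on entry and on return. */
static uint64_t run_shift(const uint64_t table[256], const uint8_t *start,
                          const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table[*s] >> (state & 63);
    return state & 63;
}
```

Note that the caller has to pass and interpret states in premultiplied form, so state number k is represented as k * 6 on both sides of the call.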

Now that latency is so low, we should also examine other potential bottlenecks that might prevent us reaching 1 byte/cycle. We'd need 2 scalar loads per cycle, for the input byte and for the row. That happens to be the limit for Skylake, which is a good stand-in for any consumer PC sold in the last 6-7 years; newer AMD processors (starting with Zen 3) and Apple's M1 processor can do 3 scalar loads per cycle. So we should just scrape by with enough load capacity, even on Skylake. Another bottleneck to consider is that the string pointer increment imposes a 1 cycle/byte latency limit. Also, modern x86 CPUs have a limit of one taken branch per cycle or even one taken branch every two cycles. For a 4-wide issue processor like Skylake, there might also be too many instructions per input byte. So four other things are just on the cusp of being too tight, which usually spells doom for sustained pipeline utilization.

Fortunately the last three bottlenecks are alleviated by just unrolling the loop. Clang will unroll aggressively on its own but GCC refuses to unroll unless forced, so you'll want to do the unrolling manually to have reliable performance across compilers. Because we need near-perfect pipeline utilization of critical resources (2 loads per cycle, etc) I kept seeing marginal gains even up to absurdly high unroll factors; just keep in mind that was in a micro-benchmark, and moderate unrolling is better suited for real-world use.
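
A sketch of what the manual unrolling can look like, assuming the packed table with premultiplied 6-bit states from above. The function name run_unrolled4 and the unroll factor of 4 are only illustrative; the right factor is something to measure on the target machine.

```c
#include <stddef.h>
#include <stdint.h>

/* Shift-based DFA loop manually unrolled 4x. `table` holds packed 64-bit rows
   with premultiplied 6-bit states. The factor 4 is an illustrative choice. */
static uint64_t run_unrolled4(const uint64_t table[256], const uint8_t *start,
                              const uint8_t *end, uint64_t state) {
    const uint8_t *s = start;
    /* Main loop: one pointer increment and one taken branch per 4 bytes. */
    for (; end - s >= 4; s += 4) {
        state = table[s[0]] >> (state & 63);
        state = table[s[1]] >> (state & 63);
        state = table[s[2]] >> (state & 63);
        state = table[s[3]] >> (state & 63);
    }
    /* Tail loop for the remaining 0..3 bytes. */
    for (; s != end; s++)
        state = table[*s] >> (state & 63);
    return state & 63;
}
```

The state update chain is unchanged; only the pointer increment, the loop branch and the instruction count per byte are amortized.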

The main restriction with a shift-based DFA is that each row is 64 bits and hence with BITS_PER_STATE = 6 you can only accommodate up to 10 states. But 10 states is enough for a lot of useful tasks, including UTF-8 validation, simple needle-in-a-haystack pattern matching, skip lexing, etc. For many of these applications it's helpful to use an absorbing "state of interest" (e.g. an absorbing error state for UTF-8 validation) so you can hoist state-based branches out of the innermost loop.
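
To illustrate the absorbing-state idea, here is a small sketch of a 3-state matcher for the two-byte needle "ab", where the "found" state never leaves itself. The state names, the helper functions and the chunk size of 64 are assumptions made for this example only.

```c
#include <stddef.h>
#include <stdint.h>

/* Premultiplied state values (state number * 6). */
enum { S_START = 0 * 6, S_SEEN_A = 1 * 6, S_FOUND = 2 * 6 };

static void build_ab_table(uint64_t table[256]) {
    for (int b = 0; b < 256; b++) {
        uint64_t from_start  = (b == 'a') ? S_SEEN_A : S_START;
        uint64_t from_seen_a = (b == 'b') ? S_FOUND
                             : (b == 'a') ? S_SEEN_A : S_START;
        uint64_t from_found  = S_FOUND; /* absorbing: never leaves */
        table[b] = (from_start << S_START) | (from_seen_a << S_SEEN_A) |
                   (from_found << S_FOUND);
    }
}

/* Returns 1 if the input contains "ab". The state is only inspected once per
   chunk, so the inner loop has no state-dependent branch. */
static int contains_ab(const uint64_t table[256], const uint8_t *p, size_t n) {
    uint64_t state = S_START;
    while (n > 0) {
        size_t chunk = n < 64 ? n : 64;
        for (size_t i = 0; i < chunk; i++)
            state = table[p[i]] >> (state & 63);
        if ((state & 63) == S_FOUND)
            return 1;
        p += chunk;
        n -= chunk;
    }
    return 0;
}
```

Because S_FOUND is absorbing, checking it late never misses a match; the cost is that up to one chunk of extra input is scanned after the match.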

Anyway, that's how you make a table-driven DFA go fast with simple, portable code. On my laptop's Ryzen 5900HX (Zen 3) processor, this runs at a throughput of 1 byte/cycle as predicted (4.5-4.6 GB/s with 4.6 GHz max clock), 4-5 times faster than a traditional table-based DFA. Not too bad for executing an arbitrary 10-state DFA!

Note that you can scale this idea beyond 10 states: make a row 2x64 bits, load both words, and mux between them. The BITS_PER_STATE is then 7 where the lower 6 bits are the shift amount and the 7th bit is the word selector, so you can fit 9 states per word and 18 states per row. But you leave the 1 byte/cycle sweet spot as soon as you add anything to the state-dependent logic (and the other bottlenecks will also assert themselves if you add much of anything). Concretely, just the TEST + CMOV you need for 2x64-bit rows adds 2 cycles to the critical path latency, so throughput is 1/3rd of a pure shift-based DFA. That said, the performance gap to a traditional table-based DFA is large enough that you can fit a few loads and muxes like this and still come out ahead, so your take-away shouldn't be that this general approach is totally useless beyond 10 states. On Skylake where an L1 load with a complex addressing mode is 6 cycles and [rax + rbx] is considered complex, 3 cycles is still 2x faster. But on Zen 3 where an L1 load with a simple addressing mode is only 4 cycles and [rax + rbx] is considered simple, 3 cycles is not very compelling given the limitations.
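
A sketch of the 2x64-bit row variant described above, under one possible encoding: bits 0..5 of a state are the shift amount 7 * (k % 9) and bit 6 selects the word k / 9. The helpers encode18, build_rows18 and run18 and the next[state][byte] input format are hypothetical. Whether the compiler turns the word selection into TEST + CMOV is not guaranteed.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical encoding for state number k in 0..17:
   bit 6 = word selector (k / 9), bits 0..5 = shift amount (7 * (k % 9)). */
static uint64_t encode18(unsigned k) {
    return (uint64_t)(((k / 9) << 6) | ((k % 9) * 7));
}

/* next[k][b] is the plain state number reached from state k on byte b. */
static void build_rows18(uint64_t table[256][2], uint8_t next[][256],
                         int num_states) {
    for (int b = 0; b < 256; b++) {
        table[b][0] = table[b][1] = 0;
        for (int k = 0; k < num_states; k++)
            table[b][k / 9] |= encode18(next[k][b]) << ((k % 9) * 7);
    }
}

static uint64_t run18(uint64_t table[256][2], const uint8_t *start,
                      const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        const uint64_t *row = table[*s];
        /* Both words depend only on the input byte; only the select and the
           shift sit on the state's dependency chain. */
        uint64_t word = (state & 64) ? row[1] : row[0];
        state = word >> (state & 63);
    }
    return state & 127;
}
```

As in the single-word version, the state is left unmasked inside the loop: the selector test only looks at bit 6 and the shift only uses the low 6 bits.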

Background: To the best of my knowledge the single-cycle shift-based DFA is original work [1], but presented in this manner I hope it seems obvious to you in hindsight. The idea of separating the input-dependent and state-dependent latencies in the state update is known to SIMD hackers (e.g. Geoff Langdale used it in Hyperscan), and that was the main inspiration: you load the row as a vector and then use PSHUFB, VTBL and similar instructions to do the column selection. The 1 byte/cycle version of a PSHUFB-based DFA is limited to 16 states. Since PSHUFB/VTBL is vectorized it also allows for other tricks based on parallel-prefix circuits you can't do with the shift-based scalar technique, but you run into the same max loads/cycle wall. The input bytes can be loaded in groups of 8 using a single load and then bitwise extracted from there. The extra overhead of the bitwise extraction was too high on x86 to be a net win when I tried it, but we got to 2.3 bytes/cycle with a parallel-prefix VTBL-based DFA (my initial version was 1.5 bytes/cycle and @dougallj and I eventually pushed it to 2.3) on an M1 Firestorm perf core in the new Macs (and since VTBL is high latency compared to PSHUFB you need the latency hiding from parallel-prefix windowing just to reach 1 byte/cycle): https://twitter.com/pervognsen/status/1365170848215142400

[1] I first tweeted about it here: https://twitter.com/pervognsen/status/1364164150843252736. If you're interested there's a discussion in the tweet thread where some of the initial refinement happened and a bunch of extensions were explored. Travis Downs pointed out that you can do 16 states using BITS_PER_STATE = 4 if you're okay with 2 cycles/byte: https://twitter.com/trav_downs/status/1366524409196912641. This is strictly better than the 3 cycles/byte 2-to-1 row mux except in the edge case where you need 17 or 18 states and 16 won't do. And at 2 cycles/byte it's a big win even on machines with lower L1 latency like Zen 3. Another sub-thread discusses the fact that with the single-cycle version you can support a restricted 11th state. You only have 4 bits left after 10 * 6 bits for the 10 states, so you can only support outgoing edges from the 11th state to three other states (assuming you're using a logical right shift) with the premultiplied values 0, 6, 12. Since you have the freedom to permute the state numbers, you can make the 11th state be any state that only connects to 3 other states. And in some special cases you can squeeze extra juice from the algorithm by using an arithmetic right shift to perform bit smearing. Go through the linked twitter threads for more ideas/details.

@waffleboot

I don't understand the phrase 'First, instead of encoding the state sequentially as 0, 1, 2, etc, we are going to pre-multiply by BITS_PER_STATE'. "state" is still masked to BITS_PER_STATE bits. It looks like we only shift by the lower bits of the rows. For example, if BITS_PER_STATE = 4 and the last octet in the row has the value 0101 (decimal 5, meaning a transition from the initial state 0 to state 5), then we will see next state 5, and on the next iteration are we going to shift the next row by 5 bits?

It looks like for a single state out of a set of 8 states we use a full byte even though 4 bits is enough for each state. We increase the state's bit space just to remove the multiplication. It looks like we run 'state = row >> state' on each step, with the masking as a small step to drop the higher bits.

@gwillen

gwillen commented Jul 10, 2021

The pre-multiplication means that the values we store in the table are "state number" times "BITS_PER_STATE". So e.g. if you had 6 bits per state, and you wanted to represent state 5, the value in the table would not be 5, but decimal 30, 0x1E = 011110.

This encoding wastes space to gain speed, and indeed with only 4 bits this encoding cannot represent state 5: 5*4 is decimal 20, 0x14 = 10100, so you can't fit it into 4 bits. With 4 bits you can only encode state numbers 0-3.
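
The capacity arithmetic in this reply can be checked with a small sketch. The function name max_premultiplied_states is hypothetical, and it assumes a single 64-bit row whose fields are exactly `bits` wide.

```c
#include <stdint.h>

/* Number of states a single 64-bit row can hold when premultiplied state
   values (k * bits) are stored in fields that are `bits` wide. */
static int max_premultiplied_states(int bits) {
    int n = 0;
    /* State n is usable if its value n * bits fits in a field (< 2^bits)
       and its own field [n * bits, n * bits + bits) lies inside the row. */
    while (n * bits < (1 << bits) && (n + 1) * bits <= 64)
        n++;
    return n;
}
```

With bits = 6 the limit comes from the row width (10 fields fit in 64 bits), while with bits = 4 the limit comes from the field width (only 0, 4, 8 and 12 fit in 4 bits).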

@bitRAKE

bitRAKE commented May 14, 2022

Sometimes simpler code is better - even if it has problems ...

	repeat UNROLL
		mov	ah, [rsi + %]
		movzx	eax, byte [rcx + rax]
	end repeat

The lack of register pressure allows several DFAs to run in parallel - shadowing any stall.

@pervognsen
Author

pervognsen commented May 15, 2022

Yes, you can apply that generally to turn a latency-limited problem into a throughput-limited problem. It's commonly used for Huffman decoding from multiple logically independent (but physically interleaved) code streams. Anyway, on x64 you still have enough registers to handle at least 5 completely separate DFA decoder states, each with their own table pointers, which the last time I did a loop analysis with uiCA was enough to cover the full latency on Zen 3: https://gist.github.com/pervognsen/5e081b19720a8d6954e28f187b50beff

Even so, that is not a drop-in replacement since it requires multiple independent workloads and often quite a few more than strictly needed to cover the latency. For example, 5 parallel DFA streams is the minimum to cover the latency in the above example on Zen 3, but if one of the streams is much longer than the 4 others, then once the other 4 streams finish you're back to being latency limited. And most applications simply aren't architected to benefit from this kind of decoupled processing with multi-workload batching, so having a faster drop-in replacement for serial decoding is still very valuable.
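
A minimal C sketch of the interleaving idea with just two streams, assuming a traditional byte-indexed table stored flat as table[byte * num_states + state]. The function name run_two_streams and the flat layout are assumptions for illustration; covering the full load latency would need more streams, as discussed above.

```c
#include <stddef.h>
#include <stdint.h>

/* Runs the same table-based DFA over two independent inputs in one loop so
   that the two dependent load chains can overlap. */
static void run_two_streams(const uint8_t *table, size_t num_states,
                            const uint8_t *a, size_t a_len,
                            const uint8_t *b, size_t b_len,
                            uint8_t *state_a, uint8_t *state_b) {
    uint8_t sa = *state_a, sb = *state_b;
    size_t common = a_len < b_len ? a_len : b_len;
    size_t i = 0;
    /* Interleaved part: neither load depends on the other stream's state. */
    for (; i < common; i++) {
        sa = table[a[i] * num_states + sa];
        sb = table[b[i] * num_states + sb];
    }
    /* Once the shorter stream ends, the rest is latency-limited again. */
    for (size_t j = i; j < a_len; j++)
        sa = table[a[j] * num_states + sa];
    for (size_t j = i; j < b_len; j++)
        sb = table[b[j] * num_states + sb];
    *state_a = sa;
    *state_b = sb;
}
```

The two tail loops show the imbalance problem directly: whatever is left of the longer stream runs as a plain serial DFA.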

I don't remember why I didn't discuss this general "micro-threading" transformation in the article since I was experimenting with it around the same time, but I appreciate you bringing it up.

@QuarticCat

QuarticCat commented Jul 15, 2023

Interesting post!

I think we can scale this idea to 20 states without giving up the nice property of BITS_PER_STATE = 6. Here's the idea:

Since state * 6 == (state * 3) << 1, the last bit is always zero. We can use that bit to mux between two words. And the words are constructed in this way:

word 0 (last bit = 0): aaaaaa bbbbbb cccccc dddddd eeeeee ffffff gggggg hhhhhh iiiiii jjjjjj 0000
word 1 (last bit = 1): 0 aaaaaa bbbbbb cccccc dddddd eeeeee ffffff gggggg hhhhhh iiiiii jjjjjj 000

What do you think?

EDIT: here is some code: https://godbolt.org/z/8j9YvxzP4
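
One way to read the two-word layout above as C is sketched below. This is not the code from the godbolt link; the encoding 6 * (k % 10) + (k / 10), the helper names and the next[state][byte] input format are assumptions made for illustration.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical encoding for state number k in 0..19. The low bit selects the
   word and the whole value doubles as the shift amount. */
static uint64_t encode20(unsigned k) {
    return 6 * (k % 10) + (k / 10);
}

/* next[k][b] is the plain state number reached from state k on byte b. */
static void build_rows20(uint64_t table[256][2], uint8_t next[][256],
                         int num_states) {
    for (int b = 0; b < 256; b++) {
        table[b][0] = table[b][1] = 0;
        for (int k = 0; k < num_states; k++) {
            /* Word 1 uses the same field layout moved up by one bit, so an
               odd shift amount lands exactly on a field boundary. */
            table[b][k / 10] |= encode20(next[k][b]) << encode20(k);
        }
    }
}

static uint64_t run20(uint64_t table[256][2], const uint8_t *start,
                      const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        const uint64_t *row = table[*s];
        uint64_t word = (state & 1) ? row[1] : row[0];
        state = word >> (state & 63);
    }
    return state & 63;
}
```

The largest encoded value is 55 and the highest field in word 1 ends at bit 60, so everything still fits in 6-bit fields inside a 64-bit word.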

@pervognsen
Author

@QuarticCat Nice observation. Sorry for not responding sooner but I was traveling at the time and must have missed the notification.

Did you do any measurements?

At a glance all of those 20-state variants look like they can run at 2 cycles per byte because of the extra instruction on the state critical path. The 16-state version has 2 extra instructions on the critical path and can only run at 3 cycles per byte.
