Conformer-transducer streaming model in SpeechBrain

The #2140 PR implements a streaming model in SpeechBrain, based on a Conformer-Transducer architecture.
The "transducer" part refers to (part of) the loss used (RNN-T loss), but the vast majority of changes needed for streaming were specific to the Conformer model.

Two main parameters exist:

  • Chunk size, which basically corresponds to how many new audio frames we need to throw at the model to get new predictions. With most models, a smaller chunk size will worsen the model's accuracy, but improve latency. Lower chunk sizes also tend to worsen the RTF (decoding speed).
  • Left context size, which here corresponds to the number of left-context frames kept at each transformer layer (detailed later). When not using streaming, the left and right contexts are technically infinite. Larger contexts mean higher accuracy, but also higher memory and computational costs.

This model leverages several techniques for streaming, and is trained to be able to cope with different chunk and left context sizes chosen at runtime (with certain limits).

Both the chunk size and left context size are defined in terms of frames at the transformer level. However, knowing the entire architecture, we can deduce how many audio samples this corresponds to, which is required at inference time.

When we stream, we have to constrain our models in certain ways. The most important one is that we have to perfectly control and understand which outputs depend on which inputs: we need to be able to infer continuously, for a potentially long time, and predictions need to be constrained to a reasonably low latency (the prediction at time t cannot depend on, say, t + 5s).
Additionally, these constraints need to be enforced during training, and this enforcement is not necessarily implemented the same between training and inference, as the former usually depends on masking and the latter needs special code to operate over individual chunks of data.
These constraints were generally verified using the newly introduced speechbrain.utils.streaming.infer_dependency_matrix (and plot_dependency_matrix), which attempts to determine which inputs each output depends on by randomizing the input data.
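For illustration, the underlying idea can be sketched as follows (a minimal, hypothetical check in PyTorch; this is not the actual infer_dependency_matrix implementation or API):

import torch

def output_depends_on_input(model, seq_len, in_t, out_t, feat_dim=80, trials=4):
    # randomize the input at time step in_t and check whether the output at
    # time step out_t changes; assumes the model maps (batch, time, feat)
    # to (batch, time, out_feat)
    base = torch.randn(1, seq_len, feat_dim)
    with torch.no_grad():
        ref = model(base)[0, out_t]
        for _ in range(trials):
            perturbed = base.clone()
            perturbed[0, in_t] = torch.randn(feat_dim)
            if not torch.allclose(model(perturbed)[0, out_t], ref):
                return True
    return False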

The streaming approach has the benefit of having a fixed memory cost long term. This means that the model can be used to decode long files which would otherwise run out of memory or require workarounds with typical models.

Dynamic Chunk Training

At training, for each different batch, we select a random chunk size and a random left context size. This is so that the model can adapt to different situations.

The strategy we adopt is (with the default hparams):

  • A 60% chance to use chunking for this batch with a chunk size randomly sampled in the range 8..32
  • Then, a 75% chance to limit the left context to a value randomly sampled in the range 16..64

Some of the prior art describes obtaining better results with this approach than if the model were trained normally for half the epochs and then exclusively with chunking for the other half (e.g. https://arxiv.org/abs/2012.05481).
Training the model for the entire duration with a mix of both methods provides the added benefit that the model can be used with an infinite chunk size, thus behaving like a non-streaming model.
When training our model on the LibriSpeech dataset, we found that for this particular architecture and setup, the error rate is very similar between the non-streaming model and the streaming model (when both are evaluated in non-streaming mode).
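
For illustration, the per-batch sampling described above could look like the following (a hypothetical sketch with made-up names; the actual sampler and hparams in the PR may differ):

import random
from dataclasses import dataclass
from typing import Optional

@dataclass
class ChunkConfig:
    # chunk_size is None when the batch is trained without chunking
    chunk_size: Optional[int]
    # left_context_size is None when the left context is unlimited
    left_context_size: Optional[int] = None

def sample_chunk_config() -> ChunkConfig:
    # 40% of batches: no chunking at all (full-context training)
    if random.random() >= 0.6:
        return ChunkConfig(chunk_size=None)
    chunk_size = random.randint(8, 32)
    # 25% of chunked batches keep an unlimited left context
    if random.random() < 0.75:
        left_context = random.randint(16, 64)
    else:
        left_context = None
    return ChunkConfig(chunk_size, left_context)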

Chunked attention

The objective of chunked attention is to prevent an output at a given time point from depending on data far into the future.
Typically, you would be able to achieve this by preventing a frame in the transformer at time t from depending on any frame at a time > t, i.e. a given output can only depend on past data, i.e. "left attention".
However, this generally tends to harm accuracy, so we prefer to restrict the attention to non-overlapping chunks (of e.g. 8 frames), where any output frame can depend on any input frame from within the same chunk (and on left context, i.e. past data, as described later).

Figure 2 in https://arxiv.org/abs/2012.05481 demonstrates the difference. Note that this is not specific to the Conformer model but can sometimes be applied to other Transformer-based models.

A nice property of this approach is that the constraint persists across layers: any given output at the last layer will only ever depend on input frames from the same chunk or from earlier chunks, never on frames from a future chunk.

For example, with a chunk size of 4 (and infinite left context), output frames at indices (4,5,6,7) at the last layer will all depend on input frames at indices (0,1,2,3,4,5,6,7), but not (8,9,...), which we wouldn't know yet in a streaming inference context.
Note that output frame 4 can directly depend on input frame 7 in the above example, because both lie within the same chunk (i//chunk_size == 1). This is the difference with pure left attention, in which case frame 4 could only depend on inputs (0,1,2,3,4).

At train time, chunked attention is enforced through masking.
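
For illustration, such a mask can be built along these lines (a minimal PyTorch sketch with a made-up helper name, not the exact mask construction used in SpeechBrain):

import torch

def make_chunked_attention_mask(seq_len, chunk_size, left_context_chunks=None):
    # frame i belongs to chunk i // chunk_size; a query frame may attend to
    # any key frame within the same chunk plus (optionally) a limited number
    # of past chunks. True marks a disallowed position, following the usual
    # PyTorch boolean attention mask convention.
    idx = torch.arange(seq_len)
    query_chunk = idx.unsqueeze(1) // chunk_size  # (seq_len, 1)
    key_chunk = idx.unsqueeze(0) // chunk_size    # (1, seq_len)

    # disallow attending to future chunks
    mask = key_chunk > query_chunk

    if left_context_chunks is not None:
        # additionally disallow chunks further in the past than the limit
        mask |= key_chunk < (query_chunk - left_context_chunks)

    return mask

# with chunk_size=4 and unlimited left context, output frame 4 may attend to
# frames 0..7, but never to frame 8 onwards
print(make_chunked_attention_mask(12, 4))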

Left context

For each layer of the conformer, we save (with the default hparams) a certain number of left frames (16 or 32 are reasonable values).

At train time, this is enforced by masking (at chunk level) any input to the layer that is further to the left than the enforced limit.
At inference time, we save and reinject that left context as necessary.

This works because, due to the chunking, no past chunk can ever depend on a future chunk. Additionally, since we do this at each layer, the earliest frame of the left context of the current layer depends on the earliest frame of the left context of the previous layer, and so on.
This effectively means that the "receptive field" of the model can be very wide with not-so-high left context sizes, even though we only need to save that fairly small context.
Note that this does not necessarily mean that the model makes meaningful use of information that goes back many seconds.
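
At inference time, the saving and re-injection could look roughly like this (a hypothetical per-layer sketch; in the PR this state lives in the layer's streaming context, and the exact handling differs):

import torch

class LeftContextCache:
    # caches the last few frames seen by a layer so they can be prepended to
    # the next chunk as left context
    def __init__(self, left_context_frames):
        self.left_context_frames = left_context_frames
        self.cached = None  # (batch, frames, feat), None at stream start

    def extend(self, chunk):
        # prepend the saved frames to the incoming chunk before attention
        if self.cached is not None:
            chunk = torch.cat([self.cached, chunk], dim=1)
        # remember the rightmost frames for the next call
        self.cached = chunk[:, -self.left_context_frames:, :]
        return chunk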

Dynamic Chunk Convolution (DCConv)

Understanding of the Conformer paper may prove helpful here.

The Conformer model is composed of stacked Conformer blocks, each of which contains feed-forward modules and, most importantly for streaming, a multi-head self-attention module followed by a convolution module. The streaming changes relevant to the self-attention were described above; those relevant to the convolution module are described next.

Usually, in the convolution module, you would support streaming by using a causal convolution, which shifts the output so that the n-th output frame only depends on the input frames from n-(kernel_size-1) to n (instead of n-((kernel_size-1)//2) to n+((kernel_size-1)//2)). This generally comes at a significant accuracy cost, even in the non-streaming case.

The approach taken by the DCConv paper, described in this Amazon paper (https://assets.amazon.science/18/80/2126d1f5416aa7143505694ae013/dynamic-chunk-convolution-for-unified-streaming-and-non-streaming-conformer-asr.pdf), is to instead mask the frames that come after the current chunk (see Fig. 2). This makes the convolution identical to a normal convolution in a non-streaming context, and it performs well in a streaming context.

The implementation for this is slightly trickier as we actually have to functionally split the chunks that are fed to the convolution.
We could simply split the tensors into a list of chunks and perform the computation that way, but this adds a lot of overhead, so we instead pack all those chunks into the batch dimension and concatenate them back afterwards. Doing things this way makes training roughly as fast as when not streaming.
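
The packing trick can be illustrated as follows (a simplified, hypothetical sketch: it ignores left context and edge padding, and assumes the sequence length is a multiple of the chunk size and that the convolution module preserves the time dimension):

import torch

def apply_conv_chunkwise(x, chunk_size, conv_module):
    # x: (batch, time, feat); conv_module maps (b, t, feat) -> (b, t, feat)
    batch, time, feat = x.shape
    num_chunks = time // chunk_size

    # fold the chunks into the batch dimension:
    # (batch, time, feat) -> (batch * num_chunks, chunk_size, feat)
    chunks = x.reshape(batch, num_chunks, chunk_size, feat)
    chunks = chunks.reshape(batch * num_chunks, chunk_size, feat)

    # each chunk is convolved independently, so no output frame can depend on
    # frames from a future chunk
    out = conv_module(chunks)

    # undo the packing
    out = out.reshape(batch, num_chunks, chunk_size, feat)
    return out.reshape(batch, time, feat)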

Feature extraction

We use a feature extraction module at the start of the model, which combines a filter bank (fbank) extractor, some normalization and a downsampling CNN; the latter reduces the frame rate by a factor of 4 compared to the frame rate at the output of the fbank extractor.

Changes to feature extraction were not really necessary for the training of a streaming model in this case. We do, however, need to understand the properties of this part of the model to know exactly how much audio to feed for each chunk.

To do this, we can consider the feature extraction module as a filter, with a specific window size and stride.

With a kernel size of 3 and a stride of 2 in each of its two stacked layers, the downsampling CNN has an effective kernel size of 3 + 2*(3-1) = 7 and an effective stride of 2*2 = 4 (expressed in fbank frames).

The following function then calculates the same for feature extraction as a whole:

from dataclasses import dataclass

@dataclass
class FilterProperties:
    # minimal stand-in so this snippet is self-contained; SpeechBrain provides
    # an equivalent filter properties structure
    window_size_frames: int
    stride_frames: int

def get_filter_properties():
    # FIXME: this should not be hardcoded

    sample_rate = 16000  # Hz
    frames_per_ms = sample_rate // 1000

    # win_length (32 in hparams) and hop_length (defaults to 10) respectively,
    # specified in milliseconds
    fbank_win_size_ms = 32
    fbank_stride_ms = 10
    fbank_win_size_frames = fbank_win_size_ms * frames_per_ms
    fbank_stride_frames = fbank_stride_ms * frames_per_ms

    # the configuration of the downsampling CNN has the following properties.
    # we express them in "fbank frames" as it consumes the fbanks.
    # this is determined from its architecture (window size, stride, layers)

    # NOTE: ideally we should have a way to poll these properties from the
    # model directly, or at least provide a mechanism to compute combined
    # filter properties (see downcnn_*_frames below)

    downcnn_win_size_fbanks = 7
    downcnn_stride_fbanks = 4

    # we can consider the fbank+featcnn combination as a filter and thus we can
    # determine the properties of the whole filter
    downcnn_win_size_frames = (
        fbank_win_size_frames
        + (fbank_stride_frames * (downcnn_win_size_fbanks - 1))
    )
    downcnn_stride_frames = fbank_stride_frames * downcnn_stride_fbanks

    return FilterProperties(
        window_size_frames=downcnn_win_size_frames,
        stride_frames=downcnn_stride_frames,
    )

The result is a window size of 1472 samples (92 ms) and a stride of 640 samples (40 ms) at 16 kHz.
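
As a rough illustration of how these properties translate a number of encoder-level frames into an audio span (a hypothetical helper, not part of the PR; the exact per-chunk bookkeeping at inference time may differ), N consecutive output frames of a filter with window W and stride S span W + S*(N-1) input samples:

# hypothetical helper using the properties computed above
def num_samples_for_frames(props, num_out_frames):
    # N consecutive output frames span window + stride * (N - 1) input samples
    return props.window_size_frames + props.stride_frames * (num_out_frames - 1)

props = get_filter_properties()
# e.g. 16 encoder-level frames span 1472 + 640 * 15 = 11072 samples,
# i.e. roughly 692 ms of audio at 16 kHz
print(num_samples_for_frames(props, 16))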

Padding and contexts at inference time

In order to avoid a discrepancy between training time and inference time, feature extraction needs some care (preserving the left context and some fixed right context). The code is available, but the logic is complicated and has not been thoroughly checked for correctness or for its effect on performance.

I found from limited testing that it roughly seems to work even without doing this.

This is still TODO, but if I don't get around to it, it might deserve a closer look.

Decoding

Currently, only greedy search (GS) is implemented, as it is the easiest to implement for streaming. Here, we only really need to do two things:

  • Saving the decoder context (here, an LSTM's) to reuse for subsequent calls, which is fairly straightforward as it is fully unidirectional and only depends on the last hidden state
  • Adapting the token decoding (token ID -> string) done by SentencePiece.

The reason why we need to do the latter is the way that SentencePiece tokenizers actually deal with whitespaces. If a token begins with the '▁' symbol, then SentencePiece will emit a space only if it is not the first token.
This makes sense generally, but as we're streaming, the "first" token might just be in the middle of a sentence which SentencePiece cannot see. Thus, we work around the issue by peeking at the first token and inserting a space if needed.
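
The workaround might look roughly like this (a hypothetical sketch around the SentencePiece Python API; not the exact code from the PR):

import sentencepiece as spm

def streaming_decode_tokens(tokenizer: spm.SentencePieceProcessor, token_ids, is_first_chunk):
    # SentencePiece drops the space implied by a leading '▁' on the first
    # token; when decoding mid-stream, that space must be re-inserted manually
    text = tokenizer.decode(token_ids)
    if token_ids and not is_first_chunk:
        if tokenizer.id_to_piece(token_ids[0]).startswith("▁"):
            text = " " + text
    return text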

Broad strokes for inference

We let encoders optionally define a make_streaming_context method if they support streaming. A streaming context is a simple mutable dataclass that holds streaming metadata (chunk size, etc.) and the tensors that need to be saved at inference time, due to the techniques outlined previously.

This is done fairly generically: not all TransformerASR encoders need to support streaming contexts, but they can do so if someone wants to implement such a model for a particular setup.

The streaming encoding/forward/etc. methods take the context as a parameter and update it, forwarding any sub-context required for the operation of e.g. an encoder (like the ConformerEncoder, which in turn holds ConformerEncoderLayer streaming contexts).

The context object, initially blank when created by make_streaming_context, is passed to these methods and updated in place.
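
For illustration, such contexts might look like the following (a hypothetical sketch; the field and class names are illustrative, not the exact ones from the PR):

from dataclasses import dataclass, field
from typing import List, Optional
import torch

@dataclass
class LayerStreamingContext:
    # per-layer state that must survive between chunks
    left_context_size: int
    # cached left-context frames for the self-attention, None at stream start
    mha_left_context: Optional[torch.Tensor] = None
    # cached frames needed by the chunked convolution
    dcconv_left_context: Optional[torch.Tensor] = None

@dataclass
class EncoderStreamingContext:
    # encoder-level state: chunking metadata plus one sub-context per layer
    chunk_size: int
    layers: List[LayerStreamingContext] = field(default_factory=list)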

Bug (since resolved): higher WER in actual streaming vs streaming emulation

NOTE: This issue was resolved. It was caused by a mismatch in the conformer convolution streaming code path. This required a model retrain.

Using a chunk size of 16 and a left context of 32 in real streaming mode, we obtain 4.69% WER on test-clean, which is far worse than emulation (i.e. masking), where we get 3.46% even with half the chunk size.

A few observations:

  • Minor: The chunk size calculation or splitting logic appears to be very slightly incorrect. Sometimes, the expected chunk count and the calculated chunk count are off-by-one but this is rare. Additionally, calculating features for the entire audio and splitting those features instead of splitting the audio and calculating features yields a very slight WER difference (4.67% vs 4.69%). This might not need immediate attention.
  • Maybe something is acting unexpectedly inside the RelMHAXL, especially regarding the use of positional embeddings. To my understanding, theoretically, this should not be the case.
  • My assumption is that the MHAXL doesn't need any sort of change to account for the offset within the stream in streaming code and the code is written as such.
  • The left context seems to be inserted correctly for the MHAXL and the DCConv, and for the latter, padding seems to match.
  • Some of the normalization inside the model might behave differently, but it seems surprising that this would have such a dramatic effect.
  • Greedy decoding yields identical results between the streaming mode and the non-streaming mode, so it is not the culprit.

I believe that debugging this issue would require carefully comparing, layer by layer, what differs between emulated streaming and actual streaming in order to find where the discrepancy occurs.

It is also possible that the masking (i.e. streaming "emulation") is wrong instead; hopefully not, as that would require a model retrain. Not very easy to debug, however.

Inference performance considerations

NOTE: The RTF figures here are not final since some adjustments were made afterwards.

Using a single CPU thread on an AMD Ryzen 9 3900X (Zen 2), a chunk size of 8 (~700ms latency, TODO check) and left context of 32, a 6x speed (~0.16 RTF) could be achieved (using a single stream, including decoding).

This RTF is currently not very good compared to certain other options and could be a future area of improvement. Note that this is a full fp32 model running in eager PyTorch mode. Also note that this was tested on a single test audio; more accurate RTF measurements should be done on test-clean/test-other.
Please take these figures with a grain of salt.

At the time of writing, a proper inference server to serve transcription over the network has not been written.

Anecdotally, we found that using more PyTorch threads did not result in better performance on that system for CPU decoding, even when increasing the number of batches (which is more complicated to set up).

GPU decoding was only briefly explored, but batching was hugely more beneficial there. It is unclear how much of that performance can be achieved without batching and exclusively through multithreading/multiprocessing.
On a 31s test audio, on a 2080 Ti, using a chunk size of 8 and left context of 16:

  • batch size = 256, autocast on: RTF=0.0011
  • batch size = 256, autocast off: RTF=0.0014
  • batch size = 128, autocast on: RTF=0.0012
  • batch size = 128, autocast off: RTF=0.0015
  • batch size = 64, autocast on: RTF=0.0016
  • batch size = 64, autocast off: RTF=0.0017
  • batch size = 64, chunk size=+inf: RTF=OutOfMemoryError
  • batch size = 16, chunk size=+inf, autocast on: RTF=0.0015

You can achieve an even better (lower) RTF if you can tolerate higher latency and only care about batched transcription: with batch size = 256, autocast and a chunk size of 32 (left context = 16), an RTF of 0.00086 is achievable on that GPU.

You could also explore running a thread pool of two workers with half the batch size each, with both threads consuming the stream of inference requests. This might result in better performance, as one thread can keep the GPU busy while the other is doing CPU-side work, for example.

For an inference server on CPU, we would recommend disabling PyTorch's parallelism using torch.set_num_threads and then using a thread or process pool to concurrently process requests and not bother with batching, as processing each request this way will scale much better than relying on PyTorch's own parallelism.
Performance using a thread pool should globally be acceptable despite the GIL, because PyTorch's functions will generally release it.
A process pool might scale better on a high number of cores, but this option was not explored.
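
A minimal sketch of that setup (hypothetical server-side code; transcribe_stream is a placeholder for running the streaming model over one request):

import torch
from concurrent.futures import ThreadPoolExecutor

# disable PyTorch's intra-op parallelism; concurrency comes from the pool instead
torch.set_num_threads(1)

def transcribe_stream(audio_chunks):
    # placeholder: feed the chunks to the streaming model and return the text
    ...

# roughly one worker per core; PyTorch releases the GIL in most operations,
# so a thread pool is usually sufficient
pool = ThreadPoolExecutor(max_workers=8)

def handle_request(audio_chunks):
    return pool.submit(transcribe_stream, audio_chunks)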

Currently, batching is tricky as you may want to interrupt or reset streams. This is not yet well supported.
