How does the forward pass in speculative decoding work?

Hi all,

I am trying to learn the details about how the forward pass in speculative decoding works.

From what I understand so far about the process:

  1. The draft model creates a sequence of draft tokens
  2. This sequence is passed through the larger (target) model for evaluation
  3. The larger model produces a prediction for each position in the sequence
  4. A function compares the larger model's prediction against the draft token at each position and replaces tokens where they differ (roughly as sketched below)
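Roughly, I picture steps 1–4 like the pseudocode below. To be clear, `draft_model`, `target_model` and their methods are placeholders I made up to show what I mean, not real vLLM APIs:

```python
# Pure pseudocode of steps 1-4 as I picture them. `draft_model`, `target_model`
# and their methods are made-up placeholders, not real vLLM objects.

def one_speculative_step(draft_model, target_model, context, k=4):
    # 1. The draft model proposes k tokens autoregressively.
    drafts = []
    for _ in range(k):
        drafts.append(draft_model.next_token(context + drafts))

    # 2./3. A single forward pass of the target model over context + drafts,
    #        which somehow yields a prediction for every draft position
    #        (this is the part I don't understand).
    target_logits = target_model.forward(context + drafts)

    # 4. Compare the target's prediction at each draft position with the draft
    #    token; keep matches, substitute the target's token at the first
    #    disagreement and stop.
    accepted = []
    for i, draft_tok in enumerate(drafts):
        # Logits at position len(context) + i - 1 predict the token at position len(context) + i.
        target_tok = target_logits[len(context) + i - 1].argmax()
        if target_tok == draft_tok:
            accepted.append(draft_tok)
        else:
            accepted.append(target_tok)
            break
    return accepted
```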

What I am struggling to understand is the internal mechanism for this. From my understanding of LLMs:

  1. The attention mechanism does not compute logits for each position in a sequence; it computes the so-called ‘affinity’ between all positions via the query/key/value setup
  2. The forward pass then predicts the logits for the ‘next’ token that follows the sequence

So I am not clear on where in the flow logits are predicted for all of the tokens passed in a sequence.

Can someone explain how this works?

It seems likely that in vLLM’s speculative decoding, the draft model generates a sequence of speculative tokens, and the larger (target) model then performs a single forward pass over the entire sequence (including the speculative tokens) to compute logits for each position. The attention mechanism computes affinities (QK scores) for all tokens in the context, and when the model is given multiple tokens as input it outputs logits for every token position, not just the last one. This lets the target model provide the “true” logits for each speculative token, which are then compared to the draft model’s predictions to decide which tokens to accept or reject. The process is described in detail in the speculative decoding documentation and the paged attention kernel design.
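As a quick sanity check of the “logits for every position” point, here is a small standalone example using Hugging Face `transformers` (not vLLM internals): a causal LM forward pass over a multi-token input returns one logits vector per input position.

```python
# Standalone sanity check with Hugging Face transformers (not vLLM internals):
# a causal LM forward pass over a multi-token input returns logits for *every*
# input position, not just the last one.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids)

print(input_ids.shape)   # (1, seq_len)
print(out.logits.shape)  # (1, seq_len, vocab_size) -> one logit vector per input position
# out.logits[0, i] is the model's distribution over the token at position i + 1,
# which is what the target model uses to score each speculative token in one pass.
```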

My understanding is that, during speculative decoding, the target model’s forward pass is run in “prefill” mode (not single-token decode), so it computes logits for all positions in the input sequence, including the speculative tokens. The logits at each speculative position are then used to verify the draft model’s predictions. This works because the attention kernel and the LM head produce an output for every token in the input, not just the last one, when the model is run in this mode. For more, see Speculative Decoding — vLLM and the QK section of the paged attention design doc.
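To make the verification step concrete, here is a toy greedy-matching sketch over synthetic tensors. This is an illustration only: vLLM’s actual rejection sampler accepts or rejects tokens probabilistically rather than by comparing argmaxes, but it consumes the same per-position target logits.

```python
# Toy greedy verification over synthetic tensors -- illustration only. vLLM's
# rejection sampler accepts/rejects probabilistically instead of comparing
# argmaxes, but it consumes the same per-position target logits.
import torch

def verify_greedy(target_logits: torch.Tensor, draft_tokens: torch.Tensor) -> torch.Tensor:
    """
    target_logits: [k, vocab] -- target-model logits at the positions that
                   predict each of the k draft tokens (from the single verify pass).
    draft_tokens:  [k]        -- token ids proposed by the draft model.
    Returns the accepted draft tokens, with the target's own token substituted
    at the first disagreement (everything after that point is thrown away).
    """
    target_choice = target_logits.argmax(dim=-1)        # target's greedy pick per position
    mismatches = (target_choice != draft_tokens).nonzero()
    if mismatches.numel() == 0:
        return draft_tokens                             # every draft token accepted
    first_bad = mismatches[0].item()
    return torch.cat([draft_tokens[:first_bad], target_choice[first_bad:first_bad + 1]])

# Quick synthetic check: force a disagreement at position 2.
vocab_size = 10
logits = torch.randn(4, vocab_size)
drafts = logits.argmax(dim=-1).clone()
drafts[2] = (drafts[2] + 1) % vocab_size
print(verify_greedy(logits, drafts))  # two accepted drafts + the target's correction
```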

Would you like a step-by-step breakdown of the forward pass and token verification process?
