In the vLLM implementation of Gemma 3, particularly when handling multimodal prompts (`has_images=True`), the `Gemma3Attention` module employs specific logic during its `forward` pass. During the decode phase:
- The `qkv_proj` layer computes Query (Q), Key (K), and Value (V) from the `hidden_states` of the current input token(s). In the decode phase, this means `q`, `k`, and `v` initially represent only the Q, K, and V of the single new token being processed.

  ```python
  # In Gemma3Attention.forward
  qkv, _ = self.qkv_proj(hidden_states)  # hidden_states is for current token(s) in decode
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
  ```
- `self.attn(q, k, v)` (vLLM's optimized `Attention` backend) is called. For multimodal inputs, the comment explicitly states its role:

  ```python
  # In Gemma3Attention.forward, when has_images is True:
  attn_output = self.attn(q, k, v)
  # ...
  # "The call to self.attn(q, k, v) is only used to populate the KV cache - its
  # output is discarded and overwritten below."
  ```

  This implies that `self.attn` correctly appends the K and V of the current new token to the existing vLLM KV cache, which holds the KV history from prefill and previous decode steps.
- Subsequently, `naive_attn_with_masks` is called to compute the actual attention output:

  ```python
  # In Gemma3Attention.forward, when has_images is True:
  attn_output = self.naive_attn_with_masks(q,  # Current token's Q
                                           k,  # Current token's K
                                           v,  # Current token's V
                                           out=attn_output,
                                           **kwargs)
  ```
- Inside `naive_attn_with_masks`, attention is computed per sequence in the batch using `torch.nn.functional.scaled_dot_product_attention` (see the shape sketch right after this list):

  ```python
  # In Gemma3Attention.naive_attn_with_masks
  def naive_attn_with_masks(
      self,
      q: torch.Tensor,    # Expected to be current token's Q in decode
      k: torch.Tensor,    # Expected to be current token's K in decode
      v: torch.Tensor,    # Expected to be current token's V in decode
      out: torch.Tensor,
      **kwargs,
  ) -> torch.Tensor:
      # ... (reshaping; k and v are reshaped from the input arguments)
      q_reshaped = q.view(-1, self.num_heads, self.head_dim)
      k_reshaped = k.view(-1, self.num_kv_heads, self.head_dim)
      v_reshaped = v.view(-1, self.num_kv_heads, self.head_dim)

      start_idx = 0
      for seq_len, attn_mask in zip(seq_lens, attn_masks):  # seq_lens, attn_masks from kwargs
          end_idx = start_idx + seq_len
          # query, key, value are sliced from the reshaped q, k, v
          query = q_reshaped[start_idx:end_idx].unsqueeze(0)
          key = k_reshaped[start_idx:end_idx].unsqueeze(0)    # (A)
          value = v_reshaped[start_idx:end_idx].unsqueeze(0)  # (B)
          output = F.scaled_dot_product_attention(
              query,      # Current token's Q for this sequence
              key,        # Needs to be full KV history for this sequence
              value,      # Needs to be full KV history for this sequence
              attn_mask,
              self.scaling,
          )
          # ...
          start_idx = end_idx
      return out
  ```
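To make the shape requirement concrete, here is a minimal standalone sketch (plain PyTorch, not vLLM code; all sizes and variable names are illustrative) of why `F.scaled_dot_product_attention` must be given key/value tensors covering the whole sequence, whereas a decode step by itself only yields a length-1 `k` and `v`:

```python
import torch
import torch.nn.functional as F

num_heads, head_dim = 4, 32
prompt_len = 7   # tokens already in the KV cache after prefill
decode_len = 1   # the single new token processed in this decode step

# Q for the current decode token only: [batch=1, heads, 1, head_dim]
q_new = torch.randn(1, num_heads, decode_len, head_dim)

# What SDPA needs to attend over: K/V for all prompt_len + decode_len positions.
k_full = torch.randn(1, num_heads, prompt_len + decode_len, head_dim)
v_full = torch.randn(1, num_heads, prompt_len + decode_len, head_dim)
out_full = F.scaled_dot_product_attention(q_new, k_full, v_full)
print(out_full.shape)  # torch.Size([1, 4, 1, 32])

# What the decode step's own k/v provide: only the newest position.
k_new = k_full[:, :, -decode_len:]
v_new = v_full[:, :, -decode_len:]
out_new_only = F.scaled_dot_product_attention(q_new, k_new, v_new)
print(out_new_only.shape)  # same shape, but the prompt history is never attended to
```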
The core of the question is: given that `F.scaled_dot_product_attention` requires the `key` and `value` arguments to contain the full KV history of the sequence up to the current token, but the `k` and `v` arguments passed into `naive_attn_with_masks` during decode are (as per step 1) only for the current new token:

How do the `key` and `value` tensors, which are derived from the `k` and `v` passed into `naive_attn_with_masks`, effectively represent, or gain access to, the full KV history (i.e., the KVs from prefill and all prior decode steps for that specific sequence) that is stored in the vLLM KV cache managed by `self.attn`?
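For concreteness, the following is the kind of step I would have expected somewhere between the two calls (purely hypothetical sketch; `cached_k`, `cached_v`, and the explicit concatenation do not appear anywhere in the vLLM code above):

```python
import torch

num_kv_heads, head_dim = 1, 32
past_len = 7  # prefill + prior decode steps for this sequence

# Hypothetical: the K/V rows for this sequence as they sit in the KV cache.
cached_k = torch.randn(past_len, num_kv_heads, head_dim)
cached_v = torch.randn(past_len, num_kv_heads, head_dim)

# K/V produced by qkv_proj for the single new decode token.
k_new = torch.randn(1, num_kv_heads, head_dim)
v_new = torch.randn(1, num_kv_heads, head_dim)

# Some concatenation like this seems necessary before SDPA can see the full
# history, yet naive_attn_with_masks only ever slices the k/v it was handed.
k_full = torch.cat([cached_k, k_new], dim=0)  # [past_len + 1, num_kv_heads, head_dim]
v_full = torch.cat([cached_v, v_new], dim=0)
```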