In the vLLM implementation of Gemma 3, particularly when handling multimodal prompts (`has_images=True`), the `Gemma3Attention` module follows a specific code path in its forward pass. During the decode phase:
1. The `qkv_proj` layer computes Query (Q), Key (K), and Value (V) from the `hidden_states` of the current input token(s). In the decode phase, this means `q`, `k`, and `v` initially represent only the Q, K, and V of the single new token being processed:

   ```python
   # In Gemma3Attention.forward
   qkv, _ = self.qkv_proj(hidden_states)  # hidden_states is for current token(s) in decode
   q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
   ```
2. `self.attn(q, k, v)` (vLLM's optimized `Attention` backend) is called. For multimodal inputs, the comment explicitly states its role:

   ```python
   # In Gemma3Attention.forward, when has_images is True:
   attn_output = self.attn(q, k, v)
   # ...
   # "The call to self.attn(q, k, v) is only used to populate the KV cache - its
   # output is discarded and overwritten below."
   ```

   This implies that `self.attn` correctly appends the K and V of the current new token to the existing vLLM KV cache, which holds the KV history from prefill and previous decode steps.
3. Subsequently, `naive_attn_with_masks` is called to compute the actual attention output:

   ```python
   # In Gemma3Attention.forward, when has_images is True:
   attn_output = self.naive_attn_with_masks(q,  # Current token's Q
                                            k,  # Current token's K
                                            v,  # Current token's V
                                            out=attn_output,
                                            **kwargs)
   ```
4. Inside `naive_attn_with_masks`, attention is computed per sequence in the batch using `torch.nn.functional.scaled_dot_product_attention` (a toy sketch of the decode-time shapes this implies follows the list):

   ```python
   # In Gemma3Attention.naive_attn_with_masks
   def naive_attn_with_masks(
       self,
       q: torch.Tensor,    # Expected to be current token's Q in decode
       k: torch.Tensor,    # Expected to be current token's K in decode
       v: torch.Tensor,    # Expected to be current token's V in decode
       out: torch.Tensor,
       **kwargs,
   ) -> torch.Tensor:
       # ... (reshaping of q, k, v)
       # k and v are reshaped from the input arguments:
       # k_reshaped = k.view(-1, self.num_kv_heads, self.head_dim)
       # v_reshaped = v.view(-1, self.num_kv_heads, self.head_dim)
       for seq_len, attn_mask in zip(seq_lens, attn_masks):  # seq_lens from kwargs
           start_idx = ...  # calculated based on seq_lens
           end_idx = start_idx + seq_len
           # query, key, value are sliced from the reshaped q, k, v via start_idx/end_idx
           query = q_reshaped[start_idx:end_idx].unsqueeze(0)
           key = k_reshaped[start_idx:end_idx].unsqueeze(0)    # (A)
           value = v_reshaped[start_idx:end_idx].unsqueeze(0)  # (B)
           output = F.scaled_dot_product_attention(
               query,  # Current token's Q for this sequence
               key,    # Needs to be full KV history for this sequence
               value,  # Needs to be full KV history for this sequence
               attn_mask,
               self.scaling,
           )
           # ...
       return out
   ```
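For concreteness, here is a minimal, self-contained PyTorch sketch of the decode-time premise above (toy shapes, a simplified loop, no GQA expansion, no attention mask; not vLLM code). If `q`, `k`, and `v` really cover only the new tokens, then each per-sequence slice in step 4 is a single row, and `F.scaled_dot_product_attention` runs without complaint but only lets the new token attend to itself:

```python
import torch
import torch.nn.functional as F

# Toy decode-style batch (hypothetical, not vLLM code): 2 sequences, each
# contributing exactly one new token, flattened back to back as in step 1.
num_heads, head_dim = 8, 16
q = torch.randn(2, num_heads * head_dim)  # current tokens' Q, one row per sequence
k = torch.randn(2, num_heads * head_dim)  # current tokens' K
v = torch.randn(2, num_heads * head_dim)  # current tokens' V

q = q.view(-1, num_heads, head_dim)
k = k.view(-1, num_heads, head_dim)
v = v.view(-1, num_heads, head_dim)

# If seq_lens counted only the new tokens, each per-sequence slice is one row.
seq_lens = [1, 1]
start_idx = 0
for seq_len in seq_lens:
    end_idx = start_idx + seq_len
    # [1, num_heads, seq_len, head_dim] after unsqueeze + transpose
    query = q[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
    key = k[start_idx:end_idx].unsqueeze(0).transpose(1, 2)
    value = v[start_idx:end_idx].unsqueeze(0).transpose(1, 2)

    out = F.scaled_dot_product_attention(query, key, value)
    # key/value here cover only the single new token -- no cached history in sight.
    print(out.shape)  # torch.Size([1, 8, 1, 16])
    start_idx = end_idx
```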
The core of the question is:
Given that `F.scaled_dot_product_attention` requires the `key` and `value` arguments to contain the full KV history for the sequence up to the current token, but the `k` and `v` arguments passed into `naive_attn_with_masks` during decode are (as per step 1) only for the current new token:

How do the `key` and `value` tensors, which are derived from the input `k` and `v` to `naive_attn_with_masks`, effectively represent or gain access to the full KV history (i.e., the KVs from prefill and all prior decode steps for that specific sequence) that is stored in the vLLM KV cache (managed by `self.attn`)?
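Purely as a hypothetical restatement of the question in code (the `cached_k`/`cached_v` tensors and the explicit concatenation below are illustrative assumptions, not vLLM's actual mechanism), this is the kind of full-history `key`/`value` the per-sequence SDPA call appears to need during decode:

```python
import torch
import torch.nn.functional as F

num_heads, head_dim, cached_len = 8, 16, 10  # hypothetical sizes

# Stand-ins for what self.attn is understood to have written to the vLLM KV
# cache so far (prefill + earlier decode steps) for one sequence.
cached_k = torch.randn(1, num_heads, cached_len, head_dim)
cached_v = torch.randn(1, num_heads, cached_len, head_dim)

# The current token's projections -- what naive_attn_with_masks receives in decode.
q_new = torch.randn(1, num_heads, 1, head_dim)
k_new = torch.randn(1, num_heads, 1, head_dim)
v_new = torch.randn(1, num_heads, 1, head_dim)

# What scaled_dot_product_attention appears to require: key/value spanning the
# full sequence so far, i.e. cached history plus the new token.
key = torch.cat([cached_k, k_new], dim=2)    # [1, heads, cached_len + 1, head_dim]
value = torch.cat([cached_v, v_new], dim=2)
output = F.scaled_dot_product_attention(q_new, key, value)
print(output.shape)  # torch.Size([1, 8, 1, 16])

# The question is where an equivalent of this concatenation (or a read from the
# vLLM-managed KV cache) happens for the k/v sliced inside naive_attn_with_masks.
```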