Problem with the PagedAttention.split_kv_cache implementation

Hi team,

I’m new here and have been reading the PagedAttention part of the latest vLLM code. When I got to the PagedAttention.split_kv_cache method, I found something I’m not sure about.

from typing import Tuple

import torch


class PagedAttention:

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x elements of this dtype span exactly 16 bytes.
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

In this method, the input kv_cache seems to have a shape of [2, num_blocks, block_size, num_kv_heads, head_size].

And the target key_cache is supposed to have a shape of [num_blocks, num_kv_heads, head_size/x, block_size, x].

Is it correct to reshape the kv_cache with

key_cache = kv_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)

rather than with

key_cache = kv_cache[0].view(num_blocks, -1, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4)?

I believe they are different, and the latter seems like the formally correct reshaping. I hope someone can help me out.
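For concreteness, a minimal standalone repro with toy sizes (not vLLM's real defaults) confirms that the two expressions produce the same shape but different element orderings:

import torch

# Toy sizes chosen only for illustration (not vLLM's defaults).
num_blocks, block_size, num_kv_heads, head_size = 2, 4, 2, 8

kv_cache = torch.arange(
    2 * num_blocks * block_size * num_kv_heads * head_size,
    dtype=torch.float16,
).reshape(2, num_blocks, block_size, num_kv_heads, head_size)

x = 16 // kv_cache.element_size()  # 16 // 2 = 8 for fp16

key = kv_cache[0]
a = key.view(num_blocks, num_kv_heads, head_size // x, -1, x)
b = key.view(num_blocks, -1, num_kv_heads,
             head_size // x, x).permute(0, 2, 3, 1, 4)

print(a.shape == b.shape)  # True: both are [2, 2, 1, 4, 8]
print(torch.equal(a, b))   # False: the element orderings differ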

As I understand it, the split_kv_cache method in the PagedAttention class splits the kv_cache tensor into separate key_cache and value_cache views. The expression x = 16 // kv_cache.element_size() sets the size of the innermost dimension so that each chunk of x elements spans exactly 16 bytes, a common granularity for vectorized memory access in CUDA kernels.
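As a quick standalone check (not vLLM code), here is what x evaluates to for common KV-cache dtypes:

import torch

# x is the number of elements that fit in 16 bytes for each dtype.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    print(dtype, x)  # float16 -> 8, bfloat16 -> 8, float32 -> 4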

The key point is that key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) does not move any data: view only reinterprets the underlying buffer. It is correct precisely because the key cache is already stored physically in the [num_blocks, num_kv_heads, head_size // x, block_size, x] order (as I understand it, the reshape_and_cache cache-write kernel writes keys directly into this blocked layout, so there is no token-major intermediate to permute from). The alternative you suggested, key_cache.view(num_blocks, -1, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4), would only be the formally correct reshaping if the keys were stored token-major as [num_blocks, block_size, num_kv_heads, head_size]; applied to the actual buffer, it describes a different memory layout that is not compatible with what the subsequent CUDA kernels expect.
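To make that concrete, here is a small sketch (toy sizes, not vLLM's defaults) showing that view shares the original storage and stays contiguous, while the permuted variant describes a non-contiguous layout over the same bytes:

import torch

# Toy sizes for illustration; x = 8 corresponds to fp16 (16 // 2).
num_blocks, block_size, num_kv_heads, head_size, x = 2, 4, 2, 8, 8

key = torch.empty(num_blocks, block_size, num_kv_heads, head_size,
                  dtype=torch.float16)

viewed = key.view(num_blocks, num_kv_heads, head_size // x, -1, x)
permuted = key.view(num_blocks, -1, num_kv_heads,
                    head_size // x, x).permute(0, 2, 3, 1, 4)

# view() only reinterprets the existing buffer in place ...
print(viewed.data_ptr() == key.data_ptr())    # True
print(viewed.is_contiguous())                 # True
# ... while permute() reorders the logical axes over the same bytes,
# yielding a genuinely different (non-contiguous) memory layout.
print(permuted.data_ptr() == key.data_ptr())  # True (same storage)
print(permuted.is_contiguous())               # False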

If you need more detailed information or further clarification, please let me know.

The layout of k_cache is as follows. Here we assume the same parameters as in the paged attention kernel document:
HEAD_SIZE = 128
BLOCK_SIZE = 16
NUM_THREADS = 128
THREAD_GROUP_SIZE = 2
VEC_SIZE = 4

The layout of k_cache in vLLM’s PagedAttention is [num_blocks, num_kv_heads, head_size/x, block_size, x]: each physical block holds block_size tokens per head, and the innermost dimension x (16 bytes’ worth of elements, i.e. 8 for FP16) exists for vectorized memory access. Each “partial head” in the diagram corresponds to a chunk of the head dimension, and the vectors (vec 0 to vec 31) are the VEC_SIZE-element pieces of one token’s head that threads fetch in parallel for efficient GPU access, as described in the paged attention kernel documentation.
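For reference, plugging these parameters into the layout (assuming FP16, so element_size is 2 bytes and x = 8; this is plain arithmetic, not vLLM code) gives:

HEAD_SIZE = 128
BLOCK_SIZE = 16
VEC_SIZE = 4
ELEMENT_SIZE = 2           # bytes per element, assuming FP16
x = 16 // ELEMENT_SIZE     # 8

# Per-block k_cache layout for one head: [head_size / x, block_size, x]
print(HEAD_SIZE // x, BLOCK_SIZE, x)  # 16 16 8

# Number of vecs covering one token's head dimension:
print(HEAD_SIZE // VEC_SIZE)          # 32 -> vec 0 .. vec 31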

This layout is specifically designed to optimize memory coalescing and parallelism in CUDA, matching the diagram’s depiction of blocks, heads, and vectorized processing. Would you like a step-by-step breakdown of how this layout maps to the diagram?