Problem with the PagedAttention.split_kv_cache implementation

Hi team,

I’m new here and have been reading the PagedAttention part of the latest vLLM code. When I got to the PagedAttention.split_kv_cache method, I found something I’m not sure about.

from typing import Tuple

import torch


class PagedAttention:

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # x = number of elements that fit in a 16-byte access
        # (e.g. 8 for float16, 4 for float32).
        x = 16 // kv_cache.element_size()
        num_blocks = kv_cache.shape[1]

        key_cache = kv_cache[0]
        # key cache view: [num_blocks, num_kv_heads, head_size // x, block_size, x]
        key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
                                   -1, x)
        value_cache = kv_cache[1]
        # value cache view: [num_blocks, num_kv_heads, head_size, block_size]
        value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
        return key_cache, value_cache

In this method, the input kv_cache seems to have a shape of [2, num_blocks, block_size, num_kv_heads, head_size].

And the resulting key_cache is supposed to have a shape of [num_blocks, num_kv_heads, head_size // x, block_size, x].

Is it correct to reshape the kv_cache with

key_cache = key_cache[0].view(num_blocks, num_kv_heads, head_size // x, -1, x)

rather than

key_cache = key_cache[0].view(num_blocks, -1, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4)?

I do think the two are different, and the latter seems to be the formally correct reshaping. I hope someone can help me out.
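To make the difference concrete, here is a small standalone sketch (plain PyTorch with toy sizes, not vLLM code) that applies both reshapes to the same flat buffer and shows they place elements in different positions:

import torch

# Toy sizes, chosen only for illustration; not the real vLLM configuration.
num_blocks, block_size, num_kv_heads, head_size = 2, 4, 3, 16
x = 8  # 16 // element_size() for float16

# A flat key buffer standing in for the kv_cache[0] slice.
key_flat = torch.arange(num_blocks * block_size * num_kv_heads * head_size,
                        dtype=torch.float16)
key_flat = key_flat.view(num_blocks, block_size * num_kv_heads * head_size)

# Reshape as in split_kv_cache: a plain view over the flat buffer.
view_a = key_flat.view(num_blocks, num_kv_heads, head_size // x, -1, x)

# Alternative: interpret the buffer as [num_blocks, block_size, num_kv_heads,
# head_size // x, x] first, then permute into the same target shape.
view_b = key_flat.view(num_blocks, -1, num_kv_heads, head_size // x,
                       x).permute(0, 2, 3, 1, 4)

print(view_a.shape == view_b.shape)  # True: both have the target shape
print(torch.equal(view_a, view_b))   # False: the elements land in different slots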

The split_kv_cache method in the PagedAttention class reshapes the kv_cache tensor into separate key_cache and value_cache views, organizing the data into the layout the attention kernels expect. The factor x = 16 // kv_cache.element_size() is the number of elements that fit in a 16-byte access (8 for float16 or bfloat16, 4 for float32); splitting the innermost head dimension by x lets the CUDA kernels read keys in vectorized 16-byte chunks, a common practice for optimizing memory access patterns.
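For example, a quick standalone check (plain PyTorch, not vLLM code) of what x works out to for common cache dtypes:

import torch

# Standalone illustration of how x depends on the cache dtype.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    element_size = torch.empty((), dtype=dtype).element_size()
    print(f"{dtype}: element_size={element_size} bytes, x={16 // element_size}")
# torch.float16:  element_size=2 bytes, x=8
# torch.bfloat16: element_size=2 bytes, x=8
# torch.float32:  element_size=4 bytes, x=4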

The reshaping key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, -1, x) is a pure reinterpretation of the flat per-block buffer, and it works because the key cache is written by the cache-write CUDA kernel (reshape_and_cache) directly in the [num_blocks, num_kv_heads, head_size // x, block_size, x] layout, not in a token-major [num_blocks, block_size, num_kv_heads, head_size] layout. In other words, kv_cache[0] never holds the data in the ordering your permute-based reshape assumes. The alternative, key_cache.view(num_blocks, -1, num_kv_heads, head_size // x, x).permute(0, 2, 3, 1, 4), would only be the formally correct transformation if the keys were stored token-major; applied to the actual cache it reorders elements and produces a non-contiguous tensor whose memory layout the paged attention kernels do not expect.
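As a rough illustration of why the plain view is enough, here is a conceptual sketch in plain PyTorch of how the cache-write side lays keys out; the toy sizes and slot_mapping values are made up for the example, and the real kernel is CUDA, not a Python loop:

import torch

# Conceptual sketch of how keys land in the paged KV cache (not vLLM code).
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 32
x = 8  # 16 // element_size() for float16

key_cache = torch.zeros(num_blocks, num_kv_heads, head_size // x,
                        block_size, x, dtype=torch.float16)

# New keys for a few tokens, plus a made-up slot mapping
# (slot = block_idx * block_size + offset_within_block).
num_tokens = 5
key = torch.randn(num_tokens, num_kv_heads, head_size, dtype=torch.float16)
slot_mapping = torch.tensor([0, 1, 2, 16, 17])

for token_idx, slot in enumerate(slot_mapping.tolist()):
    block_idx, block_offset = divmod(slot, block_size)
    for head_idx in range(num_kv_heads):
        for head_offset in range(head_size):
            # Keys are stored directly in the x-split layout, so
            # split_kv_cache only needs to view() the flat buffer.
            k = key[token_idx, head_idx, head_offset]
            key_cache[block_idx, head_idx, head_offset // x,
                      block_offset, head_offset % x] = k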

If you need more detailed information or further clarification, please let me know.
