Something weird about the reading procedure of q_vecs in the paged attention kernel

After reading the paged attention doc, I noticed that the reading procedures for q_vecs and k_vecs differ slightly. The doc says:

Next, we need to read the global memory data pointed to by q_ptr into shared memory as q_vecs . It is important to note that each vecs is assigned to a different row. For example, if the THREAD_GROUP_SIZE is 2, thread 0 will handle the 0th row vecs, while thread 1 handles the 1st row vecs. By reading the query data in this way, neighboring threads like thread 0 and thread 1 can read neighbor memory, achieving the memory coalescing to improve performance.

It looks like every thread in a thread group would read NUM_VECS_PER_THREAD vecs into q_vecs. Should the increment statement in the loop be

i++

instead of

i += NUM_THREAD_GROUPS

For comparison, the increment statement in the k_vecs part is

j++


For q_vecs:

const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
       i += NUM_THREAD_GROUPS) {
    const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
    q_vecs[thread_group_offset][i] =
        *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}

For k_vecs:

for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
  const cache_t* k_ptr =
      k_cache + physical_block_number * kv_block_stride +
      kv_head_idx * kv_head_stride + physical_block_offset * x;
  const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
  const int offset1 = (vec_idx * VEC_SIZE) / x;
  const int offset2 = (vec_idx * VEC_SIZE) % x;

  if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
    k_vecs[j] = *reinterpret_cast<const K_vec*>(
        k_ptr + offset1 * BLOCK_SIZE * x + offset2);
  } else {
    // Vector conversion from Quant_vec to K_vec.
    Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
        k_ptr + offset1 * BLOCK_SIZE * x + offset2);
    k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
        k_vec_quant, *k_scale);
  }
}

The increment i += NUM_THREAD_GROUPS in the q_vecs loop is intentional and correct. Because q_vecs lives in shared memory, the thread groups load the query vectors collaboratively: each group handles a subset of the vectors, and striding by NUM_THread_GROUPS avoids redundant loads while keeping the accesses coalesced. In contrast, k_vecs lives in registers, which are private to each thread, so each thread must load all NUM_VECS_PER_THREAD of its assigned key vectors itself, hence the simple j++ increment. This difference is by design: the query is shared by every thread group, while each thread group reads different key tokens, as explained in the paged attention kernel documentation and the kernel source code.


Because q_vecs is in shared memory, we don't need every thread group to read the same q from global memory. Each thread reads its corresponding vecs into q_vecs, and every thread group can then use this shared memory for the computation later. Some threads in the thread block are filtered out of the reading process by the loop condition i < NUM_VECS_PER_THREAD;.

You are correct: since q_vecs is in shared memory, each thread group only needs to load its assigned portion of the query vector from global memory, and all threads in the block can later access the full q_vecs for computation. The loop with i += NUM_THREAD_GROUPS ensures that only the necessary threads perform the loads, avoiding redundant reads and maximizing memory coalescing. Some threads do not participate in the load because their i never satisfies i < NUM_VECS_PER_THREAD, which is expected and efficient (paged attention kernel doc, kernel source).

This design allows each thread group to collaboratively fill q_vecs in shared memory, after which all threads can use the shared data for the attention computation, minimizing global memory traffic and improving performance.
