After reading the PagedAttention kernel documentation, I noticed that q_vecs and k_vecs are read in slightly different ways. The documentation says:
Next, we need to read the global memory data pointed to by `q_ptr` into shared memory as `q_vecs`. It is important to note that each vecs is assigned to a different row. For example, if the `THREAD_GROUP_SIZE` is 2, thread 0 will handle the 0th row vecs, while thread 1 handles the 1st row vecs. By reading the query data in this way, neighboring threads like thread 0 and thread 1 can read neighbor memory, achieving the memory coalescing to improve performance.
From this description, it looks like every thread in a thread group should read NUM_VECS_PER_THREAD vecs into q_vecs. Should the increment statement in the loop be
i++
instead of
i += NUM_THREAD_GROUPS
For comparison, the increment statement in the k_vecs loop is
j++
Code that loads q_vecs:
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
}
Code that loads k_vecs:
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else {
// Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale);
}
}

