How to crop kv_caches?

I want to implement a custom feature in Qwen2ForCausalLM that crops a certain length from the kv_caches (e.g. discards the last 64 generated tokens). How can I implement this?

I have tried directly modifying the kv_caches passed into the forward() function, but I found that the shape of kv_cache is fixed at [2, num_blocks, block_size, num_kv_heads, head_size]. I can use block_tables to fetch the kv_cache that corresponds to a completion, and there must be some parameter that indicates the end position of the actual kv_cache, but I can't find it.
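To make the question concrete, here is a minimal sketch of the bookkeeping I have in mind. It is plain Python and does not call any real API: `valid_slots`, `crop_block_table`, `block_table` and `seq_len` are hypothetical names, and it assumes a per-sequence token count `seq_len` is available from somewhere (that is the parameter I cannot locate). It only shows which slots of the fixed-shape cache would count as "real" before and after dropping tokens.

```python
def valid_slots(block_table, seq_len, block_size):
    """Return (physical_block_id, num_valid_slots) pairs for one sequence.

    Assumption: tokens fill the blocks listed in block_table in order,
    so only the last used block can be partially filled.
    """
    slots = []
    remaining = seq_len
    for block_id in block_table:
        if remaining <= 0:
            break
        used = min(block_size, remaining)
        slots.append((block_id, used))
        remaining -= used
    return slots


def crop_block_table(block_table, seq_len, block_size, num_drop):
    """Drop the last num_drop tokens of a sequence (hypothetical helper).

    Returns the new length, the blocks still needed, and the blocks
    that no longer hold any valid token for this sequence.
    """
    new_len = max(seq_len - num_drop, 0)
    blocks_needed = -(-new_len // block_size)  # ceiling division
    kept = block_table[:blocks_needed]
    released = block_table[blocks_needed:]
    return new_len, kept, released


# Example: 40 tokens stored in blocks of 16 slots, then drop the last 20.
table = [7, 3, 9]
before = valid_slots(table, seq_len=40, block_size=16)
new_len, kept, released = crop_block_table(table, 40, 16, num_drop=20)
after = valid_slots(kept, new_len, 16)
```

In this sketch the cache tensor itself is never resized; cropping only shortens the recorded length and gives up the trailing blocks, so stale values left in the dropped slots are ignored rather than erased.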

Another problem: once I have extracted the correct kv_cache for one completion, will cropping it cause problems? It might affect another completion's kv_cache, since they may share the same cuda_block.
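The sharing concern can be illustrated with a small reference-counting sketch. This is an assumption about how shared blocks might be tracked, not a description of the engine's actual block manager: `ref_counts`, `release_blocks` and `last_block_needs_copy` are hypothetical names. The idea is that a block is only returned to the free pool when no other completion references it, and that a shared, partially filled last block would need to be copied before new tokens are written into its cropped slots.

```python
def release_blocks(released, ref_counts):
    """Drop this sequence's reference to each released block.

    Returns the blocks whose count reached zero, i.e. the ones that
    no other sequence uses and that could really be freed.
    """
    freed = []
    for block_id in released:
        ref_counts[block_id] -= 1
        if ref_counts[block_id] == 0:
            freed.append(block_id)
    return freed


def last_block_needs_copy(kept, new_len, block_size, ref_counts):
    """Check whether writing after the crop could clobber shared data.

    After cropping, the next generated token lands in the last kept
    block if that block is only partially filled. If another sequence
    also references that block, it should be copied first.
    """
    if not kept or new_len % block_size == 0:
        return False  # next token starts in a fresh block
    return ref_counts[kept[-1]] > 1


# Example: block 7 is shared by two completions, blocks 3 and 9 are private.
ref_counts = {7: 2, 3: 1, 9: 1}
freed = release_blocks([9], ref_counts)
safe_case = last_block_needs_copy([7, 3], 20, 16, ref_counts)
shared_case = last_block_needs_copy([7], 10, 16, ref_counts)
```

Under these assumptions, cropping by itself never modifies a shared block; the risk only appears when later tokens overwrite slots that another completion still treats as valid.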
