Custom KV cache implementation

Hi everyone. I have implemented a novel KV caching approach that works with the Hugging Face (HF) caching API.

Driver class here:

from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers.cache_utils import Cache

# LearnableCacherLinear, LearnableCacherNonlinear and LearnableCacherDeltaMLP are my own
# compression modules, defined elsewhere in the project.


class JVCache(Cache):

    def __init__(self, compression_module_path: str = None, compression_type: str = "linear", **kwargs):
        super(JVCache, self).__init__(**kwargs)
        print("Initialising Randomised Cache")

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []

        if compression_type == "nonlinear":
            self.compression_decompression_module = LearnableCacherNonlinear.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherNonlinear(**kwargs)
        elif compression_type == "linear":
            self.compression_decompression_module = LearnableCacherLinear.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherLinear(**kwargs)
        elif compression_type == "residual":
            self.compression_decompression_module = LearnableCacherDeltaMLP.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherDeltaMLP(**kwargs)

        # self.compression_decompression_module = LearnableCacher.from_pretrained(compression_module_path, local_files_only=True) \
        #     if compression_module_path else LearnableCacher(**kwargs)
        # self.applied_projection = False

    def cast_compression_module(self, torch_dtype: torch.dtype):
        self.compression_decompression_module.convert_type(torch_dtype)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        print(f"Trying to access layer_idx: {layer_idx}")
        if layer_idx < len(self.key_cache):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self.key_cache)} layers, attempted to access layer with index {layer_idx}")

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the cache with the new key and value states for the given layer.
        """

        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty placeholders, then store the
                # compressed states for this layer.
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])

                self.key_cache.append(self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key"))
                self.value_cache.append(self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value"))
            elif len(self.key_cache[layer_idx]) == 0:
                # First real write for this layer: replace the placeholder with compressed states.
                self.key_cache[layer_idx] = self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key")
                self.value_cache[layer_idx] = self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value")
            else:
                # Subsequent writes: compress the new states and append along the sequence dimension.
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key")], dim=-2)
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value")], dim=-2)

        # Always hand back decompressed (full-size) keys/values so the attention layer sees standard shapes.
        decompressed_key_cache = self.compression_decompression_module.decompress_kv_cache(self.key_cache[layer_idx], layer_idx, cache_type="key")
        decompressed_value_cache = self.compression_decompression_module.decompress_kv_cache(self.value_cache[layer_idx], layer_idx, cache_type="value")

        return decompressed_key_cache, decompressed_value_cache

As you can see, I simply load the external compression module and use it to compress the key/value states on every cache write and decompress them again on read.
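For reference, this is roughly how the cache gets used at generation time. A minimal sketch, assuming a recent transformers version; the model name and compressor path are placeholders:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

# Build the custom cache and cast its compression module to the model dtype.
cache = JVCache(compression_module_path="path/to/compressor", compression_type="linear")
cache.cast_compression_module(torch.float16)

inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
# Recent transformers releases accept a Cache instance via past_key_values.
out = model.generate(**inputs, past_key_values=cache, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))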

I have been stuck trying to implement this in vLLM. Any suggestions on where and how to do this?

It seems likely that vLLM does not currently provide a direct plug-in interface for custom KV cache compression modules like your JVCache class, especially with learnable or non-standard compression/decompression logic. vLLM’s KV cache management is tightly integrated with its memory allocator and attention backends, and most existing support for compression is via quantization (e.g., FP8, INT8) or block/page-level eviction, not arbitrary user-defined compression modules. See the sparse KV cache RFC and the quantized KV cache support for the current directions.
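For comparison, the cache compression vLLM exposes out of the box is dtype quantization of the KV cache, selected with the kv_cache_dtype engine argument rather than a learnable module. A minimal sketch, assuming a recent vLLM version and FP8-capable hardware; the model name is a placeholder:

from vllm import LLM, SamplingParams

# vLLM's built-in KV cache "compression" is quantization of the cache dtype (e.g. FP8),
# applied by the engine itself; there is no hook here for a user-supplied compressor.
llm = LLM(model="meta-llama/Llama-3.2-1B", kv_cache_dtype="fp8")  # placeholder model

outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)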

If you want to integrate a custom compression module, you would likely need to modify vLLM’s attention layers and KV cache manager to call your compression/decompression logic at the appropriate points: specifically, where keys/values are written to or read from the cache. There is no documented API for this, so you would need to fork vLLM and insert your hooks into the attention forward pass and the cache update logic; a rough sketch of where those hooks would sit follows below. For reference, see how the quantized KV cache is handled in CompressedTensorsKVCacheMethod and the KV cache helper.
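Conceptually, the two hooks mirror what JVCache already does: compress right before new keys/values are written into the cache blocks, and decompress right before the attention computation reads them. The toy sketch below only illustrates that placement; every class, method, and parameter name in it is hypothetical and none of it is vLLM API. The compressor is assumed to expose the same compress_kv_cache/decompress_kv_cache interface as the LearnableCacher modules above:

import torch

class ToyCompressedKVCacheAdapter:
    """Hypothetical illustration of where compress/decompress hooks would sit
    in a forked attention backend. Not a vLLM interface."""

    def __init__(self, compressor):
        self.compressor = compressor  # e.g. a LearnableCacherLinear instance

    def write(self, key_cache: torch.Tensor, value_cache: torch.Tensor,
              key_states: torch.Tensor, value_states: torch.Tensor,
              slot_indices: torch.Tensor, layer_idx: int) -> None:
        # Hook 1: compress just before new keys/values are scattered into the cache
        # blocks (in vLLM this write happens inside the attention backend's forward pass).
        key_cache[slot_indices] = self.compressor.compress_kv_cache(key_states, layer_idx, cache_type="key")
        value_cache[slot_indices] = self.compressor.compress_kv_cache(value_states, layer_idx, cache_type="value")

    def read(self, key_cache: torch.Tensor, value_cache: torch.Tensor,
             block_indices: torch.Tensor, layer_idx: int):
        # Hook 2: decompress just before the attention computation consumes the cache.
        keys = self.compressor.decompress_kv_cache(key_cache[block_indices], layer_idx, cache_type="key")
        values = self.compressor.decompress_kv_cache(value_cache[block_indices], layer_idx, cache_type="value")
        return keys, values

One caveat with this placement: vLLM's paged-attention kernels read the KV blocks directly inside the kernel, so decompressing "before the kernel" as sketched here means materializing full-size keys and values again, which gives back much of the memory saving. A tighter integration would also need to teach the memory profiler and block allocator about the smaller compressed block shapes.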

Would you like a step-by-step outline of where to patch vLLM for this integration?
