Hi everyone. I have implemented a novel KV-cache compression approach that works with the Hugging Face transformers caching API (the Cache base class).
The driver class is below:
import torch
from typing import Any, Dict, List, Optional, Tuple
from transformers.cache_utils import Cache

# LearnableCacherLinear / LearnableCacherNonlinear / LearnableCacherDeltaMLP are my own
# compression modules (not shown here); each exposes compress_kv_cache / decompress_kv_cache.


class JVCache(Cache):
    def __init__(self, compression_module_path: Optional[str] = None, compression_type: str = "linear", **kwargs):
        super().__init__()
        print("Initialising Randomised Cache")
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        if compression_type == "nonlinear":
            self.compression_decompression_module = LearnableCacherNonlinear.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherNonlinear(**kwargs)
        elif compression_type == "linear":
            self.compression_decompression_module = LearnableCacherLinear.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherLinear(**kwargs)
        elif compression_type == "residual":
            self.compression_decompression_module = LearnableCacherDeltaMLP.from_pretrained(compression_module_path, local_files_only=True) \
                if compression_module_path else LearnableCacherDeltaMLP(**kwargs)
        else:
            raise ValueError(f"Unknown compression_type: {compression_type}")

    def __len__(self) -> int:
        # Number of layers that currently hold cached states (used by __getitem__ below).
        return len(self.key_cache)

    def cast_compression_module(self, torch_dtype: torch.dtype):
        self.compression_decompression_module.convert_type(torch_dtype)

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # TODO: deprecate this function in favor of `cache_position`
        # Workaround to make `key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)` <= window_length
        if len(self.key_cache) <= layer_idx or len(self.key_cache[layer_idx]) == 0:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        print(f"Trying to access layer_idx: {layer_idx}")
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update the cache with the new key and value states for the given layer.
        States are stored compressed and decompressed again on the way out.
        """
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Pad any skipped layers with empty placeholders, then append the new layer.
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                self.key_cache.append(self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key"))
                self.value_cache.append(self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value"))
            elif len(self.key_cache[layer_idx]) == 0:
                # The layer slot exists but is still an empty placeholder: fill it.
                self.key_cache[layer_idx] = self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key")
                self.value_cache[layer_idx] = self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value")
            else:
                # Append the newly compressed states along the sequence dimension.
                self.key_cache[layer_idx] = torch.cat(
                    [self.key_cache[layer_idx], self.compression_decompression_module.compress_kv_cache(key_states, layer_idx, cache_type="key")], dim=-2)
                self.value_cache[layer_idx] = torch.cat(
                    [self.value_cache[layer_idx], self.compression_decompression_module.compress_kv_cache(value_states, layer_idx, cache_type="value")], dim=-2)

        decompressed_key_cache = self.compression_decompression_module.decompress_kv_cache(self.key_cache[layer_idx], layer_idx, cache_type="key")
        decompressed_value_cache = self.compression_decompression_module.decompress_kv_cache(self.value_cache[layer_idx], layer_idx, cache_type="value")
        return decompressed_key_cache, decompressed_value_cache
As you can see, I simply load an external compression module that handles compressing and decompressing the cached key/value states.
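For context, JVCache only assumes the module exposes convert_type, compress_kv_cache and decompress_kv_cache. A simplified, hypothetical stand-in for the linear variant would look roughly like this (class name, dimensions and the per-layer down/up projections are illustrative only, not the real LearnableCacherLinear):

# Hypothetical, simplified stand-in for the interface JVCache expects.
import torch
import torch.nn as nn

class ToyLinearCacher(nn.Module):
    def __init__(self, head_dim: int = 128, compressed_dim: int = 32, num_layers: int = 32):
        super().__init__()
        # One down/up projection pair per layer and per cache type ("key" / "value").
        self.down = nn.ModuleDict({
            t: nn.ModuleList([nn.Linear(head_dim, compressed_dim, bias=False) for _ in range(num_layers)])
            for t in ("key", "value")
        })
        self.up = nn.ModuleDict({
            t: nn.ModuleList([nn.Linear(compressed_dim, head_dim, bias=False) for _ in range(num_layers)])
            for t in ("key", "value")
        })

    def convert_type(self, torch_dtype: torch.dtype):
        self.to(torch_dtype)

    def compress_kv_cache(self, states: torch.Tensor, layer_idx: int, cache_type: str) -> torch.Tensor:
        # states: [batch, num_heads, seq_len, head_dim] -> [..., compressed_dim]
        return self.down[cache_type][layer_idx](states)

    def decompress_kv_cache(self, compressed: torch.Tensor, layer_idx: int, cache_type: str) -> torch.Tensor:
        # [..., compressed_dim] -> [..., head_dim]
        return self.up[cache_type][layer_idx](compressed)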
I have been stuck trying to implement this in vLLM. Any suggestions on where and how to do this?
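For reference, a minimal sketch of how the cache plugs into generate on the transformers side (recent transformers versions accept a Cache instance via past_key_values; model name and checkpoint path are placeholders):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"  # example model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

# Example checkpoint path; the compression module may also need to be moved to model.device.
cache = JVCache(compression_module_path="./compressor_ckpt", compression_type="linear")
cache.cast_compression_module(torch.float16)

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=32, past_key_values=cache, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))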