# Uploaded by kyleliang via huggingface_hub ("Upload folder using huggingface_hub"), commit dafad47 (verified).
# Copyright 2025 China Merchants Bank. All rights reserved.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://mit-license.org
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import DynamicCache, GenerationConfig
from typing import Any, Dict, List, Optional, Tuple
# Arguments rejected by `generate` below, whether passed as kwargs or as non-default `generation_config` values.
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",  # cache-related arguments: this generate function always uses LagKVCache
    "cache_config",
    "return_legacy_cache",
    "num_beams",  # beam search (and cousin techniques) are not supported
    "compile_config",  # torch.compile is not supported with this cache
    "assistant_model",  # speculative decoding is not supported either
]
class LagKVCache(DynamicCache):
    """
    A KV cache compression algorithm, as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).

    Like SinkCache, it keeps the attention-sink tokens at the start of the sequence and a window of the most
    recent tokens, but it additionally retains a selected subset of the tokens in between. This lets the model
    generate with less memory and faster decoding while keeping most of its information-retrieval capability,
    instead of losing the middle of the context entirely as SinkCache does.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from DynamicCache.
        ratio (`float`):
            The fraction of tokens retained from each compressed middle partition. Must be in `[0, 1]`.
        sink_size (`int`):
            The number of sink tokens at the start of the sequence that are never compressed.
        lag_size (`int`):
            The size of a partition. Each partition is scored against the partition that follows it.
        score_v_ratio (`float`):
            The weight applied to the Value-state score before it is added to the Key-state score.
        skip_layer_idx (`Optional[List[int]]`):
            Indices of layers whose cache is never compressed.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM
        >>> # `LagKVCache` is the class defined in this module; import it from wherever this file lives.

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        LagKVCache()
        ```
    """

    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        # Fail early on values that would otherwise only crash (or misbehave) at the first compression.
        if lag_size <= 0:
            raise ValueError(f"`lag_size` must be a positive integer, got {lag_size}")
        if sink_size < 0:
            raise ValueError(f"`sink_size` must be non-negative, got {sink_size}")
        if not 0.0 <= ratio <= 1.0:
            raise ValueError(f"`ratio` must be in [0, 1], got {ratio}")
        super().__init__(_distributed_cache_data)
        self.ratio: float = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        # Per layer: length of the leading region (sink tokens + already-compressed tokens) that is excluded from
        # further compression. Sized to match any layers the parent constructor may already have populated
        # (e.g. from `_distributed_cache_data`) so `_compress_kv_by_lag` never indexes past the end.
        self._compressed_len: List[int] = [self.sink_size] * len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        For layers not listed in `skip_layer_idx`, the returned states are the full (not yet compressed) states of
        this step; the stored cache is compressed afterwards, so compression only affects later steps.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens (counted once per forward pass, on the first layer)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif len(self.key_cache[layer_idx]) == 0:
                # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if layer_idx not in self.skip_layer_idx:
            return self._compress_kv_by_lag(layer_idx)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition `value[:, :, base_len:end_idx]` into chunks of `lag_size` tokens and score each token.

        Every partition except the last is min-max normalized per channel with the range of the partition that
        follows it (its "lag" reference). Each token is then scored by the standard deviation of its normalized
        channels, and the scores are softmax-normalized within each partition.

        Args:
            base_len: start index of the scored region.
            in_size: size of the cache tensor. Only the batch and head dimensions are used; the head dimension is
                read from `value` itself, so keys and values may have different head sizes.
            end_idx: end (exclusive) of the scored region; `end_idx - base_len` must be a multiple of `lag_size`.
            value: cached key or value states, `[batch_size, num_heads, seq_len, head_dim]`.

        Returns:
            Scores of shape `[batch_size, num_heads, partition_num - 1, lag_size]`.
        """
        # [batch_size, num_heads, partition_num, lag_size, head_dim]
        target = value[:, :, base_len:end_idx].view(in_size[0], in_size[1], -1, self.lag_size, value.size(-1))
        ref = target[:, :, 1:]
        cur = target[:, :, :-1]
        # Per-channel range of the reference partition; keepdim lets it broadcast over the lag_size dimension.
        min_r = ref.amin(dim=-2, keepdim=True)
        range_r = ref.amax(dim=-2, keepdim=True) - min_r
        # A channel that is constant over the reference partition has zero range. Dividing by it yields inf/NaN,
        # which would turn the whole partition's softmax into NaN, so such channels contribute 0 instead.
        normalized = ((cur - min_r) / range_r).masked_fill(range_r == 0, 0.0)
        return normalized.std(dim=-1).softmax(dim=-1)

    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        """Rebuild a cache tensor as `[value[:base_len] | selected tokens | last tail_len tokens]`.

        Args:
            value: cached key or value states, `[batch_size, num_heads, seq_len, head_dim]`.
            base_len: length of the leading region kept as-is.
            end_idx: end (exclusive) of the region the indices refer to.
            selected_idx: token indices offset by `base_len`, shape `[batch_size, num_heads, num_selected, k]`.
                Entries along the last dimension are expected to be identical; only the first is used and it is
                broadcast to `value`'s own head dimension.
            tail_len: number of trailing tokens kept uncompressed (expected to be > 0).
        """
        index = selected_idx[..., :1].expand(-1, -1, -1, value.size(-1))
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, index)
        return torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)

    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of KV tokens in each head and partition, and keep the top ones. See the paper.
        The computation overhead of top-k is significantly reduced by partitioning.
        """
        keys = self.key_cache[layer_idx]
        values = self.value_cache[layer_idx]
        in_size = keys.size()
        seq_len = in_size[-2]
        end_idx = base_len + ((seq_len - base_len) // self.lag_size) * self.lag_size

        # [batch_size, num_heads, partition_num - 1, lag_size]
        key_score = self._get_states_score(base_len, in_size, end_idx, keys)
        value_score = self._get_states_score(base_len, in_size, end_idx, values)
        score = key_score + value_score * self.score_v_ratio

        # you may need to sort the index for some cases
        top_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        # Indices are local to each partition; shift partition i by i * lag_size so they address the whole
        # region starting at base_len.
        offsets = torch.arange(top_idx.size(2), device=top_idx.device, dtype=top_idx.dtype) * self.lag_size
        top_idx = top_idx + offsets.view(1, 1, -1, 1)
        # [batch_size, num_heads, (partition_num - 1) * keep, 1]; flatten (unlike reshape(-1)) also works when
        # nothing is kept.
        selected_idx = top_idx.flatten(2).unsqueeze(-1)
        new_base_len = base_len + selected_idx.size(-2)

        # always keep the last (reference) partition plus the leftover tokens that do not fill a whole partition
        tail_len = self.lag_size + seq_len - end_idx
        self.key_cache[layer_idx] = self._modify_kv(keys, base_len, end_idx, selected_idx, tail_len)
        self.value_cache[layer_idx] = self._modify_kv(values, base_len, end_idx, selected_idx, tail_len)
        self._compressed_len[layer_idx] = new_base_len

    def _compress_kv_by_lag(self, layer_idx):
        """Return the current states of `layer_idx`, then compress the stored cache if it is long enough.

        The returned tensors are the ones held before compression. Compression only runs once at least two full
        partitions lie beyond the protected prefix, since the last partition serves as the reference for the
        ones before it.
        """
        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
        base_len = self._compressed_len[layer_idx]
        if keys_to_return.size(-2) >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return
def generate(
    model,
    lag_ratio=0.5,
    lag_sink_size=16,
    lag_size=128,
    lag_score_v_ratio=1.0,
    lag_skip_layer_idx=None,
    **kwargs,
):
    """Custom generate function for LagKVCache.

    (template from https://huggingface.co/transformers-community/sink_cache)

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The fraction of tokens retained from each compressed middle partition.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of the partition. See the original paper for more information.
        lag_score_v_ratio (`float`):
            Weight of the Value-state score relative to the Key-state score.
        lag_skip_layer_idx (`Optional[List[int]]`):
            Indices of layers whose cache is never compressed.
        **kwargs:
            Forwarded to `model.generate`.

    The `lag_*` arguments are only used when no `past_key_values` is passed; a passed `LagKVCache` is used as-is.

    Returns:
        The output of `model.generate`.

    Raises:
        ValueError: if an unsupported generation argument is set, the model is an encoder-decoder model,
            `past_key_values` is not a `LagKVCache`, or `use_cache=False` is requested.
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        # Some entries (e.g. `assistant_model`) may not be `GenerationConfig` attributes; read them with a `None`
        # default so a missing attribute counts as "not set" instead of raising AttributeError.
        custom_value = getattr(generation_config, arg, None)
        has_custom_gen_config_arg = (
            generation_config is not None
            # = and not (match global default or match model-specific default)
            and not (
                getattr(default_model_generation_config, arg, None) == custom_value
                or getattr(default_global_generation_config, arg, None) == custom_value
            )
        )
        kwargs_has_arg = kwargs.get(arg) is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 1.d. `use_cache=True` is passed explicitly below, so a user-supplied `use_cache` would otherwise collide with
    # it (duplicate keyword argument). Disabling the cache defeats the purpose of this function.
    use_cache = kwargs.pop("use_cache", None)
    if use_cache is not None and not use_cache:
        raise ValueError("`use_cache=False` is not supported: this custom generate function always uses LagKVCache")

    # 2. Generate with LagKVCache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(
            ratio=lag_ratio,
            sink_size=lag_sink_size,
            lag_size=lag_size,
            score_v_ratio=lag_score_v_ratio,
            skip_layer_idx=lag_skip_layer_idx,
        )
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    # 2.b. generate with the cache
    return model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)