# Copyright 2025 China Merchants Bank. All rights reserved.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://mit-license.org
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from transformers import DynamicCache, GenerationConfig
from typing import Any, Dict, List, Optional, Tuple
UNSUPPORTED_GENERATION_ARGS = [
"cache_implementation", # cache-related arguments, here we always use SinkCache
"cache_config",
"return_legacy_cache",
"num_beams", # beam search (and cousin techniques) are not supported
"compile_config", # SinkCache doesn't support torch.compile
"assistant_model", # it also doesn't support speculative decoding
]
class LagKVCache(DynamicCache):
"""
    A KV compression algorithm described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
    Like SinkCache, it combines sink attention with a sliding window, but it additionally keeps selected tokens from
    the middle of the sequence. This lets the model generate with less memory and faster decoding.
    The model retains most of its information-retrieval capability under compression, in contrast to the complete loss
    incurred by SinkCache.
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.
Parameters:
_distributed_cache_data:
Inherited from DynamicCache.
ratio (`float`):
            The ratio of tokens retained in the middle chunks.
sink_size (`int`):
The number of sink tokens.
lag_size (`int`):
            The size of the partition. The subsequent partition serves as a reference for the prior one.
score_v_ratio (`float`):
            The multiplier applied to the score of the Value states.
skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices for which compression is skipped.
Example:
```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
LagKVCache()
```
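
    You can also let `generate` drive the cache directly (a minimal sketch; `max_new_tokens=20` is an arbitrary choice):

    ```python
    >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
    >>> out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=20, use_cache=True)
    >>> text = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    ```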
"""
def __init__(
self,
_distributed_cache_data=None,
ratio: float = 0.25,
sink_size: int = 16,
lag_size: int = 1024,
score_v_ratio: float = 1.0,
skip_layer_idx: Optional[List[int]] = None,
):
super().__init__(_distributed_cache_data)
        self.ratio: float = ratio
self.sink_size: int = sink_size
self.lag_size: int = lag_size
self.score_v_ratio: float = score_v_ratio
self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
self._compressed_len: List[int] = []
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.
Return:
A tuple containing the updated key and value states.
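
        Example (a toy sketch with illustrative shapes, sized so that compression triggers on the first update):

        ```python
        >>> import torch
        >>> cache = LagKVCache(ratio=0.25, sink_size=2, lag_size=4)
        >>> k = torch.randn(1, 2, 10, 8)  # [batch_size, num_heads, seq_len, head_dim]
        >>> v = torch.randn(1, 2, 10, 8)
        >>> k_out, v_out = cache.update(k, v, layer_idx=0)
        >>> k_out.shape  # the full states are returned for the current forward pass
        torch.Size([1, 2, 10, 8])
        >>> cache.key_cache[0].shape  # the stored cache keeps 2 sink + 1 selected + 4 tail tokens
        torch.Size([1, 2, 7, 8])
        ```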
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if len(self.key_cache) <= layer_idx:
# There may be skipped layers, fill them with empty lists
for _ in range(len(self.key_cache), layer_idx):
self.key_cache.append([])
self.value_cache.append([])
self._compressed_len.append(self.sink_size)
self.key_cache.append(key_states)
self.value_cache.append(value_states)
self._compressed_len.append(self.sink_size)
elif (
len(self.key_cache[layer_idx]) == 0
): # fills previously skipped layers; checking for tensor causes errors
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
if layer_idx not in self.skip_layer_idx:
return self._compress_kv_by_lag(layer_idx)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def _get_states_score(self, base_len, in_size, end_idx, value):
"""Partition the states then calculate the state scores"""
# [batch_size, num_heads, seq_len, head_dim]
target_v = value[:, :, base_len:end_idx]
# [batch_size, num_heads, partition_num, lag_size, head_dim]
target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
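        # each partition is scored against the one that follows it: partition i+1 supplies the per-channel
        # min/max reference used to normalize the tokens of partition i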
ref = target_v[:, :, 1:, :, :]
v = target_v[:, :, :-1, :, :]
min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
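        # per-token dispersion (std across head_dim) of the normalized states, turned into a distribution
        # over each partition's tokens via softmax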
score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
return score
def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        # `selected_idx` is relative to `base_len` (gather operates on the sliced middle section)
selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
return value
def _compress_algo(self, layer_idx, base_len):
"""
Calculate the scores of KV tokens in each head and partition. See the paper.
The computation overhead of top-k is significantly reduced by partitioning.
"""
in_size = self.key_cache[layer_idx].size()
end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
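        # `end_idx` marks the end of the last complete `lag_size` partition after `base_len`;
        # tokens beyond it are left untouched and kept in the tail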
# [batch_size, num_heads, partition_num - 1, lag_size, head_dim]
key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
score = key_score + value_score * self.score_v_ratio
        # in some cases you may need to sort the indices to keep the selected tokens in their original order
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        # shift each partition's local indices so they index into the `[base_len, end_idx)` slice
        for i in range(1, selected_idx.size()[2], 1):
            selected_idx[:, :, i] += i * self.lag_size
selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
new_base_len = base_len + selected_idx.size()[-2]
        # always keep the last window: the final full partition plus any incomplete remainder
tail_len = self.lag_size + in_size[-2] - end_idx
self.key_cache[layer_idx] = self._modify_kv(
self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
)
self.value_cache[layer_idx] = self._modify_kv(
self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
)
self._compressed_len[layer_idx] = new_base_len
def _compress_kv_by_lag(self, layer_idx):
"""the KV cache will be used then compressed"""
kv_size = self.key_cache[layer_idx].size()
base_len = self._compressed_len[layer_idx]
keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
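        # return the uncompressed states so the current forward pass attends to every token;
        # the tensors stored in the cache may be replaced by compressed ones below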
if kv_size[-2] >= base_len + 2 * self.lag_size:
self._compress_algo(layer_idx, base_len)
return keys_to_return, values_to_return
def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
"""Custom generate function for LagKVCache.
(template from https://huggingface.co/transformers-community/sink_cache)
Args:
model (`PreTrainedModel`):
The model to generate from.
        lag_ratio (`float`):
            The ratio of tokens retained in the middle chunks.
lag_sink_size (`int`):
The number of sink tokens.
lag_size (`int`):
The size of the partition. See the original paper for more information.
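
    Example (a minimal sketch: the model name mirrors the `LagKVCache` docstring and `max_new_tokens` is arbitrary;
    when published as a Hub `custom_generate` repository, this function is instead reached via
    `model.generate(..., custom_generate="<repo_id>", trust_remote_code=True)`):

    ```python
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> inputs = tokenizer("My name is Qwen2", return_tensors="pt")
    >>> out = generate(model, lag_ratio=0.25, lag_size=128, **inputs, max_new_tokens=20)
    >>> text = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    ```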
"""
# 1. General sanity checks
# 1.a. A few arguments are not allowed, especially arguments that control caches.
generation_config = kwargs.get("generation_config")
default_global_generation_config = GenerationConfig()
default_model_generation_config = model.generation_config
for arg in UNSUPPORTED_GENERATION_ARGS:
has_custom_gen_config_arg = (
generation_config is not None
# = and not (match global default or match model-specific default)
and not (
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
)
)
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
if kwargs_has_arg or has_custom_gen_config_arg:
raise ValueError(
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
)
# 1.b. The model must be decoder-only
if model.config.is_encoder_decoder:
raise ValueError("This custom generate function only works with decoder-only models")
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
kwargs.pop("custom_generate", None)
# 2. Generate with LagKVCache
# 2.a. prepare the cache, if it was not passed.
past_key_values = kwargs.pop("past_key_values", None)
if past_key_values is None:
past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
elif not isinstance(past_key_values, LagKVCache):
raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")
# 2.b. generate with the cache
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
return generation_outputs