# Copyright 2025 China Merchants Bank. All rights reserved.
#
# Licensed under the MIT License (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://mit-license.org
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from transformers import DynamicCache, GenerationConfig
from typing import Any, Dict, List, Optional, Tuple


UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",  # cache-related arguments; this function always uses LagKVCache
    "cache_config",
    "return_legacy_cache",
    "num_beams",  # beam search (and related techniques) is not supported
    "compile_config",  # LagKVCache doesn't support torch.compile
    "assistant_model",  # it also doesn't support speculative decoding
]

class LagKVCache(DynamicCache):
    """
    A KV cache compression algorithm as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
    Like SinkCache, the algorithm keeps sink tokens and a sliding window, but it additionally retains
    selected tokens from the middle of the sequence. This lets the model generate with less memory and
    faster decoding. Under compression, the model keeps most of its information-retrieval capability over
    the middle of the context, which SinkCache loses completely.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from DynamicCache.
        ratio (`float`):
            The retention ratio of tokens in the middle chunks.
        sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. The subsequent partition serves as the reference for the prior one.
        score_v_ratio (`float`):
            The ratio multiplied to the score of Value states.
        skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices that skip compression.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        LagKVCache()
        ```
    """

    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        super().__init__(_distributed_cache_data)
        self.ratio = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        self._compressed_len: List[int] = []

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        # Update the cache
        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # There may be skipped layers, fill them with empty lists
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif (
                len(self.key_cache[layer_idx]) == 0
            ):  # fills previously skipped layers; checking for tensor causes errors
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

            if layer_idx not in self.skip_layer_idx:
                # compress this layer in place; the returned states are still the full, pre-compression
                # tensors, so the current forward pass attends over every token
                return self._compress_kv_by_lag(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition the states then calculate the state scores"""
        # [batch_size, num_heads, seq_len, head_dim]
        target_v = value[:, :, base_len:end_idx]
        # [batch_size, num_heads, partition_num, lag_size, head_dim]
        target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
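        # the subsequent partition ("ref") serves as the lagged reference for the prior one ("v")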
        ref = target_v[:, :, 1:, :, :]
        v = target_v[:, :, :-1, :, :]

        # per-channel min/max over the tokens of the reference partition, broadcast across the lag window
        min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)

        # min-max normalize each token against the reference, score it by the std across head_dim,
        # and softmax within the partition so the scores form a per-partition distribution
        score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)

        return score

    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
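        # keep the prefix [0, base_len), the selected middle tokens, and the tail window; drop the rest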
        # idx is offset by base_len
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
        value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
        return value

    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of the KV tokens in each head and partition. See the paper for details.
        Partitioning significantly reduces the computational overhead of the top-k selection.
        """
        in_size = self.key_cache[layer_idx].size()
        end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
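        # end_idx marks the last full-partition boundary past base_len; tokens beyond it are left untouched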
        # [batch_size, num_heads, partition_num - 1, lag_size, head_dim]
        key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
        value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
        score = key_score + value_score * self.score_v_ratio
        # you may need to sort the indices in some cases
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        # shift each partition's local top-k indices to offsets relative to base_len
        for i in range(1, selected_idx.size()[2], 1):
            selected_idx[:, :, i] += i * self.lag_size
        selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
        new_base_len = base_len + selected_idx.size()[-2]
        # always keep the last window (the final full partition plus any uncompressed remainder)
        tail_len = self.lag_size + in_size[-2] - end_idx
        self.key_cache[layer_idx] = self._modify_kv(
            self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self.value_cache[layer_idx] = self._modify_kv(
            self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self._compressed_len[layer_idx] = new_base_len

    def _compress_kv_by_lag(self, layer_idx):
        """the KV cache will be used then compressed"""
        kv_size = self.key_cache[layer_idx].size()
        base_len = self._compressed_len[layer_idx]

        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
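        # compress only once at least two full partitions sit beyond the compressed prefix:
        # one partition to compress plus one to act as its lagged reference / retained tail window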
        if kv_size[-2] >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return

def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
    """Custom generate function for LagKVCache. 
    (template from https://huggingface.co/transformers-community/sink_cache)
    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The retention ratio of tokens in the middle chunks.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of the partition. See the original paper for more information.
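
    Example (a minimal sketch mirroring the class docstring; the checkpoint name and the
    `max_new_tokens` value are illustrative choices, not requirements of this function):

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> inputs = tokenizer("My name is Qwen2", return_tensors="pt")

        >>> outputs = generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **inputs, max_new_tokens=20)
        >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
        ```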
    """
    # 1. General sanity checks
    # 1.a. A few arguments are not allowed, especially arguments that control caches.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            # i.e. the value matches neither the model-specific default nor the global default
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    # 1.b. The model must be decoder-only
    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
    # in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
    kwargs.pop("custom_generate", None)

    # 2. Generate with LagKVCache
    # 2.a. prepare the cache, if it was not passed.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    # 2.b. generate with the cache
    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
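
# Note: if this module is published as a Hugging Face Hub `custom_generate` repository (the
# sink_cache template credited above is one), the same entry point can presumably also be reached
# through the stock API; the repo id below is a placeholder, not a real repository:
#
#     model.generate(**inputs, custom_generate="<hub-repo-id>", trust_remote_code=True,
#                    lag_ratio=0.5, lag_sink_size=16, lag_size=128)
#
# Calling `generate(model, ...)` directly, as in the docstring example, does not require this.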