from typing import Any, Dict, List, Optional, Tuple

import torch
from transformers import DynamicCache, GenerationConfig
|
UNSUPPORTED_GENERATION_ARGS = [
    "cache_implementation",
    "cache_config",
    "return_legacy_cache",
    "num_beams",
    "compile_config",
    "assistant_model",
]
|
class LagKVCache(DynamicCache):
    """
    A KV cache compression algorithm, as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).

    Like SinkCache, the algorithm combines sink attention with a sliding window, but it additionally keeps
    selected tokens from the middle of the sequence. It allows the model to generate with less memory and
    faster decoding. Under this compression the model retains most of its information-retrieval capability,
    in contrast to the complete loss incurred by SinkCache.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each
    tensor is `[batch_size, num_heads, seq_len, head_dim]`.

    For chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.

    Parameters:
        _distributed_cache_data:
            Inherited from `DynamicCache`.
        ratio (`float`):
            The ratio of tokens to retain in the middle chunks.
        sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. Each partition is scored against the subsequent one, which serves
            as its reference.
        score_v_ratio (`float`):
            The ratio applied to the score of the Value states.
        skip_layer_idx (`Optional[List[int]]`):
            A list of layer indices for which compression is skipped.

    Example:

        ```python
        >>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache

        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

        >>> # Prepare a cache instance and pass it to the model's forward
        >>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
        >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> outputs.past_key_values # access cache filled with key/values from generation
        LagKVCache()
        ```
    """
|
    def __init__(
        self,
        _distributed_cache_data=None,
        ratio: float = 0.25,
        sink_size: int = 16,
        lag_size: int = 1024,
        score_v_ratio: float = 1.0,
        skip_layer_idx: Optional[List[int]] = None,
    ):
        super().__init__(_distributed_cache_data)
        self.ratio: float = ratio
        self.sink_size: int = sink_size
        self.lag_size: int = lag_size
        self.score_v_ratio: float = score_v_ratio
        self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
        # Per-layer length of the already-compressed prefix (sink tokens + previously selected tokens).
        self._compressed_len: List[int] = []
|
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, *optional*):
                Additional arguments for the cache subclass. No additional arguments are used in `LagKVCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if key_states is not None:
            if len(self.key_cache) <= layer_idx:
                # Fill any skipped intermediate layers with empty lists so indices stay aligned.
                for _ in range(len(self.key_cache), layer_idx):
                    self.key_cache.append([])
                    self.value_cache.append([])
                    self._compressed_len.append(self.sink_size)
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
                self._compressed_len.append(self.sink_size)
            elif len(self.key_cache[layer_idx]) == 0:
                # A previously skipped layer is now receiving its first states.
                self.key_cache[layer_idx] = key_states
                self.value_cache[layer_idx] = value_states
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        if layer_idx not in self.skip_layer_idx:
            return self._compress_kv_by_lag(layer_idx)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
    def _get_states_score(self, base_len, in_size, end_idx, value):
        """Partition the states, then calculate the per-token scores."""
        # Only the fully formed partitions between the compressed prefix and the tail are scored.
        target_v = value[:, :, base_len:end_idx]
        # [batch, heads, num_partitions, lag_size, head_dim]
        target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
        # Each partition is scored against the partition that follows (lags) it.
        ref = target_v[:, :, 1:, :, :]
        v = target_v[:, :, :-1, :, :]

        min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
        max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)

        # Min-max normalize against the reference partition, take the standard deviation over
        # the head dimension, then softmax within each partition to obtain token scores.
        score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)

        return score
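
    # Worked shape example, with hypothetical settings: for base_len=16, lag_size=1024 and a
    # 4112-token sequence, end_idx = 16 + 4 * 1024 = 4112, `target_v` is viewed as
    # [batch, heads, 4, 1024, head_dim], and `score` comes out as [batch, heads, 3, 1024]:
    # the last partition only acts as a reference and is never scored itself.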
|
|
|
    def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
        """Keep the compressed prefix, the selected middle tokens, and the uncompressed tail."""
        selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
        value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
        return value
|
    def _compress_algo(self, layer_idx, base_len):
        """
        Calculate the scores of the KV tokens in each head and partition, as described in the paper.
        Partitioning significantly reduces the computational overhead of top-k.
        """
        in_size = self.key_cache[layer_idx].size()
        # End of the last fully formed partition; everything after it stays uncompressed.
        end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size

        key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
        value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
        score = key_score + value_score * self.score_v_ratio

        # Keep the top `ratio * lag_size` tokens of each scored partition.
        selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
        # Shift the per-partition indices to absolute positions within the middle span.
        for i in range(1, selected_idx.size(2)):
            selected_idx[:, :, i] += i * self.lag_size
        selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
        new_base_len = base_len + selected_idx.size(-2)

        # The tail is the last (reference-only) partition plus any not-yet-full partition.
        tail_len = self.lag_size + in_size[-2] - end_idx
        self.key_cache[layer_idx] = self._modify_kv(
            self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self.value_cache[layer_idx] = self._modify_kv(
            self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
        )
        self._compressed_len[layer_idx] = new_base_len
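
    # Size bookkeeping, with hypothetical settings: for sink_size=16, lag_size=1024, ratio=0.25
    # and 4112 cached tokens, the three scored partitions keep int(0.25 * 1024) = 256 tokens
    # each, so the cache shrinks to 16 + 3 * 256 + 1024 = 1808 tokens, and the new compressed
    # prefix length becomes 16 + 3 * 256 = 784.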
|
|
|
    def _compress_kv_by_lag(self, layer_idx):
        """Return the current KV states for this forward pass, then compress them in place."""
        kv_size = self.key_cache[layer_idx].size()
        base_len = self._compressed_len[layer_idx]

        keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
        # Compression triggers once at least one full partition plus its reference partition
        # have accumulated beyond the compressed prefix.
        if kv_size[-2] >= base_len + 2 * self.lag_size:
            self._compress_algo(layer_idx, base_len)
        return keys_to_return, values_to_return
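
# A self-contained sanity-check sketch, not tied to any model; the sizes below are
# hypothetical. Feeding random key/value tensors for a single layer shows the stored
# cache shrinking once its length reaches `sink_size + 2 * lag_size`.
def _demo_compression() -> None:
    cache = LagKVCache(ratio=0.25, sink_size=4, lag_size=8)
    # [batch_size, num_heads, seq_len, head_dim], with seq_len = 40 >= 4 + 2 * 8
    keys = torch.randn(1, 2, 40, 16)
    values = torch.randn(1, 2, 40, 16)
    returned_keys, _ = cache.update(keys, values, layer_idx=0)
    # The unmodified states are returned for the current forward pass...
    print(returned_keys.shape)  # torch.Size([1, 2, 40, 16])
    # ...while the stored cache is compressed in place: 4 sink tokens, int(0.25 * 8) = 2
    # selected tokens from each of the 3 scored partitions, and a 12-token tail -> 22 tokens.
    print(cache.key_cache[0].shape)  # torch.Size([1, 2, 22, 16])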
|
|
|
def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
    """Custom generate function for LagKVCache.
    (template from https://huggingface.co/transformers-community/sink_cache)

    Args:
        model (`PreTrainedModel`):
            The model to generate from.
        lag_ratio (`float`):
            The ratio of tokens to retain in the middle chunks.
        lag_sink_size (`int`):
            The number of sink tokens.
        lag_size (`int`):
            The size of each partition. See the original paper for more information.
    """
    # Reject generation arguments that conflict with this cache implementation. An argument
    # counts as set if it appears in `kwargs`, or if the provided `generation_config` differs
    # from both the model's generation config and the global defaults.
    generation_config = kwargs.get("generation_config")
    default_global_generation_config = GenerationConfig()
    default_model_generation_config = model.generation_config
    for arg in UNSUPPORTED_GENERATION_ARGS:
        has_custom_gen_config_arg = (
            generation_config is not None
            and not (
                getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
                or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
            )
        )
        kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
        if kwargs_has_arg or has_custom_gen_config_arg:
            raise ValueError(
                f"`{arg}` is set, but it's not supported in this custom generate function. List of "
                f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
            )

    if model.config.is_encoder_decoder:
        raise ValueError("This custom generate function only works with decoder-only models")

    # Drop the `custom_generate` argument that routed generation to this function.
    kwargs.pop("custom_generate", None)

    # Use the provided LagKVCache, or build one from the `lag_*` arguments.
    past_key_values = kwargs.pop("past_key_values", None)
    if past_key_values is None:
        past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
    elif not isinstance(past_key_values, LagKVCache):
        raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")

    generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
    return generation_outputs
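
# A minimal end-to-end sketch; the checkpoint is the one from the class docstring example,
# and any decoder-only causal LM should work the same way.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    inputs = tokenizer("My name is Qwen2", return_tensors="pt")

    # The LagKVCache is built internally from the `lag_*` arguments.
    outputs = generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, max_new_tokens=32, **inputs)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))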