Upload folder using huggingface_hub
Browse files- .ipynb_checkpoints/README-checkpoint.md +116 -0
- README.md +116 -3
- custom_generate/.ipynb_checkpoints/generate-checkpoint.py +245 -0
- custom_generate/LICENSE +5 -0
- custom_generate/generate.py +245 -0
.ipynb_checkpoints/README-checkpoint.md
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags:
|
4 |
+
- custom_generate
|
5 |
+
---
|
6 |
+
|
7 |
+
# LagKV Cache
|
8 |
+
|
9 |
+
#### Introduction
|
10 |
+
|
11 |
+

|
12 |
+
|
13 |
+
LagKV is an efficient and robust KV compression algorithm. It uses lag tokens information to compress the previous ones which significantly boost the compression performance with little computation overhead.
|
14 |
+
|
15 |
+
[Original Github](https://github.com/AI-Lab-China-Merchants-Bank/LagKV)
|
16 |
+
|
17 |
+
Details are in the following work:
|
18 |
+
|
19 |
+
[LagKV: Lag-Relative Information of the KV Cache Tells Which Tokens Are Important](https://arxiv.org/abs/2504.04704)
|
20 |
+
|
21 |
+
#### How to Use
|
22 |
+
|
23 |
+
LagKV implements the Cache interface from transformers. It's easy to be integrated into the model calling function.
|
24 |
+
|
25 |
+
```python
|
26 |
+
from lag_kv import LagKV
|
27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
28 |
+
|
29 |
+
model_path = "Qwen2.5-7B-Instruct"
|
30 |
+
device = "cuda:0"
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
32 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", attn_implementation="sdpa").to(device)
|
33 |
+
|
34 |
+
prompt = "long text"
|
35 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
36 |
+
past_key_values = LagKV(lag_size=64)
|
37 |
+
print(model.generate(input_ids, past_key_values=past_key_values))
|
38 |
+
# check KV cache size
|
39 |
+
print(past_key_values[0][0].size())
|
40 |
+
```
|
41 |
+
|
42 |
+
To compress the KV cache during the prefill stage instead of it's precisely calculated, you have to use the following inference function(for batch_size=1 only.):
|
43 |
+
|
44 |
+
```python
|
45 |
+
def inference_by_prefill_compress(model, tokenizer, inputs, max_new_tokens=256, decode=False, past_key_values=None, device="cuda"):
|
46 |
+
if isinstance(inputs, str):
|
47 |
+
input_ids = tokenizer([inputs], return_tensors="pt")["input_ids"].to(device)
|
48 |
+
else:
|
49 |
+
input_ids = inputs
|
50 |
+
if past_key_values is None:
|
51 |
+
past_key_values = LagKV(ratio=0.2,
|
52 |
+
lag_size=128,
|
53 |
+
layer_idx_skip_first=[],
|
54 |
+
use_then_compress=True)
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
sink_size = past_key_values.sink_size
|
58 |
+
lag_size = past_key_values.lag_size
|
59 |
+
trigger_len = sink_size + 2*lag_size
|
60 |
+
input_length = input_ids.shape[1]
|
61 |
+
# print(input_length > trigger_len)
|
62 |
+
if input_length > trigger_len:
|
63 |
+
start_idx = 0
|
64 |
+
end_idx = trigger_len
|
65 |
+
position_ids = torch.arange(input_length + max_new_tokens).unsqueeze(0).to(device)
|
66 |
+
def batch_input():
|
67 |
+
sel_input_ids = input_ids[:, start_idx:end_idx]
|
68 |
+
q_len = end_idx - start_idx
|
69 |
+
k_len = past_key_values.get_seq_length() + q_len
|
70 |
+
batch_size = input_ids.shape[0]
|
71 |
+
head_num = model.config.num_attention_heads
|
72 |
+
attn_mask = torch.ones((k_len, q_len),
|
73 |
+
device=input_ids.device, dtype=torch.bool)
|
74 |
+
attn_mask = torch.triu(attn_mask, diagonal=1).T
|
75 |
+
attn_mask = torch.flip(attn_mask, (0, 1))
|
76 |
+
attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
|
77 |
+
attn_mask = attn_mask.expand(batch_size, -1, -1, -1).expand(-1, head_num, -1, -1)
|
78 |
+
attention_mask = torch.zeros((batch_size, head_num, q_len, k_len), device=input_ids.device, dtype=torch.bfloat16)
|
79 |
+
attention_mask.masked_fill_(attn_mask, -torch.inf)
|
80 |
+
return {"input_ids": sel_input_ids, "attention_mask": attention_mask}
|
81 |
+
|
82 |
+
while start_idx < input_length:
|
83 |
+
tmp_pos = position_ids[:, start_idx:end_idx]
|
84 |
+
outputs = model(**batch_input(),
|
85 |
+
past_key_values=past_key_values,
|
86 |
+
position_ids=tmp_pos,
|
87 |
+
cache_position=tmp_pos[0]
|
88 |
+
)
|
89 |
+
start_idx = end_idx
|
90 |
+
end_idx += lag_size
|
91 |
+
end_idx = min(end_idx, input_length)
|
92 |
+
|
93 |
+
new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
|
94 |
+
# print(new_token_id)
|
95 |
+
new_token_count = 1
|
96 |
+
generated_ids = [new_token_id]
|
97 |
+
while new_token_id[0][0] != tokenizer.eos_token_id and new_token_count < max_new_tokens+1:
|
98 |
+
tmp_pos = position_ids[:, (input_length+new_token_count-1):(input_length+new_token_count)]
|
99 |
+
outputs = model(new_token_id,
|
100 |
+
past_key_values=past_key_values,
|
101 |
+
position_ids=tmp_pos,
|
102 |
+
cache_position=tmp_pos[0]
|
103 |
+
)
|
104 |
+
new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
|
105 |
+
new_token_count += 1
|
106 |
+
generated_ids.append(new_token_id)
|
107 |
+
generated_ids = torch.cat(generated_ids, dim=-1)
|
108 |
+
else:
|
109 |
+
generated_ids = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens, past_key_values=past_key_values)
|
110 |
+
generated_ids = generated_ids[:, input_length:]
|
111 |
+
if decode:
|
112 |
+
output = tokenizer.batch_decode(generated_ids)
|
113 |
+
else:
|
114 |
+
output = generated_ids
|
115 |
+
return output, past_key_values
|
116 |
+
```
|
README.md
CHANGED
@@ -1,3 +1,116 @@
|
|
1 |
-
---
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags:
|
4 |
+
- custom_generate
|
5 |
+
---
|
6 |
+
|
7 |
+
# LagKV Cache
|
8 |
+
|
9 |
+
#### Introduction
|
10 |
+
|
11 |
+

|
12 |
+
|
13 |
+
LagKV is an efficient and robust KV compression algorithm. It uses lag tokens information to compress the previous ones which significantly boost the compression performance with little computation overhead.
|
14 |
+
|
15 |
+
[Original Github](https://github.com/AI-Lab-China-Merchants-Bank/LagKV)
|
16 |
+
|
17 |
+
Details are in the following work:
|
18 |
+
|
19 |
+
[LagKV: Lag-Relative Information of the KV Cache Tells Which Tokens Are Important](https://arxiv.org/abs/2504.04704)
|
20 |
+
|
21 |
+
#### How to Use
|
22 |
+
|
23 |
+
LagKV implements the Cache interface from transformers. It's easy to be integrated into the model calling function.
|
24 |
+
|
25 |
+
```python
|
26 |
+
from lag_kv import LagKV
|
27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
28 |
+
|
29 |
+
model_path = "Qwen2.5-7B-Instruct"
|
30 |
+
device = "cuda:0"
|
31 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
32 |
+
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto", attn_implementation="sdpa").to(device)
|
33 |
+
|
34 |
+
prompt = "long text"
|
35 |
+
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
36 |
+
past_key_values = LagKV(lag_size=64)
|
37 |
+
print(model.generate(input_ids, past_key_values=past_key_values))
|
38 |
+
# check KV cache size
|
39 |
+
print(past_key_values[0][0].size())
|
40 |
+
```
|
41 |
+
|
42 |
+
To compress the KV cache during the prefill stage instead of it's precisely calculated, you have to use the following inference function(for batch_size=1 only.):
|
43 |
+
|
44 |
+
```python
|
45 |
+
def inference_by_prefill_compress(model, tokenizer, inputs, max_new_tokens=256, decode=False, past_key_values=None, device="cuda"):
|
46 |
+
if isinstance(inputs, str):
|
47 |
+
input_ids = tokenizer([inputs], return_tensors="pt")["input_ids"].to(device)
|
48 |
+
else:
|
49 |
+
input_ids = inputs
|
50 |
+
if past_key_values is None:
|
51 |
+
past_key_values = LagKV(ratio=0.2,
|
52 |
+
lag_size=128,
|
53 |
+
layer_idx_skip_first=[],
|
54 |
+
use_then_compress=True)
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
sink_size = past_key_values.sink_size
|
58 |
+
lag_size = past_key_values.lag_size
|
59 |
+
trigger_len = sink_size + 2*lag_size
|
60 |
+
input_length = input_ids.shape[1]
|
61 |
+
# print(input_length > trigger_len)
|
62 |
+
if input_length > trigger_len:
|
63 |
+
start_idx = 0
|
64 |
+
end_idx = trigger_len
|
65 |
+
position_ids = torch.arange(input_length + max_new_tokens).unsqueeze(0).to(device)
|
66 |
+
def batch_input():
|
67 |
+
sel_input_ids = input_ids[:, start_idx:end_idx]
|
68 |
+
q_len = end_idx - start_idx
|
69 |
+
k_len = past_key_values.get_seq_length() + q_len
|
70 |
+
batch_size = input_ids.shape[0]
|
71 |
+
head_num = model.config.num_attention_heads
|
72 |
+
attn_mask = torch.ones((k_len, q_len),
|
73 |
+
device=input_ids.device, dtype=torch.bool)
|
74 |
+
attn_mask = torch.triu(attn_mask, diagonal=1).T
|
75 |
+
attn_mask = torch.flip(attn_mask, (0, 1))
|
76 |
+
attn_mask = attn_mask.unsqueeze(0).unsqueeze(0)
|
77 |
+
attn_mask = attn_mask.expand(batch_size, -1, -1, -1).expand(-1, head_num, -1, -1)
|
78 |
+
attention_mask = torch.zeros((batch_size, head_num, q_len, k_len), device=input_ids.device, dtype=torch.bfloat16)
|
79 |
+
attention_mask.masked_fill_(attn_mask, -torch.inf)
|
80 |
+
return {"input_ids": sel_input_ids, "attention_mask": attention_mask}
|
81 |
+
|
82 |
+
while start_idx < input_length:
|
83 |
+
tmp_pos = position_ids[:, start_idx:end_idx]
|
84 |
+
outputs = model(**batch_input(),
|
85 |
+
past_key_values=past_key_values,
|
86 |
+
position_ids=tmp_pos,
|
87 |
+
cache_position=tmp_pos[0]
|
88 |
+
)
|
89 |
+
start_idx = end_idx
|
90 |
+
end_idx += lag_size
|
91 |
+
end_idx = min(end_idx, input_length)
|
92 |
+
|
93 |
+
new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
|
94 |
+
# print(new_token_id)
|
95 |
+
new_token_count = 1
|
96 |
+
generated_ids = [new_token_id]
|
97 |
+
while new_token_id[0][0] != tokenizer.eos_token_id and new_token_count < max_new_tokens+1:
|
98 |
+
tmp_pos = position_ids[:, (input_length+new_token_count-1):(input_length+new_token_count)]
|
99 |
+
outputs = model(new_token_id,
|
100 |
+
past_key_values=past_key_values,
|
101 |
+
position_ids=tmp_pos,
|
102 |
+
cache_position=tmp_pos[0]
|
103 |
+
)
|
104 |
+
new_token_id = outputs.logits[:, -1].argmax(dim=-1).unsqueeze(-1)
|
105 |
+
new_token_count += 1
|
106 |
+
generated_ids.append(new_token_id)
|
107 |
+
generated_ids = torch.cat(generated_ids, dim=-1)
|
108 |
+
else:
|
109 |
+
generated_ids = model.generate(inputs, do_sample=False, max_new_tokens=max_new_tokens, past_key_values=past_key_values)
|
110 |
+
generated_ids = generated_ids[:, input_length:]
|
111 |
+
if decode:
|
112 |
+
output = tokenizer.batch_decode(generated_ids)
|
113 |
+
else:
|
114 |
+
output = generated_ids
|
115 |
+
return output, past_key_values
|
116 |
+
```
|
custom_generate/.ipynb_checkpoints/generate-checkpoint.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 China Merchants Bank. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the MIT License (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://mit-license.org
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from transformers.cache_utils import DynamicCache
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple
|
18 |
+
|
19 |
+
|
20 |
+
class LagKVCache(DynamicCache):
|
21 |
+
"""
|
22 |
+
A KV compression algorithm that as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
|
23 |
+
The algorithm equips Sink Attention and SlidingWindow like SinkCache but with additional selective tokens in the middle.
|
24 |
+
It allows the model to generate with fewer memory resource and faster decoding speed.
|
25 |
+
The model will hold the main part of information retrieval capbility during the compression, compared to a completed loss
|
26 |
+
of the SinkCache.
|
27 |
+
|
28 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
29 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
30 |
+
|
31 |
+
For the chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
_distributed_cache_data:
|
35 |
+
Inherited from DynamicCache.
|
36 |
+
ratio (`float`):
|
37 |
+
The retrain ratio of tokens in the middle chunks.
|
38 |
+
sink_size (`int`):
|
39 |
+
The number of sink tokens.
|
40 |
+
lag_size (`int`):
|
41 |
+
The size of the partition. The subsequent partion will serve as a reference for the prior one.
|
42 |
+
score_v_ratio (`float`):
|
43 |
+
The ratio multiplied to the score of Value states.
|
44 |
+
skip_layer_idx (`Optional[List[int]]`):
|
45 |
+
A list of layer indices will skip the compression.
|
46 |
+
|
47 |
+
Example:
|
48 |
+
|
49 |
+
```python
|
50 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache
|
51 |
+
|
52 |
+
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
53 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
54 |
+
|
55 |
+
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
56 |
+
|
57 |
+
>>> # Prepare a cache class and pass it to model's forward
|
58 |
+
>>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
|
59 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
60 |
+
>>> outputs.past_key_values # access cache filled with key/values from generation
|
61 |
+
LagKVCache()
|
62 |
+
```
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
_distributed_cache_data=None,
|
68 |
+
ratio: float = 0.25,
|
69 |
+
sink_size: int = 16,
|
70 |
+
lag_size: int = 1024,
|
71 |
+
score_v_ratio: float = 1.0,
|
72 |
+
skip_layer_idx: Optional[List[int]] = None,
|
73 |
+
):
|
74 |
+
super().__init__(_distributed_cache_data)
|
75 |
+
self.ratio = ratio
|
76 |
+
self.sink_size: int = sink_size
|
77 |
+
self.lag_size: int = lag_size
|
78 |
+
self.score_v_ratio: float = score_v_ratio
|
79 |
+
self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
|
80 |
+
self._compressed_len: List[int] = []
|
81 |
+
|
82 |
+
def update(
|
83 |
+
self,
|
84 |
+
key_states: torch.Tensor,
|
85 |
+
value_states: torch.Tensor,
|
86 |
+
layer_idx: int,
|
87 |
+
cache_kwargs=None,
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
91 |
+
|
92 |
+
Parameters:
|
93 |
+
key_states (`torch.Tensor`):
|
94 |
+
The new key states to cache.
|
95 |
+
value_states (`torch.Tensor`):
|
96 |
+
The new value states to cache.
|
97 |
+
layer_idx (`int`):
|
98 |
+
The index of the layer to cache the states for.
|
99 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
100 |
+
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
101 |
+
|
102 |
+
Return:
|
103 |
+
A tuple containing the updated key and value states.
|
104 |
+
"""
|
105 |
+
# Update the number of seen tokens
|
106 |
+
if layer_idx == 0:
|
107 |
+
self._seen_tokens += key_states.shape[-2]
|
108 |
+
|
109 |
+
# Update the cache
|
110 |
+
if key_states is not None:
|
111 |
+
if len(self.key_cache) <= layer_idx:
|
112 |
+
# There may be skipped layers, fill them with empty lists
|
113 |
+
for _ in range(len(self.key_cache), layer_idx):
|
114 |
+
self.key_cache.append([])
|
115 |
+
self.value_cache.append([])
|
116 |
+
self._compressed_len.append(self.sink_size)
|
117 |
+
self.key_cache.append(key_states)
|
118 |
+
self.value_cache.append(value_states)
|
119 |
+
self._compressed_len.append(self.sink_size)
|
120 |
+
elif (
|
121 |
+
len(self.key_cache[layer_idx]) == 0
|
122 |
+
): # fills previously skipped layers; checking for tensor causes errors
|
123 |
+
self.key_cache[layer_idx] = key_states
|
124 |
+
self.value_cache[layer_idx] = value_states
|
125 |
+
else:
|
126 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
127 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
128 |
+
|
129 |
+
if layer_idx not in self.skip_layer_idx:
|
130 |
+
return self._compress_kv_by_lag(layer_idx)
|
131 |
+
|
132 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
133 |
+
|
134 |
+
def _get_states_score(self, base_len, in_size, end_idx, value):
|
135 |
+
"""Partition the states then calculate the state scores"""
|
136 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
137 |
+
target_v = value[:, :, base_len:end_idx]
|
138 |
+
# [batch_size, num_heads, partition_num, lag_size, head_dim]
|
139 |
+
target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
|
140 |
+
ref = target_v[:, :, 1:, :, :]
|
141 |
+
v = target_v[:, :, :-1, :, :]
|
142 |
+
|
143 |
+
min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
|
144 |
+
max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
|
145 |
+
|
146 |
+
score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
|
147 |
+
|
148 |
+
return score
|
149 |
+
|
150 |
+
def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
|
151 |
+
# idx is offset by base_len
|
152 |
+
selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
|
153 |
+
value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
|
154 |
+
return value
|
155 |
+
|
156 |
+
def _compress_algo(self, layer_idx, base_len):
|
157 |
+
"""
|
158 |
+
Calculate the scores of KV tokens in each head and partition. See the paper.
|
159 |
+
The computation overhead of top-k is significantly reduced by partitioning.
|
160 |
+
"""
|
161 |
+
in_size = self.key_cache[layer_idx].size()
|
162 |
+
end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
|
163 |
+
# [batch_size, num_heads, partition_num - 1, lag_size, head_dim]
|
164 |
+
key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
|
165 |
+
value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
|
166 |
+
score = key_score + value_score * self.score_v_ratio
|
167 |
+
# you may need to sort the index for some cases
|
168 |
+
selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
|
169 |
+
for i in range(1, selected_idx.size()[2], 1):
|
170 |
+
selected_idx[:, :, i] += i * self.lag_size
|
171 |
+
selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
|
172 |
+
new_base_len = base_len + selected_idx.size()[-2]
|
173 |
+
# alwarys keep the last window
|
174 |
+
tail_len = self.lag_size + in_size[-2] - end_idx
|
175 |
+
self.key_cache[layer_idx] = self._modify_kv(
|
176 |
+
self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
|
177 |
+
)
|
178 |
+
self.value_cache[layer_idx] = self._modify_kv(
|
179 |
+
self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
|
180 |
+
)
|
181 |
+
self._compressed_len[layer_idx] = new_base_len
|
182 |
+
|
183 |
+
def _compress_kv_by_lag(self, layer_idx):
|
184 |
+
"""the KV cache will be used then compressed"""
|
185 |
+
kv_size = self.key_cache[layer_idx].size()
|
186 |
+
base_len = self._compressed_len[layer_idx]
|
187 |
+
|
188 |
+
keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
|
189 |
+
if kv_size[-2] >= base_len + 2 * self.lag_size:
|
190 |
+
self._compress_algo(layer_idx, base_len)
|
191 |
+
return keys_to_return, values_to_return
|
192 |
+
|
193 |
+
def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
|
194 |
+
"""Custom generate function for LagKVCache.
|
195 |
+
(template from https://huggingface.co/transformers-community/sink_cache)
|
196 |
+
Args:
|
197 |
+
model (`PreTrainedModel`):
|
198 |
+
The model to generate from.
|
199 |
+
lag_ratio (`float`):
|
200 |
+
The retrain ratio of tokens in the middle chunks.
|
201 |
+
lag_sink_size (`int`):
|
202 |
+
The number of sink tokens.
|
203 |
+
lag_size (`int`):
|
204 |
+
The size of the partition. See the original paper for more information.
|
205 |
+
"""
|
206 |
+
# 1. General sanity checks
|
207 |
+
# 1.a. A few arguments are not allowed, especially arguments that control caches.
|
208 |
+
generation_config = kwargs.get("generation_config")
|
209 |
+
default_global_generation_config = GenerationConfig()
|
210 |
+
default_model_generation_config = model.generation_config
|
211 |
+
for arg in UNSUPPORTED_GENERATION_ARGS:
|
212 |
+
has_custom_gen_config_arg = (
|
213 |
+
generation_config is not None
|
214 |
+
# = and not (match global default or match model-specific default)
|
215 |
+
and not (
|
216 |
+
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
|
217 |
+
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
218 |
+
)
|
219 |
+
)
|
220 |
+
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
221 |
+
if kwargs_has_arg or has_custom_gen_config_arg:
|
222 |
+
raise ValueError(
|
223 |
+
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
224 |
+
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
225 |
+
)
|
226 |
+
|
227 |
+
# 1.b. The model must be decoder-only
|
228 |
+
if model.config.is_encoder_decoder:
|
229 |
+
raise ValueError("This custom generate function only works with decoder-only models")
|
230 |
+
|
231 |
+
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
|
232 |
+
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
|
233 |
+
kwargs.pop("custom_generate", None)
|
234 |
+
|
235 |
+
# 2. Generate with LagKVCache
|
236 |
+
# 2.a. prepare the cache, if it was not passed.
|
237 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
238 |
+
if past_key_values is None:
|
239 |
+
past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
|
240 |
+
elif not isinstance(past_key_values, LagKVCache):
|
241 |
+
raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")
|
242 |
+
|
243 |
+
# 2.b. generate with the cache
|
244 |
+
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
245 |
+
return generation_outputs
|
custom_generate/LICENSE
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The MIT License (MIT)
|
2 |
+
Copyright © 2025 China Merchants Bank
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4 |
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
5 |
+
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
custom_generate/generate.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2025 China Merchants Bank. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the MIT License (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://mit-license.org
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from transformers.cache_utils import DynamicCache
|
17 |
+
from typing import Any, Dict, List, Optional, Tuple
|
18 |
+
|
19 |
+
|
20 |
+
class LagKVCache(DynamicCache):
|
21 |
+
"""
|
22 |
+
A KV compression algorithm that as described in the [LagKV paper](https://arxiv.org/abs/2504.04704).
|
23 |
+
The algorithm equips Sink Attention and SlidingWindow like SinkCache but with additional selective tokens in the middle.
|
24 |
+
It allows the model to generate with fewer memory resource and faster decoding speed.
|
25 |
+
The model will hold the main part of information retrieval capbility during the compression, compared to a completed loss
|
26 |
+
of the SinkCache.
|
27 |
+
|
28 |
+
It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
|
29 |
+
`[batch_size, num_heads, seq_len, head_dim]`.
|
30 |
+
|
31 |
+
For the chunked prefilling, see https://github.com/AI-Lab-China-Merchants-Bank/LagKV.
|
32 |
+
|
33 |
+
Parameters:
|
34 |
+
_distributed_cache_data:
|
35 |
+
Inherited from DynamicCache.
|
36 |
+
ratio (`float`):
|
37 |
+
The retrain ratio of tokens in the middle chunks.
|
38 |
+
sink_size (`int`):
|
39 |
+
The number of sink tokens.
|
40 |
+
lag_size (`int`):
|
41 |
+
The size of the partition. The subsequent partion will serve as a reference for the prior one.
|
42 |
+
score_v_ratio (`float`):
|
43 |
+
The ratio multiplied to the score of Value states.
|
44 |
+
skip_layer_idx (`Optional[List[int]]`):
|
45 |
+
A list of layer indices will skip the compression.
|
46 |
+
|
47 |
+
Example:
|
48 |
+
|
49 |
+
```python
|
50 |
+
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, LagKVCache
|
51 |
+
|
52 |
+
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
53 |
+
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
54 |
+
|
55 |
+
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
56 |
+
|
57 |
+
>>> # Prepare a cache class and pass it to model's forward
|
58 |
+
>>> past_key_values = LagKVCache(ratio=0.25, lag_size=128)
|
59 |
+
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
60 |
+
>>> outputs.past_key_values # access cache filled with key/values from generation
|
61 |
+
LagKVCache()
|
62 |
+
```
|
63 |
+
"""
|
64 |
+
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
_distributed_cache_data=None,
|
68 |
+
ratio: float = 0.25,
|
69 |
+
sink_size: int = 16,
|
70 |
+
lag_size: int = 1024,
|
71 |
+
score_v_ratio: float = 1.0,
|
72 |
+
skip_layer_idx: Optional[List[int]] = None,
|
73 |
+
):
|
74 |
+
super().__init__(_distributed_cache_data)
|
75 |
+
self.ratio = ratio
|
76 |
+
self.sink_size: int = sink_size
|
77 |
+
self.lag_size: int = lag_size
|
78 |
+
self.score_v_ratio: float = score_v_ratio
|
79 |
+
self.skip_layer_idx: List[int] = skip_layer_idx if skip_layer_idx is not None else []
|
80 |
+
self._compressed_len: List[int] = []
|
81 |
+
|
82 |
+
def update(
|
83 |
+
self,
|
84 |
+
key_states: torch.Tensor,
|
85 |
+
value_states: torch.Tensor,
|
86 |
+
layer_idx: int,
|
87 |
+
cache_kwargs=None,
|
88 |
+
):
|
89 |
+
"""
|
90 |
+
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
91 |
+
|
92 |
+
Parameters:
|
93 |
+
key_states (`torch.Tensor`):
|
94 |
+
The new key states to cache.
|
95 |
+
value_states (`torch.Tensor`):
|
96 |
+
The new value states to cache.
|
97 |
+
layer_idx (`int`):
|
98 |
+
The index of the layer to cache the states for.
|
99 |
+
cache_kwargs (`Dict[str, Any]`, `optional`):
|
100 |
+
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
101 |
+
|
102 |
+
Return:
|
103 |
+
A tuple containing the updated key and value states.
|
104 |
+
"""
|
105 |
+
# Update the number of seen tokens
|
106 |
+
if layer_idx == 0:
|
107 |
+
self._seen_tokens += key_states.shape[-2]
|
108 |
+
|
109 |
+
# Update the cache
|
110 |
+
if key_states is not None:
|
111 |
+
if len(self.key_cache) <= layer_idx:
|
112 |
+
# There may be skipped layers, fill them with empty lists
|
113 |
+
for _ in range(len(self.key_cache), layer_idx):
|
114 |
+
self.key_cache.append([])
|
115 |
+
self.value_cache.append([])
|
116 |
+
self._compressed_len.append(self.sink_size)
|
117 |
+
self.key_cache.append(key_states)
|
118 |
+
self.value_cache.append(value_states)
|
119 |
+
self._compressed_len.append(self.sink_size)
|
120 |
+
elif (
|
121 |
+
len(self.key_cache[layer_idx]) == 0
|
122 |
+
): # fills previously skipped layers; checking for tensor causes errors
|
123 |
+
self.key_cache[layer_idx] = key_states
|
124 |
+
self.value_cache[layer_idx] = value_states
|
125 |
+
else:
|
126 |
+
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
127 |
+
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
128 |
+
|
129 |
+
if layer_idx not in self.skip_layer_idx:
|
130 |
+
return self._compress_kv_by_lag(layer_idx)
|
131 |
+
|
132 |
+
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
133 |
+
|
134 |
+
def _get_states_score(self, base_len, in_size, end_idx, value):
|
135 |
+
"""Partition the states then calculate the state scores"""
|
136 |
+
# [batch_size, num_heads, seq_len, head_dim]
|
137 |
+
target_v = value[:, :, base_len:end_idx]
|
138 |
+
# [batch_size, num_heads, partition_num, lag_size, head_dim]
|
139 |
+
target_v = target_v.view(in_size[0], in_size[1], -1, self.lag_size, in_size[-1])
|
140 |
+
ref = target_v[:, :, 1:, :, :]
|
141 |
+
v = target_v[:, :, :-1, :, :]
|
142 |
+
|
143 |
+
min_r = ref.min(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
|
144 |
+
max_r = ref.max(dim=-2).values.unsqueeze(-2).expand(-1, -1, -1, self.lag_size, -1)
|
145 |
+
|
146 |
+
score = ((v - min_r) / (max_r - min_r)).std(dim=-1).softmax(dim=-1)
|
147 |
+
|
148 |
+
return score
|
149 |
+
|
150 |
+
def _modify_kv(self, value, base_len, end_idx, selected_idx, tail_len):
|
151 |
+
# idx is offset by base_len
|
152 |
+
selected_value = torch.gather(value[:, :, base_len:end_idx], -2, selected_idx)
|
153 |
+
value = torch.cat((value[:, :, :base_len], selected_value, value[:, :, -tail_len:]), dim=-2)
|
154 |
+
return value
|
155 |
+
|
156 |
+
def _compress_algo(self, layer_idx, base_len):
|
157 |
+
"""
|
158 |
+
Calculate the scores of KV tokens in each head and partition. See the paper.
|
159 |
+
The computation overhead of top-k is significantly reduced by partitioning.
|
160 |
+
"""
|
161 |
+
in_size = self.key_cache[layer_idx].size()
|
162 |
+
end_idx = base_len + ((in_size[-2] - base_len) // self.lag_size) * self.lag_size
|
163 |
+
# [batch_size, num_heads, partition_num - 1, lag_size, head_dim]
|
164 |
+
key_score = self._get_states_score(base_len, in_size, end_idx, self.key_cache[layer_idx])
|
165 |
+
value_score = self._get_states_score(base_len, in_size, end_idx, self.value_cache[layer_idx])
|
166 |
+
score = key_score + value_score * self.score_v_ratio
|
167 |
+
# you may need to sort the index for some cases
|
168 |
+
selected_idx = torch.topk(score, int(self.ratio * self.lag_size), dim=-1).indices
|
169 |
+
for i in range(1, selected_idx.size()[2], 1):
|
170 |
+
selected_idx[:, :, i] += i * self.lag_size
|
171 |
+
selected_idx = selected_idx.reshape(in_size[0], in_size[1], -1).unsqueeze(-1).expand(-1, -1, -1, in_size[-1])
|
172 |
+
new_base_len = base_len + selected_idx.size()[-2]
|
173 |
+
# alwarys keep the last window
|
174 |
+
tail_len = self.lag_size + in_size[-2] - end_idx
|
175 |
+
self.key_cache[layer_idx] = self._modify_kv(
|
176 |
+
self.key_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
|
177 |
+
)
|
178 |
+
self.value_cache[layer_idx] = self._modify_kv(
|
179 |
+
self.value_cache[layer_idx], base_len, end_idx, selected_idx, tail_len
|
180 |
+
)
|
181 |
+
self._compressed_len[layer_idx] = new_base_len
|
182 |
+
|
183 |
+
def _compress_kv_by_lag(self, layer_idx):
|
184 |
+
"""the KV cache will be used then compressed"""
|
185 |
+
kv_size = self.key_cache[layer_idx].size()
|
186 |
+
base_len = self._compressed_len[layer_idx]
|
187 |
+
|
188 |
+
keys_to_return, values_to_return = self.key_cache[layer_idx], self.value_cache[layer_idx]
|
189 |
+
if kv_size[-2] >= base_len + 2 * self.lag_size:
|
190 |
+
self._compress_algo(layer_idx, base_len)
|
191 |
+
return keys_to_return, values_to_return
|
192 |
+
|
193 |
+
def generate(model, lag_ratio=0.5, lag_sink_size=16, lag_size=128, **kwargs):
|
194 |
+
"""Custom generate function for LagKVCache.
|
195 |
+
(template from https://huggingface.co/transformers-community/sink_cache)
|
196 |
+
Args:
|
197 |
+
model (`PreTrainedModel`):
|
198 |
+
The model to generate from.
|
199 |
+
lag_ratio (`float`):
|
200 |
+
The retrain ratio of tokens in the middle chunks.
|
201 |
+
lag_sink_size (`int`):
|
202 |
+
The number of sink tokens.
|
203 |
+
lag_size (`int`):
|
204 |
+
The size of the partition. See the original paper for more information.
|
205 |
+
"""
|
206 |
+
# 1. General sanity checks
|
207 |
+
# 1.a. A few arguments are not allowed, especially arguments that control caches.
|
208 |
+
generation_config = kwargs.get("generation_config")
|
209 |
+
default_global_generation_config = GenerationConfig()
|
210 |
+
default_model_generation_config = model.generation_config
|
211 |
+
for arg in UNSUPPORTED_GENERATION_ARGS:
|
212 |
+
has_custom_gen_config_arg = (
|
213 |
+
generation_config is not None
|
214 |
+
# = and not (match global default or match model-specific default)
|
215 |
+
and not (
|
216 |
+
getattr(default_model_generation_config, arg) == getattr(generation_config, arg)
|
217 |
+
or getattr(default_global_generation_config, arg) == getattr(generation_config, arg)
|
218 |
+
)
|
219 |
+
)
|
220 |
+
kwargs_has_arg = arg in kwargs and kwargs[arg] is not None
|
221 |
+
if kwargs_has_arg or has_custom_gen_config_arg:
|
222 |
+
raise ValueError(
|
223 |
+
f"`{arg}` is set, but it's not supported in this custom generate function. List of "
|
224 |
+
f"unsupported arguments: {UNSUPPORTED_GENERATION_ARGS}"
|
225 |
+
)
|
226 |
+
|
227 |
+
# 1.b. The model must be decoder-only
|
228 |
+
if model.config.is_encoder_decoder:
|
229 |
+
raise ValueError("This custom generate function only works with decoder-only models")
|
230 |
+
|
231 |
+
# 1.c. compatibility with transformers 4.52: we must pop `custom_generate` from kwargs, otherwise it will result
|
232 |
+
# in an infinite loop when we call `model.generate`. This is solved in transformers 4.53.
|
233 |
+
kwargs.pop("custom_generate", None)
|
234 |
+
|
235 |
+
# 2. Generate with LagKVCache
|
236 |
+
# 2.a. prepare the cache, if it was not passed.
|
237 |
+
past_key_values = kwargs.pop("past_key_values", None)
|
238 |
+
if past_key_values is None:
|
239 |
+
past_key_values = LagKVCache(ratio=lag_ratio, sink_size=lag_sink_size, lag_size=lag_size)
|
240 |
+
elif not isinstance(past_key_values, LagKVCache):
|
241 |
+
raise ValueError(f"`past_key_values` must be a `LagKVCache` instance, got a {type(past_key_values)} instance")
|
242 |
+
|
243 |
+
# 2.b. generate with the cache
|
244 |
+
generation_outputs = model.generate(**kwargs, past_key_values=past_key_values, use_cache=True)
|
245 |
+
return generation_outputs
|