BigDong commited on
Commit
f37d55b
·
1 Parent(s): 84eb0ec

update readme and modeling model

Browse files
Files changed (2) hide show
  1. README.md +326 -0
  2. modeling_minicpm.py +277 -565
README.md CHANGED
@@ -18,3 +18,329 @@ library_name: transformers
18
  <p align="center">
19
  👋 Contact us in <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
20
  </p>
21
+
22
+ ## What's New
23
+ - [2025.09.05] The **MiniCPM4.1** series is released! It is a hybrid reasoning model that can be used in
24
+ both deep reasoning mode and non-reasoning mode. 🔥🔥🔥
25
+ - [2025.06.06] The **MiniCPM4** series is released! This model achieves ultimate efficiency improvements while maintaining optimal performance at the same scale! It can achieve over 5x generation acceleration on typical end-side chips! You can find the technical report [here](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf). 🔥🔥🔥
26
+
27
+ ## MiniCPM4 and MiniCPM4.1 Series
28
+ The MiniCPM4 and MiniCPM4.1 series are highly efficient large language models (LLMs) designed explicitly for end-side devices. They achieve this efficiency through systematic innovation in four key dimensions: model architecture, training data, training algorithms, and inference systems.
29
+ - [MiniCPM4.1-8B](https://huggingface.co/openbmb/MiniCPM4.1-8B): The latest version of MiniCPM4, with 8B parameters, supporting hybrid (fusion) reasoning. (**<-- you are here**)
30
+ - [MiniCPM4.1-8B-GPTQ](https://huggingface.co/openbmb/MiniCPM4.1-8B-GPTQ): MiniCPM4.1-8B in GPTQ format.
31
+ - [MiniCPM4.1-8B-AutoAWQ](https://huggingface.co/openbmb/MiniCPM4.1-8B-AutoAWQ): MiniCPM4.1-8B in AutoAWQ format.
32
+ - [MiniCPM-4.1-8B-Marlin](https://huggingface.co/openbmb/MiniCPM-4.1-8B-Marlin): MiniCPM4.1-8B in Marlin format.
33
+ - [MiniCPM4.1-8B-GGUF](https://huggingface.co/openbmb/MiniCPM4.1-8B-GGUF): MiniCPM4.1-8B in GGUF format.
34
+ - [MiniCPM4.1-8B-MLX](https://huggingface.co/openbmb/MiniCPM4.1-8B-MLX): MiniCPM4.1-8B in MLX format.
35
+ - [MiniCPM4.1-8B-Eagle3](https://huggingface.co/openbmb/MiniCPM4.1-8B-Eagle3): Eagle3 model for MiniCPM4.1-8B.
36
+ - **MiniCPM4 Series**
37
+ <details>
38
+ <summary>Click to expand all MiniCPM4 series models</summary>
39
+
40
+ - [**MiniCPM4-8B**](https://huggingface.co/openbmb/MiniCPM4-8B): The flagship model with 8B parameters, trained on 8T tokens
41
+ - [**MiniCPM4-0.5B**](https://huggingface.co/openbmb/MiniCPM4-0.5B): Lightweight version with 0.5B parameters, trained on 1T tokens
42
+ - [**MiniCPM4-8B-Eagle-FRSpec**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec): Eagle head for FRSpec, accelerating speculative inference
43
+ - [**MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu): Eagle head with QAT for FRSpec, integrating speculation and quantization for ultra acceleration
44
+ - [**MiniCPM4-8B-Eagle-vLLM**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-vLLM): Eagle head in vLLM format for speculative inference
45
+ - [**MiniCPM4-8B-marlin-Eagle-vLLM**](https://huggingface.co/openbmb/MiniCPM4-8B-marlin-Eagle-vLLM): Quantized Eagle head for vLLM format
46
+ - [**BitCPM4-0.5B**](https://huggingface.co/openbmb/BitCPM4-0.5B): Extreme ternary quantization of MiniCPM4-0.5B, achieving 90% bit width reduction
47
+ - [**BitCPM4-1B**](https://huggingface.co/openbmb/BitCPM4-1B): Extreme ternary quantization of MiniCPM3-1B, achieving 90% bit width reduction
48
+ - [**MiniCPM4-Survey**](https://huggingface.co/openbmb/MiniCPM4-Survey): Generates trustworthy, long-form survey papers from user queries
49
+ - [**MiniCPM4-MCP**](https://huggingface.co/openbmb/MiniCPM4-MCP): Integrates MCP tools to autonomously satisfy user requirements
50
+ </details>
51
+
52
+ ## Introduction
53
+ MiniCPM4 and MiniCPM4.1 are extremely efficient edge-side large models that have undergone efficient optimization across four dimensions: model architecture, learning algorithms, training data, and inference systems, achieving ultimate efficiency improvements.
54
+
55
+ - 🏗️ **Efficient Model Architecture:**
56
+ - InfLLM v2 -- Trainable Sparse Attention Mechanism: Adopts a trainable sparse attention mechanism architecture where each token only needs to compute relevance with less than 5% of tokens in 128K long text processing, significantly reducing computational overhead for long texts
57
+
58
+ - 🧠 **Efficient Learning Algorithms:**
59
+ - Model Wind Tunnel 2.0 -- Efficient Predictable Scaling: Introduces scaling prediction methods for performance of downstream tasks, enabling more precise model training configuration search
60
+ - BitCPM -- Ultimate Ternary Quantization: Compresses model parameters to three (ternary) values, achieving a 90% reduction in model bit width
61
+ - Efficient Training Engineering Optimization: Adopts FP8 low-precision computing technology combined with Multi-token Prediction training strategy
62
+
63
+ - 📚 **High-Quality Training Data:**
64
+ - UltraClean -- High-quality Pre-training Data Filtering and Generation: Builds iterative data cleaning strategies based on efficient data verification, open-sourcing the high-quality Chinese and English pre-training dataset [UltraFineWeb](https://huggingface.co/datasets/openbmb/Ultra-FineWeb)
65
+ - UltraChat v2 -- High-quality Supervised Fine-tuning Data Generation: Constructs large-scale high-quality supervised fine-tuning datasets covering multiple dimensions including knowledge-intensive data, reasoning-intensive data, instruction-following data, long text understanding data, and tool calling data
66
+
67
+ - ⚡ **Efficient Inference System:**
68
+ - CPM.cu -- Lightweight and Efficient CUDA Inference Framework: Integrates sparse attention, model quantization, and speculative sampling to achieve efficient prefilling and decoding
69
+ - ArkInfer -- Cross-platform Deployment System: Supports efficient deployment across multiple backend environments, providing flexible cross-platform adaptation capabilities
70
+
71
+ ## Usage
72
+
73
+ ### Inference with [CPM.cu](https://github.com/OpenBMB/cpm.cu)
74
+
75
+ We recommend using [CPM.cu](https://github.com/OpenBMB/cpm.cu) for the inference of MiniCPM4 and MiniCPM4.1. CPM.cu is a CUDA inference framework developed by OpenBMB, which integrates efficient sparse attention, speculative sampling, and quantization techniques, fully leveraging the efficiency advantages of MiniCPM4 and MiniCPM4.1.
76
+
77
+ You can install CPM.cu by running the following command:
78
+
79
+ ```bash
80
+ git clone https://github.com/OpenBMB/cpm.cu.git --recursive
81
+ cd cpm.cu
82
+ python3 setup.py install
83
+ ```
84
+
85
+ MiniCPM4.1 natively supports context lengths of up to 65,536 (64K) tokens. To reproduce the long-text acceleration effect reported in the paper, we recommend using the validated LongRoPE factors. Change the `rope_scaling` field in the `config.json` file as follows to enable LongRoPE.
86
+ ```json
87
+ {
88
+ ...,
89
+ "rope_scaling": {
90
+ "rope_type": "longrope",
91
+ "long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
92
+ "short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
93
+ "original_max_position_embeddings": 32768
94
+ }
95
+ }
96
+ ```
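+
+ If you prefer to apply this change programmatically, the snippet below is a minimal sketch that patches `config.json` in place. The local checkpoint path is an assumption; point it at wherever the MiniCPM4.1-8B weights were downloaded, and paste in the full factor lists shown above.
+ ```python
+ import json
+ from pathlib import Path
+
+ # Hypothetical local path to the downloaded MiniCPM4.1-8B checkpoint.
+ config_path = Path("MiniCPM4.1-8B/config.json")
+
+ # Copy the complete long_factor / short_factor lists from the JSON snippet above;
+ # only the first two values are shown here to keep the sketch short.
+ long_factor = [0.9977997200264581, 1.014658295992452]
+ short_factor = list(long_factor)
+
+ config = json.loads(config_path.read_text())
+ config["rope_scaling"] = {
+     "rope_type": "longrope",
+     "long_factor": long_factor,
+     "short_factor": short_factor,
+     "original_max_position_embeddings": 32768,
+ }
+ config_path.write_text(json.dumps(config, indent=2))
+ print("rope_scaling written to", config_path)
+ ```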
97
+
98
+ After modification, you can run the following command to reproduce the long-context acceleration effect (the script will automatically download the model weights from HuggingFace):
99
+ ```bash
100
+ python3 tests/test_generate.py
101
+ ```
102
+
103
+ For more details about CPM.cu, please refer to [the repo CPM.cu](https://github.com/OpenBMB/cpm.cu).
104
+
105
+ ### Hybrid Reasoning Mode
106
+
107
+ MiniCPM4.1 supports a hybrid reasoning mode and can run in both deep reasoning mode and non-reasoning mode. To choose between them, set `enable_thinking=True` in `tokenizer.apply_chat_template` to enable reasoning mode, or `enable_thinking=False` to enable non-reasoning mode. Alternatively, you can append `\no_think` to the end of the query to enable non-reasoning mode; if no special suffix is added, or `\think` is appended, the model runs in reasoning mode.
108
+
109
+ ```python
110
+ # Enable reasoning mode
111
+ prompt_text = tokenizer.apply_chat_template(
112
+ messages,
113
+ tokenize=False,
114
+ add_generation_prompt=True,
115
+ enable_thinking=True
116
+ )
117
+ # Enable non-reasoning mode
118
+ prompt_text = tokenizer.apply_chat_template(
119
+ messages,
120
+ tokenize=False,
121
+ add_generation_prompt=True,
122
+ enable_thinking=False
123
+ )
124
+ ```
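+
+ The query-suffix switch described above can be used instead of the template flag. The sketch below appends the non-reasoning suffix exactly as written in this README (`\no_think`); note that the backslash must be escaped in Python source.
+ ```python
+ # Non-reasoning mode via the query suffix instead of enable_thinking=False.
+ suffix = "\\no_think"  # literal backslash + "no_think", as described above
+ messages = [
+     {"role": "user", "content": "Write an article about Artificial Intelligence. " + suffix},
+ ]
+ prompt_text = tokenizer.apply_chat_template(
+     messages,
+     tokenize=False,
+     add_generation_prompt=True,
+ )
+ ```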
125
+
126
+ ### Inference with Transformers
127
+ ```python
128
+ from transformers import AutoModelForCausalLM, AutoTokenizer
129
+ import torch
130
+ torch.manual_seed(0)
131
+
132
+ path = 'openbmb/MiniCPM4.1-8B'
133
+ device = "cuda"
134
+ tokenizer = AutoTokenizer.from_pretrained(path)
135
+ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
136
+
137
+ # User can directly use the chat interface
138
+ # responds, history = model.chat(tokenizer, "Write an article about Artificial Intelligence.", temperature=0.7, top_p=0.7)
139
+ # print(responds)
140
+
141
+ # User can also use the generate interface
142
+ messages = [
143
+ {"role": "user", "content": "Write an article about Artificial Intelligence."},
144
+ ]
145
+ prompt_text = tokenizer.apply_chat_template(
146
+ messages,
147
+ tokenize=False,
148
+ add_generation_prompt=True,
149
+ )
150
+ model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
151
+
152
+ model_outputs = model.generate(
153
+ **model_inputs,
154
+ max_new_tokens=8192,
155
+ top_p=0.7,
156
+ temperature=0.7
157
+ )
158
+ output_token_ids = [
159
+ model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs['input_ids']))
160
+ ]
161
+
162
+ responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
163
+ print(responses)
164
+ ```
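+
+ When reasoning mode is enabled, the decoded text may start with the model's reasoning segment before the final answer. The helper below is a minimal post-processing sketch; it assumes the reasoning is wrapped in `<think> ... </think>` tags, so check the tokenizer's chat template to confirm the exact markers before relying on it.
+ ```python
+ # Split the decoded output into a reasoning segment and the final answer.
+ # Assumption: reasoning is delimited by <think> ... </think>; adjust the
+ # markers if the chat template uses something different.
+ def split_reasoning(text: str):
+     if "</think>" in text:
+         reasoning, answer = text.split("</think>", 1)
+         return reasoning.replace("<think>", "").strip(), answer.strip()
+     return "", text.strip()
+
+ reasoning, answer = split_reasoning(responses)
+ print("Reasoning (truncated):", reasoning[:200])
+ print("Answer:", answer)
+ ```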
165
+
166
+ MiniCPM4.1-8B supports `InfLLM v2`, a sparse attention mechanism designed for efficient long-sequence inference. It requires the [infllmv2_cuda_impl](https://github.com/OpenBMB/infllmv2_cuda_impl) library.
167
+
168
+ You can install it by running the following command:
169
+ ```bash
170
+ git clone -b feature_infer https://github.com/OpenBMB/infllmv2_cuda_impl.git
171
+ cd infllmv2_cuda_impl
172
+ git submodule update --init --recursive
173
+ pip install -e . # or python setup.py install
174
+ ```
175
+
176
+ To enable InfLLM v2, you need to add the `sparse_config` field in `config.json`:
177
+ ```json
178
+ {
179
+ ...,
180
+ "sparse_config": {
181
+ "kernel_size": 32,
182
+ "kernel_stride": 16,
183
+ "init_blocks": 1,
184
+ "block_size": 64,
185
+ "window_size": 2048,
186
+ "topk": 64,
187
+ "use_nope": false,
188
+ "dense_len": 8192
189
+ }
190
+ }
191
+ ```
192
+
193
+ These parameters control the behavior of InfLLM v2 (a back-of-envelope sketch of their combined effect follows this list):
194
+ * `kernel_size` (default: 32): The size of semantic kernels.
195
+ * `kernel_stride` (default: 16): The stride between adjacent kernels.
196
+ * `init_blocks` (default: 1): The number of initial blocks that every query token attends to. This ensures attention to the beginning of the sequence.
197
+ * `block_size` (default: 64): The block size for key-value blocks.
198
+ * `window_size` (default: 2048): The size of the local sliding window.
199
+ * `topk` (default: 64): Specifies that each token computes attention with only the top-k most relevant key-value blocks.
200
+ * `use_nope` (default: false): Whether to use the NOPE technique in block selection for improved performance.
201
+ * `dense_len` (default: 8192): Since Sparse Attention offers limited benefits for short sequences, the model can use standard (dense) attention for shorter texts. The model will use dense attention for sequences with a token length below `dense_len` and switch to sparse attention for sequences exceeding this length. Set this to `-1` to always use sparse attention regardless of sequence length.
202
+
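+ As a rough illustration of what these defaults imply, the sketch below estimates how many key-value tokens a single query attends to under sparse attention (local window plus initial blocks plus top-k selected blocks). This is an approximation for intuition only, not the exact kernel behavior; overlapping blocks are counted once in the real implementation.
+ ```python
+ # Back-of-envelope estimate of attended tokens under the InfLLM v2 defaults.
+ def approx_attended_tokens(seq_len, window_size=2048, init_blocks=1, topk=64, block_size=64):
+     attended = window_size + (init_blocks + topk) * block_size
+     return min(attended, seq_len)
+
+ for seq_len in (8192, 65536, 131072):
+     attended = approx_attended_tokens(seq_len)
+     print(f"{seq_len:>7} tokens -> ~{attended} attended ({attended / seq_len:.1%})")
+ # At 131072 tokens this is roughly 6208 attended tokens (~4.7%), consistent with
+ # the "<5% of tokens at 128K" figure quoted earlier; below dense_len the model
+ # falls back to dense attention anyway.
+ ```
+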
203
+ MiniCPM4.1 natively supports context lengths of up to 65,536 (64K) tokens. For conversations where the total length (including both input and output) significantly exceeds this limit, we recommend using RoPE scaling techniques for effective handling of long texts. We have validated the model's performance on context lengths of up to 131,072 tokens by modifying the LongRoPE factor.
204
+
205
+ You can apply the LongRoPE factor modification by editing the model files. Specifically, adjust the `rope_scaling` field in the `config.json` file.
206
+ ```json
207
+ {
208
+ ...,
209
+ "rope_scaling": {
210
+ "rope_type": "longrope",
211
+ "long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
212
+ "short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
213
+ "original_max_position_embeddings": 32768
214
+ }
215
+ }
216
+ ```
217
+
218
+ ### Inference with [SGLang](https://github.com/sgl-project/sglang)
219
+
220
+ For now, you need to install our forked version of SGLang.
221
+ ```bash
222
+ git clone -b openbmb https://github.com/OpenBMB/sglang.git
223
+ cd sglang
224
+
225
+ pip install --upgrade pip
226
+ pip install -e "python[all]"
227
+ ```
228
+
229
+ You can start the inference server by running the following command:
230
+ ```bash
231
+ python -m sglang.launch_server --model openbmb/MiniCPM4.1-8B --trust-remote-code --port 30000 --chat-template chatml
232
+ ```
233
+
234
+ Then you can use the chat interface by running the following command:
235
+ ```python
236
+ import openai
237
+
238
+ client = openai.Client(base_url="http://localhost:30000/v1", api_key="None")
239
+
240
+ response = client.chat.completions.create(
241
+ model="openbmb/MiniCPM4.1-8B",
242
+ messages=[
243
+ {"role": "user", "content": "Write an article about Artificial Intelligence."},
244
+ ],
245
+ temperature=0.7,
246
+ max_tokens=8192,
247
+ )
248
+
249
+ print(response.choices[0].message.content)
250
+ ```
251
+
252
+ ### Inference with [vLLM](https://github.com/vllm-project/vllm)
253
+ For now, you need to install the latest version of vLLM.
254
+ ```bash
255
+ pip install -U vllm \
256
+ --pre \
257
+ --extra-index-url https://wheels.vllm.ai/nightly
258
+ ```
259
+
260
+ Then you can run inference on MiniCPM4.1-8B with vLLM:
261
+ ```python
262
+ from transformers import AutoTokenizer
263
+ from vllm import LLM, SamplingParams
264
+
265
+ model_name = "openbmb/MiniCPM4.1-8B"
266
+ prompt = [{"role": "user", "content": "Please recommend 5 tourist attractions in Beijing. "}]
267
+
268
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
269
+ input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
270
+
271
+ llm = LLM(
272
+ model=model_name,
273
+ trust_remote_code=True,
274
+ max_num_batched_tokens=32768,
275
+ dtype="bfloat16",
276
+ gpu_memory_utilization=0.8,
277
+ )
278
+ sampling_params = SamplingParams(top_p=0.7, temperature=0.7, max_tokens=1024, repetition_penalty=1.02)
279
+
280
+ outputs = llm.generate(prompts=input_text, sampling_params=sampling_params)
281
+
282
+ print(outputs[0].outputs[0].text)
283
+ ```
284
+
285
+ Also, you can start the inference server by running the following command:
286
+ > **Note**: In vLLM's chat API, `add_special_tokens` is `False` by default. This means important special tokens—such as the beginning-of-sequence (BOS) token—will not be added automatically. To ensure the input prompt is correctly formatted for the model, you should explicitly set `extra_body={"add_special_tokens": True}`.
287
+
288
+ ```bash
289
+ vllm serve openbmb/MiniCPM4.1-8B
290
+ ```
291
+
292
+ Then you can use the chat interface by running the following code:
293
+
294
+ ```python
295
+ import openai
296
+
297
+ client = openai.Client(base_url="http://localhost:8000/v1", api_key="EMPTY")
298
+
299
+ response = client.chat.completions.create(
300
+ model="openbmb/MiniCPM4.1-8B",
301
+ messages=[
302
+ {"role": "user", "content": "Write an article about Artificial Intelligence."},
303
+ ],
304
+ temperature=0.7,
305
+ max_tokens=1024,
306
+ extra_body=dict(add_special_tokens=True), # Ensures special tokens are added for chat template
307
+
308
+ )
309
+
310
+ print(response.choices[0].message.content)
311
+ ```
312
+
313
+ ## Evaluation Results
314
+ On two typical end-side chips, Jetson AGX Orin and RTX 4090, MiniCPM4 demonstrates significantly faster processing speed compared to similar-size models in long text processing tasks. As text length increases, MiniCPM4's efficiency advantage becomes more pronounced. On the Jetson AGX Orin platform, compared to Qwen3-8B, MiniCPM4 achieves approximately 7x decoding speed improvement.
315
+
316
+ ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/efficiency.png?raw=true)
317
+
318
+ #### Comprehensive Evaluation
319
+ MiniCPM4.1 launches an end-side version at the 8B parameter scale, achieving best-in-class performance in its category.
320
+
321
+ ![benchmark](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/benchmark4.1.png?raw=true)
322
+
323
+ #### Long Text Evaluation
324
+ MiniCPM4 is pre-trained on 32K long texts and achieves length extension through YaRN technology. In the 128K long text needle-in-a-haystack task, MiniCPM4 demonstrates outstanding performance.
325
+
326
+ ![long-niah](https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm4/128k-niah.png?raw=true)
327
+
328
+ ## Statement
329
+ - As a language model, MiniCPM generates content by learning from a vast amount of text.
330
+ - However, it does not possess the ability to comprehend or express personal opinions or value judgments.
331
+ - Any content generated by MiniCPM does not represent the viewpoints or positions of the model developers.
332
+ - Therefore, when using content generated by MiniCPM, users should take full responsibility for evaluating and verifying it on their own.
333
+
334
+ ## LICENSE
335
+ - This repository and MiniCPM models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
336
+
337
+ ## Citation
338
+ - Please cite our [paper](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf) if you find our work valuable.
339
+
340
+ ```bibtex
341
+ @article{minicpm4,
342
+ title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
343
+ author={MiniCPM Team},
344
+ year={2025}
345
+ }
346
+ ```
modeling_minicpm.py CHANGED
@@ -24,7 +24,7 @@ import torch.utils.checkpoint
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
- from transformers.cache_utils import Cache, DynamicCache
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
@@ -57,6 +57,7 @@ try:
57
  infllmv2_attn_varlen_func,
58
  infllmv2_attn_with_kvcache,
59
  max_pooling_1d,
 
60
  )
61
  except:
62
  pass
@@ -79,8 +80,7 @@ def compressed_attention(
79
  sm_scale: float = None,
80
  init_blocks: int = 1,
81
  local_blocks: int = 2,
82
- parallel_topk_compute: Union[str, bool] = 'auto',
83
- total_seq_lens=-1,
84
  ) -> Tuple[torch.Tensor, torch.Tensor]:
85
  """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
86
 
@@ -99,31 +99,32 @@ def compressed_attention(
99
  sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
100
  init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
101
  local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
102
- parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug.
103
- We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise.
104
 
105
  Returns:
106
  Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
107
  """
108
  with torch.no_grad():
109
- cache_len = 0
110
  batch_size = cu_seqlens_q.shape[0] - 1
111
- if total_seq_lens == -1:
112
- total_seq_lens = max_seqlen_q
113
- q_idx = torch.cat(
114
- [
115
- torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) + total_seq_lens - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])
116
- for i in range(batch_size)
117
- ],
118
- dim=0,
119
- )
120
- q_idx = q_idx // block_size
121
-
 
 
 
122
  else:
123
- cache_len = total_seq_lens - max_seqlen_q
124
- assert batch_size == 1, 'batch_size must be 1 when total_seq_lens is set'
125
- q_idx = torch.tensor([total_seq_lens - 1], device=q.device, dtype=torch.int32) // block_size
126
 
 
127
  score = infllmv2_attn_stage1(
128
  q.contiguous(),
129
  k.contiguous(),
@@ -132,22 +133,27 @@ def compressed_attention(
132
  cu_seqlens_k=cu_seqlens_k,
133
  max_seqlen_q=max_seqlen_q,
134
  max_seqlen_k=max_seqlen_k,
135
- causal=q_idx.shape[0] > 1)
 
136
  score = score[:, :q_idx.shape[0], :]
137
 
138
- # Replace transform_score with max_pooling_1d
139
- block_score = max_pooling_1d(
140
  score.contiguous(),
141
- cache_len=cache_len,
 
 
 
 
142
  local_blocks=local_blocks,
143
  init_blocks=init_blocks,
144
  block_size=block_size,
145
- stride=kernel_stride,
146
- )
147
  # get topk
148
  topk = min(topk, block_score.shape[-1])
149
  topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values
150
- topk_idx[topk_idx >= q_idx[None, :, None]] = -1
151
  topk_idx = topk_idx.to(torch.int32)
152
 
153
  return topk_idx
@@ -246,299 +252,89 @@ class CompressK(torch.nn.Module):
246
  return compressed_k, cu_seqlens_compressed
247
 
248
 
249
- class DynamicCacheQKV(DynamicCache):
250
- """
251
- A cache that grows dynamically as more tokens are generated. This is the default for generative models.
252
 
253
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
254
- `[batch_size, num_heads, seq_len, head_dim]`.
255
-
256
- Example:
257
- ```python
258
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
259
-
260
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
261
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
262
-
263
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
264
-
265
- >>> # Prepare a cache class and pass it to model's forward
266
- >>> past_key_values = DynamicCache()
267
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
268
- >>> outputs.past_key_values # access cache filled with key/values from generation
269
- DynamicCache()
270
- ```
271
- """
272
- def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
273
  super().__init__()
274
- if num_hidden_layers is None:
275
- self.key_cache: List[torch.Tensor] = []
276
- self.value_cache: List[torch.Tensor] = []
277
- self.compress_k_cache: List[torch.Tensor] = []
278
- self.no_compress_k_cache: List[torch.Tensor] = []
279
- self.cached_compressed_cu_seqlens: List[torch.Tensor] = []
280
- self.no_rope_key_cache: List[torch.Tensor] = []
 
 
 
281
  else:
282
- self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
283
- self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
284
- self.compress_k_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
285
- self.no_compress_k_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
286
- self.cached_compressed_cu_seqlens: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
287
- self.no_rope_key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
288
- self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
289
-
290
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
291
- """
292
- Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
293
- sequence length.
294
- """
295
- if layer_idx < len(self):
296
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
297
  else:
298
- raise KeyError(f'Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}')
299
-
300
- def __iter__(self):
301
- """
302
- Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
303
- keys and values
304
- """
305
- for layer_idx in range(len(self)):
306
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
307
-
308
- def __len__(self):
309
- """
310
- Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
311
- to the number of layers in the model.
312
- """
313
- return len(self.key_cache)
314
 
315
- def update(
316
- self,
317
- key_states: torch.Tensor,
318
- value_states: torch.Tensor,
319
- layer_idx: int,
320
- cache_kwargs: Optional[Dict[str, Any]] = None
321
- ) -> Tuple[torch.Tensor, torch.Tensor]:
322
- """
323
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
324
 
325
- Parameters:
326
- key_states (`torch.Tensor`):
327
- The new key states to cache.
328
- value_states (`torch.Tensor`):
329
- The new value states to cache.
330
- layer_idx (`int`):
331
- The index of the layer to cache the states for.
332
- cache_kwargs (`Dict[str, Any]`, `optional`):
333
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
334
-
335
- Return:
336
- A tuple containing the updated key and value states.
337
- """
338
- # Update the number of seen tokens
339
  if layer_idx == 0:
340
  self._seen_tokens += key_states.shape[-2]
 
341
 
342
- # Update the cache
343
- if len(self.key_cache) <= layer_idx:
344
- self.key_cache.append(key_states)
345
- self.value_cache.append(value_states)
346
-
347
- # content on layer cache can be a tensor and checking not tensor causes errors
348
- # so we explicitly check for the empty list
349
- elif self.key_cache[layer_idx] == []:
350
- self.key_cache[layer_idx] = key_states
351
- self.value_cache[layer_idx] = value_states
352
-
353
- else:
354
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
355
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
356
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
357
-
358
- def update_no_rope_key(
359
- self,
360
- key_states: torch.Tensor,
361
- layer_idx: int,
362
- cache_kwargs: Optional[Dict[str, Any]] = None):
363
-
364
- # Update the cache
365
- if len(self.no_rope_key_cache) <= layer_idx:
366
- self.no_rope_key_cache.append(key_states)
367
-
368
- # content on layer cache can be a tensor and checking not tensor causes errors
369
- # so we explicitly check for the empty list
370
- elif self.no_rope_key_cache[layer_idx] == []:
371
- self.no_rope_key_cache[layer_idx] = key_states
372
- else:
373
- self.no_rope_key_cache[layer_idx] = torch.cat([self.no_rope_key_cache[layer_idx], key_states], dim=1)
374
- return self.no_rope_key_cache[layer_idx]
375
-
376
- def update_compress_k(
377
- self,
378
- key_states: torch.Tensor,
379
- layer_idx: int,
380
- cache_kwargs: Optional[Dict[str, Any]] = None
381
- ) -> Tuple[torch.Tensor, torch.Tensor]:
382
- """
383
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
384
-
385
- Parameters:
386
- key_states (`torch.Tensor`):
387
- The new key states to cache.
388
- value_states (`torch.Tensor`):
389
- The new value states to cache.
390
- layer_idx (`int`):
391
- The index of the layer to cache the states for.
392
- cache_kwargs (`Dict[str, Any]`, `optional`):
393
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
394
-
395
- Return:
396
- A tuple containing the updated key and value states.
397
- """
398
 
399
- # Update the cache
400
- if len(self.compress_k_cache) <= layer_idx:
401
- self.compress_k_cache.append(key_states)
402
 
403
- # content on layer cache can be a tensor and checking not tensor causes errors
404
- # so we explicitly check for the empty list
405
- elif self.compress_k_cache[layer_idx] == []:
406
- self.compress_k_cache[layer_idx] = key_states
407
- else:
408
- self.compress_k_cache[layer_idx] = torch.cat([self.compress_k_cache[layer_idx], key_states], dim=0)
409
- return self.compress_k_cache[layer_idx]
410
 
411
- def update_no_compress_k(
412
- self,
413
- key_states: torch.Tensor,
414
- layer_idx: int,
415
- kernel_size: int = 32,
416
- kernel_stride: int = 16,
417
- cache_kwargs: Optional[Dict[str, Any]] = None
418
- ) -> Tuple[torch.Tensor, torch.Tensor]:
419
- """
420
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
421
 
422
- Parameters:
423
- key_states (`torch.Tensor`):
424
- The new key states to cache.
425
- value_states (`torch.Tensor`):
426
- The new value states to cache.
427
- layer_idx (`int`):
428
- The index of the layer to cache the states for.
429
- cache_kwargs (`Dict[str, Any]`, `optional`):
430
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
431
-
432
- Return:
433
- A tuple containing the updated key and value states.
434
- """
435
- # Update the cache
436
- if len(self.no_compress_k_cache) <= layer_idx:
437
- self.no_compress_k_cache.append(key_states)
438
-
439
- # content on layer cache can be a tensor and checking not tensor causes errors
440
- # so we explicitly check for the empty list
441
- elif self.no_compress_k_cache[layer_idx] == []:
442
- self.no_compress_k_cache[layer_idx] = key_states
443
- else:
444
- self.no_compress_k_cache[layer_idx] = torch.cat([self.no_compress_k_cache[layer_idx], key_states], dim=0)
445
 
446
- current_len = self.no_compress_k_cache[layer_idx].shape[0]
447
-
448
- if current_len >= kernel_size:
449
- k_chunk = self.no_compress_k_cache[layer_idx][:kernel_size]
450
- self.no_compress_k_cache[layer_idx] = self.no_compress_k_cache[layer_idx][kernel_stride:]
451
- return k_chunk
452
- else:
453
- return None
454
-
455
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
456
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
457
- # TODO: deprecate this function in favor of `cache_position`
458
- if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
459
- return 0
460
- return self.key_cache[layer_idx].shape[-2]
461
-
462
- def get_max_length(self) -> Optional[int]:
463
- """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
464
- return None
465
-
466
- def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
467
- """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
468
- backward compatibility."""
469
- legacy_cache = ()
470
- for layer_idx in range(len(self)):
471
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
472
- return legacy_cache
473
-
474
- # @classmethod
475
- # def from_legacy_cache(
476
- # cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
477
- # ) -> "DynamicCacheQKV":
478
- # """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
479
- # backward compatibility."""
480
- # cache = cls(num_hidden_layers)
481
- # if past_key_values is not None:
482
- # for layer_idx in range(len(past_key_values)):
483
- # key_states, value_states, query_status = past_key_values[layer_idx]
484
- # cache.update(key_states, value_states, query_status,layer_idx)
485
- # return cache
486
-
487
- def crop(self, max_length: int):
488
- """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
489
- negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
490
- # In case it is negative
491
- if max_length < 0:
492
- max_length = self.get_seq_length() - abs(max_length)
493
-
494
- if self.get_seq_length() <= max_length:
495
- return
496
-
497
- self._seen_tokens = max_length
498
- for idx in range(len(self.key_cache)):
499
- if self.key_cache[idx] != []:
500
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
501
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
502
-
503
- def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List['DynamicCacheQKV']:
504
- """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
505
- `_split_model_inputs()` in `generation.utils`"""
506
- out = []
507
- for i in range(0, full_batch_size, split_size):
508
- current_split = DynamicCacheQKV(num_hidden_layers)
509
- current_split._seen_tokens = self._seen_tokens
510
- current_split.key_cache = [tensor[i: i + split_size] for tensor in self.key_cache]
511
- current_split.value_cache = [tensor[i: i + split_size] for tensor in self.value_cache]
512
- out.append(current_split)
513
- return out
514
-
515
- @classmethod
516
- def from_batch_splits(cls, splits: List['DynamicCacheQKV'], num_hidden_layers: int) -> 'DynamicCacheQKV':
517
- """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
518
- `generation.utils`"""
519
- cache = cls(num_hidden_layers)
520
- for idx in range(len(splits[0])):
521
- key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
522
- value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
523
- query_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
524
- if key_cache != []:
525
- layer_keys = torch.cat(key_cache, dim=0)
526
- layer_values = torch.cat(value_cache, dim=0)
527
- layer_query = torch.cat(query_cache, dim=0)
528
- cache.update(layer_keys, layer_values, idx, query_states=layer_query)
529
- return cache
530
-
531
- def batch_repeat_interleave(self, repeats: int):
532
- """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
533
- for layer_idx in range(len(self)):
534
- self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
535
- self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
536
-
537
- def batch_select_indices(self, indices: torch.Tensor):
538
- """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
539
- for layer_idx in range(len(self)):
540
- self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
541
- self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
542
 
543
 
544
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
@@ -567,22 +363,6 @@ def _get_unpad_data(attention_mask):
567
  )
568
 
569
 
570
- def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
571
- warnings.warn(
572
- 'Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask'
573
- )
574
- return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
575
-
576
-
577
- def _make_causal_mask(
578
- input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
579
- ):
580
- warnings.warn(
581
- 'Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask'
582
- )
583
- return AttentionMaskConverter._make_causal_mask(
584
- input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
585
- )
586
 
587
 
588
  # @torch.jit.script # type: ignore
@@ -796,6 +576,21 @@ class MiniCPMMLP(nn.Module):
796
 
797
  return down_proj
798
 
799
 
800
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
801
  """
@@ -927,15 +722,7 @@ class MiniCPMAttention(nn.Module):
927
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
928
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
929
 
930
- kv_seq_len = key_states.shape[-2]
931
- if past_key_value is not None:
932
- if self.layer_idx is None:
933
- raise ValueError(
934
- f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
935
- 'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
936
- 'with a layer index.'
937
- )
938
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
939
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
940
 
941
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -1037,9 +824,7 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
1037
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1038
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1039
 
1040
- kv_seq_len = key_states.shape[-2]
1041
- if past_key_value is not None:
1042
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1043
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
1044
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1045
 
@@ -1211,7 +996,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1211
  self.dense_len = self.config.sparse_config.get('dense_len', 8192)
1212
 
1213
  self.local_blocks = self.window_size // self.block_size # local_blocks
1214
- self.topk = self.config.sparse_config.get('topk', 64)
1215
  self.use_nope = self.config.sparse_config.get('use_nope', False)
1216
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
1217
 
@@ -1237,7 +1022,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1237
  output_attentions = False
1238
 
1239
  bsz, q_len, _ = hidden_states.size()
1240
- assert bsz == 1, 'Only batch_size=1 is supported at the moment.'
1241
 
1242
  query_states = self.q_proj(hidden_states)
1243
  key_states = self.k_proj(hidden_states)
@@ -1255,9 +1039,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1255
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1256
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1257
 
1258
- kv_seq_len = key_states.shape[-2]
1259
- if past_key_value is not None:
1260
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1261
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
1262
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1263
 
@@ -1271,12 +1053,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1271
  key_states = key_states.transpose(1, 2)
1272
  value_states = value_states.transpose(1, 2)
1273
  if self.use_nope:
 
1274
  no_rope_param = {
1275
  'key_states_no_rope': key_states_no_rope,
1276
  'query_states_no_rope': query_states_no_rope,
1277
  }
1278
- if kv_seq_len <= self.dense_len:
1279
- past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1280
  else:
1281
  no_rope_param = None
1282
 
@@ -1308,15 +1089,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1308
  if kv_seq_len < self.dense_len:
1309
  attn_output = self._flash_attention_forward_dense(
1310
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
1311
- elif past_key_value is None or q_len != 1: # prefilling
1312
- attn_output = self._flash_attention_forward(
1313
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
1314
  no_rope_param=no_rope_param, # if past_key_value is not None else None,
1315
  past_key_value=past_key_value)
1316
- else:
1317
- attn_output = self._flash_attention_forward_with_kv_cache(
1318
- query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, no_rope_param=no_rope_param, past_key_value=past_key_value)
1319
-
1320
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
1321
  attn_output = self.o_proj(attn_output)
1322
 
@@ -1325,122 +1102,146 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1325
 
1326
  return attn_output, attn_weights, past_key_value
1327
 
1328
- def _flash_attention_forward(
1329
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
1330
- ):
1331
- """
1332
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1333
- first unpad the input, then computes the attention scores and pad the final attention scores.
1334
-
1335
- Args:
1336
- query_states (`torch.Tensor`):
1337
- Input query states to be passed to Flash Attention API
1338
- key_states (`torch.Tensor`):
1339
- Input key states to be passed to Flash Attention API
1340
- value_states (`torch.Tensor`):
1341
- Input value states to be passed to Flash Attention API
1342
- attention_mask (`torch.Tensor`):
1343
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
1344
- position of padding tokens and 1 for the position of non-padding tokens.
1345
- dropout (`int`, *optional*):
1346
- Attention dropout
1347
- softmax_scale (`float`, *optional*):
1348
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1349
- """
1350
- if not self._flash_attn_uses_top_left_mask:
1351
- causal = self.is_causal
1352
- else:
1353
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
1354
- causal = self.is_causal and query_length != 1
1355
- # Contains at least one padding token in the sequence
1356
- if attention_mask is not None:
1357
- batch_size = query_states.shape[0]
1358
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1359
- query_states, key_states, value_states, attention_mask, query_length
1360
- )
1361
- if no_rope_param is not None:
1362
- # nope unpad
1363
- no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(0)
1364
- no_rope_param['key_states_no_rope'] = no_rope_param['key_states_no_rope'].squeeze(0)
1365
-
1366
- cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1367
- max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1368
- attn_output_unpad = self.sparse_forward(
1369
- query_states,
1370
- key_states,
1371
- value_states,
1372
- cu_seqlens_q,
1373
- cu_seqlens_k,
1374
- max_seqlen_in_batch_q,
1375
- max_seqlen_in_batch_k,
1376
- no_rope_param=no_rope_param,
1377
- past_key_value=past_key_value,
1378
- )
1379
 
1380
- attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1381
- else:
1382
- raise ValueError('Need attention mask')
1383
 
1384
- return attn_output
1385
 
1386
- def _flash_attention_forward_with_kv_cache(
1387
- self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
1388
- ):
1389
  """
1390
- Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
1391
- first unpad the input, then computes the attention scores and pad the final attention scores.
1392
-
1393
  Args:
1394
- query_states (`torch.Tensor`):
1395
- Input query states to be passed to Flash Attention API
1396
- key_states (`torch.Tensor`):
1397
- Input key states to be passed to Flash Attention API
1398
- value_states (`torch.Tensor`):
1399
- Input value states to be passed to Flash Attention API
1400
- attention_mask (`torch.Tensor`):
1401
- The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
1402
- position of padding tokens and 1 for the position of non-padding tokens.
1403
- dropout (`int`, *optional*):
1404
- Attention dropout
1405
- softmax_scale (`float`, *optional*):
1406
- The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1407
  """
1408
- if not self._flash_attn_uses_top_left_mask:
1409
- causal = self.is_causal
1410
- else:
1411
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
1412
- causal = self.is_causal and query_length != 1
1413
- # Contains at least one padding token in the sequence
1414
- if attention_mask is not None:
1415
-
1416
- batch_size = query_states.shape[0]
1417
-
1418
- # query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1419
- # query_states, key_states, value_states, attention_mask, query_length=query_length
1420
- # )
1421
 
1422
- assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1423
- # prepare past kv ,new kv
1424
- new_q = query_states
 
1425
 
1426
- new_k = key_states[:, -1:, :, :].contiguous()
1427
- new_v = value_states[:, -1:, :, :].contiguous()
1428
 
1429
- past_k = key_states[:, :-1, :, :].contiguous()
1430
- past_v = value_states[:, :-1, :, :].contiguous()
1431
- if no_rope_param is not None:
1432
- # nope unpad
1433
- no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(0)
1434
- no_rope_param['key_states_no_rope'] = no_rope_param['key_states_no_rope'].squeeze(0)
1435
 
1436
- attn_output = self.sparse_forward_with_kv_cache(
1437
- past_k=past_k, past_v=past_v, new_k=new_k, new_v=new_v, new_q=new_q, batch_size=batch_size, no_rope_param=no_rope_param, past_key_value=past_key_value)
1438
 
1439
- # attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
 
 
1440
  else:
1441
- raise ValueError('need attention mask')
1442
 
1443
- return attn_output
1444
 
1445
  def sparse_forward(self,
1446
  query_layer,
@@ -1451,24 +1252,18 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1451
  max_seqlen_in_batch_q,
1452
  max_seqlen_in_batch_k,
1453
  no_rope_param=None,
1454
- past_key_value=None):
1455
- stage1_k = key_layer if no_rope_param is None else no_rope_param['key_states_no_rope']
1456
- compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, cu_seqlens_k)
1457
- compressed_v = compressed_k.clone()
1458
- if past_key_value is not None:
1459
- # Compute the start indices of keys (k) that were not compressed, Only batch_size=1 is supported at the moment.
1460
- no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
1461
- past_key_value.update_compress_k(
1462
- compressed_k, self.layer_idx
1463
- )
1464
- past_key_value.update_no_compress_k(
1465
- key_layer[no_compress_k_start:], self.layer_idx, no_compress_k_start)
1466
- past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
1467
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
 
 
 
 
 
1468
  topk_idx = compressed_attention(
1469
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1470
  compressed_k,
1471
- compressed_v,
1472
  self.kernel_size,
1473
  self.kernel_stride,
1474
  self.block_size,
@@ -1480,8 +1275,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1480
  None,
1481
  init_blocks=self.init_blocks,
1482
  local_blocks=self.local_blocks,
 
1483
  )
1484
-
1485
  topk_attn_output = infllmv2_attn_varlen_func(
1486
  query_layer,
1487
  key_layer,
@@ -1493,102 +1288,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
1493
  dropout_p=0.0,
1494
  deterministic=False,
1495
  softmax_scale=None,
1496
- causal=True,
1497
  return_attn_probs=False,
1498
- block_window_size=self.window_size // self.block_size,
1499
  topk_idx=topk_idx
1500
  )
1501
 
1502
  return topk_attn_output
1503
 
1504
- def sparse_forward_with_kv_cache(self, past_k=None, past_v=None, new_k=None, new_v=None, new_q=None, batch_size=None, no_rope_param=None, past_key_value=None):
1505
-
1506
- # stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
1507
- if past_k.shape[1] + new_k.shape[1] == self.dense_len and (past_key_value.compress_k_cache == [] or len(past_key_value.compress_k_cache) < self.layer_idx + 1 or past_key_value.compress_k_cache[self.layer_idx] == []):
1508
- if no_rope_param is not None:
1509
- stage1_k = past_key_value.no_rope_key_cache[self.layer_idx].squeeze(0).contiguous() # just batch_size ==1
1510
- else:
1511
- stage1_k = torch.cat([past_k, new_k], dim=1).contiguous().squeeze(0).contiguous() # just batch_size ==1
1512
- compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, torch.tensor([0, stage1_k.shape[0]], device=stage1_k.device, dtype=torch.int32)) # just batch_size ==1
1513
-
1514
- # Compute the start indices of keys (k) that were not compressed, Only batch_size=1 is supported at the moment.
1515
- no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
1516
- past_key_value.update_compress_k(
1517
- compressed_k, self.layer_idx
1518
- )
1519
- past_key_value.update_no_compress_k(
1520
- stage1_k[no_compress_k_start:], self.layer_idx, no_compress_k_start)
1521
- past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
1522
-
1523
- else:
1524
- stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
1525
- no_compress_k = past_key_value.update_no_compress_k(
1526
- stage1_k, self.layer_idx, kernel_stride=self.kernel_stride, kernel_size=self.kernel_size)
1527
- if no_compress_k is not None:
1528
- compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1529
-
1530
- compressed_k = past_key_value.update_compress_k(
1531
- compressed_k, self.layer_idx) # [seqlen, nheads_k, head_dim]
1532
-
1533
- past_key_value.cached_compressed_cu_seqlens[self.layer_idx][-1] += 1 # !Increment the last entry in sequence lengths by 1; currently supports only batch_size = 1
1534
- compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
1535
- else:
1536
- compressed_k = past_key_value.compress_k_cache[self.layer_idx] # [seqlen, nheads_k, head_dim]
1537
- compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
1538
-
1539
- compressed_v = compressed_k.clone()
1540
-
1541
- compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1542
- torch.cuda.synchronize()
1543
- # Manually verify that the lengths match
1544
- assert compressed_k.shape[0] == compressed_seqlens.sum().item(), 'The length of compressed_k does not match the sum of compressed_seqlens'
1545
- topk_idx = compressed_attention(
1546
- new_q.squeeze(0).contiguous() if no_rope_param is None else no_rope_param['query_states_no_rope'],
1547
- compressed_k,
1548
- compressed_v,
1549
- self.kernel_size,
1550
- self.kernel_stride,
1551
- self.block_size,
1552
- self.topk,
1553
- torch.tensor([0, 1], device=compressed_k.device, dtype=torch.int32),
1554
- compressed_cu_seqlens,
1555
- 1,
1556
- compressed_seqlens.max().item(),
1557
- None,
1558
- init_blocks=self.init_blocks,
1559
- local_blocks=self.local_blocks,
1560
- total_seq_lens=past_k.shape[1] + 1, # !Only batch_size=1 is supported at the moment.
1561
- )
1562
-
1563
- repeat_times = 1
1564
- if repeat_times > 1:
1565
- new_q = new_q.repeat_interleave(repeat_times, dim=-2)
1566
- else:
1567
- new_q = new_q
1568
-
1569
- cache_batch_idx = torch.arange(batch_size, device=new_q.device, dtype=torch.int32)
1570
-
1571
- seqlen_k = past_k.shape[1] + new_k.shape[1] # !Only batch_size=1 is supported at the moment.
1572
- seqlens_k = torch.full((batch_size,), seqlen_k - 1, dtype=torch.int32, device=new_q.device)
1573
-
1574
- past_k = torch.cat([past_k, torch.zeros_like(new_k, dtype=new_k.dtype)], dim=1).contiguous() # Append one zero vector to avoid potential out-of-bounds access
1575
- past_v = torch.cat([past_v, torch.zeros_like(new_v, dtype=new_v.dtype)], dim=1).contiguous() # Append one zero vector to avoid potential out-of-bounds access
1576
- topk_attn_output = infllmv2_attn_with_kvcache(
1577
- q=new_q,
1578
- k_cache=past_k,
1579
- v_cache=past_v,
1580
- topk_idx=topk_idx,
1581
- block_window_size=self.window_size // self.block_size,
1582
- k=new_k, # [batch_size, 1, nheads_k, d]
1583
- v=new_v, # [batch_size, 1, nheads_k, d]
1584
- cache_seqlens=seqlens_k, # current_seqlens_k-1
1585
- rotary_cos=None, # No rotary embeddings
1586
- rotary_sin=None, # No rotary embeddings
1587
- cache_batch_idx=cache_batch_idx,
1588
- causal=False, # Renaming to match function signature
1589
- )
1590
- return topk_attn_output
1591
-
1592
  def _flash_attention_forward_dense(
1593
  self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
1594
  ):
@@ -1727,9 +1434,7 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
1727
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1728
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1729
 
1730
- kv_seq_len = key_states.shape[-2]
1731
- if past_key_value is not None:
1732
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
1733
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1734
 
1735
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -2052,11 +1757,13 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
2052
  raise ValueError(
2053
  'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
2054
  )
2055
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
2056
 
2057
- past_key_values_length = past_key_values.get_usable_length(seq_length)
2058
  if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
2059
- past_key_values = DynamicCacheQKV()
2060
 
2061
  if position_ids is None:
2062
  device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -2282,12 +1989,17 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
2282
  ):
2283
  if past_key_values is not None:
2284
  if isinstance(past_key_values, Cache):
 
2285
  cache_length = past_key_values.get_seq_length()
2286
- past_length = past_key_values.seen_tokens
2287
- max_cache_length = None # past_key_values.get_max_length()
2288
- else:
2289
- cache_length = past_length = past_key_values[0][0].shape[2]
2290
  max_cache_length = None
2291
 
2292
  # Keep only the unprocessed tokens:
2293
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
 
24
  from torch import nn
25
  from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
26
  from transformers.activations import ACT2FN
27
+ from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
28
  from transformers.modeling_attn_mask_utils import (
29
  AttentionMaskConverter,
30
  _prepare_4d_attention_mask,
 
57
  infllmv2_attn_varlen_func,
58
  infllmv2_attn_with_kvcache,
59
  max_pooling_1d,
60
+ max_pooling_1d_varlen
61
  )
62
  except:
63
  pass
 
80
  sm_scale: float = None,
81
  init_blocks: int = 1,
82
  local_blocks: int = 2,
83
+ cache_lens: torch.Tensor = None,
 
84
  ) -> Tuple[torch.Tensor, torch.Tensor]:
85
  """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
86
 
 
99
  sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
100
  init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
101
  local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
102
+ cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None.
 
103
 
104
  Returns:
105
  Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
106
  """
107
  with torch.no_grad():
 
108
  batch_size = cu_seqlens_q.shape[0] - 1
109
+
110
+ # Check if it's prefilling stage
111
+ is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
112
+
113
+ # prefilling stage
114
+ if is_prefilling:
115
+ # Calculate q_idx for each query position in each batch
116
+ cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
117
+ q_idx = torch.cat([
118
+ (torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) +
119
+ max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
120
+ for i in range(batch_size)
121
+ ], dim=0) # shape: [total_q_len]
122
+ # decoding stage
123
  else:
124
+ # During decoding each sequence contributes a single query (its last position), so [batch_size] == [total_q_len]
125
+ q_idx = cache_lens // block_size
 
126
 
127
+ # compute attention score
128
  score = infllmv2_attn_stage1(
129
  q.contiguous(),
130
  k.contiguous(),
 
133
  cu_seqlens_k=cu_seqlens_k,
134
  max_seqlen_q=max_seqlen_q,
135
  max_seqlen_k=max_seqlen_k,
136
+ causal=is_prefilling)
137
+ # Shape: [num_heads, total_q_len, num_blocks]
138
  score = score[:, :q_idx.shape[0], :]
139
 
140
+ # Shape: [num_heads, total_q_len, num_blocks]
141
+ block_score = max_pooling_1d_varlen(
142
  score.contiguous(),
143
+ cu_seqlens_q,
144
+ cu_seqlens_k,
145
+ cache_lens,
146
+ max_seqlen_q,
147
+ max_seqlen_k,
148
  local_blocks=local_blocks,
149
  init_blocks=init_blocks,
150
  block_size=block_size,
151
+ stride=kernel_stride)
152
+
153
  # get topk
154
  topk = min(topk, block_score.shape[-1])
155
  topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values
156
+ topk_idx[topk_idx > q_idx[None, :, None]] = -1
157
  topk_idx = topk_idx.to(torch.int32)
158
 
159
  return topk_idx
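To make the two branches above concrete, here is a minimal sketch in plain PyTorch of the block-index bookkeeping: per-token block indices during prefilling, per-sequence indices during decoding, and masking of selected blocks beyond the causal frontier. The random `block_score` stands in for the pooled output of `infllmv2_attn_stage1` / `max_pooling_1d_varlen`; all shapes are illustrative.

```python
import torch

block_size, topk = 64, 4
cu_seqlens_q = torch.tensor([0, 100, 250], dtype=torch.int32)  # two packed sequences of 100 and 150 tokens
max_seqlen_q = 150
cache_lens = None                                              # None -> prefilling

if cache_lens is None:                                         # prefilling: one block index per token
    idx_chunks = []
    for i in range(cu_seqlens_q.numel() - 1):
        seq_len = int(cu_seqlens_q[i + 1] - cu_seqlens_q[i])
        pos = torch.arange(seq_len) + max_seqlen_q - seq_len   # right-align shorter sequences
        idx_chunks.append(pos // block_size)
    q_idx = torch.cat(idx_chunks)                              # shape [total_q_len]
else:                                                          # decoding: one block index per sequence
    q_idx = cache_lens // block_size

num_heads, num_blocks = 2, 8
block_score = torch.rand(num_heads, q_idx.shape[0], num_blocks)  # stand-in for the pooled block scores

topk_idx = block_score.topk(min(topk, num_blocks), dim=-1).indices.sort(-1).values
topk_idx[topk_idx > q_idx[None, :, None]] = -1                 # drop key blocks beyond the causal frontier
topk_idx = topk_idx.to(torch.int32)
print(topk_idx.shape)                                          # torch.Size([2, 250, 4])
```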
 
252
  return compressed_k, cu_seqlens_compressed
253
 
254
 
 
 
 
255
 
256
+ class InfLLMv2CacheLayer(DynamicLayer):
257
+ def __init__(self):
258
  super().__init__()
259
+ # InfLLMv2-specific buffers: no-RoPE keys, per-sequence compressed keys, and the uncompressed tail waiting to be compressed
260
+ self.no_rope_keys = torch.tensor([], dtype=torch.float32)
261
+ self.compress_k_cache = []
262
+ self.no_compress_k_cache = []
263
+ self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
264
+ self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
265
+
266
+ def update_no_rope_key(self, key_states):
267
+ if self.no_rope_keys.numel() == 0:
268
+ self.no_rope_keys = key_states
269
  else:
270
+ self.no_rope_keys = torch.cat([self.no_rope_keys, key_states], dim=1)
271
+ return self.no_rope_keys
272
+
273
+ def update_compress_k(self, key_states, cu_seqlens=None):
274
+ if len(self.compress_k_cache) == 0:
275
+ if cu_seqlens is not None:
276
+ self.cached_compressed_cu_seqlens = cu_seqlens.clone()
277
+ self.compress_k_cache_varlen = key_states
278
+ split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
279
+ self.compress_k_cache = list(torch.split(key_states, split_sizes))
280
  else:
281
+ for index, k in enumerate(key_states):
282
+ if k is not None:
283
+ self.compress_k_cache[index] = torch.cat([self.compress_k_cache[index], k], dim=0)
284
+ new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k_cache], dtype=torch.int32)
285
+ new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
286
+
287
+ self.compress_k_cache_varlen = torch.cat(self.compress_k_cache, dim=0)
288
+ self.cached_compressed_cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k_cache_varlen.device)
289
+ return self.compress_k_cache_varlen, self.cached_compressed_cu_seqlens
290
+
291
+ def update_no_compress_k(self, key_states, kernel_size=32, kernel_stride=16):
292
+ k_chunk_list = []
293
+ for index, k in enumerate(key_states):
294
+ if len(self.no_compress_k_cache) <= index:
295
+ self.no_compress_k_cache.append(k)
296
+ else:
297
+ self.no_compress_k_cache[index] = torch.cat([self.no_compress_k_cache[index], k], dim=0)
298
+ current_len = self.no_compress_k_cache[index].shape[0]
299
+ if current_len >= kernel_size:
300
+ k_chunk_list.append(self.no_compress_k_cache[index][:kernel_size])
301
+ self.no_compress_k_cache[index] = self.no_compress_k_cache[index][kernel_stride:]
302
+ else:
303
+ k_chunk_list.append(None)
304
+ return k_chunk_list
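The pair of update methods above implements a per-sequence sliding compression buffer. The stand-alone sketch below (plain PyTorch, illustrative shapes, not part of the modeling file) shows the same bookkeeping in isolation: keys accumulate until `kernel_size` of them are buffered, the oldest `kernel_size` are mean-pooled into one compressed key, and the buffer then slides forward by `kernel_stride`.

```python
import torch

kernel_size, kernel_stride = 32, 16
n_heads_k, head_dim = 2, 64

buffer = torch.empty(0, n_heads_k, head_dim)       # uncompressed tail of the key stream
compressed = torch.empty(0, n_heads_k, head_dim)   # one row per compressed block

for step in range(100):                            # pretend decode loop, one new key per step
    new_key = torch.randn(1, n_heads_k, head_dim)
    buffer = torch.cat([buffer, new_key], dim=0)
    if buffer.shape[0] >= kernel_size:
        block = buffer[:kernel_size].mean(dim=0, keepdim=True)  # [1, n_heads_k, head_dim]
        compressed = torch.cat([compressed, block], dim=0)
        buffer = buffer[kernel_stride:]            # keep an overlap of kernel_size - kernel_stride tokens

print(compressed.shape[0], buffer.shape[0])        # 5 20
```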
305
 
306
+ class InfLLMv2Cache(DynamicCache):
307
+ def __init__(self,
308
+ config, num_hidden_layers: Optional[int] = None) -> None:
309
+ super().__init__(config=config)
310
+ self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
311
+ self._seen_tokens = 0
312
 
313
+ def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
314
  if layer_idx == 0:
315
  self._seen_tokens += key_states.shape[-2]
316
+ return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
317
 
318
+ def update_no_rope_key(self, key_states, layer_idx, cache_kwargs=None):
319
+ return self.layers[layer_idx].update_no_rope_key(key_states)
320
 
321
+ def update_compress_k(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
322
+ return self.layers[layer_idx].update_compress_k(key_states, cu_seqlens)
 
323
 
324
+ def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
325
+ return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
326
 
327
+ def crop(self, max_length):
328
+ for layer in self.layers:
329
+ layer.crop(max_length)
330
 
331
+ def batch_repeat_interleave(self, repeats):
332
+ for layer in self.layers:
333
+ layer.batch_repeat_interleave(repeats)
334
 
335
+ def batch_select_indices(self, indices):
336
+ for layer in self.layers:
337
+ layer.batch_select_indices(indices)
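A rough usage sketch of the cache class above, with `InfLLMv2Cache` taken from this modeling file and assuming the transformers version it targets (whose `DynamicCache`/`DynamicLayer` accept the arguments used here); the checkpoint id and tensor shapes are illustrative only.

```python
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained('openbmb/MiniCPM4.1-8B', trust_remote_code=True)
cache = InfLLMv2Cache(config=config, num_hidden_layers=config.num_hidden_layers)

bsz, seq = 1, 4
n_kv_heads = config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads

key = torch.randn(bsz, n_kv_heads, seq, head_dim)    # layout expected by Cache.update
value = torch.randn(bsz, n_kv_heads, seq, head_dim)

k_out, v_out = cache.update(key, value, layer_idx=0)                                  # RoPE'd K/V path
cache.update_no_rope_key(torch.randn(bsz, seq, n_kv_heads, head_dim), layer_idx=0)    # NoPE keys kept separately
print(cache.get_seq_length())                                                         # tokens cached for layer 0
```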
 
338
 
339
 
340
  # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
 
363
  )
364
 
365
 
366
 
367
 
368
  # @torch.jit.script # type: ignore
 
576
 
577
  return down_proj
578
 
579
+ def _unpad_one_tensor(hidden_states, attention_mask):
580
+ # Unpad the hidden states using the indices
581
+ indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
582
+ batch_size, seq_len = hidden_states.shape[:2]
583
+
584
+ # Get the remaining dimensions
585
+ remaining_dims = hidden_states.shape[2:]
586
+
587
+ # Reshape to (batch_size * seq_len, *remaining_dims)
588
+ reshaped_states = hidden_states.reshape(batch_size * seq_len, *remaining_dims)
589
+
590
+ # Apply unpadding using indices
591
+ unpadded_states = index_first_axis(reshaped_states, indices)
592
+
593
+ return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
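For reference, the same unpadding can be written without the `flash_attn` helpers; the sketch below is a plain-PyTorch equivalent of what `_get_unpad_data` plus `index_first_axis` compute for the forward pass (the mask and shapes are made up).

```python
import torch
import torch.nn.functional as F

def unpad_one_tensor_reference(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)              # real tokens per sequence
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())
    flat = hidden_states.reshape(-1, *hidden_states.shape[2:])           # [batch*seq, ...]
    return flat[indices], indices, cu_seqlens, max_seqlen

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])                        # 3 and 2 real tokens
states = torch.randn(2, 4, 8, 16)                                        # [batch, seq, heads, dim]
unpadded, indices, cu_seqlens, max_seqlen = unpad_one_tensor_reference(states, mask)
print(unpadded.shape, cu_seqlens.tolist(), max_seqlen)                   # torch.Size([5, 8, 16]) [0, 3, 5] 3
```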
594
 
595
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
596
  """
 
722
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
723
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
724
 
725
+ kv_seq_len = position_ids.max().item() + 1
726
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
727
 
728
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
824
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
825
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
826
 
827
+ kv_seq_len = position_ids.max().item() + 1
 
 
828
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
829
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
830
 
 
996
  self.dense_len = self.config.sparse_config.get('dense_len', 8192)
997
 
998
  self.local_blocks = self.window_size // self.block_size # local_blocks
999
+ self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
1000
  self.use_nope = self.config.sparse_config.get('use_nope', False)
1001
  self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
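For orientation, the fields this class reads from `config.sparse_config` look like the following; the values not visible in this file (marked as assumed) are illustrative, not necessarily the shipped configuration.

```python
# Illustrative sparse_config; entries marked "assumed" are not confirmed by this file.
sparse_config = {
    'kernel_size': 32,    # tokens mean-pooled into one compressed key
    'kernel_stride': 16,  # stride between pooled windows (50% overlap)
    'init_blocks': 1,     # always-attended leading blocks (assumed default)
    'block_size': 64,     # granularity of top-k key blocks (assumed default)
    'window_size': 2048,  # local window; local_blocks = window_size // block_size (assumed default)
    'topk': 64,           # selected key blocks per query; local blocks are added on top
    'use_nope': False,    # score blocks with no-RoPE queries/keys when True
    'dense_len': 8192,    # below this KV length, fall back to dense FlashAttention
}
```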
1002
 
 
1022
  output_attentions = False
1023
 
1024
  bsz, q_len, _ = hidden_states.size()
 
1025
 
1026
  query_states = self.q_proj(hidden_states)
1027
  key_states = self.k_proj(hidden_states)
 
1039
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1040
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1041
 
1042
+ kv_seq_len = position_ids.max().item() + 1
 
 
1043
  cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
1044
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
1045
 
 
1053
  key_states = key_states.transpose(1, 2)
1054
  value_states = value_states.transpose(1, 2)
1055
  if self.use_nope:
1056
+ key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
1057
  no_rope_param = {
1058
  'key_states_no_rope': key_states_no_rope,
1059
  'query_states_no_rope': query_states_no_rope,
1060
  }
 
 
1061
  else:
1062
  no_rope_param = None
1063
 
 
1089
  if kv_seq_len < self.dense_len:
1090
  attn_output = self._flash_attention_forward_dense(
1091
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
1092
+ else:
1093
+ attn_output = self._sparse_attention_forward(
1094
  query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
1095
  no_rope_param=no_rope_param, # if past_key_value is not None else None,
1096
+ past_key_value=past_key_value)
1097
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
1098
  attn_output = self.o_proj(attn_output)
1099
 
 
1102
 
1103
  return attn_output, attn_weights, past_key_value
1104
 
1105
+ def _sparse_attention_forward(
1106
+ self,
1107
+ query_states,
1108
+ key_states,
1109
+ value_states,
1110
+ attention_mask,
1111
+ query_length,
1112
+ dropout=0.0,
1113
+ softmax_scale=None,
1114
+ no_rope_param=None,
1115
+ past_key_value=None):
1116
+ """
1117
+ Routes attention through the InfLLMv2 sparse path. If the input hidden states contain at least one padding token,
1118
+ the input is first unpadded, attention is computed over the unpadded sequences, and the final attention output is padded back.
1119
+
1120
+ Args:
1121
+ query_states (`torch.Tensor`):
1122
+ Input query states to be passed to Flash Attention API
1123
+ key_states (`torch.Tensor`):
1124
+ Input key states to be passed to Flash Attention API
1125
+ value_states (`torch.Tensor`):
1126
+ Input value states to be passed to Flash Attention API
1127
+ attention_mask (`torch.Tensor`):
1128
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
1129
+ position of padding tokens and 1 for the position of non-padding tokens.
1130
+ dropout (`int`, *optional*):
1131
+ Attention dropout
1132
+ softmax_scale (`float`, *optional*):
1133
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
1134
+ """
1135
+ if not self._flash_attn_uses_top_left_mask:
1136
+ causal = self.is_causal
1137
+ else:
1138
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
1139
+ causal = self.is_causal and query_length != 1
1140
+ # Contains at least one padding token in the sequence
1141
+ if attention_mask is not None:
1142
+ batch_size = query_states.shape[0]
1143
+ # assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
1144
+ if past_key_value is not None:
1145
+ compressed_k, compressed_cu_seqlens = self.get_compress_k(
1146
+ key_states=key_states if not self.use_nope else no_rope_param['key_states_no_rope'],  # this could be optimized a bit
1147
+ attention_mask=attention_mask,
1148
+ past_key_value=past_key_value)
1149
+
1150
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
1151
+ query_states, key_states, value_states, attention_mask, query_length
1152
+ )
1153
 
1154
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
1155
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
1156
+ if no_rope_param is not None:
1157
+ if max_seqlen_in_batch_q == 1:
1158
+ no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(1)
1159
+ else:
1160
+ no_rope_param['query_states_no_rope'], _, _, _ = _unpad_one_tensor(no_rope_param['query_states_no_rope'], attention_mask=attention_mask)
1161
+ if past_key_value is None:
1162
+ # compress_k uses the varlen form
1163
+ compressed_k, compressed_cu_seqlens = self.compress_k(key_states, cu_seqlens_k)
1164
+
1165
+ attn_output_unpad = self.sparse_forward(
1166
+ query_states,
1167
+ key_states,
1168
+ value_states,
1169
+ cu_seqlens_q,
1170
+ cu_seqlens_k,
1171
+ max_seqlen_in_batch_q,
1172
+ max_seqlen_in_batch_k,
1173
+ no_rope_param=no_rope_param,
1174
+ compressed_k=compressed_k,
1175
+ compressed_cu_seqlens=compressed_cu_seqlens)
1176
+
1177
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
1178
+ else:
1179
+ raise ValueError('Need attention mask')
1180
 
1181
+ return attn_output
1182
 
1183
+ def get_compress_k(self, key_states, attention_mask, past_key_value):
 
 
1184
  """
1185
+ Get compressed key states and corresponding cumulative sequence lengths.
1186
+
 
1187
  Args:
1188
+ key_states: Key states tensor of shape [batch_size, seq_len, num_kv_heads, head_dim]
1189
+ attention_mask: Padding mask of shape [batch_size, seq_len]; 1 marks real tokens, 0 marks padding
1190
+ past_key_value: InfLLMv2Cache holding the compressed keys and the
1191
+ uncompressed tail buffer reused across prefill and decoding
1192
+
1193
+ Returns:
1194
+ Tuple of (compressed_k, compressed_cu_seqlens)
1195
  """
1196
+ # Check if this is prefilling or initial compression condition
1197
+ is_prefilling = (
1198
+ key_states.shape[1] >= self.dense_len and
1199
+ (
1200
+ not past_key_value.layers[self.layer_idx].compress_k_cache
1201
+ )
1202
+ )
 
 
 
 
 
 
1203
 
1204
+ if is_prefilling:
1205
+ unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
1206
+ # Compress the keys
1207
+ compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
1208
 
1209
+ past_key_value.update_compress_k(
1210
+ compressed_k, self.layer_idx, compressed_cu_seqlens)
1211
 
1212
+ no_compress_k_list = []
1213
+ # Compute and update no_compress_k
1214
+ for i in range(len(compressed_cu_seqlens)-1):
1215
+ no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
 
 
1216
 
1217
+ no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
 
1218
 
1219
+ past_key_value.update_no_compress_k(
1220
+ no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
1221
+ kernel_size=self.kernel_size)
1222
  else:
1223
+ # Decode case: incremental update
1224
+ batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
1225
+ key_states_split = list(torch.split(
1226
+ key_states[:, -1:].squeeze(1),  # [batch_size, seq, k_head_num, head_dim] -> [batch_size, 1, k_head_num, head_dim] -> [batch_size, k_head_num, head_dim]
1227
+ [1] * batch_size, dim=0,
1228
+ ))
1229
+ # Try to update no_compress_k buffer
1230
+ no_compress_k_list = past_key_value.update_no_compress_k(
1231
+ key_states_split, self.layer_idx,
1232
+ kernel_stride=self.kernel_stride,
1233
+ kernel_size=self.kernel_size)
1234
+ new_compressed_k_list = []
1235
+ for no_compress_k in no_compress_k_list:
1236
+ if no_compress_k is not None:
1237
+ # We have enough tokens to compress
1238
+ new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
1239
+ new_compressed_k_list.append(new_compressed_k)
1240
+ else:
1241
+ new_compressed_k_list.append(None)
1242
+ compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
1243
 
1244
+ return compressed_k, compressed_cu_seqlens
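A small worked example of the prefill bookkeeping above, assuming `kernel_size=32`, `kernel_stride=16`, and that `CompressK` emits the usual `(len - kernel_size) // kernel_stride + 1` overlapping windows:

```python
kernel_size, kernel_stride = 32, 16
seq_len = 100                                                # prefill length of one sequence

num_blocks = (seq_len - kernel_size) // kernel_stride + 1    # 5 compressed keys
no_compress_k_start = num_blocks * kernel_stride             # 80: first position not yet compressed
tail = seq_len - no_compress_k_start                         # 20 tokens kept in the uncompressed buffer

print(num_blocks, no_compress_k_start, tail)                 # 5 80 20
```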
1245
 
1246
  def sparse_forward(self,
1247
  query_layer,
 
1252
  max_seqlen_in_batch_q,
1253
  max_seqlen_in_batch_k,
1254
  no_rope_param=None,
1255
+ compressed_k=None,
1256
+ compressed_cu_seqlens=None):
1257
  compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
1258
+ cache_lens = None
1259
+ if max_seqlen_in_batch_q == 1 and max_seqlen_in_batch_k > 1:  # decoding
1260
+ seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
1261
+ cache_lens = seq_lens_k - 1
1262
+
1263
  topk_idx = compressed_attention(
1264
  query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
1265
  compressed_k,
1266
+ compressed_k.clone(),
1267
  self.kernel_size,
1268
  self.kernel_stride,
1269
  self.block_size,
 
1275
  None,
1276
  init_blocks=self.init_blocks,
1277
  local_blocks=self.local_blocks,
1278
+ cache_lens=cache_lens
1279
  )
 
1280
  topk_attn_output = infllmv2_attn_varlen_func(
1281
  query_layer,
1282
  key_layer,
 
1288
  dropout_p=0.0,
1289
  deterministic=False,
1290
  softmax_scale=None,
1291
+ causal=max_seqlen_in_batch_q != 1,
1292
  return_attn_probs=False,
1293
+ # block_window_size=self.window_size // self.block_size,
1294
  topk_idx=topk_idx
1295
  )
1296
 
1297
  return topk_attn_output
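The cumulative-length arithmetic used here, in isolation (the offsets are made up but consistent with `kernel_size=32`, `kernel_stride=16`):

```python
import torch

cu_seqlens_k = torch.tensor([0, 9000, 21000], dtype=torch.int32)         # two packed sequences
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]                        # tensor([ 9000, 12000])

max_seqlen_in_batch_q = 1                                                # decoding: one new query per sequence
cache_lens = seq_lens_k - 1 if max_seqlen_in_batch_q == 1 else None      # tokens already cached
print(cache_lens.tolist())                                               # [8999, 11999]

compressed_cu_seqlens = torch.tensor([0, 561, 1310], dtype=torch.int32)  # (9000-32)//16+1 = 561, (12000-32)//16+1 = 749
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
print(compressed_seqlens.tolist())                                       # [561, 749]
```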
1298
 
1299
  def _flash_attention_forward_dense(
1300
  self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
1301
  ):
 
1434
  key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1435
  value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
1436
 
1437
+ kv_seq_len = position_ids.max().item() + 1
 
 
1438
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
1439
 
1440
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
1757
  raise ValueError(
1758
  'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
1759
  )
 
1760
 
1761
+ # Calculate the usable length of past key values
1762
+ past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, InfLLMv2Cache) else 0
1763
+
1764
+ # Initialize InfLLMv2Cache if needed
1765
  if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
1766
+ past_key_values = InfLLMv2Cache(config=self.config, num_hidden_layers=self.config.num_hidden_layers)
1767
 
1768
  if position_ids is None:
1769
  device = input_ids.device if input_ids is not None else inputs_embeds.device
 
1989
  ):
1990
  if past_key_values is not None:
1991
  if isinstance(past_key_values, Cache):
1992
+ # Use the new Cache class methods
1993
  cache_length = past_key_values.get_seq_length()
1994
+
1995
+ if self.config.sparse_config is not None and torch.cuda.is_available() and cache_length == 0:
1996
+ past_key_values = InfLLMv2Cache(config=self.config, num_hidden_layers=self.config.num_hidden_layers)
1997
+ past_length = cache_length
1998
  max_cache_length = None
1999
+ else:
2000
+ raise ValueError(
2001
+ 'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
2002
+ )
2003
 
2004
  # Keep only the unprocessed tokens:
2005
  # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where