update readme and modeling model
Browse files- README.md +326 -0
- modeling_minicpm.py +277 -565
README.md
CHANGED
@@ -18,3 +18,329 @@ library_name: transformers
|
|
18 |
<p align="center">
|
19 |
👋 Contact us in <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
|
20 |
</p>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
<p align="center">
|
19 |
👋 Contact us in <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
|
20 |
</p>
|
21 |
+
|
22 |
+
## What's New
|
23 |
+
- [2025.09.05] **MiniCPM4.1** series are released! This series is a hybrid reasoning model, which can be used in
|
24 |
+
both deep reasoning mode and non-reasoning mode. 🔥🔥🔥
|
25 |
+
- [2025.06.06] **MiniCPM4** series are released! This model achieves ultimate efficiency improvements while maintaining optimal performance at the same scale! It can achieve over 5x generation acceleration on typical end-side chips! You can find technical report [here](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf).🔥🔥🔥
|
26 |
+
|
27 |
+
## MiniCPM4 and MiniCPM4.1 Series
|
28 |
+
MiniCPM4 and MiniCPM4.1 series are highly efficient large language models (LLMs) designed explicitly for end-side devices, which achieves this efficiency through systematic innovation in four key dimensions: model architecture, training data, training algorithms, and inference systems.
|
29 |
+
- [MiniCPM4.1-8B](https://huggingface.co/openbmb/MiniCPM4.1-8B): The latest version of MiniCPM4, with 8B parameters, support fusion thinking. (**<-- you are here**)
|
30 |
+
- [MiniCPM4.1-8B-GPTQ](https://huggingface.co/openbmb/MiniCPM4.1-8B-GPTQ): MiniCPM4.1-8B in GPTQ format.
|
31 |
+
- [MiniCPM4.1-8B-AutoAWQ](https://huggingface.co/openbmb/MiniCPM4.1-8B-AutoAWQ): MiniCPM4.1-8B in AutoAWQ format.
|
32 |
+
- [MiniCPM-4.1-8B-Marlin](https://huggingface.co/openbmb/MiniCPM-4.1-8B-Marlin): MiniCPM4.1-8B in Marlin format.
|
33 |
+
- [MiniCPM4.1-8B-GGUF](https://huggingface.co/openbmb/MiniCPM4.1-8B-GGUF): MiniCPM4.1-8B in GGUF format.
|
34 |
+
- [MiniCPM4.1-8B-MLX](https://huggingface.co/openbmb/MiniCPM4.1-8B-MLX): MiniCPM4.1-8B in MLX format.
|
35 |
+
- [MiniCPM4.1-8B-Eagle3](https://huggingface.co/openbmb/MiniCPM4.1-8B-Eagle3): Eagle3 model for MiniCPM4.1-8B.
|
36 |
+
- **MiniCPM4 Series**
|
37 |
+
<details>
|
38 |
+
<summary>Click to expand all MiniCPM4 series models</summary>
|
39 |
+
|
40 |
+
- [**MiniCPM4-8B**](https://huggingface.co/openbmb/MiniCPM4-8B): The flagship model with 8B parameters, trained on 8T tokens
|
41 |
+
- [**MiniCPM4-0.5B**](https://huggingface.co/openbmb/MiniCPM4-0.5B): Lightweight version with 0.5B parameters, trained on 1T tokens
|
42 |
+
- [**MiniCPM4-8B-Eagle-FRSpec**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec): Eagle head for FRSpec, accelerating speculative inference
|
43 |
+
- [**MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu): Eagle head with QAT for FRSpec, integrating speculation and quantization for ultra acceleration
|
44 |
+
- [**MiniCPM4-8B-Eagle-vLLM**](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-vLLM): Eagle head in vLLM format for speculative inference
|
45 |
+
- [**MiniCPM4-8B-marlin-Eagle-vLLM**](https://huggingface.co/openbmb/MiniCPM4-8B-marlin-Eagle-vLLM): Quantized Eagle head for vLLM format
|
46 |
+
- [**BitCPM4-0.5B**](https://huggingface.co/openbmb/BitCPM4-0.5B): Extreme ternary quantization of MiniCPM4-0.5B, achieving 90% bit width reduction
|
47 |
+
- [**BitCPM4-1B**](https://huggingface.co/openbmb/BitCPM4-1B): Extreme ternary quantization of MiniCPM3-1B, achieving 90% bit width reduction
|
48 |
+
- [**MiniCPM4-Survey**](https://huggingface.co/openbmb/MiniCPM4-Survey): Generates trustworthy, long-form survey papers from user queries
|
49 |
+
- [**MiniCPM4-MCP**](https://huggingface.co/openbmb/MiniCPM4-MCP): Integrates MCP tools to autonomously satisfy user requirements
|
50 |
+
</details>
|
51 |
+
|
52 |
+
## Introduction
|
53 |
+
MiniCPM4 and MiniCPM4.1 are extremely efficient edge-side large model that has undergone efficient optimization across four dimensions: model architecture, learning algorithms, training data, and inference systems, achieving ultimate efficiency improvements.
|
54 |
+
|
55 |
+
- 🏗️ **Efficient Model Architecture:**
|
56 |
+
- InfLLM v2 -- Trainable Sparse Attention Mechanism: Adopts a trainable sparse attention mechanism architecture where each token only needs to compute relevance with less than 5% of tokens in 128K long text processing, significantly reducing computational overhead for long texts
|
57 |
+
|
58 |
+
- 🧠 **Efficient Learning Algorithms:**
|
59 |
+
- Model Wind Tunnel 2.0 -- Efficient Predictable Scaling: Introduces scaling prediction methods for performance of downstream tasks, enabling more precise model training configuration search
|
60 |
+
- BitCPM -- Ultimate Ternary Quantization: Compresses model parameter bit-width to 3 values, achieving 90% extreme model bit-width reduction
|
61 |
+
- Efficient Training Engineering Optimization: Adopts FP8 low-precision computing technology combined with Multi-token Prediction training strategy
|
62 |
+
|
63 |
+
- 📚 **High-Quality Training Data:**
|
64 |
+
- UltraClean -- High-quality Pre-training Data Filtering and Generation: Builds iterative data cleaning strategies based on efficient data verification, open-sourcing high-quality Chinese and English pre-training dataset [UltraFinweb](https://huggingface.co/datasets/openbmb/Ultra-FineWeb)
|
65 |
+
- UltraChat v2 -- High-quality Supervised Fine-tuning Data Generation: Constructs large-scale high-quality supervised fine-tuning datasets covering multiple dimensions including knowledge-intensive data, reasoning-intensive data, instruction-following data, long text understanding data, and tool calling data
|
66 |
+
|
67 |
+
- ⚡ **Efficient Inference System:**
|
68 |
+
- CPM.cu -- Lightweight and Efficient CUDA Inference Framework: Integrates sparse attention, model quantization, and speculative sampling to achieve efficient prefilling and decoding
|
69 |
+
- ArkInfer -- Cross-platform Deployment System: Supports efficient deployment across multiple backend environments, providing flexible cross-platform adaptation capabilities
|
70 |
+
|
71 |
+
## Usage
|
72 |
+
|
73 |
+
### Inference with [CPM.cu](https://github.com/OpenBMB/cpm.cu)
|
74 |
+
|
75 |
+
We recommend using [CPM.cu](https://github.com/OpenBMB/cpm.cu) for the inference of MiniCPM4 and MiniCPM4.1. CPM.cu is a CUDA inference framework developed by OpenBMB, which integrates efficient sparse, speculative sampling, and quantization techniques, fully leveraging the efficiency advantages of MiniCPM4 and MiniCPM4.1.
|
76 |
+
|
77 |
+
You can install CPM.cu by running the following command:
|
78 |
+
|
79 |
+
```bash
|
80 |
+
git clone https://github.com/OpenBMB/cpm.cu.git --recursive
|
81 |
+
cd cpm.cu
|
82 |
+
python3 setup.py install
|
83 |
+
```
|
84 |
+
|
85 |
+
MiniCPM4.1 natively supports context lengths of up to 65,536(64k) tokens. To reproduce the long-text acceleration effect in the paper, we recommend using the LongRoPE factors that have been validated. Change the `rope_scaling` field in the `config.json` file as the following to enable LongRoPE.
|
86 |
+
```json
|
87 |
+
{
|
88 |
+
...,
|
89 |
+
"rope_scaling": {
|
90 |
+
"rope_type": "longrope",
|
91 |
+
"long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
|
92 |
+
"short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
|
93 |
+
"original_max_position_embeddings": 32768
|
94 |
+
}
|
95 |
+
}
|
96 |
+
```
|
97 |
+
|
98 |
+
After modification, you can run the following command to reproduce the long-context acceleration effect (the script will automatically download the model weights from HuggingFace)
|
99 |
+
```bash
|
100 |
+
python3 tests/test_generate.py
|
101 |
+
```
|
102 |
+
|
103 |
+
For more details about CPM.cu, please refer to [the repo CPM.cu](https://github.com/OpenBMB/cpm.cu).
|
104 |
+
|
105 |
+
### Hybird Reasoning Mode
|
106 |
+
|
107 |
+
MiniCPM4.1 supports hybrid reasoning mode, which can be used in both deep reasoning mode and non-reasoning mode. To enable hybrid reasoning mode. User can set `enable_thinking=True` in `tokenizer.apply_chat_template` to enable hybrid reasoning mode, and set `enable_thinking=False` to enable non-reasoning mode. Similarly, user can directly add `\no_think` at the end of the query to enable non-reasoning mode. If not add any special token or add `\think` at the end of the query, the model will enable reasoning mode.
|
108 |
+
|
109 |
+
```python
|
110 |
+
# Enable reasoning mode
|
111 |
+
prompt_text = tokenizer.apply_chat_template(
|
112 |
+
messages,
|
113 |
+
tokenize=False,
|
114 |
+
add_generation_prompt=True,
|
115 |
+
enable_thinking=True
|
116 |
+
)
|
117 |
+
# Enable non-reasoning mode
|
118 |
+
prompt_text = tokenizer.apply_chat_template(
|
119 |
+
messages,
|
120 |
+
tokenize=False,
|
121 |
+
add_generation_prompt=True,
|
122 |
+
enable_thinking=False
|
123 |
+
)
|
124 |
+
```
|
125 |
+
|
126 |
+
### Inference with Transformers
|
127 |
+
```python
|
128 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
129 |
+
import torch
|
130 |
+
torch.manual_seed(0)
|
131 |
+
|
132 |
+
path = 'openbmb/MiniCPM4.1-8B'
|
133 |
+
device = "cuda"
|
134 |
+
tokenizer = AutoTokenizer.from_pretrained(path)
|
135 |
+
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)
|
136 |
+
|
137 |
+
# User can directly use the chat interface
|
138 |
+
# responds, history = model.chat(tokenizer, "Write an article about Artificial Intelligence.", temperature=0.7, top_p=0.7)
|
139 |
+
# print(responds)
|
140 |
+
|
141 |
+
# User can also use the generate interface
|
142 |
+
messages = [
|
143 |
+
{"role": "user", "content": "Write an article about Artificial Intelligence."},
|
144 |
+
]
|
145 |
+
prompt_text = tokenizer.apply_chat_template(
|
146 |
+
messages,
|
147 |
+
tokenize=False,
|
148 |
+
add_generation_prompt=True,
|
149 |
+
)
|
150 |
+
model_inputs = tokenizer([prompt_text], return_tensors="pt").to(device)
|
151 |
+
|
152 |
+
model_outputs = model.generate(
|
153 |
+
**model_inputs,
|
154 |
+
max_new_tokens=8192,
|
155 |
+
top_p=0.7,
|
156 |
+
temperature=0.7
|
157 |
+
)
|
158 |
+
output_token_ids = [
|
159 |
+
model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs['input_ids']))
|
160 |
+
]
|
161 |
+
|
162 |
+
responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
|
163 |
+
print(responses)
|
164 |
+
```
|
165 |
+
|
166 |
+
MiniCPM4.1-8B supports `InfLLM v2`, a sparse attention mechanism designed for efficient long-sequence inference. It requires the [infllmv2_cuda_impl](https://github.com/OpenBMB/infllmv2_cuda_impl) library.
|
167 |
+
|
168 |
+
You can install it by running the following command:
|
169 |
+
```bash
|
170 |
+
git clone -b feature_infer https://github.com/OpenBMB/infllmv2_cuda_impl.git
|
171 |
+
cd infllmv2_cuda_impl
|
172 |
+
git submodule update --init --recursive
|
173 |
+
pip install -e . # or python setup.py install
|
174 |
+
```
|
175 |
+
|
176 |
+
To enable InfLLM v2, you need to add the `sparse_config` field in `config.json`:
|
177 |
+
```json
|
178 |
+
{
|
179 |
+
...,
|
180 |
+
"sparse_config": {
|
181 |
+
"kernel_size": 32,
|
182 |
+
"kernel_stride": 16,
|
183 |
+
"init_blocks": 1,
|
184 |
+
"block_size": 64,
|
185 |
+
"window_size": 2048,
|
186 |
+
"topk": 64,
|
187 |
+
"use_nope": false,
|
188 |
+
"dense_len": 8192
|
189 |
+
}
|
190 |
+
}
|
191 |
+
```
|
192 |
+
|
193 |
+
These parameters control the behavior of InfLLM v2:
|
194 |
+
* `kernel_size` (default: 32): The size of semantic kernels.
|
195 |
+
* `kernel_stride` (default: 16): The stride between adjacent kernels.
|
196 |
+
* `init_blocks` (default: 1): The number of initial blocks that every query token attends to. This ensures attention to the beginning of the sequence.
|
197 |
+
* `block_size` (default: 64): The block size for key-value blocks.
|
198 |
+
* `window_size` (default: 2048): The size of the local sliding window.
|
199 |
+
* `topk` (default: 64): The specifies that each token computes attention with only the top-k most relevant key-value blocks.
|
200 |
+
* `use_nope` (default: false): Whether to use the NOPE technique in block selection for improved performance.
|
201 |
+
* `dense_len` (default: 8192): Since Sparse Attention offers limited benefits for short sequences, the model can use standard (dense) attention for shorter texts. The model will use dense attention for sequences with a token length below `dense_len` and switch to sparse attention for sequences exceeding this length. Set this to `-1` to always use sparse attention regardless of sequence length.
|
202 |
+
|
203 |
+
MiniCPM4.1 natively supports context lengths of up to 65,536(64k) tokens. For conversations where the total length (including both input and output) significantly exceeds this limit, we recommend using RoPE scaling techniques for effective handling of long texts. We have validated the model's performance on context lengths of up to 131,072 tokens by modifying the LongRoPE factor.
|
204 |
+
|
205 |
+
You can apply the LongRoPE factor modification by modifying the model files. Specifically, in the `config.json` file, adjust the `rope_scaling` fields.
|
206 |
+
```json
|
207 |
+
{
|
208 |
+
...,
|
209 |
+
"rope_scaling": {
|
210 |
+
"rope_type": "longrope",
|
211 |
+
"long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
|
212 |
+
"short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116],
|
213 |
+
"original_max_position_embeddings": 32768
|
214 |
+
}
|
215 |
+
}
|
216 |
+
```
|
217 |
+
|
218 |
+
### Inference with [SGLang](https://github.com/sgl-project/sglang)
|
219 |
+
|
220 |
+
For now, you need to install our forked version of SGLang.
|
221 |
+
```bash
|
222 |
+
git clone -b openbmb https://github.com/OpenBMB/sglang.git
|
223 |
+
cd sglang
|
224 |
+
|
225 |
+
pip install --upgrade pip
|
226 |
+
pip install -e "python[all]"
|
227 |
+
```
|
228 |
+
|
229 |
+
You can start the inference server by running the following command:
|
230 |
+
```bash
|
231 |
+
python -m sglang.launch_server --model openbmb/MiniCPM4.1-8B --trust-remote-code --port 30000 --chat-template chatml
|
232 |
+
```
|
233 |
+
|
234 |
+
Then you can use the chat interface by running the following command:
|
235 |
+
```python
|
236 |
+
import openai
|
237 |
+
|
238 |
+
client = openai.Client(base_url=f"http://localhost:30000/v1", api_key="None")
|
239 |
+
|
240 |
+
response = client.chat.completions.create(
|
241 |
+
model="openbmb/MiniCPM4.1-8B",
|
242 |
+
messages=[
|
243 |
+
{"role": "user", "content": "Write an article about Artificial Intelligence."},
|
244 |
+
],
|
245 |
+
temperature=0.7,
|
246 |
+
max_tokens=8192,
|
247 |
+
)
|
248 |
+
|
249 |
+
print(response.choices[0].message.content)
|
250 |
+
```
|
251 |
+
|
252 |
+
### Inference with [vLLM](https://github.com/vllm-project/vllm)
|
253 |
+
For now, you need to install the latest version of vLLM.
|
254 |
+
```
|
255 |
+
pip install -U vllm \
|
256 |
+
--pre \
|
257 |
+
--extra-index-url https://wheels.vllm.ai/nightly
|
258 |
+
```
|
259 |
+
|
260 |
+
Then you can inference MiniCPM4.1-8B with vLLM:
|
261 |
+
```python
|
262 |
+
from transformers import AutoTokenizer
|
263 |
+
from vllm import LLM, SamplingParams
|
264 |
+
|
265 |
+
model_name = "openbmb/MiniCPM4.1-8B"
|
266 |
+
prompt = [{"role": "user", "content": "Please recommend 5 tourist attractions in Beijing. "}]
|
267 |
+
|
268 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
269 |
+
input_text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
|
270 |
+
|
271 |
+
llm = LLM(
|
272 |
+
model=model_name,
|
273 |
+
trust_remote_code=True,
|
274 |
+
max_num_batched_tokens=32768,
|
275 |
+
dtype="bfloat16",
|
276 |
+
gpu_memory_utilization=0.8,
|
277 |
+
)
|
278 |
+
sampling_params = SamplingParams(top_p=0.7, temperature=0.7, max_tokens=1024, repetition_penalty=1.02)
|
279 |
+
|
280 |
+
outputs = llm.generate(prompts=input_text, sampling_params=sampling_params)
|
281 |
+
|
282 |
+
print(outputs[0].outputs[0].text)
|
283 |
+
```
|
284 |
+
|
285 |
+
Also, you can start the inference server by running the following command:
|
286 |
+
> **Note**: In vLLM's chat API, `add_special_tokens` is `False` by default. This means important special tokens—such as the beginning-of-sequence (BOS) token—will not be added automatically. To ensure the input prompt is correctly formatted for the model, you should explicitly set `extra_body={"add_special_tokens": True}`.
|
287 |
+
|
288 |
+
```bash
|
289 |
+
vllm serve openbmb/MiniCPM4.1-8B
|
290 |
+
```
|
291 |
+
|
292 |
+
Then you can use the chat interface by running the following code:
|
293 |
+
|
294 |
+
```python
|
295 |
+
import openai
|
296 |
+
|
297 |
+
client = openai.Client(base_url="http://localhost:8000/v1", api_key="EMPTY")
|
298 |
+
|
299 |
+
response = client.chat.completions.create(
|
300 |
+
model="openbmb/MiniCPM4.1-8B",
|
301 |
+
messages=[
|
302 |
+
{"role": "user", "content": "Write an article about Artificial Intelligence."},
|
303 |
+
],
|
304 |
+
temperature=0.7,
|
305 |
+
max_tokens=1024,
|
306 |
+
extra_body=dict(add_special_tokens=True), # Ensures special tokens are added for chat template
|
307 |
+
|
308 |
+
)
|
309 |
+
|
310 |
+
print(response.choices[0].message.content)
|
311 |
+
```
|
312 |
+
|
313 |
+
## Evaluation Results
|
314 |
+
On two typical end-side chips, Jetson AGX Orin and RTX 4090, MiniCPM4 demonstrates significantly faster processing speed compared to similar-size models in long text processing tasks. As text length increases, MiniCPM4's efficiency advantage becomes more pronounced. On the Jetson AGX Orin platform, compared to Qwen3-8B, MiniCPM4 achieves approximately 7x decoding speed improvement.
|
315 |
+
|
316 |
+

|
317 |
+
|
318 |
+
#### Comprehensive Evaluation
|
319 |
+
MiniCPM4.1 launches end-side versions with 8B parameter scale, both achieving best-in-class performance in their respective categories.
|
320 |
+
|
321 |
+

|
322 |
+
|
323 |
+
#### Long Text Evaluation
|
324 |
+
MiniCPM4 is pre-trained on 32K long texts and achieves length extension through YaRN technology. In the 128K long text needle-in-a-haystack task, MiniCPM4 demonstrates outstanding performance.
|
325 |
+
|
326 |
+

|
327 |
+
|
328 |
+
## Statement
|
329 |
+
- As a language model, MiniCPM generates content by learning from a vast amount of text.
|
330 |
+
- However, it does not possess the ability to comprehend or express personal opinions or value judgments.
|
331 |
+
- Any content generated by MiniCPM does not represent the viewpoints or positions of the model developers.
|
332 |
+
- Therefore, when using content generated by MiniCPM, users should take full responsibility for evaluating and verifying it on their own.
|
333 |
+
|
334 |
+
## LICENSE
|
335 |
+
- This repository and MiniCPM models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
|
336 |
+
|
337 |
+
## Citation
|
338 |
+
- Please cite our [paper](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf) if you find our work valuable.
|
339 |
+
|
340 |
+
```bibtex
|
341 |
+
@article{minicpm4,
|
342 |
+
title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
|
343 |
+
author={MiniCPM Team},
|
344 |
+
year={2025}
|
345 |
+
}
|
346 |
+
```
|
modeling_minicpm.py
CHANGED
@@ -24,7 +24,7 @@ import torch.utils.checkpoint
|
|
24 |
from torch import nn
|
25 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
26 |
from transformers.activations import ACT2FN
|
27 |
-
from transformers.cache_utils import Cache, DynamicCache
|
28 |
from transformers.modeling_attn_mask_utils import (
|
29 |
AttentionMaskConverter,
|
30 |
_prepare_4d_attention_mask,
|
@@ -57,6 +57,7 @@ try:
|
|
57 |
infllmv2_attn_varlen_func,
|
58 |
infllmv2_attn_with_kvcache,
|
59 |
max_pooling_1d,
|
|
|
60 |
)
|
61 |
except:
|
62 |
pass
|
@@ -79,8 +80,7 @@ def compressed_attention(
|
|
79 |
sm_scale: float = None,
|
80 |
init_blocks: int = 1,
|
81 |
local_blocks: int = 2,
|
82 |
-
|
83 |
-
total_seq_lens=-1,
|
84 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
85 |
"""Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
|
86 |
|
@@ -99,31 +99,32 @@ def compressed_attention(
|
|
99 |
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
|
100 |
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
|
101 |
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
|
102 |
-
|
103 |
-
We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise.
|
104 |
|
105 |
Returns:
|
106 |
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
|
107 |
"""
|
108 |
with torch.no_grad():
|
109 |
-
cache_len = 0
|
110 |
batch_size = cu_seqlens_q.shape[0] - 1
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
122 |
else:
|
123 |
-
|
124 |
-
|
125 |
-
q_idx = torch.tensor([total_seq_lens - 1], device=q.device, dtype=torch.int32) // block_size
|
126 |
|
|
|
127 |
score = infllmv2_attn_stage1(
|
128 |
q.contiguous(),
|
129 |
k.contiguous(),
|
@@ -132,22 +133,27 @@ def compressed_attention(
|
|
132 |
cu_seqlens_k=cu_seqlens_k,
|
133 |
max_seqlen_q=max_seqlen_q,
|
134 |
max_seqlen_k=max_seqlen_k,
|
135 |
-
causal=
|
|
|
136 |
score = score[:, :q_idx.shape[0], :]
|
137 |
|
138 |
-
#
|
139 |
-
block_score =
|
140 |
score.contiguous(),
|
141 |
-
|
|
|
|
|
|
|
|
|
142 |
local_blocks=local_blocks,
|
143 |
init_blocks=init_blocks,
|
144 |
block_size=block_size,
|
145 |
-
stride=kernel_stride
|
146 |
-
|
147 |
# get topk
|
148 |
topk = min(topk, block_score.shape[-1])
|
149 |
topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values
|
150 |
-
topk_idx[topk_idx
|
151 |
topk_idx = topk_idx.to(torch.int32)
|
152 |
|
153 |
return topk_idx
|
@@ -246,299 +252,89 @@ class CompressK(torch.nn.Module):
|
|
246 |
return compressed_k, cu_seqlens_compressed
|
247 |
|
248 |
|
249 |
-
class DynamicCacheQKV(DynamicCache):
|
250 |
-
"""
|
251 |
-
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
|
252 |
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
Example:
|
257 |
-
```python
|
258 |
-
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
|
259 |
-
|
260 |
-
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
261 |
-
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
|
262 |
-
|
263 |
-
>>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
|
264 |
-
|
265 |
-
>>> # Prepare a cache class and pass it to model's forward
|
266 |
-
>>> past_key_values = DynamicCache()
|
267 |
-
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
|
268 |
-
>>> outputs.past_key_values # access cache filled with key/values from generation
|
269 |
-
DynamicCache()
|
270 |
-
```
|
271 |
-
"""
|
272 |
-
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
|
273 |
super().__init__()
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
|
|
|
|
|
|
281 |
else:
|
282 |
-
self.
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
-
|
292 |
-
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
293 |
-
sequence length.
|
294 |
-
"""
|
295 |
-
if layer_idx < len(self):
|
296 |
-
return (self.key_cache[layer_idx], self.value_cache[layer_idx])
|
297 |
else:
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
def
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
322 |
-
"""
|
323 |
-
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
324 |
|
325 |
-
|
326 |
-
key_states (`torch.Tensor`):
|
327 |
-
The new key states to cache.
|
328 |
-
value_states (`torch.Tensor`):
|
329 |
-
The new value states to cache.
|
330 |
-
layer_idx (`int`):
|
331 |
-
The index of the layer to cache the states for.
|
332 |
-
cache_kwargs (`Dict[str, Any]`, `optional`):
|
333 |
-
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
334 |
-
|
335 |
-
Return:
|
336 |
-
A tuple containing the updated key and value states.
|
337 |
-
"""
|
338 |
-
# Update the number of seen tokens
|
339 |
if layer_idx == 0:
|
340 |
self._seen_tokens += key_states.shape[-2]
|
|
|
341 |
|
342 |
-
|
343 |
-
|
344 |
-
self.key_cache.append(key_states)
|
345 |
-
self.value_cache.append(value_states)
|
346 |
-
|
347 |
-
# content on layer cache can be a tensor and checking not tensor causes errors
|
348 |
-
# so we explicitly check for the empty list
|
349 |
-
elif self.key_cache[layer_idx] == []:
|
350 |
-
self.key_cache[layer_idx] = key_states
|
351 |
-
self.value_cache[layer_idx] = value_states
|
352 |
-
|
353 |
-
else:
|
354 |
-
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
|
355 |
-
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
|
356 |
-
return self.key_cache[layer_idx], self.value_cache[layer_idx]
|
357 |
-
|
358 |
-
def update_no_rope_key(
|
359 |
-
self,
|
360 |
-
key_states: torch.Tensor,
|
361 |
-
layer_idx: int,
|
362 |
-
cache_kwargs: Optional[Dict[str, Any]] = None):
|
363 |
-
|
364 |
-
# Update the cache
|
365 |
-
if len(self.no_rope_key_cache) <= layer_idx:
|
366 |
-
self.no_rope_key_cache.append(key_states)
|
367 |
-
|
368 |
-
# content on layer cache can be a tensor and checking not tensor causes errors
|
369 |
-
# so we explicitly check for the empty list
|
370 |
-
elif self.no_rope_key_cache[layer_idx] == []:
|
371 |
-
self.no_rope_key_cache[layer_idx] = key_states
|
372 |
-
else:
|
373 |
-
self.no_rope_key_cache[layer_idx] = torch.cat([self.no_rope_key_cache[layer_idx], key_states], dim=1)
|
374 |
-
return self.no_rope_key_cache[layer_idx]
|
375 |
-
|
376 |
-
def update_compress_k(
|
377 |
-
self,
|
378 |
-
key_states: torch.Tensor,
|
379 |
-
layer_idx: int,
|
380 |
-
cache_kwargs: Optional[Dict[str, Any]] = None
|
381 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
382 |
-
"""
|
383 |
-
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
384 |
-
|
385 |
-
Parameters:
|
386 |
-
key_states (`torch.Tensor`):
|
387 |
-
The new key states to cache.
|
388 |
-
value_states (`torch.Tensor`):
|
389 |
-
The new value states to cache.
|
390 |
-
layer_idx (`int`):
|
391 |
-
The index of the layer to cache the states for.
|
392 |
-
cache_kwargs (`Dict[str, Any]`, `optional`):
|
393 |
-
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
394 |
-
|
395 |
-
Return:
|
396 |
-
A tuple containing the updated key and value states.
|
397 |
-
"""
|
398 |
|
399 |
-
|
400 |
-
|
401 |
-
self.compress_k_cache.append(key_states)
|
402 |
|
403 |
-
|
404 |
-
|
405 |
-
elif self.compress_k_cache[layer_idx] == []:
|
406 |
-
self.compress_k_cache[layer_idx] = key_states
|
407 |
-
else:
|
408 |
-
self.compress_k_cache[layer_idx] = torch.cat([self.compress_k_cache[layer_idx], key_states], dim=0)
|
409 |
-
return self.compress_k_cache[layer_idx]
|
410 |
|
411 |
-
def
|
412 |
-
self
|
413 |
-
|
414 |
-
layer_idx: int,
|
415 |
-
kernel_size: int = 32,
|
416 |
-
kernel_stride: int = 16,
|
417 |
-
cache_kwargs: Optional[Dict[str, Any]] = None
|
418 |
-
) -> Tuple[torch.Tensor, torch.Tensor]:
|
419 |
-
"""
|
420 |
-
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
|
421 |
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
value_states (`torch.Tensor`):
|
426 |
-
The new value states to cache.
|
427 |
-
layer_idx (`int`):
|
428 |
-
The index of the layer to cache the states for.
|
429 |
-
cache_kwargs (`Dict[str, Any]`, `optional`):
|
430 |
-
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
|
431 |
-
|
432 |
-
Return:
|
433 |
-
A tuple containing the updated key and value states.
|
434 |
-
"""
|
435 |
-
# Update the cache
|
436 |
-
if len(self.no_compress_k_cache) <= layer_idx:
|
437 |
-
self.no_compress_k_cache.append(key_states)
|
438 |
-
|
439 |
-
# content on layer cache can be a tensor and checking not tensor causes errors
|
440 |
-
# so we explicitly check for the empty list
|
441 |
-
elif self.no_compress_k_cache[layer_idx] == []:
|
442 |
-
self.no_compress_k_cache[layer_idx] = key_states
|
443 |
-
else:
|
444 |
-
self.no_compress_k_cache[layer_idx] = torch.cat([self.no_compress_k_cache[layer_idx], key_states], dim=0)
|
445 |
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
k_chunk = self.no_compress_k_cache[layer_idx][:kernel_size]
|
450 |
-
self.no_compress_k_cache[layer_idx] = self.no_compress_k_cache[layer_idx][kernel_stride:]
|
451 |
-
return k_chunk
|
452 |
-
else:
|
453 |
-
return None
|
454 |
-
|
455 |
-
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
456 |
-
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
457 |
-
# TODO: deprecate this function in favor of `cache_position`
|
458 |
-
if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
|
459 |
-
return 0
|
460 |
-
return self.key_cache[layer_idx].shape[-2]
|
461 |
-
|
462 |
-
def get_max_length(self) -> Optional[int]:
|
463 |
-
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
|
464 |
-
return None
|
465 |
-
|
466 |
-
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
467 |
-
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
|
468 |
-
backward compatibility."""
|
469 |
-
legacy_cache = ()
|
470 |
-
for layer_idx in range(len(self)):
|
471 |
-
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
|
472 |
-
return legacy_cache
|
473 |
-
|
474 |
-
# @classmethod
|
475 |
-
# def from_legacy_cache(
|
476 |
-
# cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
|
477 |
-
# ) -> "DynamicCacheQKV":
|
478 |
-
# """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
|
479 |
-
# backward compatibility."""
|
480 |
-
# cache = cls(num_hidden_layers)
|
481 |
-
# if past_key_values is not None:
|
482 |
-
# for layer_idx in range(len(past_key_values)):
|
483 |
-
# key_states, value_states, query_status = past_key_values[layer_idx]
|
484 |
-
# cache.update(key_states, value_states, query_status,layer_idx)
|
485 |
-
# return cache
|
486 |
-
|
487 |
-
def crop(self, max_length: int):
|
488 |
-
"""Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
|
489 |
-
negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
|
490 |
-
# In case it is negative
|
491 |
-
if max_length < 0:
|
492 |
-
max_length = self.get_seq_length() - abs(max_length)
|
493 |
-
|
494 |
-
if self.get_seq_length() <= max_length:
|
495 |
-
return
|
496 |
-
|
497 |
-
self._seen_tokens = max_length
|
498 |
-
for idx in range(len(self.key_cache)):
|
499 |
-
if self.key_cache[idx] != []:
|
500 |
-
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
|
501 |
-
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
|
502 |
-
|
503 |
-
def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List['DynamicCacheQKV']:
|
504 |
-
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
|
505 |
-
`_split_model_inputs()` in `generation.utils`"""
|
506 |
-
out = []
|
507 |
-
for i in range(0, full_batch_size, split_size):
|
508 |
-
current_split = DynamicCacheQKV(num_hidden_layers)
|
509 |
-
current_split._seen_tokens = self._seen_tokens
|
510 |
-
current_split.key_cache = [tensor[i: i + split_size] for tensor in self.key_cache]
|
511 |
-
current_split.value_cache = [tensor[i: i + split_size] for tensor in self.value_cache]
|
512 |
-
out.append(current_split)
|
513 |
-
return out
|
514 |
-
|
515 |
-
@classmethod
|
516 |
-
def from_batch_splits(cls, splits: List['DynamicCacheQKV'], num_hidden_layers: int) -> 'DynamicCacheQKV':
|
517 |
-
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
|
518 |
-
`generation.utils`"""
|
519 |
-
cache = cls(num_hidden_layers)
|
520 |
-
for idx in range(len(splits[0])):
|
521 |
-
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
522 |
-
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
523 |
-
query_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
|
524 |
-
if key_cache != []:
|
525 |
-
layer_keys = torch.cat(key_cache, dim=0)
|
526 |
-
layer_values = torch.cat(value_cache, dim=0)
|
527 |
-
layer_query = torch.cat(query_cache, dim=0)
|
528 |
-
cache.update(layer_keys, layer_values, idx, query_states=layer_query)
|
529 |
-
return cache
|
530 |
-
|
531 |
-
def batch_repeat_interleave(self, repeats: int):
|
532 |
-
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
|
533 |
-
for layer_idx in range(len(self)):
|
534 |
-
self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
535 |
-
self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
|
536 |
-
|
537 |
-
def batch_select_indices(self, indices: torch.Tensor):
|
538 |
-
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
|
539 |
-
for layer_idx in range(len(self)):
|
540 |
-
self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
|
541 |
-
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
|
542 |
|
543 |
|
544 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
@@ -567,22 +363,6 @@ def _get_unpad_data(attention_mask):
|
|
567 |
)
|
568 |
|
569 |
|
570 |
-
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
571 |
-
warnings.warn(
|
572 |
-
'Calling `transformers.models.minicpm.modeling_minicpm._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask'
|
573 |
-
)
|
574 |
-
return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
575 |
-
|
576 |
-
|
577 |
-
def _make_causal_mask(
|
578 |
-
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
579 |
-
):
|
580 |
-
warnings.warn(
|
581 |
-
'Calling `transformers.models.minicpm.modeling_minicpm._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.minicpm.modeling_minicpm.AttentionMaskConverter._make_causal_mask'
|
582 |
-
)
|
583 |
-
return AttentionMaskConverter._make_causal_mask(
|
584 |
-
input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
|
585 |
-
)
|
586 |
|
587 |
|
588 |
# @torch.jit.script # type: ignore
|
@@ -796,6 +576,21 @@ class MiniCPMMLP(nn.Module):
|
|
796 |
|
797 |
return down_proj
|
798 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
799 |
|
800 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
801 |
"""
|
@@ -927,15 +722,7 @@ class MiniCPMAttention(nn.Module):
|
|
927 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
928 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
929 |
|
930 |
-
kv_seq_len =
|
931 |
-
if past_key_value is not None:
|
932 |
-
if self.layer_idx is None:
|
933 |
-
raise ValueError(
|
934 |
-
f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
|
935 |
-
'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
|
936 |
-
'with a layer index.'
|
937 |
-
)
|
938 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
939 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
940 |
|
941 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
@@ -1037,9 +824,7 @@ class MiniCPMFlashAttention2(MiniCPMAttention):
|
|
1037 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1038 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1039 |
|
1040 |
-
kv_seq_len =
|
1041 |
-
if past_key_value is not None:
|
1042 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
1043 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
1044 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
1045 |
|
@@ -1211,7 +996,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1211 |
self.dense_len = self.config.sparse_config.get('dense_len', 8192)
|
1212 |
|
1213 |
self.local_blocks = self.window_size // self.block_size # local_blocks
|
1214 |
-
self.topk = self.config.sparse_config.get('topk', 64)
|
1215 |
self.use_nope = self.config.sparse_config.get('use_nope', False)
|
1216 |
self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
|
1217 |
|
@@ -1237,7 +1022,6 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1237 |
output_attentions = False
|
1238 |
|
1239 |
bsz, q_len, _ = hidden_states.size()
|
1240 |
-
assert bsz == 1, 'Only batch_size=1 is supported at the moment.'
|
1241 |
|
1242 |
query_states = self.q_proj(hidden_states)
|
1243 |
key_states = self.k_proj(hidden_states)
|
@@ -1255,9 +1039,7 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1255 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1256 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1257 |
|
1258 |
-
kv_seq_len =
|
1259 |
-
if past_key_value is not None:
|
1260 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
1261 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
1262 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
1263 |
|
@@ -1271,12 +1053,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1271 |
key_states = key_states.transpose(1, 2)
|
1272 |
value_states = value_states.transpose(1, 2)
|
1273 |
if self.use_nope:
|
|
|
1274 |
no_rope_param = {
|
1275 |
'key_states_no_rope': key_states_no_rope,
|
1276 |
'query_states_no_rope': query_states_no_rope,
|
1277 |
}
|
1278 |
-
if kv_seq_len <= self.dense_len:
|
1279 |
-
past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
|
1280 |
else:
|
1281 |
no_rope_param = None
|
1282 |
|
@@ -1308,15 +1089,11 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1308 |
if kv_seq_len < self.dense_len:
|
1309 |
attn_output = self._flash_attention_forward_dense(
|
1310 |
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
|
1311 |
-
|
1312 |
-
attn_output = self.
|
1313 |
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
|
1314 |
no_rope_param=no_rope_param, # if past_key_value is not None else None,
|
1315 |
past_key_value=past_key_value)
|
1316 |
-
else:
|
1317 |
-
attn_output = self._flash_attention_forward_with_kv_cache(
|
1318 |
-
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, no_rope_param=no_rope_param, past_key_value=past_key_value)
|
1319 |
-
|
1320 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
1321 |
attn_output = self.o_proj(attn_output)
|
1322 |
|
@@ -1325,122 +1102,146 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1325 |
|
1326 |
return attn_output, attn_weights, past_key_value
|
1327 |
|
1328 |
-
def
|
1329 |
-
|
1330 |
-
|
1331 |
-
|
1332 |
-
|
1333 |
-
|
1334 |
-
|
1335 |
-
|
1336 |
-
|
1337 |
-
|
1338 |
-
|
1339 |
-
|
1340 |
-
|
1341 |
-
|
1342 |
-
|
1343 |
-
|
1344 |
-
|
1345 |
-
|
1346 |
-
|
1347 |
-
|
1348 |
-
|
1349 |
-
|
1350 |
-
|
1351 |
-
|
1352 |
-
|
1353 |
-
|
1354 |
-
|
1355 |
-
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
|
1366 |
-
|
1367 |
-
|
1368 |
-
|
1369 |
-
|
1370 |
-
|
1371 |
-
|
1372 |
-
|
1373 |
-
|
1374 |
-
|
1375 |
-
|
1376 |
-
no_rope_param=no_rope_param,
|
1377 |
-
past_key_value=past_key_value,
|
1378 |
-
)
|
1379 |
|
1380 |
-
|
1381 |
-
|
1382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1383 |
|
1384 |
-
|
1385 |
|
1386 |
-
def
|
1387 |
-
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None, no_rope_param=None, past_key_value=None
|
1388 |
-
):
|
1389 |
"""
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
Args:
|
1394 |
-
|
1395 |
-
|
1396 |
-
|
1397 |
-
|
1398 |
-
|
1399 |
-
|
1400 |
-
|
1401 |
-
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
1402 |
-
position of padding tokens and 1 for the position of non-padding tokens.
|
1403 |
-
dropout (`int`, *optional*):
|
1404 |
-
Attention dropout
|
1405 |
-
softmax_scale (`float`, *optional*):
|
1406 |
-
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
1407 |
"""
|
1408 |
-
if
|
1409 |
-
|
1410 |
-
|
1411 |
-
|
1412 |
-
|
1413 |
-
|
1414 |
-
|
1415 |
-
|
1416 |
-
batch_size = query_states.shape[0]
|
1417 |
-
|
1418 |
-
# query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
1419 |
-
# query_states, key_states, value_states, attention_mask, query_length=query_length
|
1420 |
-
# )
|
1421 |
|
1422 |
-
|
1423 |
-
|
1424 |
-
|
|
|
1425 |
|
1426 |
-
|
1427 |
-
|
1428 |
|
1429 |
-
|
1430 |
-
|
1431 |
-
|
1432 |
-
|
1433 |
-
no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(0)
|
1434 |
-
no_rope_param['key_states_no_rope'] = no_rope_param['key_states_no_rope'].squeeze(0)
|
1435 |
|
1436 |
-
|
1437 |
-
past_k=past_k, past_v=past_v, new_k=new_k, new_v=new_v, new_q=new_q, batch_size=batch_size, no_rope_param=no_rope_param, past_key_value=past_key_value)
|
1438 |
|
1439 |
-
|
|
|
|
|
1440 |
else:
|
1441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1442 |
|
1443 |
-
return
|
1444 |
|
1445 |
def sparse_forward(self,
|
1446 |
query_layer,
|
@@ -1451,24 +1252,18 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1451 |
max_seqlen_in_batch_q,
|
1452 |
max_seqlen_in_batch_k,
|
1453 |
no_rope_param=None,
|
1454 |
-
|
1455 |
-
|
1456 |
-
compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, cu_seqlens_k)
|
1457 |
-
compressed_v = compressed_k.clone()
|
1458 |
-
if past_key_value is not None:
|
1459 |
-
# Compute the start indices of keys (k) that were not compressed, Only batch_size=1 is supported at the moment.
|
1460 |
-
no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
|
1461 |
-
past_key_value.update_compress_k(
|
1462 |
-
compressed_k, self.layer_idx
|
1463 |
-
)
|
1464 |
-
past_key_value.update_no_compress_k(
|
1465 |
-
key_layer[no_compress_k_start:], self.layer_idx, no_compress_k_start)
|
1466 |
-
past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
|
1467 |
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
|
|
|
|
|
|
|
|
|
|
|
1468 |
topk_idx = compressed_attention(
|
1469 |
query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
|
1470 |
compressed_k,
|
1471 |
-
|
1472 |
self.kernel_size,
|
1473 |
self.kernel_stride,
|
1474 |
self.block_size,
|
@@ -1480,8 +1275,8 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1480 |
None,
|
1481 |
init_blocks=self.init_blocks,
|
1482 |
local_blocks=self.local_blocks,
|
|
|
1483 |
)
|
1484 |
-
|
1485 |
topk_attn_output = infllmv2_attn_varlen_func(
|
1486 |
query_layer,
|
1487 |
key_layer,
|
@@ -1493,102 +1288,14 @@ class MiniCPMInfLLMv2Attention(MiniCPMAttention):
|
|
1493 |
dropout_p=0.0,
|
1494 |
deterministic=False,
|
1495 |
softmax_scale=None,
|
1496 |
-
causal=
|
1497 |
return_attn_probs=False,
|
1498 |
-
block_window_size=self.window_size // self.block_size,
|
1499 |
topk_idx=topk_idx
|
1500 |
)
|
1501 |
|
1502 |
return topk_attn_output
|
1503 |
|
1504 |
-
def sparse_forward_with_kv_cache(self, past_k=None, past_v=None, new_k=None, new_v=None, new_q=None, batch_size=None, no_rope_param=None, past_key_value=None):
|
1505 |
-
|
1506 |
-
# stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
|
1507 |
-
if past_k.shape[1] + new_k.shape[1] == self.dense_len and (past_key_value.compress_k_cache == [] or len(past_key_value.compress_k_cache) < self.layer_idx + 1 or past_key_value.compress_k_cache[self.layer_idx] == []):
|
1508 |
-
if no_rope_param is not None:
|
1509 |
-
stage1_k = past_key_value.no_rope_key_cache[self.layer_idx].squeeze(0).contiguous() # just batch_size ==1
|
1510 |
-
else:
|
1511 |
-
stage1_k = torch.cat([past_k, new_k], dim=1).contiguous().squeeze(0).contiguous() # just batch_size ==1
|
1512 |
-
compressed_k, compressed_cu_seqlens = self.compress_k(stage1_k, torch.tensor([0, stage1_k.shape[0]], device=stage1_k.device, dtype=torch.int32)) # just batch_size ==1
|
1513 |
-
|
1514 |
-
# Compute the start indices of keys (k) that were not compressed, Only batch_size=1 is supported at the moment.
|
1515 |
-
no_compress_k_start = compressed_k.shape[0] * self.kernel_stride
|
1516 |
-
past_key_value.update_compress_k(
|
1517 |
-
compressed_k, self.layer_idx
|
1518 |
-
)
|
1519 |
-
past_key_value.update_no_compress_k(
|
1520 |
-
stage1_k[no_compress_k_start:], self.layer_idx, no_compress_k_start)
|
1521 |
-
past_key_value.cached_compressed_cu_seqlens.append(compressed_cu_seqlens)
|
1522 |
-
|
1523 |
-
else:
|
1524 |
-
stage1_k = new_k.squeeze(0) if no_rope_param is None else no_rope_param['key_states_no_rope']
|
1525 |
-
no_compress_k = past_key_value.update_no_compress_k(
|
1526 |
-
stage1_k, self.layer_idx, kernel_stride=self.kernel_stride, kernel_size=self.kernel_size)
|
1527 |
-
if no_compress_k is not None:
|
1528 |
-
compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
|
1529 |
-
|
1530 |
-
compressed_k = past_key_value.update_compress_k(
|
1531 |
-
compressed_k, self.layer_idx) # [seqlen, nheads_k, head_dim]
|
1532 |
-
|
1533 |
-
past_key_value.cached_compressed_cu_seqlens[self.layer_idx][-1] += 1 # !Increment the last entry in sequence lengths by 1; currently supports only batch_size = 1
|
1534 |
-
compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
|
1535 |
-
else:
|
1536 |
-
compressed_k = past_key_value.compress_k_cache[self.layer_idx] # [seqlen, nheads_k, head_dim]
|
1537 |
-
compressed_cu_seqlens = past_key_value.cached_compressed_cu_seqlens[self.layer_idx]
|
1538 |
-
|
1539 |
-
compressed_v = compressed_k.clone()
|
1540 |
-
|
1541 |
-
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
|
1542 |
-
torch.cuda.synchronize()
|
1543 |
-
# Manually verify that the lengths match
|
1544 |
-
assert compressed_k.shape[0] == compressed_seqlens.sum().item(), 'The length of compressed_k does not match the sum of compressed_seqlens'
|
1545 |
-
topk_idx = compressed_attention(
|
1546 |
-
new_q.squeeze(0).contiguous() if no_rope_param is None else no_rope_param['query_states_no_rope'],
|
1547 |
-
compressed_k,
|
1548 |
-
compressed_v,
|
1549 |
-
self.kernel_size,
|
1550 |
-
self.kernel_stride,
|
1551 |
-
self.block_size,
|
1552 |
-
self.topk,
|
1553 |
-
torch.tensor([0, 1], device=compressed_k.device, dtype=torch.int32),
|
1554 |
-
compressed_cu_seqlens,
|
1555 |
-
1,
|
1556 |
-
compressed_seqlens.max().item(),
|
1557 |
-
None,
|
1558 |
-
init_blocks=self.init_blocks,
|
1559 |
-
local_blocks=self.local_blocks,
|
1560 |
-
total_seq_lens=past_k.shape[1] + 1, # !Only batch_size=1 is supported at the moment.
|
1561 |
-
)
|
1562 |
-
|
1563 |
-
repeat_times = 1
|
1564 |
-
if repeat_times > 1:
|
1565 |
-
new_q = new_q.repeat_interleave(repeat_times, dim=-2)
|
1566 |
-
else:
|
1567 |
-
new_q = new_q
|
1568 |
-
|
1569 |
-
cache_batch_idx = torch.arange(batch_size, device=new_q.device, dtype=torch.int32)
|
1570 |
-
|
1571 |
-
seqlen_k = past_k.shape[1] + new_k.shape[1] # !Only batch_size=1 is supported at the moment.
|
1572 |
-
seqlens_k = torch.full((batch_size,), seqlen_k - 1, dtype=torch.int32, device=new_q.device)
|
1573 |
-
|
1574 |
-
past_k = torch.cat([past_k, torch.zeros_like(new_k, dtype=new_k.dtype)], dim=1).contiguous() # Append one zero vector to avoid potential out-of-bounds access
|
1575 |
-
past_v = torch.cat([past_v, torch.zeros_like(new_v, dtype=new_v.dtype)], dim=1).contiguous() # Append one zero vector to avoid potential out-of-bounds access
|
1576 |
-
topk_attn_output = infllmv2_attn_with_kvcache(
|
1577 |
-
q=new_q,
|
1578 |
-
k_cache=past_k,
|
1579 |
-
v_cache=past_v,
|
1580 |
-
topk_idx=topk_idx,
|
1581 |
-
block_window_size=self.window_size // self.block_size,
|
1582 |
-
k=new_k, # [batch_size, 1, nheads_k, d]
|
1583 |
-
v=new_v, # [batch_size, 1, nheads_k, d]
|
1584 |
-
cache_seqlens=seqlens_k, # current_seqlens_k-1
|
1585 |
-
rotary_cos=None, # No rotary embeddings
|
1586 |
-
rotary_sin=None, # No rotary embeddings
|
1587 |
-
cache_batch_idx=cache_batch_idx,
|
1588 |
-
causal=False, # Renaming to match function signature
|
1589 |
-
)
|
1590 |
-
return topk_attn_output
|
1591 |
-
|
1592 |
def _flash_attention_forward_dense(
|
1593 |
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
1594 |
):
|
@@ -1727,9 +1434,7 @@ class MiniCPMSdpaAttention(MiniCPMAttention):
|
|
1727 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1728 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1729 |
|
1730 |
-
kv_seq_len =
|
1731 |
-
if past_key_value is not None:
|
1732 |
-
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
1733 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
1734 |
|
1735 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
@@ -2052,11 +1757,13 @@ class MiniCPMModel(MiniCPMPreTrainedModel):
|
|
2052 |
raise ValueError(
|
2053 |
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
|
2054 |
)
|
2055 |
-
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
2056 |
|
2057 |
-
|
|
|
|
|
|
|
2058 |
if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
|
2059 |
-
past_key_values =
|
2060 |
|
2061 |
if position_ids is None:
|
2062 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
@@ -2282,12 +1989,17 @@ class MiniCPMForCausalLM(MiniCPMPreTrainedModel):
|
|
2282 |
):
|
2283 |
if past_key_values is not None:
|
2284 |
if isinstance(past_key_values, Cache):
|
|
|
2285 |
cache_length = past_key_values.get_seq_length()
|
2286 |
-
|
2287 |
-
|
2288 |
-
|
2289 |
-
|
2290 |
max_cache_length = None
|
|
|
|
|
|
|
|
|
2291 |
|
2292 |
# Keep only the unprocessed tokens:
|
2293 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|
|
|
24 |
from torch import nn
|
25 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
26 |
from transformers.activations import ACT2FN
|
27 |
+
from transformers.cache_utils import Cache, DynamicCache, CacheLayerMixin, DynamicLayer
|
28 |
from transformers.modeling_attn_mask_utils import (
|
29 |
AttentionMaskConverter,
|
30 |
_prepare_4d_attention_mask,
|
|
|
57 |
infllmv2_attn_varlen_func,
|
58 |
infllmv2_attn_with_kvcache,
|
59 |
max_pooling_1d,
|
60 |
+
max_pooling_1d_varlen
|
61 |
)
|
62 |
except:
|
63 |
pass
|
|
|
80 |
sm_scale: float = None,
|
81 |
init_blocks: int = 1,
|
82 |
local_blocks: int = 2,
|
83 |
+
cache_lens: torch.Tensor = None,
|
|
|
84 |
) -> Tuple[torch.Tensor, torch.Tensor]:
|
85 |
"""Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention.
|
86 |
|
|
|
99 |
sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim).
|
100 |
init_blocks (int, optional): Number of init blocks for each query. Defaults to 1.
|
101 |
local_blocks (int, optional): Number of local blocks for each query. Defaults to 2.
|
102 |
+
cache_lens (torch.Tensor, optional): shape [batch_size], used to record the cache length of each query. Defaults to None.
|
|
|
103 |
|
104 |
Returns:
|
105 |
Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention
|
106 |
"""
|
107 |
with torch.no_grad():
|
|
|
108 |
batch_size = cu_seqlens_q.shape[0] - 1
|
109 |
+
|
110 |
+
# Check if it's prefilling stage
|
111 |
+
is_prefilling = cache_lens is None or (cache_lens == 0).all().item()
|
112 |
+
|
113 |
+
# prefilling stage
|
114 |
+
if is_prefilling:
|
115 |
+
# Calculate q_idx for each query position in each batch
|
116 |
+
cache_lens = torch.zeros(batch_size, dtype=torch.int32, device=q.device)
|
117 |
+
q_idx = torch.cat([
|
118 |
+
(torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) +
|
119 |
+
max_seqlen_q - (cu_seqlens_q[i + 1] - cu_seqlens_q[i])) // block_size
|
120 |
+
for i in range(batch_size)
|
121 |
+
], dim=0) # shape: [total_q_len]
|
122 |
+
# decoding stage
|
123 |
else:
|
124 |
+
# Each batch has only one query (last position). Shape: [batch_size] = [total_q_len] in decoding
|
125 |
+
q_idx = cache_lens // block_size
|
|
|
126 |
|
127 |
+
# compute attention score
|
128 |
score = infllmv2_attn_stage1(
|
129 |
q.contiguous(),
|
130 |
k.contiguous(),
|
|
|
133 |
cu_seqlens_k=cu_seqlens_k,
|
134 |
max_seqlen_q=max_seqlen_q,
|
135 |
max_seqlen_k=max_seqlen_k,
|
136 |
+
causal=is_prefilling)
|
137 |
+
# Shape: [num_heads, total_q_len, num_blocks]
|
138 |
score = score[:, :q_idx.shape[0], :]
|
139 |
|
140 |
+
# Shape: [num_heads, total_q_len, num_blocks]
|
141 |
+
block_score = max_pooling_1d_varlen(
|
142 |
score.contiguous(),
|
143 |
+
cu_seqlens_q,
|
144 |
+
cu_seqlens_k,
|
145 |
+
cache_lens,
|
146 |
+
max_seqlen_q,
|
147 |
+
max_seqlen_k,
|
148 |
local_blocks=local_blocks,
|
149 |
init_blocks=init_blocks,
|
150 |
block_size=block_size,
|
151 |
+
stride=kernel_stride)
|
152 |
+
|
153 |
# get topk
|
154 |
topk = min(topk, block_score.shape[-1])
|
155 |
topk_idx = block_score.topk(topk, dim=-1).indices.sort(-1).values
|
156 |
+
topk_idx[topk_idx > q_idx[None, :, None]] = -1
|
157 |
topk_idx = topk_idx.to(torch.int32)
|
158 |
|
159 |
return topk_idx
|
|
|
252 |
return compressed_k, cu_seqlens_compressed
|
253 |
|
254 |
|
|
|
|
|
|
|
255 |
|
256 |
+
class InfLLMv2CacheLayer(DynamicLayer):
|
257 |
+
def __init__(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
258 |
super().__init__()
|
259 |
+
# Initialize any additional attributes specific to InfLLMv2CacheLayer
|
260 |
+
self.no_rope_keys = torch.tensor([], dtype=torch.float32)
|
261 |
+
self.compress_k_cache = []
|
262 |
+
self.no_compress_k_cache = []
|
263 |
+
self.cached_compressed_cu_seqlens = torch.tensor([], dtype=torch.int32)
|
264 |
+
self.compress_k_cache_varlen = torch.tensor([], dtype=torch.float32)
|
265 |
+
|
266 |
+
def update_no_rope_key(self, key_states):
|
267 |
+
if self.no_rope_keys.numel() == 0:
|
268 |
+
self.no_rope_keys = key_states
|
269 |
else:
|
270 |
+
self.no_rope_keys = torch.cat([self.no_rope_keys, key_states], dim=1)
|
271 |
+
return self.no_rope_keys
|
272 |
+
|
273 |
+
def update_compress_k(self, key_states, cu_seqlens=None):
|
274 |
+
if len(self.compress_k_cache) == 0:
|
275 |
+
if cu_seqlens is not None:
|
276 |
+
self.cached_compressed_cu_seqlens = cu_seqlens.clone()
|
277 |
+
self.compress_k_cache_varlen = key_states
|
278 |
+
split_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
279 |
+
self.compress_k_cache = list(torch.split(key_states, split_sizes))
|
|
|
|
|
|
|
|
|
|
|
280 |
else:
|
281 |
+
for index, k in enumerate(key_states):
|
282 |
+
if k is not None:
|
283 |
+
self.compress_k_cache[index] = torch.cat([self.compress_k_cache[index], k], dim=0)
|
284 |
+
new_seq_lens = torch.tensor([tensor.shape[0] for tensor in self.compress_k_cache], dtype=torch.int32)
|
285 |
+
new_cumsum = torch.cumsum(new_seq_lens, dim=0, dtype=torch.int32)
|
286 |
+
|
287 |
+
self.compress_k_cache_varlen = torch.cat(self.compress_k_cache, dim=0)
|
288 |
+
self.cached_compressed_cu_seqlens = torch.cat([torch.tensor([0], dtype=torch.int32), new_cumsum]).to(self.compress_k_cache_varlen.device)
|
289 |
+
return self.compress_k_cache_varlen, self.cached_compressed_cu_seqlens
|
290 |
+
|
291 |
+
def update_no_compress_k(self, key_states, kernel_size=32, kernel_stride=16):
|
292 |
+
k_chunk_list = []
|
293 |
+
for index, k in enumerate(key_states):
|
294 |
+
if len(self.no_compress_k_cache) <= index:
|
295 |
+
self.no_compress_k_cache.append(k)
|
296 |
+
else:
|
297 |
+
self.no_compress_k_cache[index] = torch.cat([self.no_compress_k_cache[index], k], dim=0)
|
298 |
+
current_len = self.no_compress_k_cache[index].shape[0]
|
299 |
+
if current_len >= kernel_size:
|
300 |
+
k_chunk_list.append(self.no_compress_k_cache[index][:kernel_size])
|
301 |
+
self.no_compress_k_cache[index] = self.no_compress_k_cache[index][kernel_stride:]
|
302 |
+
else:
|
303 |
+
k_chunk_list.append(None)
|
304 |
+
return k_chunk_list
|
305 |
|
306 |
+
class InfLLMv2Cache(DynamicCache):
|
307 |
+
def __init__(self,
|
308 |
+
config,num_hidden_layers: Optional[int] = None) -> None:
|
309 |
+
super().__init__(config=config)
|
310 |
+
self.layers = [InfLLMv2CacheLayer() for _ in range(num_hidden_layers)] if num_hidden_layers else []
|
311 |
+
self._seen_tokens = 0
|
|
|
|
|
|
|
312 |
|
313 |
+
def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
if layer_idx == 0:
|
315 |
self._seen_tokens += key_states.shape[-2]
|
316 |
+
return self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
|
317 |
|
318 |
+
def update_no_rope_key(self, key_states, layer_idx, cache_kwargs=None):
|
319 |
+
return self.layers[layer_idx].update_no_rope_key(key_states)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
320 |
|
321 |
+
def update_compress_k(self, key_states, layer_idx, cu_seqlens=None, cache_kwargs=None):
|
322 |
+
return self.layers[layer_idx].update_compress_k(key_states, cu_seqlens)
|
|
|
323 |
|
324 |
+
def update_no_compress_k(self, key_states, layer_idx, kernel_size=32, kernel_stride=16, cache_kwargs=None):
|
325 |
+
return self.layers[layer_idx].update_no_compress_k(key_states, kernel_size, kernel_stride)
|
|
|
|
|
|
|
|
|
|
|
326 |
|
327 |
+
def crop(self, max_length):
|
328 |
+
for layer in self.layers:
|
329 |
+
layer.crop(max_length)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
330 |
|
331 |
+
def batch_repeat_interleave(self, repeats):
|
332 |
+
for layer in self.layers:
|
333 |
+
layer.batch_repeat_interleave(repeats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
|
335 |
+
def batch_select_indices(self, indices):
|
336 |
+
for layer in self.layers:
|
337 |
+
layer.batch_select_indices(indices)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
|
340 |
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
|
|
|
363 |
)
|
364 |
|
365 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
366 |
|
367 |
|
368 |
# @torch.jit.script # type: ignore
|
|
|
576 |
|
577 |
return down_proj
|
578 |
|
579 |
+
def _unpad_one_tensor(hidden_states, attention_mask):
|
580 |
+
# Unpad the hidden states using the indices
|
581 |
+
indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
|
582 |
+
batch_size, seq_len = hidden_states.shape[:2]
|
583 |
+
|
584 |
+
# Get the remaining dimensions
|
585 |
+
remaining_dims = hidden_states.shape[2:]
|
586 |
+
|
587 |
+
# Reshape to (batch_size * seq_len, *remaining_dims)
|
588 |
+
reshaped_states = hidden_states.reshape(batch_size * seq_len, *remaining_dims)
|
589 |
+
|
590 |
+
# Apply unpadding using indices
|
591 |
+
unpadded_states = index_first_axis(reshaped_states, indices)
|
592 |
+
|
593 |
+
return unpadded_states, indices, cu_seqlens, max_seqlen_in_batch
|
594 |
|
595 |
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
596 |
"""
|
|
|
722 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
723 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
724 |
|
725 |
+
kv_seq_len = position_ids.max().item() + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
726 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
727 |
|
728 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
824 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
825 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
826 |
|
827 |
+
kv_seq_len = position_ids.max().item() + 1
|
|
|
|
|
828 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
829 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
830 |
|
|
|
996 |
self.dense_len = self.config.sparse_config.get('dense_len', 8192)
|
997 |
|
998 |
self.local_blocks = self.window_size // self.block_size # local_blocks
|
999 |
+
self.topk = self.config.sparse_config.get('topk', 64) + (self.window_size//self.block_size)
|
1000 |
self.use_nope = self.config.sparse_config.get('use_nope', False)
|
1001 |
self.compress_k = CompressK(self.num_key_value_heads, self.head_dim, kernel_size=self.kernel_size, kernel_stride=self.kernel_stride)
|
1002 |
|
|
|
1022 |
output_attentions = False
|
1023 |
|
1024 |
bsz, q_len, _ = hidden_states.size()
|
|
|
1025 |
|
1026 |
query_states = self.q_proj(hidden_states)
|
1027 |
key_states = self.k_proj(hidden_states)
|
|
|
1039 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1040 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1041 |
|
1042 |
+
kv_seq_len = position_ids.max().item() + 1
|
|
|
|
|
1043 |
cos, sin = self.rotary_emb(value_states.to(torch.float32), seq_len=kv_seq_len)
|
1044 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
1045 |
|
|
|
1053 |
key_states = key_states.transpose(1, 2)
|
1054 |
value_states = value_states.transpose(1, 2)
|
1055 |
if self.use_nope:
|
1056 |
+
key_states_no_rope = past_key_value.update_no_rope_key(key_states_no_rope, self.layer_idx)
|
1057 |
no_rope_param = {
|
1058 |
'key_states_no_rope': key_states_no_rope,
|
1059 |
'query_states_no_rope': query_states_no_rope,
|
1060 |
}
|
|
|
|
|
1061 |
else:
|
1062 |
no_rope_param = None
|
1063 |
|
|
|
1089 |
if kv_seq_len < self.dense_len:
|
1090 |
attn_output = self._flash_attention_forward_dense(
|
1091 |
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate)
|
1092 |
+
else:
|
1093 |
+
attn_output = self._sparse_attention_forward(
|
1094 |
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate,
|
1095 |
no_rope_param=no_rope_param, # if past_key_value is not None else None,
|
1096 |
past_key_value=past_key_value)
|
|
|
|
|
|
|
|
|
1097 |
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
1098 |
attn_output = self.o_proj(attn_output)
|
1099 |
|
|
|
1102 |
|
1103 |
return attn_output, attn_weights, past_key_value
|
1104 |
|
1105 |
+
def _sparse_attention_forward(
|
1106 |
+
self,
|
1107 |
+
query_states,
|
1108 |
+
key_states,
|
1109 |
+
value_states,
|
1110 |
+
attention_mask,
|
1111 |
+
query_length,
|
1112 |
+
dropout=0.0,
|
1113 |
+
softmax_scale=None,
|
1114 |
+
no_rope_param=None,
|
1115 |
+
past_key_value=None):
|
1116 |
+
"""
|
1117 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
1118 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
1119 |
+
|
1120 |
+
Args:
|
1121 |
+
query_states (`torch.Tensor`):
|
1122 |
+
Input query states to be passed to Flash Attention API
|
1123 |
+
key_states (`torch.Tensor`):
|
1124 |
+
Input key states to be passed to Flash Attention API
|
1125 |
+
value_states (`torch.Tensor`):
|
1126 |
+
Input value states to be passed to Flash Attention API
|
1127 |
+
attention_mask (`torch.Tensor`):
|
1128 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
1129 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
1130 |
+
dropout (`int`, *optional*):
|
1131 |
+
Attention dropout
|
1132 |
+
softmax_scale (`float`, *optional*):
|
1133 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
1134 |
+
"""
|
1135 |
+
if not self._flash_attn_uses_top_left_mask:
|
1136 |
+
causal = self.is_causal
|
1137 |
+
else:
|
1138 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in MiniCPMFlashAttention2 __init__.
|
1139 |
+
causal = self.is_causal and query_length != 1
|
1140 |
+
# Contains at least one padding token in the sequence
|
1141 |
+
if attention_mask is not None:
|
1142 |
+
batch_size = query_states.shape[0]
|
1143 |
+
# assert batch_size == 1, 'Only batch_size=1 is supported at the moment.'
|
1144 |
+
if past_key_value!=None:
|
1145 |
+
compressed_k, compressed_cu_seqlens = self.get_compress_k(
|
1146 |
+
key_states=key_states if self.use_nope ==False else no_rope_param['key_states_no_rope'], # This can be optimized a bit;
|
1147 |
+
attention_mask=attention_mask,
|
1148 |
+
past_key_value=past_key_value)
|
1149 |
+
|
1150 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
1151 |
+
query_states, key_states, value_states, attention_mask, query_length
|
1152 |
+
)
|
|
|
|
|
|
|
1153 |
|
1154 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
1155 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
1156 |
+
if no_rope_param != None:
|
1157 |
+
if max_seqlen_in_batch_q == 1:
|
1158 |
+
no_rope_param['query_states_no_rope'] = no_rope_param['query_states_no_rope'].squeeze(1)
|
1159 |
+
else:
|
1160 |
+
no_rope_param['query_states_no_rope'],_, _, _ = _unpad_one_tensor(no_rope_param['query_states_no_rope'],attention_mask=attention_mask)
|
1161 |
+
if past_key_value==None:
|
1162 |
+
# compress_k use varlen form
|
1163 |
+
compressed_k, compressed_cu_seqlens = self.compress_k(key_states,cu_seqlens_k)
|
1164 |
+
|
1165 |
+
attn_output_unpad = self.sparse_forward(
|
1166 |
+
query_states,
|
1167 |
+
key_states,
|
1168 |
+
value_states,
|
1169 |
+
cu_seqlens_q,
|
1170 |
+
cu_seqlens_k,
|
1171 |
+
max_seqlen_in_batch_q,
|
1172 |
+
max_seqlen_in_batch_k,
|
1173 |
+
no_rope_param=no_rope_param,
|
1174 |
+
compressed_k=compressed_k,
|
1175 |
+
compressed_cu_seqlens=compressed_cu_seqlens)
|
1176 |
+
|
1177 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
1178 |
+
else:
|
1179 |
+
raise ValueError('Need attention mask')
|
1180 |
|
1181 |
+
return attn_output
|
1182 |
|
1183 |
+
def get_compress_k(self, key_states, attention_mask, past_key_value):
|
|
|
|
|
1184 |
"""
|
1185 |
+
Get compressed key states and corresponding cumulative sequence lengths.
|
1186 |
+
|
|
|
1187 |
Args:
|
1188 |
+
key_states: Key states tensor
|
1189 |
+
cu_seqlens_k: Cumulative sequence lengths for keys
|
1190 |
+
past_key_value: Past key-value cache
|
1191 |
+
no_rope_param: Optional parameter containing key states without rope
|
1192 |
+
|
1193 |
+
Returns:
|
1194 |
+
Tuple of (compressed_k, compressed_cu_seqlens)
|
|
|
|
|
|
|
|
|
|
|
|
|
1195 |
"""
|
1196 |
+
# Check if this is prefilling or initial compression condition
|
1197 |
+
is_prefilling = (
|
1198 |
+
key_states.shape[1] >= self.dense_len and
|
1199 |
+
(
|
1200 |
+
not past_key_value.layers[self.layer_idx].compress_k_cache
|
1201 |
+
)
|
1202 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
1203 |
|
1204 |
+
if is_prefilling:
|
1205 |
+
unpadded_key_states, indices, cu_seqlens, max_seqlen_in_batch = _unpad_one_tensor(key_states,attention_mask=attention_mask)
|
1206 |
+
# Compress the keys
|
1207 |
+
compressed_k, compressed_cu_seqlens = self.compress_k(unpadded_key_states, cu_seqlens)
|
1208 |
|
1209 |
+
past_key_value.update_compress_k(
|
1210 |
+
compressed_k, self.layer_idx, compressed_cu_seqlens)
|
1211 |
|
1212 |
+
no_compress_k_list = []
|
1213 |
+
# Compute and update no_compress_k
|
1214 |
+
for i in range(len(compressed_cu_seqlens)-1):
|
1215 |
+
no_compress_k_start = (compressed_cu_seqlens[i+1]- compressed_cu_seqlens[i]) * self.kernel_stride
|
|
|
|
|
1216 |
|
1217 |
+
no_compress_k_list.append(unpadded_key_states[cu_seqlens[i]+no_compress_k_start:cu_seqlens[i+1]].clone())
|
|
|
1218 |
|
1219 |
+
past_key_value.update_no_compress_k(
|
1220 |
+
no_compress_k_list, self.layer_idx,kernel_stride=self.kernel_stride,
|
1221 |
+
kernel_size=self.kernel_size)
|
1222 |
else:
|
1223 |
+
# Decode case: incremental update
|
1224 |
+
batch_size = key_states.shape[0] # key_states.shape = [batch_size, seq, k_head_num, head_dim]
|
1225 |
+
key_states_split = list(torch.split(
|
1226 |
+
key_states[:,-1:].squeeze(1), #[batch_size, seq, k_head_num, head_dim]->[batch_size, 1, k_head_num, head_dim]-> [batch_size, k_head_num, head_dim]
|
1227 |
+
[1] * batch_size,dim=0,
|
1228 |
+
))
|
1229 |
+
# Try to update no_compress_k buffer
|
1230 |
+
no_compress_k_list = past_key_value.update_no_compress_k(
|
1231 |
+
key_states_split, self.layer_idx,
|
1232 |
+
kernel_stride=self.kernel_stride,
|
1233 |
+
kernel_size=self.kernel_size)
|
1234 |
+
new_compressed_k_list = []
|
1235 |
+
for no_compress_k in no_compress_k_list:
|
1236 |
+
if no_compress_k is not None:
|
1237 |
+
# We have enough tokens to compress
|
1238 |
+
new_compressed_k = no_compress_k.mean(dim=0, keepdim=True) # [1, n_heads_k, head_dim]
|
1239 |
+
new_compressed_k_list.append(new_compressed_k)
|
1240 |
+
else:
|
1241 |
+
new_compressed_k_list.append(None)
|
1242 |
+
compressed_k, compressed_cu_seqlens = past_key_value.update_compress_k(new_compressed_k_list, self.layer_idx,)
|
1243 |
|
1244 |
+
return compressed_k, compressed_cu_seqlens
|
1245 |
|
1246 |
def sparse_forward(self,
|
1247 |
query_layer,
|
|
|
1252 |
max_seqlen_in_batch_q,
|
1253 |
max_seqlen_in_batch_k,
|
1254 |
no_rope_param=None,
|
1255 |
+
compressed_k=None,
|
1256 |
+
compressed_cu_seqlens=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1257 |
compressed_seqlens = compressed_cu_seqlens[1:] - compressed_cu_seqlens[:-1]
|
1258 |
+
cache_lens = None
|
1259 |
+
if max_seqlen_in_batch_q==1 and max_seqlen_in_batch_k>1: #decoding
|
1260 |
+
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
|
1261 |
+
cache_lens = seq_lens_k-1
|
1262 |
+
|
1263 |
topk_idx = compressed_attention(
|
1264 |
query_layer if no_rope_param is None else no_rope_param['query_states_no_rope'],
|
1265 |
compressed_k,
|
1266 |
+
compressed_k.clone(),
|
1267 |
self.kernel_size,
|
1268 |
self.kernel_stride,
|
1269 |
self.block_size,
|
|
|
1275 |
None,
|
1276 |
init_blocks=self.init_blocks,
|
1277 |
local_blocks=self.local_blocks,
|
1278 |
+
cache_lens=cache_lens
|
1279 |
)
|
|
|
1280 |
topk_attn_output = infllmv2_attn_varlen_func(
|
1281 |
query_layer,
|
1282 |
key_layer,
|
|
|
1288 |
dropout_p=0.0,
|
1289 |
deterministic=False,
|
1290 |
softmax_scale=None,
|
1291 |
+
causal=max_seqlen_in_batch_q != 1,
|
1292 |
return_attn_probs=False,
|
1293 |
+
# block_window_size=self.window_size // self.block_size,
|
1294 |
topk_idx=topk_idx
|
1295 |
)
|
1296 |
|
1297 |
return topk_attn_output
|
1298 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1299 |
def _flash_attention_forward_dense(
|
1300 |
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
1301 |
):
|
|
|
1434 |
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1435 |
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
1436 |
|
1437 |
+
kv_seq_len = position_ids.max().item() + 1
|
|
|
|
|
1438 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
1439 |
|
1440 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
|
|
1757 |
raise ValueError(
|
1758 |
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
|
1759 |
)
|
|
|
1760 |
|
1761 |
+
# Calculate the usable length of past key values
|
1762 |
+
past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, InfLLMv2Cache) else 0
|
1763 |
+
|
1764 |
+
# Initialize InfLLMv2Cache if needed
|
1765 |
if self.config.sparse_config is not None and torch.cuda.is_available() and past_key_values_length == 0:
|
1766 |
+
past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers)
|
1767 |
|
1768 |
if position_ids is None:
|
1769 |
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
|
|
1989 |
):
|
1990 |
if past_key_values is not None:
|
1991 |
if isinstance(past_key_values, Cache):
|
1992 |
+
# Use the new Cache class methods
|
1993 |
cache_length = past_key_values.get_seq_length()
|
1994 |
+
|
1995 |
+
if self.config.sparse_config is not None and torch.cuda.is_available() and cache_length == 0:
|
1996 |
+
past_key_values = InfLLMv2Cache(config = self.config, num_hidden_layers=self.config.num_hidden_layers)
|
1997 |
+
past_length = cache_length
|
1998 |
max_cache_length = None
|
1999 |
+
else:
|
2000 |
+
raise ValueError(
|
2001 |
+
'You must use the new past_key_values format, such as the Cache class, instead of the old tuple format.'
|
2002 |
+
)
|
2003 |
|
2004 |
# Keep only the unprocessed tokens:
|
2005 |
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
|