danieldk HF Staff commited on
Commit
c077d9f
·
1 Parent(s): 167406c

Remove source

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +2 -0
  2. build.toml +0 -589
  3. flake.lock +0 -168
  4. flake.nix +0 -35
  5. flash-attn/block.h +0 -139
  6. flash-attn/copy_sm90_bulk_reduce.hpp +0 -49
  7. flash-attn/cuda_check.h +0 -19
  8. flash-attn/epilogue_bwd.hpp +0 -533
  9. flash-attn/epilogue_fwd.hpp +0 -484
  10. flash-attn/flash.h +0 -218
  11. flash-attn/flash_api.cpp +0 -1720
  12. flash-attn/flash_bwd_kernel_sm80.h +0 -173
  13. flash-attn/flash_bwd_kernel_sm90.h +0 -282
  14. flash-attn/flash_bwd_launch_template.h +0 -390
  15. flash-attn/flash_bwd_postprocess_kernel.h +0 -256
  16. flash-attn/flash_bwd_preprocess_kernel.h +0 -252
  17. flash-attn/flash_fwd_combine.cu +0 -13
  18. flash-attn/flash_fwd_combine_kernel.h +0 -482
  19. flash-attn/flash_fwd_combine_launch_template.h +0 -80
  20. flash-attn/flash_fwd_kernel_sm80.h +0 -215
  21. flash-attn/flash_fwd_kernel_sm90.h +0 -458
  22. flash-attn/flash_fwd_launch_template.h +0 -223
  23. flash-attn/flash_prepare_scheduler.cu +0 -124
  24. flash-attn/heuristics.h +0 -59
  25. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +0 -18
  26. flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +0 -12
  27. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +0 -18
  28. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +0 -12
  29. flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +0 -6
  30. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +0 -18
  31. flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +0 -12
  32. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +0 -18
  33. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +0 -12
  34. flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +0 -6
  35. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +0 -18
  36. flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +0 -12
  37. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +0 -18
  38. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +0 -12
  39. flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +0 -6
  40. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +0 -18
  41. flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +0 -12
  42. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +0 -18
  43. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +0 -12
  44. flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +0 -6
  45. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +0 -18
  46. flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +0 -12
  47. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +0 -18
  48. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +0 -12
  49. flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +0 -6
  50. flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +0 -18
README.md CHANGED
@@ -11,3 +11,5 @@ attention mechanism, designed to work with large models and long sequences.
11
  This is a Hugging Face compliant kernel build of Flash Attention.
12
 
13
  Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
 
 
 
11
  This is a Hugging Face compliant kernel build of Flash Attention.
12
 
13
  Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
14
+
15
+ Kernel source: https://github.com/huggingface/kernels-community/tree/main/flash-attn3
build.toml DELETED
@@ -1,589 +0,0 @@
1
- [general]
2
- name = "flash_attn3"
3
- universal = false
4
- cuda-minver = "12.4"
5
- cuda-maxver = "12.4"
6
-
7
- [torch]
8
- src = [
9
- "torch-ext/pytorch_shim.h",
10
- "torch-ext/torch_binding.cpp",
11
- "torch-ext/torch_binding.h",
12
- ]
13
-
14
- [kernel.flash_attn]
15
- backend = "cuda"
16
- cuda-capabilities = ["8.0", "9.0a"]
17
- cuda-flags = [
18
- "-O3",
19
- "-std=c++17",
20
- "--ftemplate-backtrace-limit=0", # To debug template code
21
- "--use_fast_math",
22
- "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
23
- "-DCUTLASS_ENABLE_GDC_FOR_SM90",
24
- "--expt-relaxed-constexpr",
25
- "--expt-extended-lambda",
26
- "--use_fast_math",
27
- "-DNDEBUG",
28
- ]
29
-
30
- src = [
31
- "flash-attn/cuda_check.h",
32
- "flash-attn/flash_api.cpp",
33
- "flash-attn/flash_fwd_combine.cu",
34
- "flash-attn/flash_fwd_combine_kernel.h",
35
- "flash-attn/flash_fwd_combine_launch_template.h",
36
- "flash-attn/flash.h",
37
- "flash-attn/flash_prepare_scheduler.cu",
38
- "flash-attn/heuristics.h",
39
- "flash-attn/seqlen.h",
40
- "flash-attn/static_switch.h",
41
- "flash-attn/tile_size.h",
42
- "flash-attn/utils.h",
43
- ]
44
- depends = ["torch", "cutlass_3_9"]
45
-
46
- [kernel.flash_attn_sm80]
47
- backend = "cuda"
48
- cuda-capabilities = ["8.0", "9.0a"]
49
- cuda-flags = [
50
- "-O3",
51
- "-std=c++17",
52
- "--ftemplate-backtrace-limit=0", # To debug template code
53
- "--use_fast_math",
54
- "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
55
- "-DCUTLASS_ENABLE_GDC_FOR_SM90",
56
- "--expt-relaxed-constexpr",
57
- "--expt-extended-lambda",
58
- "--use_fast_math",
59
- "-DNDEBUG",
60
- ]
61
- src = [
62
- "flash-attn/block.h",
63
- "flash-attn/copy_sm90_bulk_reduce.hpp",
64
- "flash-attn/epilogue_bwd.hpp",
65
- "flash-attn/epilogue_fwd.hpp",
66
- "flash-attn/flash.h",
67
- "flash-attn/flash_bwd_kernel_sm80.h",
68
- "flash-attn/flash_bwd_kernel_sm90.h",
69
- "flash-attn/flash_bwd_launch_template.h",
70
- "flash-attn/flash_bwd_postprocess_kernel.h",
71
- "flash-attn/flash_bwd_preprocess_kernel.h",
72
- "flash-attn/flash_fwd_launch_template.h",
73
- "flash-attn/flash_fwd_kernel_sm80.h",
74
- "flash-attn/flash_fwd_kernel_sm90.h",
75
- "flash-attn/heuristics.h",
76
- "flash-attn/mainloop_bwd_sm80.hpp",
77
- "flash-attn/mainloop_fwd_sm80.hpp",
78
- "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
79
- "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
80
- "flash-attn/mask.h",
81
- "flash-attn/named_barrier.hpp",
82
- "flash-attn/pack_gqa.h",
83
- "flash-attn/paged_kv.h",
84
- "flash-attn/rotary.h",
85
- "flash-attn/sm90_pipeline_no_cluster.hpp",
86
- "flash-attn/softmax.h",
87
- "flash-attn/tile_size.h",
88
- "flash-attn/tile_scheduler.hpp",
89
-
90
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
91
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
92
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
93
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
94
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
95
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
96
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
97
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
98
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
99
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
100
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
101
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
102
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
103
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
104
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
105
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
106
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
107
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
108
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
109
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
110
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
111
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
112
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
113
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
114
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
115
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
116
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
117
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
118
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
119
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
120
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
121
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
122
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
123
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
124
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
125
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
126
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
127
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
128
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
129
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
130
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
131
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
132
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
133
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
134
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
135
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
136
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
137
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
138
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
139
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
140
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
141
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
142
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
143
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
144
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
145
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
146
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
147
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
148
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
149
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
150
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
151
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
152
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
153
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
154
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
155
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
156
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
157
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
158
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
159
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
160
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
161
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
162
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
163
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
164
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
165
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
166
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
167
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
168
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
169
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
170
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
171
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
172
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
173
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
174
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
175
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
176
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
177
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
178
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
179
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
180
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
181
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
182
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
183
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
184
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
185
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
186
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
187
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
188
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
189
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
190
- ]
191
- include = ["flash-attn"]
192
- depends = ["torch", "cutlass_3_9"]
193
-
194
- [kernel.flash_attn_sm90]
195
- backend = "cuda"
196
- cuda-capabilities = ["8.0", "9.0a"]
197
- cuda-flags = [
198
- "-O3",
199
- "-std=c++17",
200
- "--ftemplate-backtrace-limit=0", # To debug template code
201
- "--use_fast_math",
202
- "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
203
- "-DCUTLASS_ENABLE_GDC_FOR_SM90",
204
- "--expt-relaxed-constexpr",
205
- "--expt-extended-lambda",
206
- "--use_fast_math",
207
- "-DNDEBUG",
208
- ]
209
- src = [
210
- "flash-attn/block.h",
211
- "flash-attn/copy_sm90_bulk_reduce.hpp",
212
- "flash-attn/epilogue_bwd.hpp",
213
- "flash-attn/epilogue_fwd.hpp",
214
- "flash-attn/flash.h",
215
- "flash-attn/flash_bwd_kernel_sm80.h",
216
- "flash-attn/flash_bwd_kernel_sm90.h",
217
- "flash-attn/flash_bwd_launch_template.h",
218
- "flash-attn/flash_bwd_postprocess_kernel.h",
219
- "flash-attn/flash_bwd_preprocess_kernel.h",
220
- "flash-attn/flash_fwd_launch_template.h",
221
- "flash-attn/flash_fwd_kernel_sm80.h",
222
- "flash-attn/flash_fwd_kernel_sm90.h",
223
- "flash-attn/heuristics.h",
224
- "flash-attn/mainloop_bwd_sm80.hpp",
225
- "flash-attn/mainloop_fwd_sm80.hpp",
226
- "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
227
- "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
228
- "flash-attn/mask.h",
229
- "flash-attn/named_barrier.hpp",
230
- "flash-attn/pack_gqa.h",
231
- "flash-attn/paged_kv.h",
232
- "flash-attn/rotary.h",
233
- "flash-attn/sm90_pipeline_no_cluster.hpp",
234
- "flash-attn/softmax.h",
235
- "flash-attn/tile_size.h",
236
- "flash-attn/tile_scheduler.hpp",
237
-
238
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
239
- "flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
240
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
241
- "flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
242
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
243
- "flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
244
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
245
- "flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
246
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
247
- "flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
248
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
249
- "flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
250
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
251
- "flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
252
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
253
- "flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
254
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
255
- "flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
256
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
257
- "flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
258
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
259
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
260
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
261
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
262
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
263
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
264
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
265
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
266
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
267
- "flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
268
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
269
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
270
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
271
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
272
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
273
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
274
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
275
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
276
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
277
- "flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
278
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
279
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
280
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
281
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
282
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
283
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
284
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
285
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
286
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
287
- "flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
288
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
289
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
290
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
291
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
292
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
293
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
294
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
295
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
296
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
297
- "flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
298
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
299
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
300
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
301
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
302
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
303
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
304
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
305
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
306
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
307
- "flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
308
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
309
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
310
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
311
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
312
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
313
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
314
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
315
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
316
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
317
- "flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
318
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
319
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
320
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
321
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
322
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
323
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
324
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
325
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
326
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
327
- "flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
328
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
329
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
330
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
331
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
332
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
333
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
334
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
335
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
336
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
337
- "flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
338
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
339
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
340
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
341
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
342
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
343
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
344
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
345
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
346
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
347
- "flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
348
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
349
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
350
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
351
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
352
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
353
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
354
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
355
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
356
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
357
- "flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
358
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
359
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
360
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
361
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
362
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
363
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
364
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
365
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
366
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
367
- "flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
368
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
369
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
370
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
371
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
372
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
373
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
374
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
375
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
376
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
377
- "flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
378
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
379
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
380
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
381
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
382
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
383
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
384
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
385
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
386
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
387
- "flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
388
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
389
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
390
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
391
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
392
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
393
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
394
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
395
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
396
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
397
- "flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
398
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
399
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
400
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
401
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
402
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
403
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
404
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
405
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
406
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
407
- "flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
408
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
409
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
410
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
411
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
412
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
413
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
414
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
415
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
416
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
417
- "flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
418
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
419
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
420
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
421
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
422
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
423
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
424
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
425
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
426
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
427
- "flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
428
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
429
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
430
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
431
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
432
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
433
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
434
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
435
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
436
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
437
- "flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
438
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
439
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
440
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
441
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
442
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
443
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
444
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
445
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
446
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
447
- "flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
448
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
449
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
450
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
451
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
452
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
453
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
454
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
455
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
456
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
457
- "flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
458
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
459
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
460
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
461
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
462
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
463
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
464
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
465
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
466
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
467
- "flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
468
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
469
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
470
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
471
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
472
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
473
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
474
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
475
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
476
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
477
- "flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
478
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
479
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
480
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
481
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
482
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
483
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
484
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
485
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
486
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
487
- "flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
488
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
489
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
490
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
491
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
492
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
493
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
494
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
495
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
496
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
497
- "flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
498
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
499
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
500
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
501
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
502
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
503
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
504
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
505
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
506
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
507
- "flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
508
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
509
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
510
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
511
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
512
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
513
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
514
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
515
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
516
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
517
- "flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
518
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
519
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
520
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
521
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
522
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
523
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
524
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
525
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
526
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
527
- "flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
528
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
529
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
530
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
531
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
532
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
533
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
534
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
535
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
536
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
537
- "flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
538
- ]
539
- include = ["flash-attn"]
540
- depends = ["torch", "cutlass_3_9"]
541
-
542
- # [kernel.flash_attn_sm100]
543
- # backend = "cuda"
544
- # cuda-capabilities = ["8.0", "9.0a", "10.0"]
545
- # cuda-flags = [
546
- # "-O3",
547
- # "-std=c++17",
548
- # "--ftemplate-backtrace-limit=0", # To debug template code
549
- # "--use_fast_math",
550
- # "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
551
- # "-DCUTLASS_ENABLE_GDC_FOR_SM90",
552
- # "--expt-relaxed-constexpr",
553
- # "--expt-extended-lambda",
554
- # "--use_fast_math",
555
- # "-DNDEBUG",
556
- # ]
557
- # src = [
558
- # "flash-attn/block.h",
559
- # "flash-attn/copy_sm90_bulk_reduce.hpp",
560
- # "flash-attn/epilogue_bwd.hpp",
561
- # "flash-attn/epilogue_fwd.hpp",
562
- # "flash-attn/flash.h",
563
- # "flash-attn/flash_bwd_kernel_sm80.h",
564
- # "flash-attn/flash_bwd_kernel_sm90.h",
565
- # "flash-attn/flash_bwd_launch_template.h",
566
- # "flash-attn/flash_bwd_postprocess_kernel.h",
567
- # "flash-attn/flash_bwd_preprocess_kernel.h",
568
- # "flash-attn/flash_fwd_launch_template.h",
569
- # "flash-attn/flash_fwd_kernel_sm80.h",
570
- # "flash-attn/flash_fwd_kernel_sm90.h",
571
- # "flash-attn/heuristics.h",
572
- # "flash-attn/mainloop_bwd_sm80.hpp",
573
- # "flash-attn/mainloop_fwd_sm80.hpp",
574
- # "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
575
- # "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
576
- # "flash-attn/mask.h",
577
- # "flash-attn/named_barrier.hpp",
578
- # "flash-attn/pack_gqa.h",
579
- # "flash-attn/paged_kv.h",
580
- # "flash-attn/rotary.h",
581
- # "flash-attn/sm90_pipeline_no_cluster.hpp",
582
- # "flash-attn/softmax.h",
583
- # "flash-attn/tile_size.h",
584
- # "flash-attn/tile_scheduler.hpp",
585
- #
586
- # "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
587
- # ]
588
- # include = ["flash-attn"]
589
- # depends = ["torch", "cutlass_3_9"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flake.lock DELETED
@@ -1,168 +0,0 @@
1
- {
2
- "nodes": {
3
- "flake-compat": {
4
- "locked": {
5
- "lastModified": 1747046372,
6
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
- "owner": "edolstra",
8
- "repo": "flake-compat",
9
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
- "type": "github"
11
- },
12
- "original": {
13
- "owner": "edolstra",
14
- "repo": "flake-compat",
15
- "type": "github"
16
- }
17
- },
18
- "flake-compat_2": {
19
- "locked": {
20
- "lastModified": 1747046372,
21
- "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
22
- "owner": "edolstra",
23
- "repo": "flake-compat",
24
- "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
25
- "type": "github"
26
- },
27
- "original": {
28
- "owner": "edolstra",
29
- "repo": "flake-compat",
30
- "type": "github"
31
- }
32
- },
33
- "flake-utils": {
34
- "inputs": {
35
- "systems": "systems"
36
- },
37
- "locked": {
38
- "lastModified": 1731533236,
39
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
- "owner": "numtide",
41
- "repo": "flake-utils",
42
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
- "type": "github"
44
- },
45
- "original": {
46
- "owner": "numtide",
47
- "repo": "flake-utils",
48
- "type": "github"
49
- }
50
- },
51
- "flake-utils_2": {
52
- "inputs": {
53
- "systems": "systems_2"
54
- },
55
- "locked": {
56
- "lastModified": 1731533236,
57
- "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
- "owner": "numtide",
59
- "repo": "flake-utils",
60
- "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
- "type": "github"
62
- },
63
- "original": {
64
- "owner": "numtide",
65
- "repo": "flake-utils",
66
- "type": "github"
67
- }
68
- },
69
- "hf-nix": {
70
- "inputs": {
71
- "flake-compat": "flake-compat_2",
72
- "flake-utils": "flake-utils_2",
73
- "nixpkgs": "nixpkgs"
74
- },
75
- "locked": {
76
- "lastModified": 1759493343,
77
- "narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
78
- "owner": "huggingface",
79
- "repo": "hf-nix",
80
- "rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
81
- "type": "github"
82
- },
83
- "original": {
84
- "owner": "huggingface",
85
- "repo": "hf-nix",
86
- "type": "github"
87
- }
88
- },
89
- "kernel-builder": {
90
- "inputs": {
91
- "flake-compat": "flake-compat",
92
- "flake-utils": "flake-utils",
93
- "hf-nix": "hf-nix",
94
- "nixpkgs": [
95
- "kernel-builder",
96
- "hf-nix",
97
- "nixpkgs"
98
- ]
99
- },
100
- "locked": {
101
- "lastModified": 1759516823,
102
- "narHash": "sha256-UJVvZHtS9c64Dm4iZRaOKWB+VHI7jzcazGH57KXWeg8=",
103
- "owner": "huggingface",
104
- "repo": "kernel-builder",
105
- "rev": "e13610a05f67b7296be9ead89ad172a0a088a1c3",
106
- "type": "github"
107
- },
108
- "original": {
109
- "owner": "huggingface",
110
- "repo": "kernel-builder",
111
- "type": "github"
112
- }
113
- },
114
- "nixpkgs": {
115
- "locked": {
116
- "lastModified": 1755963616,
117
- "narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
118
- "owner": "nixos",
119
- "repo": "nixpkgs",
120
- "rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
121
- "type": "github"
122
- },
123
- "original": {
124
- "owner": "nixos",
125
- "ref": "nixos-unstable-small",
126
- "repo": "nixpkgs",
127
- "type": "github"
128
- }
129
- },
130
- "root": {
131
- "inputs": {
132
- "kernel-builder": "kernel-builder"
133
- }
134
- },
135
- "systems": {
136
- "locked": {
137
- "lastModified": 1681028828,
138
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
- "owner": "nix-systems",
140
- "repo": "default",
141
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
- "type": "github"
143
- },
144
- "original": {
145
- "owner": "nix-systems",
146
- "repo": "default",
147
- "type": "github"
148
- }
149
- },
150
- "systems_2": {
151
- "locked": {
152
- "lastModified": 1681028828,
153
- "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
- "owner": "nix-systems",
155
- "repo": "default",
156
- "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
- "type": "github"
158
- },
159
- "original": {
160
- "owner": "nix-systems",
161
- "repo": "default",
162
- "type": "github"
163
- }
164
- }
165
- },
166
- "root": "root",
167
- "version": 7
168
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flake.nix DELETED
@@ -1,35 +0,0 @@
1
- {
2
- description = "Flake for Hopper Flash Attention kernel";
3
-
4
- inputs = {
5
- kernel-builder.url = "github:huggingface/kernel-builder";
6
- };
7
-
8
- outputs =
9
- {
10
- self,
11
- kernel-builder,
12
- }:
13
- kernel-builder.lib.genFlakeOutputs {
14
- inherit self;
15
- path = ./.;
16
- # Building with CUDA later than 12.4 fails with:
17
- #
18
- # error: 'ptxas' died due to signal 11 (Invalid memory reference)
19
- #
20
- # So, build for 12.4 only and copy to all the other build variants
21
- # by hand (which works fine thanks to backward compat).
22
- torchVersions = _: [
23
- {
24
- torchVersion = "2.9";
25
- cudaVersion = "12.4";
26
- cxx11Abi = true;
27
- systems = [
28
- "x86_64-linux"
29
- "aarch64-linux"
30
- ];
31
- bundleBuild = true;
32
- }
33
- ];
34
- };
35
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/block.h DELETED
@@ -1,139 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- namespace flash {
8
-
9
- template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
10
- struct BlockMN {
11
-
12
- static
13
- CUTLASS_DEVICE
14
- cute::tuple<int, int> get_n_block_min_max(
15
- SeqlenInfo_t const& seqlen_info,
16
- int const m_block, int const bidb, int const split_idx, int const num_splits,
17
- int const window_size_left, int const window_size_right,
18
- cutlass::FastDivmod const& attention_chunk_divmod,
19
- cutlass::FastDivmod const& qhead_per_khead_divmod) {
20
-
21
- int const seqlen_k = seqlen_info.seqlen_k;
22
- int const seqlen_q = seqlen_info.seqlen_q;
23
- int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
24
- if constexpr (Is_causal || Is_local) {
25
- int m_idx_max = (m_block + 1) * kBlockM;
26
- // TODO: check off-by-1 error
27
- if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
28
- int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
29
- int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
30
- if (Is_local && attention_chunk_divmod.divisor > 0) {
31
- n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
32
- }
33
- n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
34
- }
35
- int n_block_min = 0;
36
- if constexpr (Is_local) {
37
- int m_idx_min = m_block * kBlockM;
38
- if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
39
- int const n_idx = m_idx_min + seqlen_k - seqlen_q;
40
- int n_idx_left = n_idx - window_size_left;
41
- if (attention_chunk_divmod.divisor > 0) {
42
- n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
43
- }
44
- n_block_min = std::max(int(0), n_idx_left / kBlockN);
45
- }
46
- // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
47
- if constexpr (Split) {
48
- uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
49
- int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
50
- int split_idx_actual = split_idx & 0x0000FFFF;
51
- int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
52
- int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
53
- n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
54
- n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
55
- // if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
56
- }
57
- // if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
58
- return {n_block_min, n_block_max};
59
- }
60
-
61
- static
62
- CUTLASS_DEVICE
63
- cute::tuple<int, int> get_n_block_k_new_min_max(
64
- SeqlenInfo_t const& seqlen_info,
65
- int const m_block, int const bidb, int const split_idx, int const num_splits,
66
- int const window_size_left, int const window_size_right,
67
- cutlass::FastDivmod const& attention_chunk_divmod,
68
- cutlass::FastDivmod const& qhead_per_khead_divmod) {
69
-
70
- auto [n_block_min, n_block_max] = get_n_block_min_max(
71
- seqlen_info, m_block, bidb, split_idx, num_splits,
72
- window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
73
- int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
74
- int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
75
- int const n_block_new_min = idx_k_new_min / kBlockN;
76
- int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
77
- // if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
78
- return {n_block_new_min, n_block_new_max};
79
- }
80
-
81
- static
82
- CUTLASS_DEVICE
83
- cute::tuple<int, int> get_m_block_min_max(
84
- SeqlenInfo_t const& seqlen_info,
85
- int const n_block, int const bidb,
86
- int const window_size_left, int const window_size_right, int const sink_token_length) {
87
- // TODO: support attention_chunk
88
- int const seqlen_q = seqlen_info.seqlen_q;
89
- int const seqlen_k = seqlen_info.seqlen_k;
90
- int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
91
- if constexpr (Is_local) {
92
- if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
93
- m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
94
- }
95
- }
96
- int m_block_min = 0;
97
- if constexpr (Is_causal || Is_local) {
98
- m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
99
- }
100
- return {m_block_min, m_block_max};
101
- }
102
-
103
- // If we have separate iterations with causal or local masking at the start, where do we stop
104
- static
105
- CUTLASS_DEVICE
106
- int get_n_block_min_causal_local_mask(
107
- SeqlenInfo_t const& seqlen_info,
108
- int const m_block, int const n_block_min, int const window_size_right,
109
- cutlass::FastDivmod const& attention_chunk_divmod,
110
- cutlass::FastDivmod const& qhead_per_khead_divmod) {
111
- int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
112
- int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
113
- int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
114
- if (Is_local && attention_chunk_divmod.divisor > 0) {
115
- n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
116
- }
117
- return std::max(n_block_min, n_idx_right / kBlockN);
118
- }
119
-
120
- // If we have separate iterations with local masking at the end, where do we stop the non-masked iterations
121
- static
122
- CUTLASS_DEVICE
123
- int get_n_block_min_before_local_mask(
124
- SeqlenInfo_t const& seqlen_info,
125
- int const m_block, int const n_block_min, int const window_size_left,
126
- cutlass::FastDivmod const& attention_chunk_divmod,
127
- cutlass::FastDivmod const& qhead_per_khead_divmod) {
128
- int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
129
- int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
130
- int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;
131
- if (Is_local && attention_chunk_divmod.divisor > 0) {
132
- n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
133
- }
134
- return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
135
- }
136
-
137
- };
138
-
139
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/copy_sm90_bulk_reduce.hpp DELETED
@@ -1,49 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include<cute/arch/copy_sm90_tma.hpp>
8
-
9
- namespace cute
10
- {
11
-
12
- ////////////////////////////////////////////////////////////////////////////////////////////////////
13
-
14
- struct SM90_BULK_REDUCE_ADD
15
- {
16
- CUTE_HOST_DEVICE static void
17
- copy(float const* smem_ptr,
18
- float * gmem_ptr, int32_t store_bytes)
19
- {
20
- #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
21
- uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
22
- asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
23
- :
24
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
25
- : "memory");
26
- #else
27
- CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
28
- #endif
29
- }
30
-
31
- CUTE_HOST_DEVICE static void
32
- copy(float const* smem_ptr,
33
- float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
34
- {
35
- #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
36
- uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
37
- asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
38
- :
39
- : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
40
- : "memory");
41
- #else
42
- CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
43
- #endif
44
- }
45
- };
46
-
47
- ////////////////////////////////////////////////////////////////////////////////////////////////////
48
-
49
- } // end namespace cute
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/cuda_check.h DELETED
@@ -1,19 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <assert.h>
8
- #include <stdlib.h>
9
-
10
- #define CHECK_CUDA(call) \
11
- do { \
12
- cudaError_t status_ = call; \
13
- if (status_ != cudaSuccess) { \
14
- fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
15
- exit(1); \
16
- } \
17
- } while(0)
18
-
19
- #define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/epilogue_bwd.hpp DELETED
@@ -1,533 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cutlass/cutlass.h"
8
- #include "cutlass/barrier.h"
9
- #include "cute/tensor.hpp"
10
-
11
- #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
-
13
- #include "seqlen.h"
14
- #include "named_barrier.hpp"
15
- #include "utils.h"
16
-
17
- namespace flash {
18
-
19
- using namespace cute;
20
-
21
- template <class TileShape_MNK_, class Element_, class ArchTag_,
22
- int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
23
- struct CollectiveEpilogueBwd {
24
-
25
- using TileShape_MNK = TileShape_MNK_;
26
- using Element = Element_;
27
- using ArchTag = ArchTag_;
28
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
29
- static constexpr bool Varlen = Varlen_;
30
- static constexpr bool dKV_swapAB = dKV_swapAB_;
31
- static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
32
-
33
- static_assert(ArchTag::kMinComputeCapability >= 80);
34
-
35
- using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
36
-
37
- // These are for storing the output tensor without TMA (e.g., for setting output to zero)
38
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
- static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
- static constexpr int kHeadDim = get<2>(TileShape_MNK{});
41
- static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
42
- static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
43
- using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
44
- Stride<Int<kGmemThreadsPerRow>, _1>>;
45
- using GmemTiledCopydKV = decltype(
46
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
47
- GmemLayoutAtom{},
48
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
49
-
50
- using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
51
- // TODO: do we have to change this if dKV_swapAB is true?
52
- decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
53
- using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
54
- using SmemLayoutdKVtTMA =
55
- decltype(cute::composition(SmemLayoutdKVTMA{},
56
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
57
- make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
58
-
59
- // If we don't use TMA
60
- static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
61
- static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
62
- using SmemLayoutAtomdKVSTG =
63
- decltype(composition(Swizzle<kSwizzle, 3, 3>{},
64
- Layout<Shape<Int<8>, Int<kBlockKSmem>>,
65
- Stride<Int<kBlockKSmem>, _1>>{}));
66
-
67
- using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
68
- using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
69
- using SmemLayoutdKVt =
70
- decltype(cute::composition(SmemLayoutdKV{},
71
- make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
72
- make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
73
-
74
- using SmemCopyAtomdKV = Copy_Atom<
75
- std::conditional_t<
76
- ArchTag::kMinComputeCapability >= 90,
77
- std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
78
- AutoVectorizingCopyWithAssumedAlignment<128>
79
- >,
80
- Element>;
81
-
82
- static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
83
- static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
84
-
85
- struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
86
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
87
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
88
- };
89
-
90
- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
91
- using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
92
-
93
- using TMA_dKV = std::conditional_t<
94
- Use_TMA,
95
- decltype(make_tma_copy(
96
- GmemTiledCopydKVTMA{},
97
- make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
98
- SmemLayoutdKVTMA{},
99
- select<1, 2>(TileShape_MNK{}),
100
- _1{})), // no mcast for dKV
101
- std::nullptr_t
102
- >;
103
-
104
- // Host side kernel arguments
105
- struct Arguments {
106
- Element* ptr_dK;
107
- ShapedKV const shape_dK;
108
- StridedKV const stride_dK;
109
- Element* ptr_dV;
110
- ShapedKV const shape_dV;
111
- StridedKV const stride_dV;
112
- int const num_heads_q;
113
- int* dk_semaphore;
114
- int* dv_semaphore;
115
- int const* cu_seqlens;
116
- int const* seqused;
117
- };
118
-
119
- // Device side kernel params
120
- struct Params {
121
- Element* ptr_dK;
122
- ShapedKV const shape_dK;
123
- StridedKV const stride_dK;
124
- Element* ptr_dV;
125
- ShapedKV const shape_dV;
126
- StridedKV const stride_dV;
127
- TMA_dKV tma_store_dK, tma_store_dV;
128
- int const* cu_seqlens = nullptr;
129
- int const* seqused = nullptr;
130
- };
131
-
132
- static Params
133
- to_underlying_arguments(Arguments const& args) {
134
- Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
135
- Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);
136
- TMA_dKV tma_store_dK = [&] {
137
- if constexpr (Use_TMA) {
138
- return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
139
- } else {
140
- return nullptr;
141
- }
142
- }();
143
- TMA_dKV tma_store_dV = [&] {
144
- if constexpr (Use_TMA) {
145
- return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
146
- } else {
147
- return nullptr;
148
- }
149
- }();
150
- return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,
151
- tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
152
- }
153
-
154
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
155
- CUTLASS_DEVICE
156
- static void prefetch_tma_descriptors(Params const& params) {
157
- if constexpr (Use_TMA) {
158
- cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
159
- cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
160
- }
161
- }
162
-
163
- template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
164
- CUTLASS_DEVICE void
165
- store(Params const& params,
166
- FrgTensorO const& tdKrdK,
167
- FrgTensorO const& tdVrdV,
168
- SharedStorage& shared_storage,
169
- TiledMma tiled_mma,
170
- int thread_idx,
171
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
172
- ) {
173
-
174
- auto [n_block, bidh, bidb] = block_coord;
175
- Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
176
- Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
177
- Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
178
- Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
179
- auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
180
- auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
181
-
182
- Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
183
- flash::convert_type_out(tdVrdV, tdVrdV_out);
184
- Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
185
- flash::convert_type_out(tdKrdK, tdKrdK_out);
186
- Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
187
- Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
188
- // if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
189
- Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
190
- Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
191
-
192
- // Make sure all WGs have finished reading K and V
193
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
194
- cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
195
- cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
196
- if constexpr (Use_TMA) {
197
- cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
198
- cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
199
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
200
-
201
- Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
202
- Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);
203
- Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
204
- Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
205
- auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
206
- auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
207
- Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
208
- Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
209
- Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
210
- Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
211
- int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
212
- if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
213
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
214
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
215
- if (cute::elect_one_sync()) {
216
- cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
217
- cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
218
- tma_store_arrive();
219
- }
220
- }
221
- tma_store_wait<0>();
222
- // // Tell warp 0 that smem_k and smem_v are ready
223
- // cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
224
-
225
- } else {
226
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
227
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
228
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
229
- bool const is_varlen = Varlen && params.cu_seqlens;
230
- Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
231
- Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
232
- Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
233
- Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
234
-
235
- GmemTiledCopydKV gmem_tiled_copy_dKV;
236
- auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
237
- Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
238
- Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
239
- Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
240
- Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
241
- Tensor tdKVrdV = make_fragment_like(tdKVgdV);
242
- Tensor tdKVrdK = make_fragment_like(tdKVgdK);
243
- Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
244
- // Repeat the partitioning with identity layouts
245
- Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
246
- Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
247
- Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
248
- #pragma unroll
249
- for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
250
- #pragma unroll
251
- for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
252
- // Need to check OOB when reading from smem if kBlockN isn't evenly tiled
253
- static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
254
- flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
255
- gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);
256
- flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
257
- gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);
258
- // // Tell warp 0 that smem_k and smem_v are ready
259
- // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
260
- // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
261
- // Construct identity layout for gdKV
262
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
263
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
264
- gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
265
- );
266
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
267
- gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
268
- );
269
- }
270
- }
271
-
272
- CUTLASS_DEVICE void
273
- store_tail() {
274
- // if constexpr (Use_TMA) { tma_store_wait<0>(); }
275
- }
276
-
277
- // Write 0 to dK and dV
278
- CUTLASS_DEVICE void
279
- store_zero(
280
- Params const& params,
281
- int thread_idx,
282
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
283
- ) {
284
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
285
- auto [n_block, bidh, bidb] = block_coord;
286
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
287
- bool const is_varlen = Varlen && params.cu_seqlens;
288
- Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
289
- Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
290
- Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
291
- Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
292
-
293
- GmemTiledCopydKV gmem_tiled_copy_dKV;
294
- auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
295
- Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
296
- Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
297
- Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
298
- clear(tdKVrdKV);
299
- // Construct identity layout for gdKV
300
- Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
301
- // Repeat the partitioning with identity layouts
302
- Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
303
- Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
304
- Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
305
- #pragma unroll
306
- for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
307
- #pragma unroll
308
- for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
309
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
310
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
311
- gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN
312
- );
313
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
314
- gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN
315
- );
316
- }
317
-
318
- };
319
-
320
- template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
321
- int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
322
- struct CollectiveEpilogueBwdGQA {
323
-
324
- using TileShape_MNK = TileShape_MNK_;
325
- using Element = ElementAccum;
326
- using ArchTag = ArchTag_;
327
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
328
- static constexpr bool Varlen = Varlen_;
329
- static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
330
-
331
- static_assert(ArchTag::kMinComputeCapability >= 80);
332
-
333
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
334
- static constexpr int kHeadDim = get<2>(TileShape_MNK{});
335
- static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
336
- static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
337
- // Thread layout, 256 or 384 threads per row
338
- // We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
339
- using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
340
- using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
341
- Layout<Shape < _4>>{})); // Val layout, 4 vals per store
342
- // For Sm80
343
- using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
344
- using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
345
- Layout<Shape < _1>>{})); // Val layout, 1 vals per store
346
-
347
- using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
348
- using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
349
-
350
- // Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
351
- // only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
352
- static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
353
- struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
354
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
355
- };
356
- struct TensorStorageSTG {
357
- cute::array<ElementAccum, 0> smem_dkv;
358
- };
359
- using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
360
-
361
- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
362
- using StridedKV = cute::Stride<_1, int64_t, int64_t>;
363
-
364
- // Host side kernel arguments
365
- struct Arguments {
366
- ElementAccum* ptr_dKaccum;
367
- ShapedKV const shape_dKaccum;
368
- StridedKV const stride_dKaccum;
369
- ElementAccum* ptr_dVaccum;
370
- ShapedKV const shape_dVaccum;
371
- StridedKV const stride_dVaccum;
372
- int num_heads_q;
373
- int* dk_semaphore;
374
- int* dv_semaphore;
375
- int const* cu_seqlens;
376
- int const* seqused;
377
- };
378
-
379
- // Device side kernel params
380
- struct Params {
381
- ElementAccum* ptr_dKaccum;
382
- ShapedKV const shape_dKaccum;
383
- StridedKV const stride_dKaccum;
384
- ElementAccum* ptr_dVaccum;
385
- ShapedKV const shape_dVaccum;
386
- StridedKV const stride_dVaccum;
387
- cutlass::FastDivmod qhead_per_khead_divmod;
388
- int* dk_semaphore;
389
- int* dv_semaphore;
390
- int const* cu_seqlens = nullptr;
391
- int const* seqused = nullptr;
392
- };
393
-
394
- static Params
395
- to_underlying_arguments(Arguments const& args) {
396
- if constexpr (Deterministic) {
397
- assert(args.dk_semaphore != nullptr);
398
- assert(args.dv_semaphore != nullptr);
399
- }
400
- return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum,
401
- cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
402
- args.dk_semaphore, args.dv_semaphore,
403
- args.cu_seqlens, args.seqused};
404
- }
405
-
406
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
407
- CUTLASS_DEVICE
408
- static void prefetch_tma_descriptors(Params const& params) {
409
- }
410
-
411
- template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
412
- CUTLASS_DEVICE void
413
- store(Params const& params,
414
- FrgTensorO const& tdKrdK,
415
- FrgTensorO const& tdVrdV,
416
- SharedStorage& shared_storage,
417
- TiledMma tiled_mma,
418
- int thread_idx,
419
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
420
- ) {
421
-
422
- auto [n_block, bidh, bidb] = block_coord;
423
- int bidh_idx_in_group;
424
- int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
425
- Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
426
- Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
427
- static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
428
-
429
- flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
430
- bool const is_varlen = Varlen && params.cu_seqlens;
431
- Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
432
- Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
433
- Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
434
- Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
435
-
436
- R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
437
- auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
438
- Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
439
-
440
- // Only used if !Use_TMA
441
- R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
442
- auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
443
-
444
- // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
445
- // because smem_q could be changed.
446
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
447
- if constexpr (Use_TMA) {
448
- Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
449
- cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
450
- }
451
-
452
- // int const num_batch = params.num_batch;
453
- int const num_batch = get<2>(params.shape_dKaccum);
454
- int const num_head_kv = get<1>(params.shape_dKaccum);
455
- int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
456
- using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
457
-
458
- // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
459
-
460
- if constexpr (Deterministic) {
461
- Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
462
- }
463
- // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
464
- if constexpr (Use_TMA) {
465
- cutlass::arch::fence_view_async_shared();
466
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
467
- if (thread_idx == 0) {
468
- SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
469
- tma_store_arrive();
470
- tma_store_wait<0>();
471
- }
472
- } else {
473
- Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
474
- Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
475
- static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
476
- #pragma unroll
477
- for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
478
- }
479
- if constexpr (Deterministic) {
480
- Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
481
- }
482
-
483
- if constexpr (Use_TMA) {
484
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
485
- Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
486
- cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
487
- }
488
- lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
489
- // if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
490
-
491
- if constexpr (Deterministic) {
492
- Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
493
- }
494
- // if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
495
- if constexpr (Use_TMA) {
496
- cutlass::arch::fence_view_async_shared();
497
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
498
- if (thread_idx == 0) {
499
- SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
500
- tma_store_arrive();
501
- tma_store_wait<0>();
502
- }
503
- } else {
504
- Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
505
- Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
506
- static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
507
- #pragma unroll
508
- for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
509
- }
510
- if constexpr (Deterministic) {
511
- Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
512
- }
513
- // // Tell warp 0 that smem_k and smem_v are ready
514
- // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
515
- }
516
-
517
- CUTLASS_DEVICE void
518
- store_tail() {
519
- }
520
-
521
- // Write 0 to dK and dV
522
- CUTLASS_DEVICE void
523
- store_zero(
524
- Params const& params,
525
- int thread_idx,
526
- cute::tuple<int32_t, int32_t, int32_t> const& block_coord
527
- ) {
528
- // Don't need to do anything since dKaccum and dVaccum are already zero-initialized
529
- }
530
-
531
- };
532
-
533
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/epilogue_fwd.hpp DELETED
@@ -1,484 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cutlass/cutlass.h>
8
- #include <cutlass/fast_math.h> // For FastDivMod
9
- #include "cute/tensor.hpp"
10
-
11
- #include "cutlass/gemm/collective/builders/sm90_common.inl"
12
- #include "cutlass/epilogue/collective/builders/sm90_common.inl"
13
-
14
- #include "seqlen.h"
15
- #include "named_barrier.hpp"
16
- #include "pack_gqa.h"
17
- #include "utils.h"
18
-
19
- namespace flash {
20
-
21
- using namespace cute;
22
-
23
- template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
24
- int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
25
- struct CollectiveEpilogueFwd {
26
-
27
- using TileShape_MNK_PV = TileShape_MNK_PV_;
28
- using ClusterShape = ClusterShape_;
29
- using Element = Element_;
30
- using ElementPartial = float;
31
- using ArchTag = ArchTag_;
32
- static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
33
- static constexpr bool Varlen = Varlen_;
34
- static constexpr bool PackGQA = PackGQA_;
35
- static constexpr bool Split = Split_;
36
- static constexpr bool Use_smem = !(Split && !Varlen);
37
- static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
38
-
39
- static_assert(ArchTag::kMinComputeCapability >= 80);
40
- static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
41
- static_assert(sizeof(Element) <= 2);
42
-
43
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
44
- static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
45
-
46
- static constexpr bool LargeHeadDimV = kHeadDimV > 256;
47
-
48
- using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
49
-
50
- // These are for storing the output tensor without TMA (e.g., for setting output to zero)
51
- static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
52
- static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
53
- // We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
54
- // in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
55
- // we need to call divmod.
56
- static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
57
- static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
58
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
59
- // If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
60
- static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
61
- static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
62
- using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
63
- Stride<Int<kGmemThreadsPerRow>, _1>>;
64
- static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
65
- using GmemTiledCopyO = decltype(
66
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
67
- GmemLayoutAtom{},
68
- Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
69
-
70
- using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
71
- decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
72
- using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
73
- static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
74
- static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
75
- using SmemLayoutAtomO = decltype(
76
- composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
77
- Layout<Shape<_8, Int<kBlockKGmem>>,
78
- Stride<Int<kBlockKGmem>, _1>>{}));
79
- using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
80
- using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
81
-
82
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
83
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
84
- using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
85
- // ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
86
- using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
87
- using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
88
- // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
89
- using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
90
- using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
91
-
92
- using CopyOpR2S = std::conditional_t<
93
- ArchTag::kMinComputeCapability >= 90,
94
- // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
95
- decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
96
- AutoVectorizingCopyWithAssumedAlignment<128>
97
- >;
98
- using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
99
-
100
- // static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
101
- // static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
102
- // struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
103
- // cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
104
- // };
105
- struct TensorStorage : cute::aligned_struct<128> {
106
- cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
107
- };
108
-
109
- using TMA_O = std::conditional_t<
110
- Use_TMA_O,
111
- decltype(make_tma_copy(
112
- GmemTiledCopyOTMA{},
113
- make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
114
- SmemLayoutOTMA{},
115
- select<0, 1>(TileShape_MNK_PV{}),
116
- _1{})), // no mcast for O
117
- std::nullptr_t
118
- >;
119
-
120
- // Host side kernel arguments
121
- struct Arguments {
122
- Element* ptr_O;
123
- ShapeO const shape_O;
124
- StrideO const stride_O;
125
- ElementPartial* ptr_O_partial;
126
- StrideO const stride_O_partial;
127
- float* ptr_LSE;
128
- StrideLSE const stride_LSE;
129
- float* ptr_LSE_partial;
130
- StrideLSE const stride_LSE_partial;
131
- int32_t const nheads_kv;
132
- int const* cu_seqlens = nullptr;
133
- int const* seqused = nullptr;
134
- };
135
-
136
- // Device side kernel params
137
- struct Params {
138
- Element* ptr_O;
139
- ShapeO const shape_O;
140
- StrideO const stride_O;
141
- ShapeOPacked const shape_O_packed;
142
- StrideOPacked const stride_O_packed;
143
- ElementPartial* ptr_O_partial;
144
- StrideO const stride_O_partial;
145
- StrideOPacked const stride_O_partial_packed;
146
- float* ptr_LSE;
147
- StrideLSE const stride_LSE;
148
- ShapeLSEPacked const shape_LSE_packed;
149
- StrideLSEPacked const stride_LSE_packed;
150
- float* ptr_LSE_partial;
151
- StrideLSE const stride_LSE_partial;
152
- StrideLSEPacked const stride_LSE_partial_packed;
153
- cutlass::FastDivmod qhead_per_khead_divmod;
154
- TMA_O tma_store_O;
155
- int const* cu_seqlens = nullptr;
156
- int const* seqused = nullptr;
157
- };
158
-
159
- static Params
160
- to_underlying_arguments(Arguments const& args) {
161
- Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
162
- TMA_O tma_store_O = [&]{
163
- if constexpr (Use_TMA_O) {
164
- return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
165
- } else {
166
- return nullptr;
167
- }
168
- }();
169
- // If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
170
- int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
171
- auto const shape_O_packed = cute::conditional_return<!PackGQA>(
172
- args.shape_O,
173
- make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
174
- );
175
- auto const stride_O_packed = cute::conditional_return<!PackGQA>(
176
- args.stride_O,
177
- make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
178
- );
179
- auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
180
- args.stride_O_partial,
181
- make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
182
- );
183
- // If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
184
- auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
185
- select<0, 2, 3, 4>(args.shape_O),
186
- make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
187
- );
188
- auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
189
- args.stride_LSE,
190
- make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
191
- );
192
- auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
193
- args.stride_LSE_partial,
194
- make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
195
- );
196
- return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
197
- args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
198
- args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
199
- args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
200
- cutlass::FastDivmod(qhead_per_khead),
201
- tma_store_O, args.cu_seqlens, args.seqused};
202
- }
203
-
204
- /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
205
- CUTLASS_DEVICE
206
- static void prefetch_tma_descriptors(Params const& params) {
207
- if constexpr (Use_TMA_O) {
208
- cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
209
- }
210
- }
211
-
212
- template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
213
- CUTLASS_DEVICE void
214
- store(Params const& params,
215
- FrgTensorO& tOrO,
216
- FrgTensorLSE const& lse,
217
- SharedStorage& shared_storage,
218
- TiledMma tiled_mma,
219
- int thread_idx,
220
- cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
221
- ) {
222
-
223
- auto [m_block, bidh, bidb, split_idx] = block_coord;
224
- int num_splits = get<4>(params.shape_O_packed);
225
- if constexpr (Split && Varlen) {
226
- uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
227
- int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
228
- num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
229
- split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
230
- }
231
- bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
232
-
233
- Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
234
- // Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
235
-
236
- static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
237
- // If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
238
- // Otherwise we can permute after conversion.
239
- if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
240
- Tensor tOrO_out = make_tensor_like<Element>(tOrO);
241
- flash::convert_type_out(tOrO, tOrO_out);
242
- if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
243
-
244
- // Make sure all WGs have finished reading V
245
- // Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
246
- // all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
247
- // cp.async if we need).
248
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
249
-
250
- // Step 1: Write O from rmem -> smem
251
- if constexpr (Use_smem) {
252
- auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
253
- auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
254
- Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
255
- Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
256
- // Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
257
- cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
258
- if constexpr (Use_TMA_O) {
259
- cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
260
- cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
261
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
262
- } else {
263
- flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
264
- }
265
- } else {
266
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
267
- #pragma unroll
268
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
269
- shared_storage.pipelines.barrier_O.arrive(cta_id);
270
- }
271
- }
272
- }
273
-
274
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
275
- bool is_varlen = Varlen && params.cu_seqlens;
276
- int offset_o = seqlen_info.offset;
277
- int seqlen_o = seqlen_info.seqlen;
278
- int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
279
-
280
- // Step 2: Write LSE from rmem -> gmem
281
- auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
282
- // (MMA,MMA_M,MMA_K)
283
- Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
284
- static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
285
- static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
286
- Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
287
- Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
288
- CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
289
-
290
- using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
291
- using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
292
-
293
- Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
294
- params.shape_LSE_packed,
295
- !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
296
- // if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
297
- if (!LargeHeadDimV || warp_group_idx == 0) {
298
- if constexpr (!PackGQA) {
299
- #pragma unroll
300
- for (int mi = 0; mi < size(lse); ++mi) {
301
- int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
302
- if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
303
- }
304
- } else {
305
- PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
306
- }
307
- }
308
-
309
- // Step 3: Write O from smem -> gmem
310
- if constexpr (Use_TMA_O) {
311
- Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
312
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
313
- auto block_tma_O = params.tma_store_O.get_slice(_0{});
314
- Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
315
- Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
316
- int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
317
- if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
318
- cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
319
- cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
320
- if (cute::elect_one_sync()) {
321
- cute::copy(params.tma_store_O, tOsO, tOgO);
322
- tma_store_arrive();
323
- tma_store_wait<0>();
324
- #pragma unroll
325
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
326
- shared_storage.pipelines.barrier_O.arrive(cta_id);
327
- }
328
- }
329
- }
330
- } else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
331
- if (!is_split) {
332
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
333
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
334
- // if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
335
- GmemTiledCopyO gmem_tiled_copy_O;
336
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
337
- Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
338
- // Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
339
- Tensor tOrO = make_fragment_like(tOsO);
340
- cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
341
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
342
- cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
343
- #pragma unroll
344
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
345
- shared_storage.pipelines.barrier_O.arrive(cta_id);
346
- }
347
- }
348
- if constexpr (!PackGQA) {
349
- // (BLK_M,BLK_K) -> (blk_m,blk_k)
350
- Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
351
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
352
- #pragma unroll
353
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
354
- Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
355
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
356
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
357
- gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
358
- );
359
- } else {
360
- // If PackGQA, we split the work of compute O_ptr among threads in the same row
361
- PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
362
- }
363
- } else {
364
- Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
365
- Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
366
- // We already arrived on barrier_O earlier if !Use_smem
367
- if constexpr (Use_smem) {
368
- if constexpr (ArchTag::kMinComputeCapability >= 90) {
369
- #pragma unroll
370
- for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
371
- shared_storage.pipelines.barrier_O.arrive(cta_id);
372
- }
373
- }
374
- }
375
- if constexpr (!PackGQA) {
376
- static constexpr int kGmemElemsPerStoreDirect = 2;
377
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
378
- // Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
379
- Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
380
- Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
381
- Tensor tOgO = thread_mma.partition_C(gOpartial);
382
- Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
383
- Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
384
- Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
385
- #pragma unroll
386
- for (int m = 0; m < size(taccOcO_row); ++m) {
387
- if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
388
- #pragma unroll
389
- for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
390
- if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
391
- cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
392
- }
393
- }
394
- }
395
- }
396
- } else {
397
- PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
398
- }
399
- }
400
- }
401
- }
402
-
403
- CUTLASS_DEVICE void
404
- store_tail() {
405
- // Don't need to do tma_store_wait<0>() here since we already did in @store
406
- }
407
-
408
- // Write 0 to output and -inf to LSE
409
- CUTLASS_DEVICE void
410
- store_zero(
411
- Params const& params,
412
- int thread_idx,
413
- cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
414
- ) {
415
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
416
- auto [m_block, bidh, bidb, split_idx] = block_coord;
417
- int num_splits = get<4>(params.shape_O_packed);
418
- if constexpr (Split && Varlen) {
419
- uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
420
- int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
421
- num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
422
- split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
423
- }
424
- bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
425
-
426
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
427
- bool const is_varlen = Varlen && params.cu_seqlens;
428
- int offset_o = seqlen_info.offset;
429
- int seqlen_o = seqlen_info.seqlen;
430
- int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
431
- Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
432
- params.shape_LSE_packed,
433
- !is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
434
- Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
435
-
436
- static_assert(kBlockM <= NumEpilogueThreads);
437
- if (thread_idx < kBlockM) {
438
- const int row = m_block * kBlockM + thread_idx;
439
- if constexpr (!PackGQA) {
440
- if (row < seqlen_o) { mLSE(row) = -INFINITY; }
441
- } else {
442
- if (row < seqlen_o * qhead_per_khead) {
443
- int m_idx, h_idx;
444
- m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
445
- // mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
446
- mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
447
- }
448
- }
449
- }
450
-
451
- // If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
452
- // since it will not use the value of O if LSE is -inf.
453
- if (!is_split) {
454
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
455
-
456
- GmemTiledCopyO gmem_tiled_copy_O;
457
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
458
- Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
459
- if constexpr (!PackGQA) {
460
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
461
- #pragma unroll
462
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
463
- Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
464
- Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
465
- Tensor tOrO = make_fragment_like(tOgO);
466
- cute::clear(tOrO);
467
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
468
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
469
- gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
470
- );
471
- } else {
472
- // If PackGQA, we split the work of compute O_ptr among threads in the same row
473
- using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
474
- Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
475
- cute::clear(tOrO);
476
- PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
477
- }
478
- }
479
-
480
- }
481
-
482
- };
483
-
484
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash.h DELETED
@@ -1,218 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2023, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <cuda.h>
8
- #include <vector>
9
-
10
- ////////////////////////////////////////////////////////////////////////////////////////////////////
11
-
12
- struct Qkv_params {
13
- using index_t = int64_t;
14
- // The QKV matrices.
15
- void *__restrict__ q_ptr;
16
- void *__restrict__ k_ptr;
17
- void *__restrict__ v_ptr;
18
-
19
- // The stride between rows of the Q, K and V matrices.
20
- index_t q_batch_stride;
21
- index_t k_batch_stride;
22
- index_t v_batch_stride;
23
- index_t q_row_stride;
24
- index_t k_row_stride;
25
- index_t v_row_stride;
26
- index_t q_head_stride;
27
- index_t k_head_stride;
28
- index_t v_head_stride;
29
- index_t v_dim_stride;
30
-
31
- // The number of heads.
32
- int h, h_k;
33
- };
34
-
35
- ////////////////////////////////////////////////////////////////////////////////////////////////////
36
-
37
- struct Flash_fwd_params : public Qkv_params {
38
- using index_t = int64_t;
39
-
40
- // The O matrix (output).
41
- void * __restrict__ o_ptr;
42
- void * __restrict__ oaccum_ptr;
43
-
44
- // The stride between rows of O.
45
- index_t o_batch_stride;
46
- index_t o_row_stride;
47
- index_t o_head_stride;
48
-
49
- // The pointer to the softmax sum.
50
- void * __restrict__ softmax_lse_ptr;
51
- void * __restrict__ softmax_lseaccum_ptr;
52
-
53
- // For FP8 scaling
54
- float * __restrict__ q_descale_ptr;
55
- float * __restrict__ k_descale_ptr;
56
- float * __restrict__ v_descale_ptr;
57
- index_t q_descale_batch_stride;
58
- index_t q_descale_head_stride;
59
- index_t k_descale_batch_stride;
60
- index_t k_descale_head_stride;
61
- index_t v_descale_batch_stride;
62
- index_t v_descale_head_stride;
63
-
64
- // The dimensions.
65
- int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
66
- int total_q, total_k, total_knew;
67
- int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
68
- int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
69
-
70
- // The scaling factors for the kernel.
71
- float scale_softmax;
72
- float softcap;
73
-
74
- // array of length b+1 holding starting offset of each sequence.
75
- int * __restrict__ cu_seqlens_q;
76
- int * __restrict__ cu_seqlens_k;
77
- int * __restrict__ cu_seqlens_knew;
78
- int * __restrict__ leftpad_k;
79
-
80
- // If provided, the actual length of each q/k sequence.
81
- int *__restrict__ seqused_q;
82
- int *__restrict__ seqused_k;
83
-
84
- // The stride between rows of Oaccum.
85
- index_t oaccum_split_stride;
86
- index_t oaccum_batch_stride;
87
- index_t oaccum_row_stride;
88
- index_t oaccum_head_stride;
89
-
90
- // The stride between rows of LSEaccum.
91
- index_t lseaccum_split_stride;
92
- index_t lseaccum_batch_stride;
93
- index_t lseaccum_head_stride;
94
-
95
- // The K_new and V_new matrices.
96
- void * __restrict__ knew_ptr;
97
- void * __restrict__ vnew_ptr;
98
-
99
- // The stride between rows of the Q, K and V matrices.
100
- index_t knew_batch_stride;
101
- index_t vnew_batch_stride;
102
- index_t knew_row_stride;
103
- index_t vnew_row_stride;
104
- index_t knew_head_stride;
105
- index_t vnew_head_stride;
106
-
107
- void *__restrict__ qv_ptr;
108
- index_t qv_batch_stride;
109
- index_t qv_row_stride;
110
- index_t qv_head_stride;
111
-
112
- // The cos and sin matrices for rotary embedding.
113
- void * __restrict__ rotary_cos_ptr;
114
- void * __restrict__ rotary_sin_ptr;
115
- int *__restrict__ seqlens_rotary;
116
-
117
- // The indices to index into the KV cache.
118
- int * __restrict__ kv_batch_idx;
119
-
120
- // Paged KV cache
121
- int * __restrict__ page_table;
122
- index_t page_table_batch_stride;
123
- int page_size;
124
- int num_pages;
125
- bool pagedkv_tma;
126
-
127
- // The dropout probability (probability of keeping an activation).
128
- float p_dropout;
129
- // uint32_t p_dropout_in_uint;
130
- // uint16_t p_dropout_in_uint16_t;
131
- uint8_t p_dropout_in_uint8_t;
132
-
133
- // Scale factor of 1 / (1 - p_dropout).
134
- float rp_dropout;
135
-
136
- // Local window size
137
- int window_size_left, window_size_right;
138
- int attention_chunk;
139
-
140
- // Pointer to the RNG seed (idx 0) and offset (idx 1).
141
- uint64_t * rng_state;
142
-
143
- bool is_bf16;
144
- bool is_fp32;
145
- bool is_e4m3;
146
- bool is_causal;
147
- bool is_local;
148
-
149
- bool is_rotary_interleaved;
150
-
151
- int num_splits; // For split-KV version
152
- bool pack_gqa;
153
-
154
- int * __restrict__ tile_count_semaphore;
155
- // int * __restrict__ num_m_blocks_ptr;
156
- // int * __restrict__ num_n_blocks_ptr;
157
- int * __restrict__ num_splits_dynamic_ptr;
158
- bool skip_scheduler_metadata_computation;
159
-
160
- int arch;
161
- int num_sm;
162
- };
163
-
164
- ////////////////////////////////////////////////////////////////////////////////////////////////////
165
-
166
- struct Flash_bwd_params : public Flash_fwd_params {
167
- using index_t = int64_t;
168
-
169
- // The dO and dQKV matrices.
170
- void *__restrict__ do_ptr;
171
- void *__restrict__ dq_ptr;
172
- void *__restrict__ dk_ptr;
173
- void *__restrict__ dv_ptr;
174
-
175
- // To accumulate dQ
176
- void *__restrict__ dq_accum_ptr;
177
- void *__restrict__ dk_accum_ptr;
178
- void *__restrict__ dv_accum_ptr;
179
-
180
- // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
181
- // dimension void *__restrict__ dk_accum_ptr; void *__restrict__
182
- // dv_accum_ptr;
183
-
184
- // The stride between rows of the dO, dQ, dK and dV matrices.
185
- index_t do_batch_stride;
186
- index_t do_row_stride;
187
- index_t do_head_stride;
188
- index_t dq_batch_stride;
189
- index_t dk_batch_stride;
190
- index_t dv_batch_stride;
191
- index_t dq_row_stride;
192
- index_t dk_row_stride;
193
- index_t dv_row_stride;
194
- index_t dq_head_stride;
195
- index_t dk_head_stride;
196
- index_t dv_head_stride;
197
-
198
- // The pointer to the softmax d sum.
199
- void *__restrict__ dsoftmax_sum;
200
- void *__restrict__ softmax_lse_log2_ptr;
201
-
202
- int *__restrict__ dq_semaphore;
203
- int *__restrict__ dk_semaphore;
204
- int *__restrict__ dv_semaphore;
205
-
206
- bool deterministic;
207
- index_t dq_accum_split_stride;
208
- };
209
-
210
- ////////////////////////////////////////////////////////////////////////////////////////////////////
211
-
212
- template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
213
- void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
214
- void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
215
- template <int Arch, typename T, int kHeadDim, bool Has_softcap>
216
- void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
217
- template <typename T, typename Tpartial, int kBlockK>
218
- void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_api.cpp DELETED
@@ -1,1720 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include <Python.h>
6
- #include <torch/nn/functional/padding.h>
7
- #include <ATen/cuda/CUDAContextLight.h>
8
- #include <c10/cuda/CUDAGuard.h>
9
-
10
- #include <cutlass/numeric_types.h>
11
-
12
- #include "flash.h"
13
- #include "static_switch.h"
14
- #include "tile_size.h"
15
- #include "heuristics.h"
16
- #include "cuda_check.h"
17
-
18
-
19
- extern "C" {
20
- /* Creates a dummy empty _C module that can be imported from Python.
21
- The import from Python will load the .so consisting of this file
22
- in this extension, so that the TORCH_LIBRARY static initializers
23
- below are run. */
24
- PyObject* PyInit__C(void)
25
- {
26
- static struct PyModuleDef module_def = {
27
- PyModuleDef_HEAD_INIT,
28
- "_C", /* name of module */
29
- NULL, /* module documentation, may be NULL */
30
- -1, /* size of per-interpreter state of the module,
31
- or -1 if the module keeps state in global variables. */
32
- NULL, /* methods */
33
- };
34
- return PyModule_Create(&module_def);
35
- }
36
- }
37
-
38
- #define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
39
- #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
40
- #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
41
-
42
- void set_params_fprop(Flash_fwd_params &params,
43
- // sizes
44
- const size_t b,
45
- const size_t seqlen_q,
46
- const size_t seqlen_k,
47
- const size_t seqlen_q_rounded,
48
- const size_t seqlen_k_rounded,
49
- const size_t h,
50
- const size_t h_k,
51
- const size_t d,
52
- const size_t d_rounded,
53
- // device pointers
54
- const at::Tensor q,
55
- const at::Tensor k,
56
- const at::Tensor v,
57
- at::Tensor out,
58
- void *cu_seqlens_q_d,
59
- void *cu_seqlens_k_d,
60
- void *seqused_q,
61
- void *seqused_k,
62
- void *softmax_lse_d,
63
- float p_dropout,
64
- float softmax_scale,
65
- int window_size_left,
66
- int window_size_right,
67
- int attention_chunk,
68
- const float softcap=0.f,
69
- const int sm_margin=0) {
70
-
71
- // Reset the parameters
72
- params = {};
73
-
74
- params.is_bf16 = q.dtype() == torch::kBFloat16;
75
- params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
76
-
77
- // Set the pointers and strides.
78
- params.q_ptr = q.data_ptr();
79
- params.k_ptr = k.data_ptr();
80
- params.v_ptr = v.data_ptr();
81
- // All stride are in elements, not bytes.
82
- params.q_row_stride = q.stride(-3);
83
- params.k_row_stride = k.stride(-3);
84
- params.v_row_stride = v.stride(-3);
85
- params.q_head_stride = q.stride(-2);
86
- params.k_head_stride = k.stride(-2);
87
- params.v_head_stride = v.stride(-2);
88
- params.v_dim_stride = v.stride(-1);
89
- params.o_ptr = out.data_ptr();
90
- params.o_row_stride = out.stride(-3);
91
- params.o_head_stride = out.stride(-2);
92
-
93
- if (cu_seqlens_q_d == nullptr) {
94
- params.q_batch_stride = q.stride(0);
95
- params.o_batch_stride = out.stride(0);
96
- }
97
- if (cu_seqlens_k_d == nullptr) {
98
- params.k_batch_stride = k.stride(0);
99
- params.v_batch_stride = v.stride(0);
100
- }
101
-
102
- params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
103
- params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
104
- params.seqused_q = static_cast<int *>(seqused_q);
105
- params.seqused_k = static_cast<int *>(seqused_k);
106
-
107
- // Softmax sum
108
- params.softmax_lse_ptr = softmax_lse_d;
109
-
110
- // Set the dimensions.
111
- params.b = b;
112
- params.h = h;
113
- params.h_k = h_k;
114
- params.seqlen_q = seqlen_q;
115
- params.seqlen_k = seqlen_k;
116
- params.seqlen_q_rounded = seqlen_q_rounded;
117
- params.seqlen_k_rounded = seqlen_k_rounded;
118
- params.d = d;
119
- params.d_rounded = d_rounded;
120
-
121
- // Set the different scale values.
122
- params.scale_softmax = softmax_scale;
123
- params.softcap = softcap;
124
-
125
- // Set this to probability of keeping an element to simplify things.
126
- params.p_dropout = 1.f - p_dropout;
127
- // Convert p from float to int so we don't have to convert the random uint to float to compare.
128
- // [Minor] We want to round down since when we do the comparison we use <= instead of <
129
- // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
130
- // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
131
- params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
132
- params.rp_dropout = 1.f / params.p_dropout;
133
- TORCH_CHECK(p_dropout < 1.f);
134
- #ifdef FLASHATTENTION_DISABLE_DROPOUT
135
- TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
136
- #endif
137
-
138
- // Causal is the special case where window_size_right == 0 and window_size_left < 0.
139
- // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
140
- params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
141
- params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
142
-
143
- // TODO: check this
144
- if (window_size_left < 0) { window_size_left = seqlen_k - 1; }
145
- if (window_size_right < 0) { window_size_right = seqlen_q - 1; }
146
- if (attention_chunk > 0) {
147
- window_size_left = std::min(window_size_left, attention_chunk - 1);
148
- window_size_right = std::min(window_size_right, attention_chunk - 1);
149
- }
150
- params.window_size_left = window_size_left;
151
- params.window_size_right = window_size_right;
152
- params.attention_chunk = attention_chunk;
153
-
154
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
155
- params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
156
-
157
- #ifdef FLASHATTENTION_DISABLE_LOCAL
158
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
159
- #endif
160
- }
161
-
162
- void set_params_dgrad(Flash_bwd_params &params,
163
- // sizes
164
- const size_t b,
165
- const size_t seqlen_q,
166
- const size_t seqlen_k,
167
- const size_t seqlen_q_rounded,
168
- const size_t seqlen_k_rounded,
169
- const size_t h,
170
- const size_t h_k,
171
- const size_t d,
172
- const size_t d_rounded,
173
- // device pointers
174
- const at::Tensor q,
175
- const at::Tensor k,
176
- const at::Tensor v,
177
- const at::Tensor out,
178
- const at::Tensor dout,
179
- at::Tensor dq,
180
- at::Tensor dk,
181
- at::Tensor dv,
182
- void *cu_seqlens_q_d,
183
- void *cu_seqlens_k_d,
184
- void *seqused_q,
185
- void *seqused_k,
186
- void *dq_accum_d,
187
- void *dk_accum_d,
188
- void *dv_accum_d,
189
- void *softmax_lse_d,
190
- void *dsoftmax_sum_d,
191
- float p_dropout,
192
- float softmax_scale,
193
- int window_size_left,
194
- int window_size_right,
195
- int attention_chunk,
196
- const float softcap=0.f,
197
- bool deterministic=false,
198
- int const sm_margin=0) {
199
-
200
- set_params_fprop(params,
201
- b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
202
- q, k, v, out,
203
- cu_seqlens_q_d,
204
- cu_seqlens_k_d,
205
- seqused_q,
206
- seqused_k,
207
- softmax_lse_d,
208
- p_dropout,
209
- softmax_scale,
210
- window_size_left,
211
- window_size_right,
212
- attention_chunk,
213
- softcap,
214
- sm_margin);
215
-
216
- // Set the pointers and strides.
217
- params.do_ptr = dout.data_ptr();
218
- params.do_row_stride = dout.stride(-3);
219
- params.do_head_stride = dout.stride(-2);
220
- params.dq_ptr = dq.data_ptr();
221
- params.dk_ptr = dk.data_ptr();
222
- params.dv_ptr = dv.data_ptr();
223
- params.dq_row_stride = dq.stride(-3);
224
- params.dk_row_stride = dk.stride(-3);
225
- params.dv_row_stride = dv.stride(-3);
226
- params.dq_head_stride = dq.stride(-2);
227
- params.dk_head_stride = dk.stride(-2);
228
- params.dv_head_stride = dv.stride(-2);
229
-
230
- if (cu_seqlens_q_d == nullptr) {
231
- params.do_batch_stride = dout.stride(0);
232
- params.dq_batch_stride = dq.stride(0);
233
- params.dk_batch_stride = dk.stride(0);
234
- params.dv_batch_stride = dv.stride(0);
235
- }
236
-
237
- params.dq_accum_ptr = dq_accum_d;
238
- params.dk_accum_ptr = dk_accum_d;
239
- params.dv_accum_ptr = dv_accum_d;
240
-
241
- // Softmax sum
242
- params.dsoftmax_sum = dsoftmax_sum_d;
243
-
244
- params.deterministic = deterministic;
245
- }
246
-
247
- template <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>
248
- void run_mha_fwd_constexpr(Flash_fwd_params &params, cudaStream_t stream) {
249
- if (!params.is_e4m3) {
250
- if (params.is_bf16) {
251
- #ifndef FLASHATTENTION_DISABLE_HDIM64
252
- if (params.d <= 64) {
253
- if constexpr (Arch == 90) {
254
- if (params.dv > 256) {
255
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
256
- } else if (params.dv > 64) {
257
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
258
- }
259
- }
260
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
261
- }
262
- #endif
263
- #ifndef FLASHATTENTION_DISABLE_HDIM96
264
- if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
265
- #endif
266
- #ifndef FLASHATTENTION_DISABLE_HDIM128
267
- if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
268
- #endif
269
- #ifndef FLASHATTENTION_DISABLE_HDIM192
270
- if (params.d <= 192) {
271
- if constexpr (Arch == 90) {
272
- if (params.dv <= 128) {
273
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
274
- }
275
- }
276
- return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
277
- }
278
- #endif
279
- #ifndef FLASHATTENTION_DISABLE_HDIM256
280
- if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
281
- #endif
282
- } else {
283
- #ifndef FLASHATTENTION_DISABLE_FP16
284
- #ifndef FLASHATTENTION_DISABLE_HDIM64
285
- if (params.d <= 64) {
286
- if constexpr (Arch == 90) {
287
- if (params.dv > 256) {
288
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
289
- } else if (params.dv > 64) {
290
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
291
- }
292
- }
293
- return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
294
- }
295
- #endif
296
- #ifndef FLASHATTENTION_DISABLE_HDIM96
297
- if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
298
- #endif
299
- #ifndef FLASHATTENTION_DISABLE_HDIM128
300
- if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
301
- #endif
302
- #ifndef FLASHATTENTION_DISABLE_HDIM192
303
- if (params.d <= 192) {
304
- if constexpr (Arch == 90) {
305
- if (params.dv <= 128) {
306
- return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
307
- }
308
- }
309
- return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310
- }
311
- #endif
312
- #ifndef FLASHATTENTION_DISABLE_HDIM256
313
- if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
314
- #endif
315
- #else
316
- TORCH_CHECK(false, "This flash attention build does not support FP16.");
317
- #endif
318
- }
319
- } else {
320
- #ifndef FLASHATTENTION_DISABLE_FP8
321
- #ifndef FLASHATTENTION_DISABLE_HDIM64
322
- if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
323
- #endif
324
- #ifndef FLASHATTENTION_DISABLE_HDIM96
325
- if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
326
- #endif
327
- #ifndef FLASHATTENTION_DISABLE_HDIM128
328
- if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
329
- #endif
330
- #ifndef FLASHATTENTION_DISABLE_HDIM192
331
- if (params.d <= 192) {
332
- if constexpr (Arch == 90) {
333
- if (params.dv <= 128) {
334
- return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
335
- }
336
- }
337
- return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
338
- }
339
- #endif
340
- #ifndef FLASHATTENTION_DISABLE_HDIM256
341
- if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
342
- #endif
343
- #else
344
- TORCH_CHECK(false, "This flash attention build does not support FP8.");
345
- #endif
346
- }
347
- }
348
-
349
- void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
350
- // HEADDIM_SWITCH(params.d, [&] {
351
- // run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
352
- // });
353
- TORCH_CHECK(params.num_splits >= 1);
354
- ARCH_SWITCH(params.arch, Arch, [&] {
355
- SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
356
- PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
357
- PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
358
- // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
359
- static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
360
- SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
361
- run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);
362
- });
363
- });
364
- });
365
- });
366
- });
367
- }
368
-
369
- void run_mha_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl=false) {
370
- #ifndef FLASHATTENTION_DISABLE_SPLIT
371
- // If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
372
- // so that kBlockM is smaller and we have more parallelism.
373
- if (params.is_fp32) {
374
- if (params.dv <= 64) {
375
- run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
376
- } else {
377
- run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
378
- }
379
- } else if (params.is_bf16) {
380
- if (params.dv <= 64) {
381
- run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
382
- } else {
383
- run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
384
- }
385
- } else {
386
- if (params.dv <= 64) {
387
- run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
388
- } else {
389
- run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
390
- }
391
- }
392
- #else
393
- TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
394
- #endif
395
- }
396
-
397
- inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
398
- if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
399
- // This needs to match the kernel configs
400
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
401
- int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
402
- int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
403
- // Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
404
- // at least for MLA.
405
- return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
406
- }
407
-
408
- inline bool get_pack_gqa(Flash_fwd_params const& params) {
409
- // Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
410
- // Has little effect on speed.
411
- if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
412
- #ifdef FLASHATTENTION_DISABLE_PACKGQA
413
- return false;
414
- #else
415
- // params.page_table must already be set
416
- if (params.h == params.h_k) { return false; }
417
- // This needs to match the kernel configs
418
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
419
- int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
420
- return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
421
- #endif
422
- }
423
-
424
- inline int get_num_splits(Flash_fwd_params const& params) {
425
- #ifdef FLASHATTENTION_DISABLE_SPLIT
426
- return 1;
427
- #else
428
- // Always enable PackGQA for Split
429
- // params.page_table must already be set
430
- // This needs to match the kernel configs
431
- bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
432
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
433
- // Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
434
- // has not been set here. It's OK though because we might just underestimate kBlockN a bit
435
- auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
436
- int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
437
- int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
438
- int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
439
- // If is_local, we're not going to load all of seqlen_k
440
- int const seqlen_k_loaded = !params.is_local
441
- ? params.seqlen_k
442
- : std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
443
- int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
444
- int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
445
- int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
446
- // Always enable PackGQA for Split
447
- // If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
448
- // We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
449
- // that batch = 1.
450
- int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
451
- return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
452
- #endif
453
- }
454
-
455
- inline int get_max_headdim() {
456
- #ifndef FLASHATTENTION_DISABLE_HDIM256
457
- return 256;
458
- #endif
459
- #ifndef FLASHATTENTION_DISABLE_HDIM192
460
- return 192;
461
- #endif
462
- #ifndef FLASHATTENTION_DISABLE_HDIM128
463
- return 128;
464
- #endif
465
- #ifndef FLASHATTENTION_DISABLE_HDIM96
466
- return 96;
467
- #endif
468
- #ifndef FLASHATTENTION_DISABLE_HDIM64
469
- return 64;
470
- #endif
471
- return 0;
472
- }
473
-
474
- inline int round_up_headdim(int head_size) {
475
- #ifndef FLASHATTENTION_DISABLE_HDIM64
476
- if (head_size <= 64) { return 64; }
477
- #endif
478
- #ifndef FLASHATTENTION_DISABLE_HDIM96
479
- if (head_size <= 96) { return 96; }
480
- #endif
481
- #ifndef FLASHATTENTION_DISABLE_HDIM128
482
- if (head_size <= 128) { return 128; }
483
- #endif
484
- #ifndef FLASHATTENTION_DISABLE_HDIM192
485
- if (head_size <= 192) { return 192; }
486
- #endif
487
- #ifndef FLASHATTENTION_DISABLE_HDIM256
488
- if (head_size <= 256) { return 256; }
489
- #endif
490
- return 256;
491
- }
492
-
493
- inline int round_up_headdimv(int head_size) {
494
- if (head_size <= 64) { return 64; }
495
- if (head_size <= 96) { return 96; }
496
- if (head_size <= 128) { return 128; }
497
- if (head_size <= 192) { return 192; }
498
- if (head_size <= 256) { return 256; }
499
- return 512;
500
- }
501
-
502
- // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
503
- at::Tensor
504
- mha_fwd_get_scheduler_metadata(
505
- int64_t batch_size,
506
- int64_t max_seqlen_q,
507
- int64_t max_seqlen_k,
508
- int64_t num_heads,
509
- int64_t num_heads_k,
510
- int64_t headdim,
511
- int64_t headdim_v,
512
- at::ScalarType qkv_dtype,
513
- at::Tensor seqused_k, // b
514
- std::optional<at::Tensor> cu_seqlens_q_, // b+1
515
- std::optional<at::Tensor> cu_seqlens_k_, // b+1
516
- std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
517
- std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
518
- std::optional<at::Tensor> leftpad_k_, // b
519
- std::optional<int64_t> page_size,
520
- int64_t max_seqlen_k_new, // 0 means we're not appending new KV
521
- bool is_causal,
522
- int64_t window_size_left,
523
- int64_t window_size_right,
524
- int64_t attention_chunk,
525
- bool has_softcap,
526
- int64_t num_splits,
527
- std::optional<bool> pack_gqa_,
528
- int64_t sm_margin
529
- ) {
530
-
531
- TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
532
- "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
533
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
534
-
535
- // Reset the parameters
536
- Flash_fwd_params params{};
537
- params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
538
- params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
539
- params.b = batch_size;
540
- params.seqlen_q = max_seqlen_q;
541
- params.seqlen_k = max_seqlen_k;
542
- params.h = num_heads;
543
- params.h_k = num_heads_k;
544
- params.d = headdim;
545
- params.dv = headdim_v;
546
- params.d_rounded = round_up_headdim(headdim);
547
- params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
548
- params.seqlen_knew = max_seqlen_k_new;
549
-
550
- bool const is_varlen_q = cu_seqlens_q_.has_value();
551
- params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
552
- bool const is_varlen_k = cu_seqlens_k_.has_value();
553
- params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
554
- params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
555
- params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
556
- params.seqused_k = seqused_k.data_ptr<int>();
557
- params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
558
- params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
559
- if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
560
- if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
561
- // causal=true is the same as causal=false in this case
562
- if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
563
- // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
564
- if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
565
- is_causal = false;
566
- }
567
- }
568
- if (is_causal) { window_size_right = 0; }
569
-
570
- params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
571
- params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
572
- if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }
573
- if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
574
- if (attention_chunk > 0) {
575
- window_size_left = std::min(window_size_left, attention_chunk - 1);
576
- window_size_right = std::min(window_size_right, attention_chunk - 1);
577
- }
578
- params.window_size_left = window_size_left;
579
- params.window_size_right = window_size_right;
580
- params.attention_chunk = attention_chunk;
581
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
582
- params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
583
- params.softcap = has_softcap ? 1.0f : 0.0f;
584
-
585
- params.page_size = page_size.has_value() ? page_size.value() : 1;
586
- params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
587
-
588
- bool const use_dynamic_split = params.b <= 992;
589
- params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
590
-
591
- params.pagedkv_tma = get_pagedkv_tma(params);
592
- params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
593
- // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
594
- params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
595
-
596
- bool is_varlen = true;
597
-
598
- // Otherwise the kernel will be launched from cuda:0 device
599
- // Cast to char to avoid compiler warning about narrowing
600
- at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
601
-
602
- auto opts = seqused_k.options();
603
- // This needs to be set after get_num_splits
604
- at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
605
- bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
606
- if (scheduler_needs_semaphore || use_dynamic_split) {
607
- tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
608
- if (scheduler_needs_semaphore) {
609
- if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
610
- params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
611
- } else {
612
- params.tile_count_semaphore = nullptr;
613
- }
614
- params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
615
- }
616
-
617
- if (params.num_splits_dynamic_ptr) {
618
- auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
619
- auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
620
- int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
621
- int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
622
- auto stream = at::cuda::getCurrentCUDAStream().stream();
623
- prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
624
- CHECK_CUDA_KERNEL_LAUNCH();
625
- }
626
- return tile_count_semaphore;
627
- }
628
-
629
- // b: batch_size
630
- // b_k: batch_size_k
631
- // s_q: seqlen_q
632
- // s_k: seqlen_k
633
- // s_k_new: seqlen_k_new
634
- // h: num_heads
635
- // h_k: num_heads_k
636
- // d: head_size
637
- std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
638
- mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
639
- at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
640
- at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
641
- std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
642
- std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
643
- std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
644
- std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
645
- std::optional<at::Tensor> cu_seqlens_q_, // b+1
646
- std::optional<at::Tensor> cu_seqlens_k_, // b+1
647
- std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
648
- std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
649
- std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
650
- std::optional<int64_t> max_seqlen_q_,
651
- // TODO: check if we need max_seqlen_k
652
- std::optional<int64_t> max_seqlen_k_,
653
- std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
654
- std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
655
- std::optional<at::Tensor> leftpad_k_, // b
656
- std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
657
- std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
658
- std::optional<at::Tensor> seqlens_rotary_, // b
659
- std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
660
- std::optional<at::Tensor> k_descale_, // (b, h_k)
661
- std::optional<at::Tensor> v_descale_, // (b, h_k)
662
- std::optional<double> softmax_scale_,
663
- bool is_causal,
664
- int64_t window_size_left,
665
- int64_t window_size_right,
666
- int64_t attention_chunk,
667
- double softcap,
668
- bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
669
- std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
670
- int64_t num_splits,
671
- std::optional<bool> pack_gqa_,
672
- int64_t sm_margin
673
- ) {
674
-
675
- auto dprops = at::cuda::getCurrentDeviceProperties();
676
- bool is_sm8x = dprops->major >= 8;
677
- TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
678
-
679
- auto q_type = q.scalar_type();
680
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
681
- "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
682
- if (dprops->major < 9) {
683
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
684
- "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
685
- }
686
- TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
687
- TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
688
-
689
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
690
-
691
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
692
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
693
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
694
-
695
- at::Tensor page_table;
696
- const bool paged_KV = page_table_.has_value();
697
- if (paged_KV) {
698
- page_table = page_table_.value();
699
- CHECK_DEVICE(page_table);
700
- TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
701
- TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
702
- }
703
-
704
- at::Tensor cu_seqlens_q;
705
- bool const is_varlen_q = cu_seqlens_q_.has_value();
706
- if (is_varlen_q) {
707
- cu_seqlens_q = cu_seqlens_q_.value();
708
- CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
709
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
710
- TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
711
- }
712
- at::Tensor cu_seqlens_k;
713
- bool const is_varlen_k = cu_seqlens_k_.has_value();
714
- if (is_varlen_k) {
715
- cu_seqlens_k = cu_seqlens_k_.value();
716
- CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
717
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
718
- TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
719
- TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
720
- TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
721
- }
722
-
723
- auto const sizes = q.sizes();
724
- const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
725
- int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
726
- int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
727
- int num_heads = q.size(-2);
728
- int const head_size = q.size(-1);
729
- int const head_size_v = v.size(-1);
730
- int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
731
- int const num_pages = !paged_KV ? 0 : k.size(0);
732
- int const page_size = !paged_KV ? 1 : k.size(1);
733
- int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
734
- int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
735
- int const num_heads_k = k.size(-2);
736
- int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
737
- double softmax_scale = 1.0 / sqrt(double(head_size));
738
- if (softmax_scale_.has_value()) {
739
- softmax_scale = softmax_scale_.value();
740
- }
741
- if (!kv_batch_idx_.has_value()) {
742
- TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
743
- }
744
- int const max_headdim = get_max_headdim();
745
- TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
746
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
747
- if (head_size_v != head_size) {
748
- TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
749
- (head_size <= 64 && head_size_v <= 512),
750
- "If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
751
- "or (Q/K <= 64 and V <= 512).");
752
- TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
753
- if (head_size_v > 256) {
754
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
755
- "HeaddimV > 256 requires fp16 and bf16 data type");
756
- }
757
- }
758
-
759
- // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
760
- // TODO: check this
761
- if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
762
- if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
763
- // causal=true is the same as causal=false in this case
764
- if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
765
- // Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
766
- if ((head_size <= 64 || head_size > 128) || !paged_KV) {
767
- is_causal = false;
768
- }
769
- }
770
- if (is_causal) { window_size_right = 0; }
771
-
772
- if (!is_varlen_q) {
773
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
774
- } else {
775
- CHECK_SHAPE(q, total_q, num_heads, head_size);
776
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
777
- }
778
- if (!paged_KV) {
779
- if (!is_varlen_k) {
780
- CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
781
- CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
782
- } else {
783
- CHECK_SHAPE(k, total_k, num_heads_k, head_size);
784
- CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
785
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
786
- }
787
- } else {
788
- CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
789
- CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
790
- CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
791
- }
792
-
793
- if (seqused_q_.has_value()){
794
- auto seqused_q = seqused_q_.value();
795
- TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
796
- CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
797
- CHECK_SHAPE(seqused_q, batch_size);
798
- }
799
- if (seqused_k_.has_value()) {
800
- auto seqused_k = seqused_k_.value();
801
- TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
802
- CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
803
- CHECK_SHAPE(seqused_k, batch_size);
804
- }
805
-
806
- if (leftpad_k_.has_value()) {
807
- auto leftpad_k = leftpad_k_.value();
808
- TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
809
- CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
810
- CHECK_SHAPE(leftpad_k, batch_size);
811
- }
812
-
813
- // This is what we will template on
814
- bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
815
- #ifdef FLASHATTENTION_DISABLE_VARLEN
816
- TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
817
- #endif
818
-
819
- int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
820
- TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
821
- TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
822
-
823
- auto opts = q.options();
824
- auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
825
- at::Tensor out;
826
- if (out_.has_value()) {
827
- out = out_.value();
828
- TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
829
- CHECK_DEVICE(out);
830
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
831
- if (!is_varlen_q) {
832
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
833
- } else {
834
- CHECK_SHAPE(out, total_q, num_heads, head_size_v);
835
- }
836
- } else {
837
- out = !is_varlen_q
838
- ? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
839
- : torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
840
- }
841
-
842
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
843
- int const head_size_rounded = round_up_headdim(head_size);
844
- int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
845
- int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
846
- int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
847
-
848
- // Otherwise the kernel will be launched from cuda:0 device
849
- // Cast to char to avoid compiler warning about narrowing
850
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
851
-
852
- at::Tensor softmax_lse;
853
- if (!is_varlen_q) {
854
- softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
855
- } else {
856
- softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
857
- }
858
-
859
- Flash_fwd_params params;
860
- set_params_fprop(params,
861
- batch_size,
862
- seqlen_q, seqlen_k,
863
- seqlen_q_rounded, seqlen_k_rounded,
864
- num_heads, num_heads_k,
865
- head_size, head_size_rounded,
866
- q, k, v, out,
867
- !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
868
- !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
869
- seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
870
- seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
871
- softmax_lse.data_ptr(),
872
- /*p_dropout=*/0.f,
873
- softmax_scale,
874
- window_size_left,
875
- window_size_right,
876
- attention_chunk,
877
- softcap,
878
- sm_margin);
879
- params.total_q = total_q;
880
- params.total_k = total_k;
881
- params.b_k = batch_size_k;
882
- params.dv = head_size_v;
883
- params.dv_rounded = head_size_v_rounded;
884
- if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
885
- params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
886
- }
887
- if (paged_KV) {
888
- params.page_table = page_table.data_ptr<int>();
889
- params.page_table_batch_stride = page_table.stride(0);
890
- }
891
- params.page_size = page_size;
892
- params.num_pages = num_pages;
893
-
894
- if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
895
- at::Tensor k_new, v_new;
896
- TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
897
- TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
898
- TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
899
- at::Tensor cu_seqlens_k_new;
900
- bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
901
- if (is_varlen_k_new) {
902
- cu_seqlens_k_new = cu_seqlens_k_new_.value();
903
- CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
904
- TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
905
- }
906
- k_new = k_new_.value();
907
- v_new = v_new_.value();
908
- TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
909
- TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
910
- CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
911
- TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
912
- TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
913
- // We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
914
- int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
915
- int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
916
- if (!is_varlen_k_new) {
917
- CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
918
- CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
919
- } else {
920
- CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
921
- CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
922
- CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
923
- }
924
- params.seqlen_knew = seqlen_k_new;
925
- params.total_knew = total_k_new;
926
- params.knew_ptr = k_new.data_ptr();
927
- params.vnew_ptr = v_new.data_ptr();
928
- // All stride are in elements, not bytes.
929
- params.knew_row_stride = k_new.stride(-3);
930
- params.vnew_row_stride = v_new.stride(-3);
931
- params.knew_head_stride = k_new.stride(-2);
932
- params.vnew_head_stride = v_new.stride(-2);
933
- if (!is_varlen_k_new) {
934
- params.knew_batch_stride = k_new.stride(0);
935
- params.vnew_batch_stride = v_new.stride(0);
936
- }
937
- if (is_varlen_k_new) {
938
- params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
939
- }
940
- }
941
-
942
- // 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
943
- bool const use_dynamic_split = is_varlen && params.b <= 992;
944
- // Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
945
- params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
946
-
947
- params.pagedkv_tma = get_pagedkv_tma(params);
948
- params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
949
- // Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
950
- params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
951
-
952
- // This needs to be set after get_num_splits
953
- at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
954
- // We don't use the persistent scheduler if Split and not Varlen
955
- bool const scheduler_needs_semaphore = params.arch >= 90
956
- ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
957
- : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
958
- if (scheduler_needs_semaphore || use_dynamic_split) {
959
- int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
960
- params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
961
- if (scheduler_metadata_.has_value()) {
962
- at::Tensor scheduler_metadata = scheduler_metadata_.value();
963
- CHECK_DEVICE(scheduler_metadata);
964
- CHECK_SHAPE(scheduler_metadata, metadata_size);
965
- CHECK_CONTIGUOUS(scheduler_metadata);
966
- TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
967
- tile_count_semaphore = scheduler_metadata;
968
- } else {
969
- tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
970
- }
971
- if (scheduler_needs_semaphore && !use_dynamic_split) {
972
- tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
973
- }
974
- params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
975
- params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
976
- }
977
-
978
- if (q_v_.has_value()) {
979
- TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
980
- TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
981
- "q_v is only supported for fp16 and bf16 data type");
982
- TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
983
- at::Tensor q_v = q_v_.value();
984
- TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
985
- CHECK_DEVICE(q_v);
986
- TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
987
- if (!is_varlen_q) {
988
- CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
989
- } else {
990
- CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
991
- }
992
- params.qv_ptr = q_v.data_ptr();
993
- // All stride are in elements, not bytes.
994
- params.qv_row_stride = q_v.stride(-3);
995
- params.qv_head_stride = q_v.stride(-2);
996
- if (!is_varlen_q) {
997
- params.qv_batch_stride = q_v.stride(0);
998
- }
999
- }
1000
-
1001
- if (rotary_cos_.has_value()) {
1002
- TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
1003
- auto rotary_cos = rotary_cos_.value();
1004
- CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
1005
- params.rotary_dim = rotary_cos.size(1) * 2;
1006
- TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
1007
- TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
1008
- const int seqlen_ro = rotary_cos.size(0);
1009
- if (paged_KV) {
1010
- TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
1011
- }
1012
- CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
1013
- TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1014
-
1015
- TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
1016
- auto rotary_sin = rotary_sin_.value();
1017
- CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
1018
- CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
1019
- TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
1020
- params.rotary_cos_ptr = rotary_cos.data_ptr();
1021
- params.rotary_sin_ptr = rotary_sin.data_ptr();
1022
- params.is_rotary_interleaved = is_rotary_interleaved;
1023
- if (seqlens_rotary_.has_value()) {
1024
- at::Tensor seqlens_rotary = seqlens_rotary_.value();
1025
- CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
1026
- TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
1027
- CHECK_SHAPE(seqlens_rotary, batch_size);
1028
- params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
1029
- }
1030
- } else {
1031
- params.rotary_dim = 0;
1032
- }
1033
-
1034
- if (kv_batch_idx_.has_value()) {
1035
- auto kv_batch_idx = kv_batch_idx_.value();
1036
- CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
1037
- TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
1038
- params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
1039
- }
1040
-
1041
- at::Tensor out_accum, softmax_lse_accum;
1042
- auto outaccum_type = at::ScalarType::Float;
1043
- if (params.num_splits > 1) {
1044
- TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
1045
- if (!is_varlen_q) {
1046
- out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
1047
- softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
1048
- params.oaccum_batch_stride = out_accum.stride(1);
1049
- params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
1050
- } else {
1051
- out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
1052
- softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
1053
- }
1054
- params.is_fp32 = false;
1055
- params.oaccum_ptr = out_accum.data_ptr();
1056
- params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
1057
- params.oaccum_split_stride = out_accum.stride(0);
1058
- params.oaccum_row_stride = out_accum.stride(-2);
1059
- params.oaccum_head_stride = out_accum.stride(-3);
1060
- params.lseaccum_split_stride = softmax_lse_accum.stride(0);
1061
- params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
1062
- }
1063
-
1064
- if (q_type == at::ScalarType::Float8_e4m3fn) {
1065
- if (q_descale_.has_value()) {
1066
- auto q_descale = q_descale_.value();
1067
- CHECK_DEVICE(q_descale);
1068
- CHECK_SHAPE(q_descale, batch_size, num_heads_k);
1069
- params.q_descale_ptr = q_descale.data_ptr<float>();
1070
- params.q_descale_batch_stride = q_descale.stride(0);
1071
- params.q_descale_head_stride = q_descale.stride(1);
1072
- } else {
1073
- params.q_descale_ptr = nullptr;
1074
- }
1075
- if (k_descale_.has_value()) {
1076
- auto k_descale = k_descale_.value();
1077
- CHECK_DEVICE(k_descale);
1078
- CHECK_SHAPE(k_descale, batch_size, num_heads_k);
1079
- params.k_descale_ptr = k_descale.data_ptr<float>();
1080
- params.k_descale_batch_stride = k_descale.stride(0);
1081
- params.k_descale_head_stride = k_descale.stride(1);
1082
- } else {
1083
- params.k_descale_ptr = nullptr;
1084
- }
1085
- if (v_descale_.has_value()) {
1086
- auto v_descale = v_descale_.value();
1087
- CHECK_DEVICE(v_descale);
1088
- CHECK_SHAPE(v_descale, batch_size, num_heads_k);
1089
- params.v_descale_ptr = v_descale.data_ptr<float>();
1090
- params.v_descale_batch_stride = v_descale.stride(0);
1091
- params.v_descale_head_stride = v_descale.stride(1);
1092
- } else {
1093
- params.v_descale_ptr = nullptr;
1094
- }
1095
- }
1096
-
1097
- #ifdef FLASHATTENTION_DISABLE_LOCAL
1098
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1099
- #endif
1100
- #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1101
- TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1102
- #endif
1103
- #ifdef FLASHATTENTION_DISABLE_SPLIT
1104
- TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
1105
- #endif
1106
- #ifdef FLASHATTENTION_DISABLE_PACKGQA
1107
- TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
1108
- #endif
1109
- #ifdef FLASHATTENTION_DISABLE_PAGEDKV
1110
- TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
1111
- #endif
1112
- #ifdef FLASHATTENTION_DISABLE_APPENDKV
1113
- TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
1114
- #endif
1115
-
1116
- if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
1117
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1118
- run_mha_fwd(params, stream);
1119
- if (params.num_splits > 1) {
1120
- if (out_type == at::ScalarType::BFloat16) {
1121
- // Since we want output in BF16. Otherwise fwd_combine will output to FP16
1122
- params.is_bf16 = true;
1123
- }
1124
- // Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
1125
- // and seqlen = total_q, and don't need to dispatch to Varlen there.
1126
- // However, with dynamic split, each row needs to know which batch it belongs to
1127
- // to read the number of splits, so we just use the varlen version of combine kernel.
1128
- // if (is_varlen_q && !seqused_q_.has_value()) {
1129
- // if (is_varlen_q) {
1130
- // params.b = 1;
1131
- // params.seqlen_q = total_q;
1132
- // }
1133
- // This will zero out the semaphore if needed
1134
- run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
1135
- } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
1136
- // need to zero out the semaphore in this case
1137
- tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
1138
- }
1139
- } else if (total_q > 0 && num_heads_k > 0) {
1140
- // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
1141
- out.zero_();
1142
- softmax_lse.fill_(std::numeric_limits<float>::infinity());
1143
- }
1144
-
1145
- // return {out, softmax_lse};
1146
- return {out, softmax_lse, out_accum, softmax_lse_accum};
1147
- }
1148
-
1149
- #ifdef FLASHATTENTION_DISABLE_BACKWARD
1150
- void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1151
- TORCH_CHECK(false, "Flash-Attention was built with backward disabled");
1152
- }
1153
- #else
1154
- template <int Arch, bool Has_softcap>
1155
- void run_mha_bwd_constexpr(Flash_bwd_params &params, cudaStream_t stream) {
1156
- if (!params.is_bf16) {
1157
- #ifndef FLASHATTENTION_DISABLE_FP16
1158
- #ifndef FLASHATTENTION_DISABLE_HDIM64
1159
- if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
1160
- #endif
1161
- #ifndef FLASHATTENTION_DISABLE_HDIM96
1162
- if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
1163
- #endif
1164
- #ifndef FLASHATTENTION_DISABLE_HDIM128
1165
- if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
1166
- #endif
1167
- #ifndef FLASHATTENTION_DISABLE_HDIM192
1168
- if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
1169
- #endif
1170
- #ifndef FLASHATTENTION_DISABLE_HDIM256
1171
- if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
1172
- #endif
1173
- #else
1174
- TORCH_CHECK(false, "This flash attention build does not support FP16.");
1175
- #endif
1176
- } else {
1177
- #ifndef FLASHATTENTION_DISABLE_HDIM64
1178
- if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
1179
- #endif
1180
- #ifndef FLASHATTENTION_DISABLE_HDIM96
1181
- if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
1182
- #endif
1183
- #ifndef FLASHATTENTION_DISABLE_HDIM128
1184
- if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
1185
- #endif
1186
- #ifndef FLASHATTENTION_DISABLE_HDIM192
1187
- if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
1188
- #endif
1189
- #ifndef FLASHATTENTION_DISABLE_HDIM256
1190
- if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
1191
- #endif
1192
- }
1193
- }
1194
-
1195
- void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
1196
- // FP16_SWITCH(!params.is_bf16, [&] {
1197
- // HEADDIM_SWITCH(params.d, [&] {
1198
- // run_mha_bwd_<elem_type, kHeadDim>(params, stream);
1199
- // });
1200
- // });
1201
- ARCH_SWITCH(params.arch, Arch, [&] {
1202
- SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
1203
- run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);
1204
- });
1205
- });
1206
- }
1207
- #endif
1208
-
1209
-
1210
- // b: batch_size
1211
- // s_q: seqlen_q
1212
- // s_k: seqlen_k
1213
- // h: num_heads
1214
- // h_k: num_heads_k
1215
- // d: head_size
1216
- std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
1217
- at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1218
- at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1219
- at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1220
- at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
1221
- at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
1222
- at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
1223
- std::optional<at::Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
1224
- std::optional<at::Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
1225
- std::optional<at::Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
1226
- std::optional<at::Tensor> cu_seqlens_q_, // b+1
1227
- std::optional<at::Tensor> cu_seqlens_k_, // b+1
1228
- std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
1229
- std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
1230
- std::optional<int64_t> max_seqlen_q_,
1231
- std::optional<int64_t> max_seqlen_k_,
1232
- std::optional<double> softmax_scale_,
1233
- bool is_causal,
1234
- int64_t window_size_left,
1235
- int64_t window_size_right,
1236
- double softcap,
1237
- bool deterministic,
1238
- int64_t sm_margin
1239
- ) {
1240
-
1241
- #ifdef FLASHATTENTION_DISABLE_BACKWARD
1242
- TORCH_CHECK(false, "This flash attention build does not support backward.");
1243
- #endif
1244
-
1245
- auto dprops = at::cuda::getCurrentDeviceProperties();
1246
- bool is_sm8x = dprops->major >= 8;
1247
- TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
1248
-
1249
- auto q_type = q.dtype();
1250
- TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
1251
- "FlashAttention only support fp16 and bf16 data type");
1252
- TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
1253
- TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
1254
- TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
1255
- TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
1256
-
1257
- CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
1258
- CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
1259
-
1260
- TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1261
- TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1262
- TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1263
- TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
1264
- TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
1265
-
1266
- at::Tensor cu_seqlens_q;
1267
- bool const is_varlen_q = cu_seqlens_q_.has_value();
1268
- if (is_varlen_q) {
1269
- cu_seqlens_q = cu_seqlens_q_.value();
1270
- CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
1271
- TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
1272
- TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
1273
- }
1274
- at::Tensor cu_seqlens_k;
1275
- bool const is_varlen_k = cu_seqlens_k_.has_value();
1276
- if (is_varlen_k) {
1277
- cu_seqlens_k = cu_seqlens_k_.value();
1278
- CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
1279
- TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
1280
- TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
1281
- }
1282
- // This is what we will template on
1283
- bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
1284
- #ifdef FLASHATTENTION_DISABLE_VARLEN
1285
- TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
1286
- #endif
1287
-
1288
- auto const sizes = q.sizes();
1289
- int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
1290
- int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
1291
- int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
1292
- int const num_heads = q.size(-2);
1293
- int const head_size = q.size(-1);
1294
- int const head_size_v = v.size(-1);
1295
- int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
1296
- int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
1297
- int const num_heads_k = k.size(-2);
1298
- TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
1299
- TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
1300
- int const max_headdim = get_max_headdim();
1301
- TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
1302
- TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
1303
- double softmax_scale = 1.0 / sqrt(double(head_size));
1304
- if (softmax_scale_.has_value()) {
1305
- softmax_scale = softmax_scale_.value();
1306
- }
1307
-
1308
- // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
1309
- if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
1310
- if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
1311
- if (is_causal) { window_size_right = 0; }
1312
- // There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
1313
- // If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
1314
- is_causal = window_size_left < 0 && window_size_right == 0;
1315
-
1316
- int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1317
- int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));
1318
- int const head_size_v_rounded = head_size_rounded;
1319
- // Very important that these match the kernel configs
1320
- bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
1321
- int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
1322
- : (head_size_rounded <= 96 ? 64
1323
- : (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
1324
- : 64));
1325
- int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
1326
- int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
1327
- int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
1328
- int const kBlockN_sm90 = head_size_rounded <= 128
1329
- ? 128
1330
- : (head_size_rounded <= 192 ? 96 : 80);
1331
- int const kBlockN_sm80 = head_size_rounded <= 128
1332
- ? 128
1333
- : (head_size_rounded <= 192 ? 80 : 64);
1334
- int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
1335
- : (head_size_rounded <= 96 ? 128
1336
- : (head_size_rounded <= 128 ? 96
1337
- : (head_size_rounded <= 192 ? 64 : 64)));
1338
- int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
1339
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1340
- int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
1341
- int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
1342
- int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
1343
- int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
1344
-
1345
- if (!is_varlen_q) {
1346
- CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
1347
- CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
1348
- CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);
1349
- } else {
1350
- CHECK_SHAPE(q, total_q, num_heads, head_size);
1351
- CHECK_SHAPE(out, total_q, num_heads, head_size_v);
1352
- CHECK_SHAPE(dout, total_q, num_heads, head_size_v);
1353
- CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
1354
- }
1355
- if (!is_varlen_k) {
1356
- CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
1357
- CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v);
1358
- } else {
1359
- CHECK_SHAPE(k, total_k, num_heads_k, head_size);
1360
- CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
1361
- CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
1362
- }
1363
-
1364
- if (seqused_q_.has_value()){
1365
- auto seqused_q = seqused_q_.value();
1366
- TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
1367
- CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
1368
- CHECK_SHAPE(seqused_q, batch_size);
1369
- }
1370
- if (seqused_k_.has_value()){
1371
- auto seqused_k = seqused_k_.value();
1372
- TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
1373
- CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
1374
- CHECK_SHAPE(seqused_k, batch_size);
1375
- }
1376
-
1377
- at::Tensor dq, dk, dv;
1378
- if (dq_.has_value()) {
1379
- dq = dq_.value();
1380
- TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
1381
- CHECK_DEVICE(dq);
1382
- TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
1383
- if (!is_varlen_q) {
1384
- CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
1385
- } else {
1386
- CHECK_SHAPE(dq, total_q, num_heads, head_size);
1387
- }
1388
- } else {
1389
- dq = torch::empty_like(q);
1390
- }
1391
- if (dk_.has_value()) {
1392
- dk = dk_.value();
1393
- TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
1394
- CHECK_DEVICE(dk);
1395
- TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
1396
- if (!is_varlen_k) {
1397
- CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
1398
- } else {
1399
- CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
1400
- }
1401
- } else {
1402
- dk = torch::empty_like(k);
1403
- }
1404
- if (dv_.has_value()) {
1405
- dv = dv_.value();
1406
- TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
1407
- CHECK_DEVICE(dv);
1408
- TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
1409
- if (!is_varlen_k) {
1410
- CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);
1411
- } else {
1412
- CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);
1413
- }
1414
- } else {
1415
- dv = torch::empty_like(v);
1416
- }
1417
-
1418
- // Otherwise the kernel will be launched from cuda:0 device
1419
- // Cast to char to avoid compiler warning about narrowing
1420
- at::cuda::CUDAGuard device_guard{(char)q.get_device()};
1421
-
1422
- auto opts = q.options();
1423
- // Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1424
- at::Tensor softmax_d, softmax_lse_log2;
1425
- if (!is_varlen) {
1426
- // Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
1427
- softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1428
- softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
1429
- } else {
1430
- softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1431
- softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
1432
- }
1433
- at::Tensor dq_accum, dk_accum, dv_accum;
1434
- if (!is_varlen) {
1435
- dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1436
- } else {
1437
- dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1438
- }
1439
- if (num_heads_k != num_heads) { // MQA / GQA
1440
- if (!is_varlen) {
1441
- dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
1442
- dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat));
1443
- } else {
1444
- dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
1445
- dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat));
1446
- }
1447
- }
1448
-
1449
- Flash_bwd_params params;
1450
- set_params_dgrad(params,
1451
- batch_size,
1452
- seqlen_q, seqlen_k,
1453
- seqlen_q_rounded, seqlen_k_rounded,
1454
- num_heads, num_heads_k,
1455
- head_size, head_size_rounded,
1456
- q, k, v, out,
1457
- dout, dq, dk, dv,
1458
- !is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
1459
- !is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
1460
- seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
1461
- seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
1462
- dq_accum.data_ptr(),
1463
- num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
1464
- num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
1465
- softmax_lse.data_ptr(),
1466
- softmax_d.data_ptr(),
1467
- /*p_dropout=*/0.f,
1468
- softmax_scale,
1469
- window_size_left,
1470
- window_size_right,
1471
- 0, // attention_chunk
1472
- softcap,
1473
- deterministic,
1474
- sm_margin);
1475
- params.total_q = total_q;
1476
- params.total_k = total_k;
1477
- params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
1478
- params.dv = head_size_v;
1479
- params.dv_rounded = head_size_v_rounded;
1480
-
1481
- // auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
1482
- // params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
1483
- // Will be zero'ed out in the backward preprocess kernel
1484
- at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
1485
- params.dq_semaphore = dq_semaphore.data_ptr<int>();
1486
- if (num_heads_k != num_heads && params.deterministic) {
1487
- // TODO: do we need to zero them out?
1488
- at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1489
- at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
1490
- params.dk_semaphore = dk_semaphore.data_ptr<int>();
1491
- params.dv_semaphore = dv_semaphore.data_ptr<int>();
1492
- }
1493
-
1494
- #ifdef FLASHATTENTION_DISABLE_LOCAL
1495
- TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
1496
- #endif
1497
- #ifdef FLASHATTENTION_DISABLE_SOFTCAP
1498
- TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
1499
- #endif
1500
-
1501
- if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
1502
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1503
- run_mha_bwd(params, stream);
1504
- } else if (total_k > 0 && num_heads_k > 0) {
1505
- // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
1506
- dk.zero_();
1507
- dv.zero_();
1508
- softmax_d.zero_();
1509
- } else if (total_q > 0 && num_heads_k > 0) {
1510
- dq.zero_();
1511
- softmax_d.zero_();
1512
- }
1513
-
1514
- return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
1515
- }
1516
-
1517
- std::tuple<at::Tensor, at::Tensor>
1518
- mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
1519
- at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads
1520
- std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
1521
- std::optional<at::ScalarType> out_dtype_
1522
- ) {
1523
-
1524
- auto dprops = at::cuda::getCurrentDeviceProperties();
1525
- bool is_sm8x = dprops->major >= 8;
1526
- TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
1527
-
1528
- auto out_partial_type = out_partial.scalar_type();
1529
- TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
1530
- TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
1531
-
1532
- CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
1533
-
1534
- TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
1535
- TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
1536
-
1537
- const auto sizes = out_partial.sizes();
1538
-
1539
- const int num_splits = sizes[0];
1540
- const int batch_size = sizes[1];
1541
- const int seqlen = sizes[2];
1542
- const int num_heads = sizes[3];
1543
- const int head_size_og = sizes[4];
1544
- TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
1545
-
1546
- CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
1547
- CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
1548
-
1549
- int const alignment = 4;
1550
- at::Tensor out_partial_padded;
1551
- auto pad = [](at::Tensor x, int alignment) {
1552
- return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
1553
- };
1554
- out_partial_padded = pad(out_partial, alignment);
1555
-
1556
- auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
1557
- const int head_size = round_multiple(head_size_og, alignment);
1558
-
1559
- auto opts = out_partial.options();
1560
- at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
1561
- TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
1562
- at::Tensor out;
1563
- if (out_.has_value()) {
1564
- out = out_.value();
1565
- TORCH_CHECK(out.scalar_type() == out_type);
1566
- CHECK_DEVICE(out);
1567
- TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
1568
- CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
1569
- if (head_size_og % alignment != 0) {
1570
- out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1571
- }
1572
- } else {
1573
- out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
1574
- }
1575
-
1576
- // Otherwise the kernel will be launched from cuda:0 device
1577
- // Cast to char to avoid compiler warning about narrowing
1578
- at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
1579
-
1580
- auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
1581
-
1582
- Flash_fwd_params params {}; // Need to reset the params to set everything to zero
1583
- params.is_fp32 = out_type == at::ScalarType::Float;
1584
- params.is_bf16 = out_type == at::ScalarType::BFloat16;
1585
- params.oaccum_ptr = out_partial_padded.data_ptr();
1586
- params.softmax_lseaccum_ptr = lse_partial.data_ptr();
1587
- params.o_ptr = out.data_ptr();
1588
- params.softmax_lse_ptr = softmax_lse.data_ptr();
1589
- params.b = batch_size;
1590
- params.h = num_heads;
1591
- params.seqlen_q = seqlen;
1592
- params.dv = head_size;
1593
- params.num_splits = num_splits;
1594
- params.oaccum_split_stride = out_partial_padded.stride(0);
1595
- params.oaccum_row_stride = out_partial_padded.stride(2);
1596
- params.oaccum_head_stride = out_partial_padded.stride(3);
1597
- params.oaccum_batch_stride = out_partial_padded.stride(1);
1598
- params.lseaccum_split_stride = lse_partial.stride(0);
1599
- params.lseaccum_head_stride = lse_partial.stride(3);
1600
- params.lseaccum_batch_stride = lse_partial.stride(1);
1601
- params.o_row_stride = out.stride(1);
1602
- params.o_head_stride = out.stride(2);
1603
- params.o_batch_stride = out.stride(0);
1604
- params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
1605
-
1606
- if (seqlen > 0 && batch_size > 0) {
1607
- auto stream = at::cuda::getCurrentCUDAStream().stream();
1608
- run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
1609
- }
1610
-
1611
- at::Tensor out_padded = out;
1612
- if (head_size_og % alignment != 0) {
1613
- out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
1614
- // if (out_.has_value()) { out_.value().copy_(out); }
1615
- }
1616
-
1617
- return {out, softmax_lse};
1618
- }
1619
-
1620
- #ifdef false
1621
-
1622
- TORCH_LIBRARY(flash_attn_3, m) {
1623
- m.def("fwd("
1624
- "Tensor q,"
1625
- "Tensor k,"
1626
- "Tensor v,"
1627
- "Tensor(k_new!)? k_new = None,"
1628
- "Tensor(v_new!)? v_new = None,"
1629
- "Tensor? q_v = None,"
1630
- "Tensor(out!)? out = None,"
1631
- "Tensor? cu_seqlens_q = None,"
1632
- "Tensor? cu_seqlens_k = None,"
1633
- "Tensor? cu_seqlens_k_new = None,"
1634
- "Tensor? seqused_q = None,"
1635
- "Tensor? seqused_k = None,"
1636
- "int? max_seqlen_q = None,"
1637
- "int? max_seqlen_k = None,"
1638
- "Tensor? page_table = None,"
1639
- "Tensor? kv_batch_idx = None,"
1640
- "Tensor? leftpad_k = None,"
1641
- "Tensor? rotary_cos = None,"
1642
- "Tensor? rotary_sin = None,"
1643
- "Tensor? seqlens_rotary = None,"
1644
- "Tensor? q_descale = None,"
1645
- "Tensor? k_descale = None,"
1646
- "Tensor? v_descale = None,"
1647
- "float? softmax_scale = None,"
1648
- "bool is_causal = False,"
1649
- "int window_size_left = -1,"
1650
- "int window_size_right = -1,"
1651
- "int attention_chunk = 0,"
1652
- "float softcap = 0.0,"
1653
- "bool is_rotary_interleaved = False,"
1654
- "Tensor? scheduler_metadata = None,"
1655
- "int num_splits = 0,"
1656
- "bool? pack_gqa = None,"
1657
- "int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
1658
- m.def("bwd("
1659
- "Tensor dout,"
1660
- "Tensor q,"
1661
- "Tensor k,"
1662
- "Tensor v,"
1663
- "Tensor out,"
1664
- "Tensor softmax_lse,"
1665
- "Tensor(dq!)? dq = None,"
1666
- "Tensor(dk!)? dk = None,"
1667
- "Tensor(dv!)? dv = None,"
1668
- "Tensor? cu_seqlens_q = None,"
1669
- "Tensor? cu_seqlens_k = None,"
1670
- "Tensor? seqused_q = None,"
1671
- "Tensor? seqused_k = None,"
1672
- "int? max_seqlen_q = None,"
1673
- "int? max_seqlen_k = None,"
1674
- "float? softmax_scale = None,"
1675
- "bool is_causal = False,"
1676
- "int window_size_left = -1,"
1677
- "int window_size_right = -1,"
1678
- "float softcap = 0.0,"
1679
- "bool deterministic = False,"
1680
- "int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
1681
- m.def("fwd_combine("
1682
- "Tensor out_partial,"
1683
- "Tensor lse_partial,"
1684
- "Tensor(out!)? out = None,"
1685
- "ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
1686
- m.def("get_scheduler_metadata("
1687
- "int batch_size,"
1688
- "int max_seqlen_q,"
1689
- "int max_seqlen_k,"
1690
- "int num_heads,"
1691
- "int num_heads_k,"
1692
- "int headdim,"
1693
- "int headdim_v,"
1694
- "ScalarType qkv_dtype,"
1695
- "Tensor seqused_k,"
1696
- "Tensor? cu_seqlens_q = None,"
1697
- "Tensor? cu_seqlens_k = None,"
1698
- "Tensor? cu_seqlens_k_new = None,"
1699
- "Tensor? seqused_q = None,"
1700
- "Tensor? leftpad_k = None,"
1701
- "int? page_size = None,"
1702
- "int max_seqlen_k_new = 0,"
1703
- "bool is_causal = False,"
1704
- "int window_size_left = -1,"
1705
- "int window_size_right = -1,"
1706
- "int attention_chunk = 0,"
1707
- "bool has_softcap = False,"
1708
- "int num_splits = 0,"
1709
- "bool? pack_gqa = None,"
1710
- "int sm_margin = 0) -> Tensor");
1711
- }
1712
-
1713
- TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
1714
- m.impl("fwd", &mha_fwd);
1715
- m.impl("bwd", &mha_bwd);
1716
- m.impl("fwd_combine", &mha_combine);
1717
- m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
1718
- }
1719
-
1720
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_kernel_sm80.h DELETED
@@ -1,173 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/kernel_hardware_info.h>
13
-
14
- #include "utils.h"
15
-
16
- namespace flash {
17
-
18
- using namespace cute;
19
-
20
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
21
- class FlashAttnBwdSm80 {
22
-
23
- public:
24
-
25
- // Type Aliases
26
- static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
27
- static constexpr bool Is_local = CollectiveMainloop_::Is_local;
28
- static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
29
- static constexpr bool Varlen = CollectiveMainloop_::Varlen;
30
-
31
- // Mainloop derived types
32
- using CollectiveMainloop = CollectiveMainloop_;
33
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
34
- using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
35
- using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
36
- using ArchTag = typename CollectiveMainloop::ArchTag;
37
- using MainloopArguments = typename CollectiveMainloop::Arguments;
38
- using MainloopParams = typename CollectiveMainloop::Params;
39
- static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
40
-
41
- // Epilogue derived types
42
- using CollectiveEpilogue = CollectiveEpilogue_;
43
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
44
- using EpilogueParams = typename CollectiveEpilogue::Params;
45
-
46
- static_assert(ArchTag::kMinComputeCapability >= 80);
47
-
48
- using TileScheduler = TileScheduler_;
49
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
50
- using TileSchedulerParams = typename TileScheduler::Params;
51
-
52
- static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
53
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
54
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
55
-
56
- // Kernel level shared memory storage
57
- struct SharedStorage {
58
- struct TensorStorage : cute::aligned_struct<128> {
59
- union {
60
- typename CollectiveMainloop::TensorStorage mainloop;
61
- typename CollectiveEpilogue::TensorStorage epilogue;
62
- };
63
- } tensors;
64
-
65
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
66
-
67
- };
68
-
69
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
70
-
71
- // Device side arguments
72
- struct Arguments {
73
- MainloopArguments mainloop{};
74
- EpilogueArguments epilogue{};
75
- cutlass::KernelHardwareInfo hw_info{};
76
- TileSchedulerArguments scheduler{};
77
- };
78
-
79
- // Kernel entry point API
80
- struct Params {
81
- MainloopParams mainloop{};
82
- EpilogueParams epilogue{};
83
- cutlass::KernelHardwareInfo hw_info{};
84
- TileSchedulerParams scheduler{};
85
- };
86
-
87
- //
88
- // Methods
89
- //
90
-
91
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
92
- static
93
- Params
94
- to_underlying_arguments(Arguments const& args) {
95
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
96
-
97
- // Get SM count if needed, otherwise use user supplied SM count
98
- int sm_count = args.hw_info.sm_count;
99
- if (sm_count <= 0) {
100
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
101
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
102
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
103
- }
104
-
105
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
106
-
107
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
108
- return {
109
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
110
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
111
- hw_info,
112
- TileScheduler::to_underlying_arguments(args.scheduler)
113
- };
114
- }
115
-
116
- // Computes the kernel launch grid shape based on runtime parameters
117
- static dim3
118
- get_grid_shape(Params const& params) {
119
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
120
- }
121
-
122
- static dim3
123
- get_block_shape() {
124
- return dim3(MaxThreadsPerBlock, 1, 1);
125
- }
126
-
127
- CUTLASS_DEVICE
128
- void
129
- operator()(Params const& params, char* smem_buf) {
130
-
131
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
132
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
133
-
134
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
135
-
136
- CollectiveMainloop mainloop;
137
- CollectiveEpilogue epilogue;
138
-
139
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
140
- // Initialize matmul objects.
141
- TiledMmadKV tiled_mma_dKV;
142
-
143
- scheduler.init_consumer();
144
-
145
- int warp_idx = cutlass::canonical_warp_idx_sync();
146
- CUTLASS_PRAGMA_NO_UNROLL
147
- for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
148
- work_tile_info.is_valid(params.scheduler);
149
- work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
150
-
151
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
152
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
153
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
154
-
155
- // dK and dV output accumulator.
156
- Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
157
- Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
158
- bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
159
- block_coord, shared_storage);
160
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
161
- if (tile_valid) {
162
- epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
163
- threadIdx.x, block_coord);
164
- } else {
165
- epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
166
- }
167
- }
168
-
169
- }
170
-
171
- };
172
-
173
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_kernel_sm90.h DELETED
@@ -1,282 +0,0 @@
1
-
2
- /******************************************************************************
3
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
4
- ******************************************************************************/
5
-
6
- #pragma once
7
-
8
- #include "cute/tensor.hpp"
9
-
10
- #include <cutlass/cutlass.h>
11
- #include <cutlass/arch/reg_reconfig.h>
12
- #include <cutlass/array.h>
13
- #include <cutlass/numeric_types.h>
14
- #include <cutlass/numeric_conversion.h>
15
- #include <cutlass/kernel_hardware_info.h>
16
- #include "cutlass/pipeline/pipeline.hpp"
17
-
18
- #include "utils.h"
19
-
20
- namespace flash {
21
-
22
- using namespace cute;
23
-
24
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
25
- class FlashAttnBwdSm90 {
26
-
27
- public:
28
-
29
- // Type Aliases
30
- static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
31
- static constexpr bool Is_local = CollectiveMainloop_::Is_local;
32
- static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
33
- static constexpr bool Varlen = CollectiveMainloop_::Varlen;
34
-
35
- // Mainloop derived types
36
- using CollectiveMainloop = CollectiveMainloop_;
37
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
38
- using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
39
- using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
40
- using ArchTag = typename CollectiveMainloop::ArchTag;
41
- using ClusterShape = typename CollectiveMainloop::ClusterShape;
42
- using MainloopArguments = typename CollectiveMainloop::Arguments;
43
- using MainloopParams = typename CollectiveMainloop::Params;
44
- static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
45
-
46
- // Epilogue derived types
47
- using CollectiveEpilogue = CollectiveEpilogue_;
48
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
49
- using EpilogueParams = typename CollectiveEpilogue::Params;
50
-
51
- static_assert(ArchTag::kMinComputeCapability >= 90);
52
-
53
- using TileScheduler = TileScheduler_;
54
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
55
- using TileSchedulerParams = typename TileScheduler::Params;
56
-
57
- static constexpr uint32_t NumLoadWarpGroups = 1;
58
- static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
59
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
60
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
61
- static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
62
-
63
- /// Register requirement for Load and Math WGs
64
- static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
65
- static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
66
- // If you want to print from the producer warp, you'd need to increase the number of registers
67
- // Otherwise you'll get CUDA error.
68
- // static constexpr uint32_t LoadRegisterRequirement = 40;
69
- // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
70
-
71
- // Kernel level shared memory storage
72
- struct SharedStorage {
73
- struct TensorStorage : cute::aligned_struct<128> {
74
- union {
75
- typename CollectiveMainloop::TensorStorage mainloop;
76
- typename CollectiveEpilogue::TensorStorage epilogue;
77
- };
78
- } tensors;
79
-
80
- struct PipelineStorage : cute::aligned_struct<16> {
81
- alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
82
- alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
83
- alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
84
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
- } pipelines;
86
-
87
- };
88
-
89
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
90
-
91
- // Device side arguments
92
- struct Arguments {
93
- MainloopArguments mainloop{};
94
- EpilogueArguments epilogue{};
95
- cutlass::KernelHardwareInfo hw_info{};
96
- TileSchedulerArguments scheduler{};
97
- };
98
-
99
- // Kernel entry point API
100
- struct Params {
101
- MainloopParams mainloop{};
102
- EpilogueParams epilogue{};
103
- cutlass::KernelHardwareInfo hw_info{};
104
- TileSchedulerParams scheduler{};
105
- };
106
-
107
- //
108
- // Methods
109
- //
110
-
111
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
112
- static
113
- Params
114
- to_underlying_arguments(Arguments const& args) {
115
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
116
-
117
- // Get SM count if needed, otherwise use user supplied SM count
118
- int sm_count = args.hw_info.sm_count;
119
- if (sm_count <= 0) {
120
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
121
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
122
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
123
- }
124
-
125
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
126
-
127
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
128
- return {
129
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
130
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
131
- hw_info,
132
- TileScheduler::to_underlying_arguments(args.scheduler)
133
- };
134
- }
135
-
136
- // Computes the kernel launch grid shape based on runtime parameters
137
- static dim3
138
- get_grid_shape(Params const& params) {
139
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
140
- }
141
-
142
- static dim3
143
- get_block_shape() {
144
- return dim3(MaxThreadsPerBlock, 1, 1);
145
- }
146
-
147
- CUTLASS_DEVICE
148
- void
149
- operator()(Params const& params, char* smem_buf) {
150
-
151
- static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
152
- static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
153
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
154
- static constexpr int kBlockN = get<1>(TileShape_MNK{});
155
-
156
- using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
157
- using PipelineParams = typename MainloopPipeline::Params;
158
- using PipelineState = typename MainloopPipeline::PipelineState;
159
- using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
160
- using PipelineParams_dO = typename MainloopPipeline_dO::Params;
161
- using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
162
- static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
163
-
164
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
165
-
166
- int const lane_predicate = cute::elect_one_sync();
167
- int const warp_idx = cutlass::canonical_warp_idx_sync();
168
-
169
- // Issue Tma Descriptor Prefetch from a single thread
170
- if (warp_idx == 0 && lane_predicate) {
171
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
172
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
173
- }
174
-
175
- // Obtain warp index
176
- int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
177
-
178
- PipelineParams pipeline_params;
179
- pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
180
- int warp_group_idx = cutlass::canonical_warp_group_idx();
181
- pipeline_params.role = warp_group_idx == 0
182
- ? MainloopPipeline::ThreadCategory::Producer
183
- : MainloopPipeline::ThreadCategory::Consumer;
184
- pipeline_params.is_leader = warp_group_thread_idx == 0;
185
- pipeline_params.num_consumers = NumMmaThreads;
186
-
187
- if (warp_idx == 0 && lane_predicate) {
188
- shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
189
- }
190
- // We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
191
- MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
192
- auto role_dO = warp_group_idx == 0
193
- ? MainloopPipeline_dO::ThreadCategory::Producer
194
- : MainloopPipeline_dO::ThreadCategory::Consumer;
195
- PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
196
- MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
197
-
198
- CollectiveMainloop mainloop;
199
- CollectiveEpilogue epilogue;
200
-
201
- // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
202
- if constexpr (size(ClusterShape{}) > 1) {
203
- cute::cluster_arrive_relaxed();
204
- cute::cluster_wait();
205
- } else {
206
- __syncthreads();
207
- }
208
-
209
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
210
-
211
- if (warp_group_idx == 0) { // Producer
212
- cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
213
-
214
- int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
215
- if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
216
- PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
217
- PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
218
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
219
- work_tile_info.is_valid(params.scheduler);
220
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
221
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
222
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
223
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
224
- auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
225
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
226
- };
227
- mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
228
- smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
229
- }
230
- mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
231
- } else if (warp_idx_in_warpgroup == 1) {
232
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
233
- work_tile_info.is_valid(params.scheduler);
234
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
235
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
236
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
237
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
238
- mainloop.store_dq(params.mainloop, shared_storage, block_coord);
239
- }
240
- }
241
- } else { // Consumer
242
- cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
243
- // Initialize matmul objects.
244
- TiledMmadKV tiled_mma_dKV;
245
-
246
- PipelineState smem_pipe_read;
247
- PipelineState_dO smem_pipe_read_do;
248
-
249
- mainloop.mma_init();
250
- scheduler.init_consumer();
251
-
252
- int work_idx = 0;
253
- CUTLASS_PRAGMA_NO_UNROLL
254
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
255
- work_tile_info.is_valid(params.scheduler);
256
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
257
- auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
258
- auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
259
- cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
260
-
261
- // dK and dV output accumulator.
262
- Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
263
- Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
264
- bool tile_valid = mainloop.mma(
265
- params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
266
- tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
267
- if (tile_valid) {
268
- epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
269
- threadIdx.x - NumCopyThreads, block_coord);
270
- } else {
271
- epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
272
- }
273
-
274
- }
275
- epilogue.store_tail();
276
- }
277
-
278
- }
279
-
280
- };
281
-
282
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_launch_template.h DELETED
@@ -1,390 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/device_kernel.h" // For device_kernel
10
- #include "cutlass/kernel_launch.h" // For kernel_launch
11
- #include "cutlass/cluster_launch.hpp" // For ClusterLauncher
12
-
13
- #include "static_switch.h"
14
- #include "flash.h"
15
- #include "flash_bwd_preprocess_kernel.h"
16
- #include "flash_bwd_postprocess_kernel.h"
17
- #include "tile_scheduler.hpp"
18
- #include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
19
- #include "mainloop_bwd_sm80.hpp"
20
- #include "epilogue_bwd.hpp"
21
- #include "flash_bwd_kernel_sm90.h"
22
- #include "flash_bwd_kernel_sm80.h"
23
-
24
- using namespace cute;
25
-
26
- template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
27
- bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
28
- int Stages_dO=2, int Stages_dS_or_QSm80=2,
29
- bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
30
- int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
31
- bool V_in_regs=false>
32
- void run_flash_bwd(Flash_bwd_params &params, cudaStream_t stream) {
33
- static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
34
- using ElementAccum = float;
35
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
36
-
37
- int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
38
- int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
39
- bool const is_varlen_q = params.cu_seqlens_q;
40
- bool const is_varlen_k = params.cu_seqlens_k;
41
- int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
42
- int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
43
- int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
44
- int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
45
- int batch_q = !is_varlen_q ? params.b : 1;
46
- int batch_k = !is_varlen_k ? params.b : 1;
47
-
48
- using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
49
- using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
50
- typename PreprocessKernel::Arguments preprocess_args {
51
- static_cast<Element const*>(params.o_ptr),
52
- {seqlen_q, params.dv, params.h, batch_q}, // shape_O
53
- {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
54
- static_cast<Element const*>(params.do_ptr),
55
- {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
56
- static_cast<float*>(params.dsoftmax_sum),
57
- {seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
58
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
59
- static_cast<float*>(params.softmax_lse_ptr),
60
- {_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
61
- static_cast<float*>(params.softmax_lse_log2_ptr),
62
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
63
- static_cast<ElementAccum*>(params.dq_accum_ptr),
64
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
65
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
66
- params.b,
67
- params.dq_semaphore,
68
- params.cu_seqlens_q,
69
- params.seqused_q
70
- };
71
- typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
72
- int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
73
- dim3 grid_m(num_m_block, params.h, params.b);
74
- cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
75
- CHECK_CUDA_KERNEL_LAUNCH();
76
-
77
- using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
78
- using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster
79
- // Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
80
- static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
81
- static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
82
- using CollectiveMainloop = std::conditional_t<
83
- Arch >= 90,
84
- flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
85
- Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
86
- SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
87
- flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
88
- Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
89
- SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
90
- >;
91
- using CollectiveEpilogue = std::conditional_t<
92
- !GQA,
93
- flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
94
- flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
95
- >;
96
- using Scheduler = std::conditional_t<
97
- Is_causal && !Varlen,
98
- flash::SingleTileBwdLPTScheduler,
99
- flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>
100
- >;
101
- using AttnKernel = std::conditional_t<
102
- Arch >= 90,
103
- flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
104
- flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
105
- >;
106
-
107
- typename CollectiveMainloop::Arguments mainloop_args {
108
- static_cast<Element const*>(params.q_ptr),
109
- {seqlen_q, params.d, params.h, batch_q}, // shape_Q
110
- {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
111
- static_cast<Element const*>(params.k_ptr),
112
- {seqlen_k, params.d, params.h_k, batch_k}, // shape_K
113
- {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
114
- static_cast<Element const*>(params.v_ptr),
115
- {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V
116
- {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
117
- static_cast<Element const*>(params.do_ptr),
118
- {seqlen_q, params.dv, params.h, batch_q}, // shape_dO
119
- {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
120
- static_cast<ElementAccum*>(params.dq_accum_ptr),
121
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
122
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
123
- static_cast<float*>(params.softmax_lse_log2_ptr),
124
- {seqlen_q_rounded, params.h, batch_q}, // shape_LSE
125
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
126
- static_cast<float*>(params.dsoftmax_sum),
127
- {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
128
- params.scale_softmax,
129
- params.window_size_left, params.window_size_right, 0 /*attention_chunk*/,
130
- params.softcap,
131
- params.b,
132
- params.dq_semaphore,
133
- params.cu_seqlens_q, params.cu_seqlens_k,
134
- params.seqused_q, params.seqused_k
135
- };
136
- // The case work with GQA is ugly but idk how to fix it.
137
- typename CollectiveEpilogue::Arguments epilogue_args {
138
- static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
139
- [&] {
140
- if constexpr (!GQA) {
141
- return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
142
- } else {
143
- return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
144
- }
145
- }(),
146
- [&] {
147
- if constexpr (!GQA) {
148
- return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
149
- } else {
150
- return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
151
- }
152
- }(),
153
- static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
154
- [&] {
155
- if constexpr (!GQA) {
156
- return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV
157
- } else {
158
- return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum
159
- }
160
- }(),
161
- [&] {
162
- if constexpr (!GQA) {
163
- return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
164
- } else {
165
- return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
166
- }
167
- }(),
168
- params.h,
169
- params.dk_semaphore,
170
- params.dv_semaphore,
171
- params.cu_seqlens_k,
172
- params.seqused_k,
173
- };
174
-
175
- int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
176
- num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
177
- typename flash::TileSchedulerArguments scheduler_args {
178
- num_blocks_n, params.h, params.b, 1 /*num_splits*/,
179
- params.h / params.h_k,
180
- params.seqlen_k,
181
- params.seqlen_q, params.d, params.dv, sizeof(Element),
182
- params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
183
- };
184
-
185
- int device;
186
- cudaGetDevice(&device);
187
- typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
188
- mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
189
- });
190
-
191
- dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
192
- dim3 block_dims = AttnKernel::get_block_shape();
193
- int smem_size = AttnKernel::SharedStorageSize;
194
- // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
195
- // int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
196
- // int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
197
- // int smem_size_dqacc = [&] {
198
- // if constexpr (Arch >= 90) {
199
- // return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
200
- // } else {
201
- // return 0;
202
- // }
203
- // }();
204
- // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
205
- // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
206
- // int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
207
- // int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
208
- // printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
209
- if constexpr (size(ClusterShape{}) > 1) {
210
- void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
211
- if (smem_size >= 48 * 1024) {
212
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
213
- }
214
- dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
215
- cutlass::ClusterLauncher::launch(
216
- grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
217
- } else {
218
- if (smem_size >= 48 * 1024) {
219
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
220
- }
221
- cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
222
- }
223
- CHECK_CUDA_KERNEL_LAUNCH();
224
-
225
- using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
226
- AttnKernel::CollectiveMainloop::NumMmaThreads,
227
- typename AttnKernel::CollectiveMainloop::TiledMmadQ,
228
- AttnKernel::CollectiveMainloop::dQ_swapAB
229
- >;
230
- typename PostprocessKernel::Arguments postprocess_args {
231
- static_cast<ElementAccum const*>(params.dq_accum_ptr),
232
- {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
233
- {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
234
- static_cast<Element*>(params.dq_ptr),
235
- {seqlen_q, params.d, params.h, batch_q}, // shape_dQ
236
- {params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
237
- params.scale_softmax,
238
- params.cu_seqlens_q,
239
- params.seqused_q
240
- };
241
- typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
242
- int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
243
- dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
244
- int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
245
- if (smem_size_postprocess >= 48 * 1024) {
246
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
247
- }
248
- cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
249
- CHECK_CUDA_KERNEL_LAUNCH();
250
-
251
- if constexpr (GQA) {
252
- using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
253
- using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
254
- AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
255
- typename AttnKernel::CollectiveMainloop::TiledMmadKV,
256
- AttnKernel::CollectiveMainloop::dKV_swapAB
257
- >;
258
- typename PostprocessKerneldKV::Arguments postprocess_dK_args {
259
- static_cast<ElementAccum const*>(params.dk_accum_ptr),
260
- {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
261
- {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
262
- static_cast<Element*>(params.dk_ptr),
263
- {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
264
- {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
265
- 1.f,
266
- params.cu_seqlens_k,
267
- params.seqused_k
268
- };
269
- typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
270
- typename PostprocessKerneldKV::Arguments postprocess_dV_args {
271
- static_cast<ElementAccum const*>(params.dv_accum_ptr),
272
- {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum
273
- {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
274
- static_cast<Element*>(params.dv_ptr),
275
- {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV
276
- {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
277
- 1.f,
278
- params.cu_seqlens_k,
279
- params.seqused_k
280
- };
281
- typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
282
- int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
283
- dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
284
- int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
285
- if (smem_size_postprocess >= 48 * 1024) {
286
- CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
287
- }
288
- cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
289
- CHECK_CUDA_KERNEL_LAUNCH();
290
- cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
291
- CHECK_CUDA_KERNEL_LAUNCH();
292
- }
293
-
294
- }
295
-
296
- template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
297
- int Stages_dO=2, int Stages_dS_or_QSm80=2,
298
- bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
299
- int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
300
- bool V_in_regs=false>
301
- void run_mha_bwd_dispatch(Flash_bwd_params &params, cudaStream_t stream) {
302
- VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
303
- BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
304
- // BOOL_SWITCH(params.deterministic, Deterministic, [&] {
305
- // run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
306
- run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
307
- // });
308
- });
309
- });
310
- }
311
-
312
-
313
- template<int Arch, typename T, bool Has_softcap>
314
- void run_mha_bwd_hdim64(Flash_bwd_params &params, cudaStream_t stream) {
315
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
316
- if constexpr (Arch >= 90) {
317
- if constexpr (Is_causal && Has_softcap) {
318
- // register spill with 128 x 128
319
- run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
320
- } else {
321
- // With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
322
- run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
323
- }
324
- } else if constexpr (Arch == 86 || Arch == 89) {
325
- run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
326
- // run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
327
- // run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
328
- // run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
329
- } else {
330
- run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
331
- }
332
- });
333
- }
334
-
335
- template<int Arch, typename T, bool Has_softcap>
336
- void run_mha_bwd_hdim96(Flash_bwd_params &params, cudaStream_t stream) {
337
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
338
- if constexpr (Arch >= 90) {
339
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
340
- } else if constexpr (Arch == 86 || Arch == 89) {
341
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
342
- } else {
343
- run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
344
- }
345
- });
346
- }
347
-
348
- template<int Arch, typename T, bool Has_softcap>
349
- void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
350
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
351
- if constexpr (Arch >= 90) {
352
- if constexpr (Is_causal || Is_local || Has_softcap) {
353
- run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
354
- } else {
355
- run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
356
- }
357
- } else if constexpr (Arch == 86 || Arch == 89) {
358
- run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
359
- } else {
360
- run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
361
- }
362
- });
363
- }
364
-
365
- template<int Arch, typename T, bool Has_softcap>
366
- void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
367
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
368
- if constexpr (Arch >= 90) {
369
- run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
370
- } else if constexpr (Arch == 86 || Arch == 89) {
371
- run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
372
- } else {
373
- run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
374
- }
375
- });
376
- }
377
-
378
- template<int Arch, typename T, bool Has_softcap>
379
- void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
380
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
381
- if constexpr (Arch >= 90) {
382
- run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
383
- } else if constexpr (Arch == 86 || Arch == 89) {
384
- run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
385
- // run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
386
- } else {
387
- run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
388
- }
389
- });
390
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_postprocess_kernel.h DELETED
@@ -1,256 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/numeric_conversion.h>
13
- #include "cutlass/arch/barrier.h"
14
-
15
- #include "seqlen.h"
16
- #include "utils.h"
17
-
18
- namespace flash {
19
-
20
- using namespace cute;
21
-
22
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
23
- class FlashAttnBwdPostprocessConvertdQ {
24
-
25
- public:
26
-
27
- // Type Aliases
28
- using TileShape_MK = TileShape_MK_;
29
- using ArchTag = ArchTag_;
30
-
31
- static_assert(ArchTag::kMinComputeCapability >= 75);
32
- static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
33
-
34
- static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
35
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
-
37
- static constexpr int kBlockM = get<0>(TileShape_MK{});
38
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
39
- static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
40
- static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
41
- using R2SLayoutAtomdQaccum = std::conditional_t<
42
- IsSm90,
43
- Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
44
- Layout<Shape<Int<kNThreads>>>
45
- >;
46
- using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
47
- Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
48
- using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
49
- // UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
50
- using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
51
- Layout<Shape<_4>>{})); // Val layout, 4 vals per read
52
- // We don't do bound checking for the gmem -> smem load so we just assert here.
53
- static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
54
- static constexpr int SmemdQaccumSize = size(TileShape_MK{});
55
- using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
56
- using SmemLayoutdQaccum = std::conditional_t<
57
- IsSm90,
58
- Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
59
- Layout<Shape<Int<kBlockM * kHeadDim>>>
60
- >;
61
-
62
- // We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
63
- // then setting kBlockKSmem to 32 will cause "Static shape_div failure".
64
- // We want to treat it as 64 x 48, so kBlockKSmem should be 16.
65
- static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
66
- static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
67
- static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
68
- using SmemLayoutAtomdQ =
69
- decltype(composition(Swizzle<kSwizzle, 3, 3>{},
70
- Layout<Shape<Int<8>, Int<kBlockKSmem>>,
71
- Stride<Int<kBlockKSmem>, _1>>{}));
72
- using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
73
- using SmemLayoutdQt =
74
- decltype(cute::composition(SmemLayoutdQ{},
75
- make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
76
- make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
77
-
78
- using SmemCopyAtomdQ = Copy_Atom<
79
- std::conditional_t<
80
- IsSm90,
81
- std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
82
- AutoVectorizingCopyWithAssumedAlignment<128>
83
- >,
84
- Element>;
85
-
86
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
87
- static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
88
- static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
89
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
90
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
91
- Stride<Int<kGmemThreadsPerRow>, _1>>;
92
- using GmemTiledCopy = decltype(
93
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
94
- GmemLayoutAtom{},
95
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
96
-
97
- struct SharedStorage : cute::aligned_struct<128> {
98
- cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
99
- cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
100
- alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
101
- };
102
-
103
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
104
-
105
- using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
106
- using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
107
- using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
108
- using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
109
-
110
- // Device side arguments
111
- struct Arguments {
112
- ElementAccum const* ptr_dQaccum;
113
- ShapedQaccum const shape_dQaccum;
114
- StridedQaccum const stride_dQaccum;
115
- Element* ptr_dQ;
116
- ShapedQ const shape_dQ;
117
- StridedQ const stride_dQ;
118
- float const softmax_scale;
119
- int const* cu_seqlens = nullptr;
120
- int const* seqused = nullptr;
121
- };
122
-
123
- // Kernel entry point API
124
- struct Params {
125
- ElementAccum const* ptr_dQaccum;
126
- ShapedQaccum const shape_dQaccum;
127
- StridedQaccum const stride_dQaccum;
128
- Element* ptr_dQ;
129
- ShapedQ const shape_dQ;
130
- StridedQ const stride_dQ;
131
- float const softmax_scale;
132
- int const* cu_seqlens = nullptr;
133
- int const* seqused = nullptr;
134
- };
135
-
136
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
137
- static
138
- Params
139
- to_underlying_arguments(Arguments const& args) {
140
- return {
141
- args.ptr_dQaccum,
142
- args.shape_dQaccum,
143
- args.stride_dQaccum,
144
- args.ptr_dQ,
145
- args.shape_dQ,
146
- args.stride_dQ,
147
- args.softmax_scale,
148
- args.cu_seqlens,
149
- args.seqused
150
- };
151
- }
152
-
153
- CUTLASS_DEVICE
154
- void
155
- operator()(Params const& params, char* smem_buf) {
156
-
157
- static constexpr int kBlockM = get<0>(TileShape_MK{});
158
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
159
-
160
- Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
161
- Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
162
- Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
163
- Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
164
-
165
- int const thread_idx = threadIdx.x;
166
- int const m_block = blockIdx.x;
167
- int const bidh = blockIdx.y;
168
- int const bidb = blockIdx.z;
169
-
170
- flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
171
- bool const is_varlen = params.cu_seqlens;
172
- if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
173
-
174
- // Step 1: load dQaccum from gmem to smem
175
- Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
176
- params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
177
- Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
178
- if constexpr (IsSm90) { // Use BulkCopy
179
- static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
180
- auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
181
- // if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
182
- if (thread_idx == 0) {
183
- shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
184
- shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
185
- copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
186
- }
187
- __syncthreads();
188
- shared_storage.barrier_dQaccum.wait(0);
189
- } else {
190
- G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
191
- auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
192
- Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
193
- Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
194
- cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
195
- __syncthreads();
196
- }
197
-
198
- // __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
199
-
200
- // Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
201
- R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
202
- auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
203
- Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
204
- TiledMma tiled_mma_dQ;
205
- Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
206
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
207
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
208
- // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
209
- CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
210
- Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
211
- cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
212
- #pragma unroll
213
- for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
214
- // Convert tdQrdQ from fp32 to fp16
215
- Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
216
- flash::convert_type_out(taccdQrdQaccum, rdQ);
217
-
218
- // Step 3: Copy dQ from register to smem
219
- auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
220
- auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
221
- Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
222
- // if (cute::thread0()) { print(smem_tiled_copy_dQ); }
223
- // if (cute::thread0()) { print(smem_thr_copy_dQ); }
224
- // if (cute::thread0()) { print(sdQ); }
225
- Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
226
- cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
227
- __syncthreads();
228
-
229
- // Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
230
- Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
231
- Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
232
- GmemTiledCopy gmem_tiled_copy_dQ;
233
- auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
234
- Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
235
- Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
236
-
237
- Tensor tdQrdQ = make_fragment_like(tdQsdQ);
238
- Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
239
- Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
240
- #pragma unroll
241
- for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
242
- // Need to check OOB when reading from smem if kBlockM isn't evenly tiled
243
- static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
244
- flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
245
- gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
246
-
247
- // Step 5: Copy dQ from register to gmem
248
- // Clear_OOB_K must be false since we don't want to write zeros to gmem
249
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
250
- gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
251
- );
252
- }
253
-
254
- };
255
-
256
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_bwd_preprocess_kernel.h DELETED
@@ -1,252 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/numeric_conversion.h>
13
-
14
- #include "seqlen.h"
15
- #include "utils.h"
16
-
17
- namespace flash {
18
-
19
- using namespace cute;
20
-
21
- template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
22
- class FlashAttnBwdPreprocess {
23
-
24
- public:
25
-
26
- // Type Aliases
27
- using TileShape_MK = TileShape_MK_;
28
- using ArchTag = ArchTag_;
29
-
30
- static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
31
- std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
32
- std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
33
-
34
- static constexpr uint32_t MaxThreadsPerBlock = 256;
35
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
36
- static constexpr int SharedStorageSize = 0;
37
-
38
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
39
- static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
40
- static constexpr int kBlockM = get<0>(TileShape_MK{});
41
- static constexpr int kHeadDim = get<1>(TileShape_MK{});
42
- // We want kBlockKGmem to be a power of 2 so that when we do the summing,
43
- // it's just between threads in the same warp
44
- static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
45
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
46
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
47
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
48
- Stride<Int<kGmemThreadsPerRow>, _1>>;
49
- using GmemTiledCopy = decltype(
50
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
51
- GmemLayoutAtom{},
52
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
53
-
54
- static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
55
- static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
56
- using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
57
- using GmemTiledCopyAccum = decltype(
58
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
59
- GmemLayoutAtomAccum{},
60
- Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
61
-
62
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
63
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
64
- using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
65
- using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
66
- using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
67
- using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
68
-
69
- // Device side arguments
70
- struct Arguments {
71
- Element const* ptr_O;
72
- ShapeO const shape_O;
73
- StrideO const stride_O;
74
- Element const* ptr_dO;
75
- StrideO const stride_dO;
76
- float* ptr_dPsum;
77
- ShapedPsum const shape_dPsum;
78
- StridedPsum const stride_dPsum;
79
- float const* ptr_LSE;
80
- StridedPsum const stride_LSE;
81
- float *ptr_LSE_log2;
82
- StridedPsum const stride_LSE_log2;
83
- ElementAccum* ptr_dQaccum;
84
- ShapedQaccum const shape_dQaccum;
85
- StridedQaccum const stride_dQaccum;
86
- int num_batch; // We need this to know the size of dq_semaphore in case of varlen
87
- int* dq_semaphore;
88
- int const* cu_seqlens = nullptr;
89
- int const* seqused = nullptr;
90
- };
91
-
92
- // Kernel entry point API
93
- struct Params {
94
- Element const* ptr_O;
95
- ShapeO const shape_O;
96
- StrideO const stride_O;
97
- Element const* ptr_dO;
98
- StrideO const stride_dO;
99
- float* ptr_dPsum;
100
- ShapedPsum const shape_dPsum;
101
- StridedPsum const stride_dPsum;
102
- float const* ptr_LSE;
103
- StridedPsum const stride_LSE;
104
- float* ptr_LSE_log2;
105
- StridedPsum const stride_LSE_log2;
106
- ElementAccum* ptr_dQaccum;
107
- ShapedQaccum const shape_dQaccum;
108
- StridedQaccum const stride_dQaccum;
109
- int num_batch;
110
- int* dq_semaphore;
111
- int const* cu_seqlens = nullptr;
112
- int const* seqused = nullptr;
113
- };
114
-
115
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
116
- static
117
- Params
118
- to_underlying_arguments(Arguments const& args) {
119
- return {
120
- args.ptr_O,
121
- args.shape_O,
122
- args.stride_O,
123
- args.ptr_dO,
124
- args.stride_dO,
125
- args.ptr_dPsum,
126
- args.shape_dPsum,
127
- args.stride_dPsum,
128
- args.ptr_LSE,
129
- args.stride_LSE,
130
- args.ptr_LSE_log2,
131
- args.stride_LSE_log2,
132
- args.ptr_dQaccum,
133
- args.shape_dQaccum,
134
- args.stride_dQaccum,
135
- args.num_batch,
136
- args.dq_semaphore,
137
- args.cu_seqlens,
138
- args.seqused
139
- };
140
- }
141
-
142
- CUTLASS_DEVICE
143
- void
144
- operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
145
-
146
- static constexpr int kBlockM = get<0>(TileShape_MK{});
147
-
148
- int const thread_idx = threadIdx.x;
149
- int const m_block = blockIdx.x;
150
- int const bidh = blockIdx.y;
151
- int const bidb = blockIdx.z;
152
-
153
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
154
- bool const is_varlen = Varlen && params.cu_seqlens;
155
- int const seqlen_o = seqlen_info.seqlen;
156
- if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
157
-
158
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
159
- Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
160
- Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
161
- Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
162
-
163
- auto shape_LSE = select<0, 2, 3>(params.shape_O);
164
- Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
165
- Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
166
- static_assert(kBlockM <= MaxThreadsPerBlock);
167
- float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
168
-
169
- GmemTiledCopy gmem_tiled_copy_O;
170
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
171
-
172
- Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
173
- Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
174
- // Construct identity layout for gO
175
- Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
176
- // Repeat the partitioning with identity layouts
177
- Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
178
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
179
- #pragma unroll
180
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
181
-
182
- // (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
183
- Tensor tOrO = make_fragment_like(tOgO);
184
- Tensor tOrdO = make_fragment_like(tOgdO);
185
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
186
- gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
187
- );
188
- flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
189
- gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
190
- );
191
- // if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
192
-
193
- // Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
194
- Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
195
- Tensor tOrO_l = make_tensor(tOrO.data(), l);
196
- Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
197
- flash::convert_type_out(tOrO_l, o_fp32);
198
- Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
199
- Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
200
- flash::convert_type_out(tOrdO_l, do_fp32);
201
- // Sum across the last dimension
202
- Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
203
- #pragma unroll
204
- for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
205
- float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
206
- #pragma unroll
207
- for (int ni = 1; ni < size<1>(o_fp32); ni++) {
208
- dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
209
- }
210
- flash::SumOp<float> sum_op;
211
- dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
212
- }
213
-
214
- Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
215
- Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
216
- if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
217
- #pragma unroll
218
- for (int mi = 0; mi < size(dP_sum); ++mi) {
219
- int const row = get<0>(tOcO(_0{}, mi, _0{}));
220
- gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
221
- }
222
- }
223
-
224
- int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
225
- Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
226
- Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
227
- if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
228
- gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
229
- }
230
-
231
- if constexpr (Clear_dQaccum) {
232
- Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
233
- Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
234
- GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
235
- auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
236
- Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
237
- Tensor zero = make_fragment_like(tdQgdQaccum);
238
- clear(zero);
239
- cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
240
- }
241
-
242
- if (params.dq_semaphore != nullptr && thread_idx == 0) {
243
- int const num_batch = params.num_batch;
244
- int const num_head = get<2>(params.shape_O);
245
- params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
246
- }
247
-
248
- }
249
-
250
- };
251
-
252
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_combine.cu DELETED
@@ -1,13 +0,0 @@
1
- // Copyright (c) 2024, Tri Dao.
2
- // Splitting the different head dimensions to different files to speed up compilation.
3
-
4
- #include "flash_fwd_combine_launch_template.h"
5
-
6
- template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
7
- template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
8
-
9
- template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
10
- template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
11
-
12
- template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
13
- template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_combine_kernel.h DELETED
@@ -1,482 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/arch/memory.h>
11
- #include <cutlass/array.h>
12
- #include <cutlass/numeric_types.h>
13
- #include <cutlass/numeric_conversion.h>
14
-
15
- #include "cutlass/arch/grid_dependency_control.h"
16
-
17
- #include "seqlen.h"
18
- #include "utils.h"
19
-
20
- namespace flash {
21
-
22
- using namespace cute;
23
-
24
- template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
25
- bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
26
- class FlashAttnFwdCombine {
27
-
28
- public:
29
-
30
- // Type Aliases
31
- using TileShape_MK = TileShape_MK_;
32
- using ArchTag = ArchTag_;
33
- static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
34
- static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
35
- static_assert(AlignmentLSE >= 1);
36
- static constexpr int kStages = 4;
37
-
38
- static_assert(ArchTag::kMinComputeCapability >= 75);
39
- static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
40
-
41
- static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
42
- static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
43
-
44
- static constexpr int kBlockM = get<0>(TileShape_MK{});
45
- static constexpr int kBlockK = get<1>(TileShape_MK{});
46
-
47
- static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
48
- static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
49
- static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
50
- static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
51
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
52
- using GmemCopyAtom = std::conditional_t<
53
- Has_cp_async,
54
- cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
55
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
56
- >;
57
- using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
58
- Stride<Int<kGmemThreadsPerRow>, _1>>;
59
- static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
60
- using GmemTiledCopyAccum = decltype(
61
- make_tiled_copy(GmemCopyAtom{},
62
- GmemLayoutAtom{},
63
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
64
- using GmemTiledCopy = decltype(
65
- make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
66
- GmemLayoutAtom{},
67
- Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
68
-
69
- using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
70
- static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
71
- static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
72
- static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
73
- static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
74
- static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
75
- static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
76
- using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
77
- Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
78
- static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
79
- using GmemCopyAtomLSE = std::conditional_t<
80
- Has_cp_async,
81
- cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
82
- cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
83
- >;
84
- using GmemTiledCopyLSE = decltype(
85
- make_tiled_copy(GmemCopyAtomLSE{},
86
- GmemLayoutAtomLSE{},
87
- Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
88
-
89
- // Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
90
- static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
91
- // This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
92
- using SmemLSESwizzle = std::conditional_t<
93
- kBlockMSmem == 8,
94
- Swizzle<5, 0, 5>,
95
- std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
96
- >;
97
- using SmemLayoutAtomLSE =
98
- decltype(composition(SmemLSESwizzle{},
99
- Layout<Shape<Int<8>, Int<kBlockMSmem>>,
100
- Stride<Int<kBlockMSmem>, _1>>{}));
101
- using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
102
-
103
- using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
104
- Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
105
-
106
- // We want each column (kMaxSplits) to be processed by threads in the same warp.
107
- // To reduce the number of shuffles, we want as few threads on the same column as possible.
108
- // E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column
109
- // have have 64 such quads.
110
- static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
111
- static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
112
- static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
113
- using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
114
- using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
115
-
116
- using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
117
- using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
118
- using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
119
- using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
120
- using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
121
- using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
122
- using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
123
- using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
124
-
125
- struct SharedStorage : cute::aligned_struct<128> {
126
- cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
127
- cute::array_aligned<int, kBlockM> smem_max_valid_split;
128
- cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
129
- };
130
-
131
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
132
-
133
- // Device side arguments
134
- struct Arguments {
135
- ElementPartial const* const ptr_O_partial;
136
- ShapeOPartial const shape_O_partial;
137
- StrideOPartial const stride_O_partial;
138
- float const* const ptr_LSE_partial;
139
- ShapeLSEPartial const shape_LSE_partial;
140
- StrideLSEPartial const stride_LSE_partial;
141
- Element* const ptr_O;
142
- StrideO const stride_O;
143
- float* const ptr_LSE;
144
- StrideLSE const stride_LSE;
145
- int const* const cu_seqlens = nullptr;
146
- int const* const seqused = nullptr;
147
- int const* const num_splits_dynamic_ptr = nullptr;
148
- int* const semaphore_to_reset = nullptr;
149
- };
150
-
151
- // Kernel entry point API
152
- struct Params {
153
- ElementPartial const* const ptr_O_partial;
154
- ShapeOPartial const shape_O_partial;
155
- StrideOPartial const stride_O_partial;
156
- float const* const ptr_LSE_partial;
157
- ShapeLSEPartial const shape_LSE_partial;
158
- StrideLSEPartial const stride_LSE_partial;
159
- Element* const ptr_O;
160
- StrideO const stride_O;
161
- float* const ptr_LSE;
162
- StrideLSE const stride_LSE;
163
- cutlass::FastDivmod seqlen_divmod, head_divmod;
164
- int const* const cu_seqlens = nullptr;
165
- int const* const seqused = nullptr;
166
- int const* const num_splits_dynamic_ptr = nullptr;
167
- int* const semaphore_to_reset = nullptr;
168
- };
169
-
170
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
171
- static
172
- Params
173
- to_underlying_arguments(Arguments const& args) {
174
- assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
175
- return {
176
- args.ptr_O_partial,
177
- args.shape_O_partial,
178
- args.stride_O_partial,
179
- args.ptr_LSE_partial,
180
- args.shape_LSE_partial,
181
- args.stride_LSE_partial,
182
- args.ptr_O,
183
- args.stride_O,
184
- args.ptr_LSE,
185
- args.stride_LSE,
186
- cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
187
- args.cu_seqlens,
188
- args.seqused,
189
- args.num_splits_dynamic_ptr,
190
- args.semaphore_to_reset
191
- };
192
- }
193
-
194
- CUTLASS_DEVICE
195
- void
196
- operator()(Params const& params, char* smem_buf) {
197
-
198
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
199
- Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
200
- Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
201
- Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
202
-
203
- int const thread_idx = threadIdx.x;
204
- int const m_block = blockIdx.x;
205
- int const k_block = blockIdx.y;
206
- int const batch = blockIdx.z;
207
- int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
208
-
209
- if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
210
- cutlass::arch::wait_on_dependent_grids();
211
- *params.semaphore_to_reset = 0;
212
- }
213
- if (num_splits <= 1) { return; }
214
- flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
215
- int const offset = seqlen_info.offset;
216
- int const seqlen = seqlen_info.seqlen;
217
- int max_idx = seqlen * get<2>(params.shape_LSE_partial);
218
- if constexpr (Varlen) {
219
- if (m_block * kBlockM >= max_idx) { return; }
220
- }
221
-
222
- cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
223
-
224
- // Step 1: load LSE_partial from gmem -> smem
225
- Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
226
- select<1, 0, 2, 3>(params.shape_LSE_partial),
227
- select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
228
- Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
229
- GmemTiledCopyLSE gmem_tiled_copy_LSE;
230
- auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
231
- Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
232
-
233
- // Construct identity layout for sLSE
234
- Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
235
- // Repeat the partitioning with identity layouts
236
- Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
237
-
238
- cutlass::arch::wait_on_dependent_grids();
239
-
240
- #pragma unroll
241
- for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
242
- int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
243
- int idx = m_block * kBlockM + mi;
244
- if (idx < max_idx) {
245
- int m_idx, bidh;
246
- if constexpr (!Varlen) {
247
- bidh = params.seqlen_divmod.divmod(m_idx, idx);
248
- } else {
249
- bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
250
- }
251
- Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
252
- #pragma unroll
253
- for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
254
- int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
255
- // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
256
- if (si < num_splits) {
257
- cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
258
- } else {
259
- cute::fill(tLSEsLSE(_, s, m), -INFINITY);
260
- }
261
- }
262
- } else {
263
- // We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
264
- // cute::fill(tLSEsLSE(_, _, m), -INFINITY);
265
- }
266
- }
267
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
268
-
269
- // Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
270
- // We want these async loads to be in flight as we compute the LSE.
271
- GmemTiledCopyAccum gmem_tiled_copy_O_partial;
272
- auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
273
- // Construct identity layout for gO
274
- Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
275
- // Repeat the partitioning with identity layouts
276
- Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
277
- Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
278
- params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
279
-
280
- // Precompute these values to avoid recomputing them in the loop
281
- Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
282
- Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
283
- Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
284
- #pragma unroll
285
- for (int m = 0; m < size<1>(tOcO); ++m) {
286
- int mi = get<0>(tOcO(_0{}, m, _0{}));
287
- int idx = m_block * kBlockM + mi;
288
- if constexpr (!Varlen) {
289
- tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
290
- } else {
291
- tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
292
- }
293
- tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
294
- if (idx >= max_idx) {
295
- tObidh[m] = -1;
296
- }
297
- }
298
-
299
- Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
300
- if constexpr (!(Is_even_K)) {
301
- #pragma unroll
302
- for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
303
- }
304
-
305
- Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
306
-
307
- auto load_O_partial = [&] (int split, int stage) {
308
- Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
309
- #pragma unroll
310
- for (int m = 0; m < size<1>(tOcO); ++m) {
311
- if (tObidh(m) >= 0) {
312
- Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
313
- Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
314
- #pragma unroll
315
- for (int k = 0; k < size<2>(tOcO); ++k) {
316
- int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
317
- if (Is_even_K || tOpO(k)) {
318
- cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
319
- }
320
- }
321
- }
322
- }
323
- };
324
-
325
- for (int s = 0; s < kStages - 1; ++s) {
326
- if (s < num_splits) { load_O_partial(s, s); }
327
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
328
- }
329
-
330
- // Step 3: load and transpose LSE_partial from smem -> rmem
331
- if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
332
- __syncthreads();
333
-
334
- S2RTiledCopyLSE s2r_tiled_copy_LSE;
335
- auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
336
- Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
337
- Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
338
- cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
339
-
340
- // Step 4: compute the final LSE along the split dimension
341
- Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
342
- Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
343
- // We compute the max valid split for each row to short-circuit the computation later
344
- Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
345
- static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
346
- #pragma unroll
347
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
348
- float lse_max = ts2rrLSE(_0{}, _0{}, m);
349
- #pragma unroll
350
- for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
351
- MaxOp<float> max_op;
352
- lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
353
- int max_valid_idx = -1;
354
- #pragma unroll
355
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
356
- if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
357
- }
358
- MaxOp<int> max_int_op;
359
- max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
360
- float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
361
- float lse_sum_cur = 0.f;
362
- #pragma unroll
363
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
364
- float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
365
- lse_sum_cur += scale;
366
- // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
367
- // ts2rsLSE(_0{}, m, s) = scale;
368
- ts2rrLSE(_0{}, s, m) = scale;
369
- }
370
- SumOp<float> sum_op;
371
- lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
372
- lse_sum(m) = logf(lse_sum_cur) + lse_max;
373
- float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
374
- #pragma unroll
375
- for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
376
- }
377
- // Store the scales exp(lse - lse_logsum) back to smem
378
- cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
379
-
380
- // Store max_valid_split to smem
381
- #pragma unroll
382
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
383
- if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem
384
- int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
385
- if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
386
- }
387
- }
388
-
389
- // Step 5: store final LSE back to gmem
390
- if (k_block == 0) {
391
- auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
392
- Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
393
- #pragma unroll
394
- for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
395
- if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem
396
- int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
397
- int idx = m_block * kBlockM + mi;
398
- if (idx < max_idx) {
399
- int m_idx, bidh;
400
- if constexpr (!Varlen) {
401
- bidh = params.seqlen_divmod.divmod(m_idx, idx);
402
- } else {
403
- bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
404
- }
405
- // printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
406
- mLSE(m_idx, bidh) = lse_sum(m);
407
- }
408
- }
409
- }
410
- }
411
-
412
- // Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
413
- __syncthreads();
414
- int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
415
- #pragma unroll
416
- for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
417
- Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
418
- Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
419
- Tensor tOrO = make_fragment_like<float>(tOrOpartial);
420
- clear(tOrO);
421
- int stage_load = kStages - 1, stage_compute = 0;
422
- #pragma unroll 4 // Already tuned for speed
423
- for (int s = 0; s <= thr_max_valid_split; ++s) {
424
- Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
425
- #pragma unroll
426
- for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
427
-
428
- if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
429
- if constexpr (Has_cp_async) { cute::cp_async_fence(); }
430
- stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
431
- if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
432
- // We don't need __syncthreads() because each thread is just reading its own data from smem
433
- cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
434
- tOsOpartial(_, _, _, stage_compute), tOrOpartial);
435
- stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
436
-
437
- #pragma unroll
438
- for (int m = 0; m < size<1>(tOrOpartial); ++m) {
439
- if (tObidh(m) >= 0 && scale(m) > 0.f) {
440
- #pragma unroll
441
- for (int k = 0; k < size<2>(tOrOpartial); ++k) {
442
- if (Is_even_K || tOpO(k)) {
443
- Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
444
- flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
445
- #pragma unroll
446
- for (int i = 0; i < size<0>(tOrOpartial); ++i) {
447
- tOrO(i, m, k) += scale(m) * rOpartial[i];
448
- }
449
- }
450
- }
451
- }
452
- }
453
- }
454
-
455
- // Step 7: Write the final O to gmem
456
- Tensor rO = make_tensor_like<Element>(tOrO);
457
- flash::convert_type_out(tOrO, rO);
458
- auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
459
- Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
460
- shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
461
- Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
462
- GmemTiledCopy gmem_tiled_copy_O;
463
- auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
464
-
465
- #pragma unroll
466
- for (int m = 0; m < size<1>(tOcO); ++m) {
467
- if (tObidh(m) >= 0) {
468
- #pragma unroll
469
- for (int k = 0; k < size<2>(tOcO); ++k) {
470
- int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
471
- if (Is_even_K || tOpO(k)) {
472
- cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
473
- }
474
- }
475
- }
476
- }
477
-
478
- }
479
-
480
- };
481
-
482
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_combine_launch_template.h DELETED
@@ -1,80 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/cutlass.h"
10
- #include "cutlass/arch/arch.h" // For cutlass::arch::Sm80
11
- #include "cutlass/device_kernel.h" // For device_kernel
12
- #include "cutlass/kernel_launch.h" // For kernel_launch
13
-
14
- #include "static_switch.h"
15
- #include "flash.h"
16
- #include "flash_fwd_combine_kernel.h"
17
-
18
- using namespace cute;
19
-
20
- template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
21
- void run_flash_fwd_combine(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
22
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
23
- using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
24
- using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
25
- IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
26
-
27
- typename CombineKernel::Arguments args {
28
- static_cast<ElementPartial const*>(params.oaccum_ptr),
29
- {!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial
30
- {params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial
31
- static_cast<float*>(params.softmax_lseaccum_ptr),
32
- {!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial
33
- {_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial
34
- static_cast<Element*>(params.o_ptr),
35
- {params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
36
- static_cast<float*>(params.softmax_lse_ptr),
37
- {_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
38
- params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
39
- };
40
-
41
- typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
42
- int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
43
- int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
44
- dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
45
- auto kernel = cutlass::device_kernel<CombineKernel>;
46
- int smem_size = CombineKernel::SharedStorageSize;
47
- if (smem_size >= 48 * 1024) {
48
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
49
- }
50
- // kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
51
- cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
52
- CHECK_CUDA_KERNEL_LAUNCH();
53
- }
54
-
55
- template<typename T, typename Tpartial, int kBlockK>
56
- void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream, bool enable_pdl) {
57
- // We want kBlockM to be as small as possible to maximize parallelism.
58
- // E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
59
- static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
60
- static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
61
- ARCH_SWITCH(params.arch, Arch, [&] {
62
- BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
63
- if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
64
- if (params.num_splits <= 16) {
65
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
66
- return;
67
- }
68
- }
69
- if (params.num_splits <= 32) {
70
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
71
- } else if (params.num_splits <= 64) {
72
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
73
- } else if (params.num_splits <= 128) {
74
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
75
- } else {
76
- run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
77
- }
78
- });
79
- });
80
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_kernel_sm80.h DELETED
@@ -1,215 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/array.h>
11
- #include <cutlass/numeric_types.h>
12
- #include <cutlass/kernel_hardware_info.h>
13
-
14
- #include "seqlen.h"
15
- #include "utils.h"
16
- #include "softmax.h"
17
-
18
- namespace flash {
19
-
20
- using namespace cute;
21
-
22
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
23
- class FlashAttnFwdSm80 {
24
-
25
- public:
26
-
27
- // Type Aliases
28
- using CollectiveMainloop = CollectiveMainloop_;
29
- using CollectiveEpilogue = CollectiveEpilogue_;
30
- static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
31
- static constexpr bool Is_local = CollectiveMainloop::Is_local;
32
- static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
33
- static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
34
- static constexpr bool Varlen = CollectiveMainloop::Varlen;
35
- static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
36
- static constexpr bool Split = CollectiveMainloop::Split;
37
- static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
38
- static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
39
- static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
40
- static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
41
- static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
42
- using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
43
-
44
- // Mainloop derived types
45
- using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
46
- using TiledMma = typename CollectiveMainloop::TiledMma;
47
- using ArchTag = typename CollectiveMainloop::ArchTag;
48
- using MainloopArguments = typename CollectiveMainloop::Arguments;
49
- using MainloopParams = typename CollectiveMainloop::Params;
50
-
51
- // Epilogue derived types
52
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
53
- using EpilogueParams = typename CollectiveEpilogue::Params;
54
-
55
- static_assert(ArchTag::kMinComputeCapability >= 80);
56
-
57
- using TileScheduler = TileScheduler_;
58
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
59
- using TileSchedulerParams = typename TileScheduler::Params;
60
-
61
- static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
62
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
63
- static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
64
-
65
- // Kernel level shared memory storage
66
- // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
67
- // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
68
- static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
69
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
70
- - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
71
- static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
72
- struct SharedStorage {
73
- struct TensorStorage : cute::aligned_struct<128> {
74
- union {
75
- struct {
76
- cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
77
- typename CollectiveMainloop::TensorStorage mainloop;
78
- };
79
- // We want smem_o to line up with the start of smem_v
80
- typename CollectiveEpilogue::TensorStorage epilogue;
81
- };
82
- } tensors;
83
-
84
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
85
-
86
- };
87
-
88
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
89
-
90
- // Device side arguments
91
- struct Arguments {
92
- MainloopArguments mainloop{};
93
- EpilogueArguments epilogue{};
94
- cutlass::KernelHardwareInfo hw_info{};
95
- TileSchedulerArguments scheduler{};
96
- };
97
-
98
- // Kernel entry point API
99
- struct Params {
100
- MainloopParams mainloop{};
101
- EpilogueParams epilogue{};
102
- cutlass::KernelHardwareInfo hw_info{};
103
- TileSchedulerParams scheduler{};
104
- };
105
-
106
- //
107
- // Methods
108
- //
109
-
110
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
111
- static
112
- Params
113
- to_underlying_arguments(Arguments const& args) {
114
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
115
-
116
- // Get SM count if needed, otherwise use user supplied SM count
117
- int sm_count = args.hw_info.sm_count;
118
- if (sm_count <= 0) {
119
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
120
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
121
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
122
- }
123
-
124
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
125
-
126
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
127
- return {
128
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
129
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
130
- hw_info,
131
- TileScheduler::to_underlying_arguments(args.scheduler)
132
- };
133
- }
134
-
135
- // Computes the kernel launch grid shape based on runtime parameters
136
- static dim3
137
- get_grid_shape(Params const& params) {
138
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
139
- }
140
-
141
- static dim3
142
- get_block_shape() {
143
- return dim3(MaxThreadsPerBlock, 1, 1);
144
- }
145
-
146
- CUTLASS_DEVICE
147
- void
148
- operator()(Params const& params, char* smem_buf) {
149
-
150
- static constexpr int kBlockM = get<0>(TileShape_MNK{});
151
-
152
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
153
-
154
- CollectiveMainloop mainloop;
155
- CollectiveEpilogue epilogue;
156
-
157
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
158
- // Initialize matmul objects.
159
- TiledMma tiled_mma;
160
-
161
- scheduler.init_consumer();
162
-
163
- int warp_idx = cutlass::canonical_warp_idx_sync();
164
- CUTLASS_PRAGMA_NO_UNROLL
165
- for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
166
- work_tile_info.is_valid(params.scheduler);
167
- work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
168
- // Attention output (GEMM-II) accumulator.
169
- Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
170
- float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
171
- // If there's tanh softcap, the scaling will be done before tanh.
172
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
173
- int const bidb = get<2>(block_coord);
174
- if constexpr (Is_FP8 && !Has_softcap) {
175
- int const bidh = get<1>(block_coord);
176
- int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
177
- float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
178
- float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
179
- softmax_scale_log2 *= q_descale * k_descale;
180
- }
181
- flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
182
-
183
- SeqlenInfo_t seqlen_info{
184
- bidb,
185
- get<0>(params.mainloop.shape_Q),
186
- !PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
187
- get<0>(params.mainloop.shape_K_new),
188
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
189
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
190
- params.mainloop.seqlens_rotary
191
- };
192
- if constexpr (AppendKV) {
193
- bool tile_new_valid = mainloop.store_kv_new(
194
- params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
195
- if (tile_new_valid) { __syncthreads(); }
196
- }
197
- bool tile_valid = mainloop.mma(
198
- params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
199
- shared_storage);
200
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
201
- if (tile_valid) {
202
- // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
203
- epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
204
- threadIdx.x, block_coord);
205
- } else {
206
- // Write 0 to gO and -inf to gLSE.
207
- epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
208
- }
209
- }
210
-
211
- }
212
-
213
- };
214
-
215
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_kernel_sm90.h DELETED
@@ -1,458 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include <cutlass/cutlass.h>
10
- #include <cutlass/arch/reg_reconfig.h>
11
- #include <cutlass/array.h>
12
- #include <cutlass/numeric_types.h>
13
- #include <cutlass/numeric_conversion.h>
14
- #include <cutlass/kernel_hardware_info.h>
15
- #include "cutlass/pipeline/pipeline.hpp"
16
-
17
- #include "cutlass/arch/grid_dependency_control.h"
18
-
19
- #include "seqlen.h"
20
- #include "utils.h"
21
- #include "softmax.h"
22
-
23
- namespace flash {
24
-
25
- using namespace cute;
26
-
27
- template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
28
- class FlashAttnFwdSm90 {
29
-
30
- public:
31
-
32
- // Type Aliases
33
- using CollectiveMainloop = CollectiveMainloop_;
34
- using CollectiveEpilogue = CollectiveEpilogue_;
35
- static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
36
- static constexpr bool Is_local = CollectiveMainloop::Is_local;
37
- static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
38
- static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
39
- static constexpr bool Varlen = CollectiveMainloop::Varlen;
40
- static constexpr bool Split = CollectiveMainloop::Split;
41
- static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
42
- static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
43
- static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
44
- static constexpr bool HasQv = CollectiveMainloop::HasQv;
45
- static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
46
- static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
47
- static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
48
- static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
49
- static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
50
- static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
51
- static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
52
- static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
53
- using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
54
-
55
- // Mainloop derived types
56
- using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
57
- using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
58
- using ArchTag = typename CollectiveMainloop::ArchTag;
59
- using ClusterShape = typename CollectiveMainloop::ClusterShape;
60
- using MainloopArguments = typename CollectiveMainloop::Arguments;
61
- using MainloopParams = typename CollectiveMainloop::Params;
62
- using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
63
-
64
- // Epilogue derived types
65
- using EpilogueArguments = typename CollectiveEpilogue::Arguments;
66
- using EpilogueParams = typename CollectiveEpilogue::Params;
67
-
68
- static_assert(ArchTag::kMinComputeCapability >= 90);
69
-
70
- using TileScheduler = TileScheduler_;
71
- using TileSchedulerArguments = typename flash::TileSchedulerArguments;
72
- using TileSchedulerParams = typename TileScheduler::Params;
73
-
74
- static constexpr uint32_t NumLoadWarpGroups = 1;
75
- static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
76
- static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
77
- static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
78
- static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
79
-
80
- /// Register requirement for Load and Math WGs
81
- // If we use cp.async to load K and V, we need more registers for the producer WG.
82
- static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
83
- static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
84
- // If you want to print from the producer warp, you'd need to increase the number of registers
85
- // Otherwise you'll get CUDA error.
86
- // static constexpr uint32_t LoadRegisterRequirement = 40;
87
- // static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
88
-
89
- // Kernel level shared memory storage
90
- // We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
91
- // and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
92
- static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
93
- static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
94
- struct SharedStorage {
95
- struct TensorStorage : cute::aligned_struct<128, _1> {
96
- union {
97
- struct {
98
- cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
99
- typename CollectiveMainloop::TensorStorage mainloop;
100
- };
101
- // We want smem_o to line up with the start of smem_v
102
- typename CollectiveEpilogue::TensorStorage epilogue;
103
- };
104
- } tensors;
105
- struct PipelineStorage : cute::aligned_struct<16, _1> {
106
- alignas(16) BarrierQ barrier_Q;
107
- alignas(16) BarrierQ barrier_Qv;
108
- alignas(16) cutlass::arch::ClusterBarrier barrier_O;
109
- alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
110
- alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
111
- alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
112
- alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
113
- alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
114
- alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
115
- } pipelines;
116
-
117
- };
118
-
119
- static constexpr int SharedStorageSize = sizeof(SharedStorage);
120
-
121
- // Device side arguments
122
- struct Arguments {
123
- MainloopArguments mainloop{};
124
- EpilogueArguments epilogue{};
125
- cutlass::KernelHardwareInfo hw_info{};
126
- TileSchedulerArguments scheduler{};
127
- };
128
-
129
- // Kernel entry point API
130
- struct Params {
131
- MainloopParams mainloop{};
132
- EpilogueParams epilogue{};
133
- cutlass::KernelHardwareInfo hw_info{};
134
- TileSchedulerParams scheduler{};
135
- };
136
-
137
- //
138
- // Methods
139
- //
140
-
141
- // Convert to underlying arguments. In this case, a simple copy for the aliased type.
142
- static
143
- Params
144
- to_underlying_arguments(Arguments const& args) {
145
- CUTLASS_TRACE_HOST("to_underlying_arguments():");
146
-
147
- // Get SM count if needed, otherwise use user supplied SM count
148
- int sm_count = args.hw_info.sm_count;
149
- if (sm_count <= 0) {
150
- CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
151
- " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
152
- sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
153
- }
154
-
155
- CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
156
-
157
- cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
158
- return {
159
- CollectiveMainloop::to_underlying_arguments(args.mainloop),
160
- CollectiveEpilogue::to_underlying_arguments(args.epilogue),
161
- hw_info,
162
- TileScheduler::to_underlying_arguments(args.scheduler)
163
- };
164
- }
165
-
166
- // Computes the kernel launch grid shape based on runtime parameters
167
- static dim3
168
- get_grid_shape(Params const& params) {
169
- return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
170
- }
171
-
172
- static dim3
173
- get_block_shape() {
174
- return dim3(MaxThreadsPerBlock, 1, 1);
175
- }
176
-
177
- CUTLASS_DEVICE
178
- void
179
- operator()(Params const& params, char* smem_buf) {
180
-
181
- static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
182
- static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
183
- static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
184
-
185
- using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
186
- using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
187
- using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
188
- using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
189
- using PipelineState = typename CollectiveMainloop::PipelineState;
190
- using PipelineParamsK = typename MainloopPipelineK::Params;
191
- using PipelineParamsV = typename MainloopPipelineV::Params;
192
- using PipelineParamsVt = typename MainloopPipelineVt::Params;
193
- using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
194
-
195
- SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
196
-
197
- int const lane_predicate = cute::elect_one_sync();
198
- int const warp_idx = cutlass::canonical_warp_idx_sync();
199
-
200
- // Issue Tma Descriptor Prefetch from a single thread
201
- if (warp_idx == 0 && lane_predicate) {
202
- CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
203
- CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
204
- }
205
-
206
- // Obtain warp index
207
- int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
208
- int warp_group_idx = cutlass::canonical_warp_group_idx();
209
-
210
- if (warp_idx == 0 && lane_predicate) {
211
- shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
212
- if constexpr (HasQv) {
213
- shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
214
- }
215
- shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
216
- }
217
-
218
- // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
219
- PipelineParamsK pipeline_params_k;
220
- pipeline_params_k.role = warp_group_idx == 0
221
- ? MainloopPipelineK::ThreadCategory::Producer
222
- : MainloopPipelineK::ThreadCategory::Consumer;
223
- if constexpr (Use_TMA_KV) {
224
- pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
225
- pipeline_params_k.is_leader = warp_group_thread_idx == 0;
226
- pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
227
- } else {
228
- pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
229
- pipeline_params_k.producer_arv_count = NumProducerThreads;
230
- }
231
-
232
- static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
233
- PipelineParamsVt pipeline_params_vt = pipeline_params_k;
234
- if constexpr (Use_TMA_KV && !SameHeadDim) {
235
- pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
236
- if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
237
- } else {
238
- if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
239
- }
240
-
241
- MainloopPipelineK pipeline_k = [&] {
242
- if constexpr (Use_TMA_KV) {
243
- return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
244
- } else {
245
- return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
246
- }
247
- }();
248
- // MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
249
- MainloopPipelineV pipeline_v = [&] {
250
- if constexpr (!Transpose_V) {
251
- static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
252
- if constexpr (Use_TMA_KV) {
253
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
254
- } else {
255
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
256
- }
257
- } else {
258
- PipelineParamsV pipeline_params_v;
259
- pipeline_params_v.role = warp_group_idx == 0
260
- ? MainloopPipelineV::ThreadCategory::Producer
261
- : MainloopPipelineV::ThreadCategory::Consumer;
262
- pipeline_params_v.producer_arv_count = NumProducerThreads;
263
- pipeline_params_v.consumer_arv_count = NumMmaThreads;
264
- return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
265
- }
266
- }();
267
- // If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
268
- // the producer WG will read from pipeline_vt and write to pipeline_v.
269
- // If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
270
- // Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
271
- // However, the thread role isn't used in the pipeline implementation.
272
- MainloopPipelineVt pipeline_vt = [&] {
273
- if constexpr (Use_TMA_KV) {
274
- pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
275
- return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
276
- } else {
277
- pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
278
- return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
279
- }
280
- }();
281
-
282
- PipelineParamsKVNew pipeline_params_kv_new;
283
- pipeline_params_kv_new.role = warp_group_idx == 0
284
- ? MainloopPipelineKVNew::ThreadCategory::Producer
285
- : MainloopPipelineKVNew::ThreadCategory::Consumer;
286
- pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
287
- pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
288
- pipeline_params_kv_new.num_consumers = NumMmaThreads;
289
- auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
290
- if constexpr (!SameHeadDim) {
291
- pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
292
- }
293
- auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
294
-
295
- CollectiveMainloop mainloop;
296
- CollectiveEpilogue epilogue;
297
-
298
- // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
299
- if constexpr (size(ClusterShape{}) > 1) {
300
- cute::cluster_arrive_relaxed();
301
- cute::cluster_wait();
302
- } else {
303
- __syncthreads();
304
- }
305
-
306
- TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
307
-
308
- if (warp_group_idx == 0) { // Producer
309
- cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
310
-
311
- // The pipelines for AppendKV and main attention are different, since e.g. main attention
312
- // might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
313
- // KV_new. Since the pipeline states are different, we have to manually sync to make
314
- // sure the two pipelines don't race when accessing smem_k and smem_v.
315
- PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
316
- PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
317
- int work_idx = 0;
318
- int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
319
- static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
320
- if constexpr (SingleProducerWarp) {
321
- if (warp_idx_in_warpgroup != 0) { return; }
322
- }
323
- if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
324
-
325
- cutlass::arch::wait_on_dependent_grids();
326
-
327
- // Load Q, K, V
328
- for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
329
- work_tile_info.is_valid(params.scheduler);
330
- work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
331
-
332
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
333
- SeqlenInfo_t seqlen_info{
334
- get<2>(block_coord) /*bidb*/,
335
- get<0>(params.mainloop.shape_Q),
336
- !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
337
- get<0>(params.mainloop.shape_K_new),
338
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
339
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
340
- params.mainloop.seqlens_rotary
341
- };
342
- if constexpr (AppendKV) {
343
- bool tile_new_valid = mainloop.load_kv_new(
344
- params.mainloop, pipeline_k_new, pipeline_v_new,
345
- smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
346
- if (tile_new_valid) {
347
- // if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
348
- cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
349
- // if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
350
- }
351
- }
352
- auto scheduler_prefetch = [&scheduler, &params, &work_tile_info]() {
353
- scheduler.prefetch_next_work(params.scheduler, work_tile_info);
354
- };
355
- // pipeline_vt won't be used if we don't need to transpose V.
356
- mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
357
- shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
358
- }
359
- mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
360
- } else { // Consumer
361
- cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
362
-
363
- // Initialize matmul objects.
364
- TiledMmaPV tiled_mma_pv;
365
-
366
- PipelineState smem_pipe_read;
367
- PipelineState smem_pipe_read_new;
368
- // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
369
- // (like in Cutlass's gemm) because the read and release pipeline states are always the same.
370
-
371
- scheduler.init_consumer();
372
- mainloop.mma_init();
373
-
374
- int work_idx = 0;
375
- CUTLASS_PRAGMA_NO_UNROLL
376
- for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
377
- work_tile_info.is_valid(params.scheduler);
378
- // get_next_work will be called before the epilogue
379
- ) {
380
- auto block_coord = work_tile_info.get_block_coord(params.scheduler);
381
- int const bidb = get<2>(block_coord);
382
- SeqlenInfo_t seqlen_info{
383
- bidb,
384
- get<0>(params.mainloop.shape_Q),
385
- !params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
386
- get<0>(params.mainloop.shape_K_new),
387
- params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
388
- params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
389
- params.mainloop.seqlens_rotary
390
- };
391
- if constexpr (AppendKV) {
392
- bool tile_new_valid = mainloop.store_kv_new(
393
- params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
394
- threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
395
- if (tile_new_valid) {
396
- // if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
397
- // We need this sync so that the gmem write from the consumers is visible to the producer
398
- // that might do TMA read after that.
399
- asm volatile ("fence.proxy.async.global;");
400
- cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
401
- // arrive is enough, we don't need sync. The producer will sync, which means
402
- // after that sync we're guaranteed that the AppendKV pipeline have finished
403
- // loading and consumer smem_k and smem_v.
404
- // if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
405
- }
406
- }
407
- // If there's tanh softcap, the scaling will be done before tanh.
408
- float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
409
- if constexpr (Is_FP8 && !Has_softcap) {
410
- int const bidh = get<1>(block_coord);
411
- int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
412
- float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
413
- float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
414
- softmax_scale_log2 *= q_descale * k_descale;
415
- }
416
- flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
417
- // Attention output (GEMM-II) accumulator.
418
- Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
419
- bool tile_valid;
420
- if constexpr (!LargeHeadDimV) {
421
- tile_valid = mainloop.mma(
422
- params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
423
- tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
424
- } else { // mma_pv might not compile if !LargeHeadDimV
425
- if (warp_group_idx == 1) {
426
- tile_valid = mainloop.mma(
427
- params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
428
- tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
429
- } else {
430
- tile_valid = mainloop.mma_pv(
431
- params.mainloop, pipeline_v, smem_pipe_read,
432
- tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
433
- }
434
- }
435
- // Do this here before the epilogue so that the next tile is ready to go.
436
- work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
437
- if constexpr (Split && Varlen) {
438
- if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
439
- cutlass::arch::launch_dependent_grids();
440
- }
441
- }
442
- if (tile_valid) {
443
- // if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
444
- epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
445
- threadIdx.x - MmaThreadOffset, block_coord);
446
- } else {
447
- // Write 0 to gO and -inf to gLSE.
448
- epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
449
- }
450
- }
451
- epilogue.store_tail();
452
- }
453
-
454
- }
455
-
456
- };
457
-
458
- } // namespace flash
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_fwd_launch_template.h DELETED
@@ -1,223 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include "cute/tensor.hpp"
8
-
9
- #include "cutlass/cutlass.h"
10
- #include "cutlass/device_kernel.h" // For device_kernel
11
- #include <cutlass/kernel_hardware_info.h>
12
- #include "cutlass/cluster_launch.hpp"
13
- #include "cutlass/kernel_launch.h"
14
-
15
- #include "static_switch.h"
16
- #include "flash.h"
17
- #include "tile_size.h"
18
- #include "tile_scheduler.hpp"
19
- #include "flash_fwd_kernel_sm90.h"
20
- #include "flash_fwd_kernel_sm80.h"
21
- #include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
22
- #include "mainloop_fwd_sm80.hpp"
23
- #include "epilogue_fwd.hpp"
24
-
25
- using namespace cute;
26
-
27
- template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
28
- bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
29
- bool PackGQA, bool Split, bool V_colmajor>
30
- void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
31
- static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
32
- static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
33
- static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
34
- static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
35
- static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
36
- using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
37
-
38
- // Can't use structured binding since it's not compatible with constexpr
39
- static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
40
- static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
41
- static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
42
- static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
43
- static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
44
- static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
45
- static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
46
- static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
47
- static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
48
-
49
- using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
50
- using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
51
- using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
52
- using CollectiveMainloop = std::conditional_t<
53
- Arch >= 90,
54
- flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
55
- flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>
56
- >;
57
- using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;
58
-
59
- static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
60
- using SchedulerPersistent = std::conditional_t<Varlen,
61
- flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
62
- std::conditional_t<!Is_causal && !Is_local,
63
- flash::StaticPersistentTileScheduler<Split>,
64
- flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
65
- >
66
- >;
67
- using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
68
- // If Split then we probably don't have enough work for PersistentScheduler to be useful.
69
- // However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
70
- // since we'll avoid launching a bunch of thread blocks that immediately exit.
71
- // On Sm80, noncausal persistent seems a bit slower.
72
- static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
73
- using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
74
- using AttnKernel = std::conditional_t<
75
- Arch >= 90,
76
- flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
77
- flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
78
- >;
79
-
80
- bool const is_varlen_q = params.cu_seqlens_q;
81
- bool const is_varlen_k = params.cu_seqlens_k;
82
- bool const is_varlen_k_new = params.cu_seqlens_knew;
83
- int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
84
- int batch_q = !is_varlen_q ? params.b : 1;
85
- int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
86
- typename CollectiveMainloop::StrideV v_strides =
87
- cute::conditional_return<!V_colmajor>(
88
- make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
89
- make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
90
- typename CollectiveMainloop::Arguments mainloop_args {
91
- static_cast<Element const*>(params.q_ptr),
92
- {seqlen_q, params.d, params.h, batch_q}, // shape_Q
93
- {params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
94
- static_cast<Element*>(params.k_ptr),
95
- {!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
96
- params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K
97
- {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
98
- static_cast<Element*>(params.v_ptr),
99
- params.dv, // headdim_v
100
- v_strides, // stride_V
101
- static_cast<Element const*>(params.knew_ptr),
102
- {!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new
103
- {params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new
104
- static_cast<Element const*>(params.vnew_ptr),
105
- {params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
106
- static_cast<Element const*>(params.qv_ptr),
107
- {params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv
108
- static_cast<Element const*>(params.rotary_cos_ptr),
109
- {params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter
110
- {params.rotary_dim / 2, _1{}}, // stride_rotary_cos
111
- static_cast<Element const*>(params.rotary_sin_ptr),
112
- {params.rotary_dim / 2, _1{}}, // stride_rotary_sin
113
- params.is_rotary_interleaved,
114
- params.page_table,
115
- // if page_size is not set, avoid dividing by zero
116
- {params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
117
- {params.page_table_batch_stride, _1{}}, // stride_page_table
118
- params.scale_softmax,
119
- params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
120
- {params.q_descale_batch_stride, params.q_descale_head_stride},
121
- {params.k_descale_batch_stride, params.k_descale_head_stride},
122
- {params.v_descale_batch_stride, params.v_descale_head_stride},
123
- params.window_size_left, params.window_size_right, params.attention_chunk,
124
- params.softcap,
125
- params.num_splits,
126
- params.kv_batch_idx,
127
- params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
128
- params.seqused_q, params.seqused_k,
129
- params.leftpad_k, params.seqlens_rotary
130
- };
131
- typename CollectiveEpilogue::Arguments epilogue_args {
132
- static_cast<ElementOut*>(params.o_ptr),
133
- {seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O
134
- {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
135
- static_cast<float*>(params.oaccum_ptr),
136
- {params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
137
- static_cast<float*>(params.softmax_lse_ptr),
138
- {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE
139
- static_cast<float*>(params.softmax_lseaccum_ptr),
140
- {_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial
141
- params.h_k,
142
- params.cu_seqlens_q, params.seqused_q
143
- };
144
-
145
- int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
146
- int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
147
- num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
148
- typename flash::TileSchedulerArguments scheduler_args {
149
- num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
150
- params.h / params.h_k,
151
- params.seqlen_q,
152
- params.seqlen_k, params.d, params.dv, sizeof(Element),
153
- params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
154
- // params.num_m_blocks_ptr,
155
- params.num_splits_dynamic_ptr,
156
- };
157
-
158
- if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
159
- prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
160
- CHECK_CUDA_KERNEL_LAUNCH();
161
- }
162
-
163
- int device;
164
- CHECK_CUDA(cudaGetDevice(&device));
165
- typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
166
- mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
167
- });
168
-
169
- dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
170
- dim3 block_dims = AttnKernel::get_block_shape();
171
- int smem_size = AttnKernel::SharedStorageSize;
172
- // int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
173
- // int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
174
- // int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
175
- // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
176
- // Get the ptr to kernel function.
177
- if constexpr (size(ClusterShape{}) > 1) {
178
- void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
179
- if (smem_size >= 48 * 1024) {
180
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
181
- }
182
- dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
183
- cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
184
- cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
185
- } else {
186
- auto kernel = cutlass::device_kernel<AttnKernel>;
187
- if (smem_size >= 48 * 1024) {
188
- CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
189
- }
190
- // kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
191
- cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
192
- Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
193
- }
194
- CHECK_CUDA_KERNEL_LAUNCH();
195
- }
196
-
197
- template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
198
- void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream) {
199
- static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
200
- static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
201
- using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
202
- CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
203
- VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
204
- static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
205
- VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
206
- // Only needed here to decide if we should use cluster
207
- static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;
208
-
209
- static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
210
- BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
211
- static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
212
- APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
213
- // Only use Cluster if number of tiles along seqlen_q is even and not varlen
214
- CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
215
- static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
216
- run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
217
- });
218
- });
219
- });
220
- });
221
- });
222
- });
223
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/flash_prepare_scheduler.cu DELETED
@@ -1,124 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #include "cutlass/fast_math.h"
6
- #include "cutlass/barrier.h"
7
- #include "cutlass/arch/barrier.h"
8
-
9
- #include "cutlass/arch/grid_dependency_control.h"
10
-
11
- #include "flash.h"
12
-
13
- namespace flash {
14
-
15
- __global__ void prepare_varlen_num_blocks_kernel(
16
- int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
17
- int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
18
- int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
19
- int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
20
- cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
21
- int* const tile_count_semaphore,
22
- // int* const num_m_blocks_ptr,
23
- int* const num_splits_dynamic_ptr,
24
- bool enable_pdl) {
25
-
26
- static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
27
- static constexpr int kSmemSize = 1;
28
- // Assume that there's only one block in the grid
29
- __shared__ int total_blocks_smem[kSmemSize];
30
-
31
- // There's only 1 block in the grid, so might as well start launching the main attn kernel
32
- if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
33
-
34
- if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
35
- __syncthreads();
36
-
37
- if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
38
-
39
- int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
40
-
41
- auto get_num_m_blocks = [&](int bidb_start) {
42
- int batch_idx = lane + bidb_start;
43
- int seqlen;
44
- if (seqused_q) {
45
- seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
46
- } else if (cu_seqlens_q) {
47
- int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
48
- int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
49
- seqlen = next_cu_seqlen - cur_cu_seqlen;
50
- } else {
51
- seqlen = seqlen_q_static;
52
- }
53
- seqlen *= qhead_per_khead;
54
- return batch_idx < num_batch && lane < kNumBatchPerWarp
55
- ? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
56
- };
57
-
58
- auto get_num_n_blocks = [&](int bidb_start) {
59
- int batch_idx = lane + bidb_start;
60
- int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
61
- int seqlen;
62
- if (seqused_k) {
63
- seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
64
- } else if (cu_seqlens_k) {
65
- int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
66
- int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
67
- seqlen = next_cu_seqlen - cur_cu_seqlen;
68
- } else {
69
- seqlen = seqlen_k_static;
70
- }
71
- int seqlen_new;
72
- if (cu_seqlens_k_new) {
73
- int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
74
- int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
75
- seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
76
- } else {
77
- seqlen_new = seqlen_k_new_static;
78
- }
79
- // if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
80
- seqlen = seqlen - leftpad_k + seqlen_new;
81
- return batch_idx < num_batch && lane < kNumBatchPerWarp
82
- ? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
83
- };
84
-
85
- int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
86
- int bidb_start = kNumBatchPerWarp * warp_idx;
87
- int num_m_blocks = get_num_m_blocks(bidb_start);
88
- int num_n_blocks = get_num_n_blocks(bidb_start);
89
-
90
- int total_blocks = num_m_blocks * num_n_blocks;
91
- // Warp sum
92
- #pragma unroll
93
- for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
94
- total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
95
- }
96
- if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
97
- __syncthreads();
98
- total_blocks = total_blocks_smem[0];
99
- // 10% margin
100
- int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
101
- // blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM
102
- int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
103
- if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
104
- num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
105
- // printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
106
- }
107
- }
108
-
109
- } // flash
110
-
111
- void prepare_varlen_num_blocks(Flash_fwd_params &params, cudaStream_t stream, bool packgqa,
112
- int blockM, int blockN, bool enable_pdl) {
113
- // Only support batch <= 992 (32 warps, each with 31 batches)
114
- int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
115
- flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
116
- params.seqlen_q, params.seqlen_k, params.seqlen_knew,
117
- params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
118
- params.seqused_q, params.seqused_k, params.leftpad_k,
119
- params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
120
- cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
121
- params.tile_count_semaphore,
122
- // params.num_m_blocks_ptr,
123
- params.num_splits_dynamic_ptr, enable_pdl);
124
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/heuristics.h DELETED
@@ -1,59 +0,0 @@
1
- /******************************************************************************
2
- * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
3
- ******************************************************************************/
4
-
5
- #pragma once
6
-
7
- #include <vector>
8
-
9
- inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
10
- // If varlen, we don't actually know seqlen_q but only max_seqlen_q.
11
- if (varlen_q) return true;
12
- // Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
13
- auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
14
- float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
15
- float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
16
- return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
17
- };
18
-
19
- // Find the number of splits that maximizes the occupancy. For example, if we have
20
- // batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
21
- // better than having 3 splits (efficiency = 0.67). However, we also don't want too many
22
- // splits as that would incur more HBM reads/writes.
23
- // So we find the best efficiency, then find the smallest number of splits that gets 85%
24
- // of the best efficiency.
25
- inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
26
- // If we have enough to almost fill the SMs, then just use 1 split
27
- // However, in the case of super long seqlen where each head of KV doesn't even fit into
28
- // L2 (we assume that L2 size is 50MB), we want to split.
29
- if (total_mblocks >= 0.8f * num_SMs) {
30
- int const size_l2 = 50 * 1024 * 1024;
31
- // Only split if there are enough queries to go over the KV at least twice
32
- // Don't split if causal
33
- if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
34
- return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
35
- } else {
36
- return 1;
37
- }
38
- }
39
- // If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
40
- if (num_n_blocks <= 4) { return 1; }
41
- max_splits = std::min({max_splits, num_SMs, num_n_blocks});
42
- float max_efficiency = 0.f;
43
- std::vector<float> efficiency;
44
- efficiency.reserve(max_splits);
45
- for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
46
- float n_waves = float(total_mblocks * num_splits) / num_SMs;
47
- float eff = n_waves / ceil(n_waves);
48
- // printf("num_splits = %d, eff = %f\n", num_splits, eff);
49
- if (eff > max_efficiency) { max_efficiency = eff; }
50
- efficiency.push_back(eff);
51
- }
52
- for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
53
- if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
54
- // printf("num_splits chosen = %d\n", num_splits);
55
- return num_splits;
56
- }
57
- }
58
- return 1;
59
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim128_bf16_sm90.cu"
6
- #include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM128
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM128
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim128_fp16_sm90.cu"
6
- #include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim192_bf16_sm90.cu"
6
- #include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM192
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM192
8
- template<>
9
- void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim192_fp16_sm90.cu"
6
- #include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM256
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
16
- }
17
- #endif
18
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu DELETED
@@ -1,12 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_HDIM256
8
- template<>
9
- void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params &params, cudaStream_t stream) {
10
- run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
11
- }
12
- #endif
 
 
 
 
 
 
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu DELETED
@@ -1,6 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_hdim256_bf16_sm90.cu"
6
- #include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
 
 
 
 
 
 
 
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu DELETED
@@ -1,18 +0,0 @@
1
- // Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
2
- // Splitting the different template instantiations to different files to speed up compilation.
3
- // This file is auto-generated. See "generate_kernels.py"
4
-
5
- #include "flash_bwd_launch_template.h"
6
-
7
- #ifndef FLASHATTENTION_DISABLE_SM8x
8
- #ifndef FLASHATTENTION_DISABLE_HDIM256
9
- template<>
10
- void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
11
- run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
12
- }
13
- template<>
14
- void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params &params, cudaStream_t stream) {
15
- run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
16
- }
17
- #endif
18
- #endif