Remove source
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +2 -0
- build.toml +0 -589
- flake.lock +0 -168
- flake.nix +0 -35
- flash-attn/block.h +0 -139
- flash-attn/copy_sm90_bulk_reduce.hpp +0 -49
- flash-attn/cuda_check.h +0 -19
- flash-attn/epilogue_bwd.hpp +0 -533
- flash-attn/epilogue_fwd.hpp +0 -484
- flash-attn/flash.h +0 -218
- flash-attn/flash_api.cpp +0 -1720
- flash-attn/flash_bwd_kernel_sm80.h +0 -173
- flash-attn/flash_bwd_kernel_sm90.h +0 -282
- flash-attn/flash_bwd_launch_template.h +0 -390
- flash-attn/flash_bwd_postprocess_kernel.h +0 -256
- flash-attn/flash_bwd_preprocess_kernel.h +0 -252
- flash-attn/flash_fwd_combine.cu +0 -13
- flash-attn/flash_fwd_combine_kernel.h +0 -482
- flash-attn/flash_fwd_combine_launch_template.h +0 -80
- flash-attn/flash_fwd_kernel_sm80.h +0 -215
- flash-attn/flash_fwd_kernel_sm90.h +0 -458
- flash-attn/flash_fwd_launch_template.h +0 -223
- flash-attn/flash_prepare_scheduler.cu +0 -124
- flash-attn/heuristics.h +0 -59
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu +0 -18
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu +0 -12
- flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu +0 -6
- flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu +0 -18
README.md
CHANGED
|
@@ -11,3 +11,5 @@ attention mechanism, designed to work with large models and long sequences.
|
|
| 11 |
This is a Hugging Face compliant kernel build of Flash Attention.
|
| 12 |
|
| 13 |
Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
|
|
|
|
|
|
|
|
|
| 11 |
This is a Hugging Face compliant kernel build of Flash Attention.
|
| 12 |
|
| 13 |
Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
|
| 14 |
+
|
| 15 |
+
Kernel source: https://github.com/huggingface/kernels-community/tree/main/flash-attn3
|
build.toml
DELETED
|
@@ -1,589 +0,0 @@
|
|
| 1 |
-
[general]
|
| 2 |
-
name = "flash_attn3"
|
| 3 |
-
universal = false
|
| 4 |
-
cuda-minver = "12.4"
|
| 5 |
-
cuda-maxver = "12.4"
|
| 6 |
-
|
| 7 |
-
[torch]
|
| 8 |
-
src = [
|
| 9 |
-
"torch-ext/pytorch_shim.h",
|
| 10 |
-
"torch-ext/torch_binding.cpp",
|
| 11 |
-
"torch-ext/torch_binding.h",
|
| 12 |
-
]
|
| 13 |
-
|
| 14 |
-
[kernel.flash_attn]
|
| 15 |
-
backend = "cuda"
|
| 16 |
-
cuda-capabilities = ["8.0", "9.0a"]
|
| 17 |
-
cuda-flags = [
|
| 18 |
-
"-O3",
|
| 19 |
-
"-std=c++17",
|
| 20 |
-
"--ftemplate-backtrace-limit=0", # To debug template code
|
| 21 |
-
"--use_fast_math",
|
| 22 |
-
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
| 23 |
-
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
| 24 |
-
"--expt-relaxed-constexpr",
|
| 25 |
-
"--expt-extended-lambda",
|
| 26 |
-
"--use_fast_math",
|
| 27 |
-
"-DNDEBUG",
|
| 28 |
-
]
|
| 29 |
-
|
| 30 |
-
src = [
|
| 31 |
-
"flash-attn/cuda_check.h",
|
| 32 |
-
"flash-attn/flash_api.cpp",
|
| 33 |
-
"flash-attn/flash_fwd_combine.cu",
|
| 34 |
-
"flash-attn/flash_fwd_combine_kernel.h",
|
| 35 |
-
"flash-attn/flash_fwd_combine_launch_template.h",
|
| 36 |
-
"flash-attn/flash.h",
|
| 37 |
-
"flash-attn/flash_prepare_scheduler.cu",
|
| 38 |
-
"flash-attn/heuristics.h",
|
| 39 |
-
"flash-attn/seqlen.h",
|
| 40 |
-
"flash-attn/static_switch.h",
|
| 41 |
-
"flash-attn/tile_size.h",
|
| 42 |
-
"flash-attn/utils.h",
|
| 43 |
-
]
|
| 44 |
-
depends = ["torch", "cutlass_3_9"]
|
| 45 |
-
|
| 46 |
-
[kernel.flash_attn_sm80]
|
| 47 |
-
backend = "cuda"
|
| 48 |
-
cuda-capabilities = ["8.0", "9.0a"]
|
| 49 |
-
cuda-flags = [
|
| 50 |
-
"-O3",
|
| 51 |
-
"-std=c++17",
|
| 52 |
-
"--ftemplate-backtrace-limit=0", # To debug template code
|
| 53 |
-
"--use_fast_math",
|
| 54 |
-
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
| 55 |
-
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
| 56 |
-
"--expt-relaxed-constexpr",
|
| 57 |
-
"--expt-extended-lambda",
|
| 58 |
-
"--use_fast_math",
|
| 59 |
-
"-DNDEBUG",
|
| 60 |
-
]
|
| 61 |
-
src = [
|
| 62 |
-
"flash-attn/block.h",
|
| 63 |
-
"flash-attn/copy_sm90_bulk_reduce.hpp",
|
| 64 |
-
"flash-attn/epilogue_bwd.hpp",
|
| 65 |
-
"flash-attn/epilogue_fwd.hpp",
|
| 66 |
-
"flash-attn/flash.h",
|
| 67 |
-
"flash-attn/flash_bwd_kernel_sm80.h",
|
| 68 |
-
"flash-attn/flash_bwd_kernel_sm90.h",
|
| 69 |
-
"flash-attn/flash_bwd_launch_template.h",
|
| 70 |
-
"flash-attn/flash_bwd_postprocess_kernel.h",
|
| 71 |
-
"flash-attn/flash_bwd_preprocess_kernel.h",
|
| 72 |
-
"flash-attn/flash_fwd_launch_template.h",
|
| 73 |
-
"flash-attn/flash_fwd_kernel_sm80.h",
|
| 74 |
-
"flash-attn/flash_fwd_kernel_sm90.h",
|
| 75 |
-
"flash-attn/heuristics.h",
|
| 76 |
-
"flash-attn/mainloop_bwd_sm80.hpp",
|
| 77 |
-
"flash-attn/mainloop_fwd_sm80.hpp",
|
| 78 |
-
"flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
|
| 79 |
-
"flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
|
| 80 |
-
"flash-attn/mask.h",
|
| 81 |
-
"flash-attn/named_barrier.hpp",
|
| 82 |
-
"flash-attn/pack_gqa.h",
|
| 83 |
-
"flash-attn/paged_kv.h",
|
| 84 |
-
"flash-attn/rotary.h",
|
| 85 |
-
"flash-attn/sm90_pipeline_no_cluster.hpp",
|
| 86 |
-
"flash-attn/softmax.h",
|
| 87 |
-
"flash-attn/tile_size.h",
|
| 88 |
-
"flash-attn/tile_scheduler.hpp",
|
| 89 |
-
|
| 90 |
-
"flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu",
|
| 91 |
-
"flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu",
|
| 92 |
-
"flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu",
|
| 93 |
-
"flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu",
|
| 94 |
-
"flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu",
|
| 95 |
-
"flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu",
|
| 96 |
-
"flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu",
|
| 97 |
-
"flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu",
|
| 98 |
-
"flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu",
|
| 99 |
-
"flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu",
|
| 100 |
-
"flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu",
|
| 101 |
-
"flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm80.cu",
|
| 102 |
-
"flash-attn/instantiations/flash_bwd_hdim64_bf16_sm80.cu",
|
| 103 |
-
"flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm80.cu",
|
| 104 |
-
"flash-attn/instantiations/flash_bwd_hdim64_fp16_sm80.cu",
|
| 105 |
-
"flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm80.cu",
|
| 106 |
-
"flash-attn/instantiations/flash_bwd_hdim96_bf16_sm80.cu",
|
| 107 |
-
"flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm80.cu",
|
| 108 |
-
"flash-attn/instantiations/flash_bwd_hdim96_fp16_sm80.cu",
|
| 109 |
-
"flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm80.cu",
|
| 110 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm80.cu",
|
| 111 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm80.cu",
|
| 112 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm80.cu",
|
| 113 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm80.cu",
|
| 114 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_sm80.cu",
|
| 115 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm80.cu",
|
| 116 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm80.cu",
|
| 117 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm80.cu",
|
| 118 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm80.cu",
|
| 119 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm80.cu",
|
| 120 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm80.cu",
|
| 121 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm80.cu",
|
| 122 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_sm80.cu",
|
| 123 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm80.cu",
|
| 124 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm80.cu",
|
| 125 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm80.cu",
|
| 126 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm80.cu",
|
| 127 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm80.cu",
|
| 128 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm80.cu",
|
| 129 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm80.cu",
|
| 130 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_sm80.cu",
|
| 131 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm80.cu",
|
| 132 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm80.cu",
|
| 133 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm80.cu",
|
| 134 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm80.cu",
|
| 135 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm80.cu",
|
| 136 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm80.cu",
|
| 137 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm80.cu",
|
| 138 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_sm80.cu",
|
| 139 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm80.cu",
|
| 140 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm80.cu",
|
| 141 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm80.cu",
|
| 142 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm80.cu",
|
| 143 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm80.cu",
|
| 144 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm80.cu",
|
| 145 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm80.cu",
|
| 146 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_sm80.cu",
|
| 147 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm80.cu",
|
| 148 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm80.cu",
|
| 149 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm80.cu",
|
| 150 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm80.cu",
|
| 151 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm80.cu",
|
| 152 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm80.cu",
|
| 153 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm80.cu",
|
| 154 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_sm80.cu",
|
| 155 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm80.cu",
|
| 156 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm80.cu",
|
| 157 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm80.cu",
|
| 158 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm80.cu",
|
| 159 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm80.cu",
|
| 160 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm80.cu",
|
| 161 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm80.cu",
|
| 162 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_sm80.cu",
|
| 163 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm80.cu",
|
| 164 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm80.cu",
|
| 165 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm80.cu",
|
| 166 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm80.cu",
|
| 167 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm80.cu",
|
| 168 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm80.cu",
|
| 169 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm80.cu",
|
| 170 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_sm80.cu",
|
| 171 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm80.cu",
|
| 172 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm80.cu",
|
| 173 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm80.cu",
|
| 174 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm80.cu",
|
| 175 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm80.cu",
|
| 176 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm80.cu",
|
| 177 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm80.cu",
|
| 178 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_sm80.cu",
|
| 179 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm80.cu",
|
| 180 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm80.cu",
|
| 181 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm80.cu",
|
| 182 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm80.cu",
|
| 183 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm80.cu",
|
| 184 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm80.cu",
|
| 185 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm80.cu",
|
| 186 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_sm80.cu",
|
| 187 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm80.cu",
|
| 188 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm80.cu",
|
| 189 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm80.cu"
|
| 190 |
-
]
|
| 191 |
-
include = ["flash-attn"]
|
| 192 |
-
depends = ["torch", "cutlass_3_9"]
|
| 193 |
-
|
| 194 |
-
[kernel.flash_attn_sm90]
|
| 195 |
-
backend = "cuda"
|
| 196 |
-
cuda-capabilities = ["8.0", "9.0a"]
|
| 197 |
-
cuda-flags = [
|
| 198 |
-
"-O3",
|
| 199 |
-
"-std=c++17",
|
| 200 |
-
"--ftemplate-backtrace-limit=0", # To debug template code
|
| 201 |
-
"--use_fast_math",
|
| 202 |
-
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
| 203 |
-
"-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
| 204 |
-
"--expt-relaxed-constexpr",
|
| 205 |
-
"--expt-extended-lambda",
|
| 206 |
-
"--use_fast_math",
|
| 207 |
-
"-DNDEBUG",
|
| 208 |
-
]
|
| 209 |
-
src = [
|
| 210 |
-
"flash-attn/block.h",
|
| 211 |
-
"flash-attn/copy_sm90_bulk_reduce.hpp",
|
| 212 |
-
"flash-attn/epilogue_bwd.hpp",
|
| 213 |
-
"flash-attn/epilogue_fwd.hpp",
|
| 214 |
-
"flash-attn/flash.h",
|
| 215 |
-
"flash-attn/flash_bwd_kernel_sm80.h",
|
| 216 |
-
"flash-attn/flash_bwd_kernel_sm90.h",
|
| 217 |
-
"flash-attn/flash_bwd_launch_template.h",
|
| 218 |
-
"flash-attn/flash_bwd_postprocess_kernel.h",
|
| 219 |
-
"flash-attn/flash_bwd_preprocess_kernel.h",
|
| 220 |
-
"flash-attn/flash_fwd_launch_template.h",
|
| 221 |
-
"flash-attn/flash_fwd_kernel_sm80.h",
|
| 222 |
-
"flash-attn/flash_fwd_kernel_sm90.h",
|
| 223 |
-
"flash-attn/heuristics.h",
|
| 224 |
-
"flash-attn/mainloop_bwd_sm80.hpp",
|
| 225 |
-
"flash-attn/mainloop_fwd_sm80.hpp",
|
| 226 |
-
"flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
|
| 227 |
-
"flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
|
| 228 |
-
"flash-attn/mask.h",
|
| 229 |
-
"flash-attn/named_barrier.hpp",
|
| 230 |
-
"flash-attn/pack_gqa.h",
|
| 231 |
-
"flash-attn/paged_kv.h",
|
| 232 |
-
"flash-attn/rotary.h",
|
| 233 |
-
"flash-attn/sm90_pipeline_no_cluster.hpp",
|
| 234 |
-
"flash-attn/softmax.h",
|
| 235 |
-
"flash-attn/tile_size.h",
|
| 236 |
-
"flash-attn/tile_scheduler.hpp",
|
| 237 |
-
|
| 238 |
-
"flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu",
|
| 239 |
-
"flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu",
|
| 240 |
-
"flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu",
|
| 241 |
-
"flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu",
|
| 242 |
-
"flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu",
|
| 243 |
-
"flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu",
|
| 244 |
-
"flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu",
|
| 245 |
-
"flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu",
|
| 246 |
-
"flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu",
|
| 247 |
-
"flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu",
|
| 248 |
-
"flash-attn/instantiations/flash_bwd_hdim256_fp16_sm90.cu",
|
| 249 |
-
"flash-attn/instantiations/flash_bwd_hdim256_fp16_softcap_sm90.cu",
|
| 250 |
-
"flash-attn/instantiations/flash_bwd_hdim64_bf16_sm90.cu",
|
| 251 |
-
"flash-attn/instantiations/flash_bwd_hdim64_bf16_softcap_sm90.cu",
|
| 252 |
-
"flash-attn/instantiations/flash_bwd_hdim64_fp16_sm90.cu",
|
| 253 |
-
"flash-attn/instantiations/flash_bwd_hdim64_fp16_softcap_sm90.cu",
|
| 254 |
-
"flash-attn/instantiations/flash_bwd_hdim96_bf16_sm90.cu",
|
| 255 |
-
"flash-attn/instantiations/flash_bwd_hdim96_bf16_softcap_sm90.cu",
|
| 256 |
-
"flash-attn/instantiations/flash_bwd_hdim96_fp16_sm90.cu",
|
| 257 |
-
"flash-attn/instantiations/flash_bwd_hdim96_fp16_softcap_sm90.cu",
|
| 258 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_packgqa_sm90.cu",
|
| 259 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_sm90.cu",
|
| 260 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_softcap_sm90.cu",
|
| 261 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_sm90.cu",
|
| 262 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_paged_split_softcap_sm90.cu",
|
| 263 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_sm90.cu",
|
| 264 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_packgqa_sm90.cu",
|
| 265 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_softcap_sm90.cu",
|
| 266 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_split_sm90.cu",
|
| 267 |
-
"flash-attn/instantiations/flash_fwd_hdim128_bf16_split_softcap_sm90.cu",
|
| 268 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_packgqa_sm90.cu",
|
| 269 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_sm90.cu",
|
| 270 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_softcap_sm90.cu",
|
| 271 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_sm90.cu",
|
| 272 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_paged_split_softcap_sm90.cu",
|
| 273 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_sm90.cu",
|
| 274 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_packgqa_sm90.cu",
|
| 275 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_softcap_sm90.cu",
|
| 276 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_sm90.cu",
|
| 277 |
-
"flash-attn/instantiations/flash_fwd_hdim128_e4m3_split_softcap_sm90.cu",
|
| 278 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_packgqa_sm90.cu",
|
| 279 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_sm90.cu",
|
| 280 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_softcap_sm90.cu",
|
| 281 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_sm90.cu",
|
| 282 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_paged_split_softcap_sm90.cu",
|
| 283 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_sm90.cu",
|
| 284 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_packgqa_sm90.cu",
|
| 285 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_softcap_sm90.cu",
|
| 286 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_split_sm90.cu",
|
| 287 |
-
"flash-attn/instantiations/flash_fwd_hdim128_fp16_split_softcap_sm90.cu",
|
| 288 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_packgqa_sm90.cu",
|
| 289 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_sm90.cu",
|
| 290 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_softcap_sm90.cu",
|
| 291 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_sm90.cu",
|
| 292 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_paged_split_softcap_sm90.cu",
|
| 293 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_sm90.cu",
|
| 294 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_packgqa_sm90.cu",
|
| 295 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_softcap_sm90.cu",
|
| 296 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_sm90.cu",
|
| 297 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_bf16_split_softcap_sm90.cu",
|
| 298 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_packgqa_sm90.cu",
|
| 299 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_sm90.cu",
|
| 300 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_softcap_sm90.cu",
|
| 301 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_sm90.cu",
|
| 302 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_paged_split_softcap_sm90.cu",
|
| 303 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_sm90.cu",
|
| 304 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_packgqa_sm90.cu",
|
| 305 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_softcap_sm90.cu",
|
| 306 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_sm90.cu",
|
| 307 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_e4m3_split_softcap_sm90.cu",
|
| 308 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_packgqa_sm90.cu",
|
| 309 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_sm90.cu",
|
| 310 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_softcap_sm90.cu",
|
| 311 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_sm90.cu",
|
| 312 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_paged_split_softcap_sm90.cu",
|
| 313 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_sm90.cu",
|
| 314 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_packgqa_sm90.cu",
|
| 315 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_softcap_sm90.cu",
|
| 316 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_sm90.cu",
|
| 317 |
-
"flash-attn/instantiations/flash_fwd_hdim192_128_fp16_split_softcap_sm90.cu",
|
| 318 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_packgqa_sm90.cu",
|
| 319 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_sm90.cu",
|
| 320 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_softcap_sm90.cu",
|
| 321 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_sm90.cu",
|
| 322 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_paged_split_softcap_sm90.cu",
|
| 323 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_sm90.cu",
|
| 324 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_packgqa_sm90.cu",
|
| 325 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_softcap_sm90.cu",
|
| 326 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_split_sm90.cu",
|
| 327 |
-
"flash-attn/instantiations/flash_fwd_hdim192_bf16_split_softcap_sm90.cu",
|
| 328 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_packgqa_sm90.cu",
|
| 329 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_sm90.cu",
|
| 330 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_softcap_sm90.cu",
|
| 331 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_sm90.cu",
|
| 332 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_paged_split_softcap_sm90.cu",
|
| 333 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_sm90.cu",
|
| 334 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_packgqa_sm90.cu",
|
| 335 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_softcap_sm90.cu",
|
| 336 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_sm90.cu",
|
| 337 |
-
"flash-attn/instantiations/flash_fwd_hdim192_e4m3_split_softcap_sm90.cu",
|
| 338 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_packgqa_sm90.cu",
|
| 339 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_sm90.cu",
|
| 340 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_softcap_sm90.cu",
|
| 341 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_sm90.cu",
|
| 342 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_paged_split_softcap_sm90.cu",
|
| 343 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_sm90.cu",
|
| 344 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_packgqa_sm90.cu",
|
| 345 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_softcap_sm90.cu",
|
| 346 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_split_sm90.cu",
|
| 347 |
-
"flash-attn/instantiations/flash_fwd_hdim192_fp16_split_softcap_sm90.cu",
|
| 348 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_packgqa_sm90.cu",
|
| 349 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_sm90.cu",
|
| 350 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_softcap_sm90.cu",
|
| 351 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_sm90.cu",
|
| 352 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_paged_split_softcap_sm90.cu",
|
| 353 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_sm90.cu",
|
| 354 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_packgqa_sm90.cu",
|
| 355 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_softcap_sm90.cu",
|
| 356 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_split_sm90.cu",
|
| 357 |
-
"flash-attn/instantiations/flash_fwd_hdim256_bf16_split_softcap_sm90.cu",
|
| 358 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_packgqa_sm90.cu",
|
| 359 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_sm90.cu",
|
| 360 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_softcap_sm90.cu",
|
| 361 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_sm90.cu",
|
| 362 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_paged_split_softcap_sm90.cu",
|
| 363 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_sm90.cu",
|
| 364 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_packgqa_sm90.cu",
|
| 365 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_softcap_sm90.cu",
|
| 366 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_sm90.cu",
|
| 367 |
-
"flash-attn/instantiations/flash_fwd_hdim256_e4m3_split_softcap_sm90.cu",
|
| 368 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_packgqa_sm90.cu",
|
| 369 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_sm90.cu",
|
| 370 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_softcap_sm90.cu",
|
| 371 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_sm90.cu",
|
| 372 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_paged_split_softcap_sm90.cu",
|
| 373 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_sm90.cu",
|
| 374 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_packgqa_sm90.cu",
|
| 375 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_softcap_sm90.cu",
|
| 376 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_split_sm90.cu",
|
| 377 |
-
"flash-attn/instantiations/flash_fwd_hdim256_fp16_split_softcap_sm90.cu",
|
| 378 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_packgqa_sm90.cu",
|
| 379 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_sm90.cu",
|
| 380 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_softcap_sm90.cu",
|
| 381 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_sm90.cu",
|
| 382 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_paged_split_softcap_sm90.cu",
|
| 383 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_sm90.cu",
|
| 384 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_packgqa_sm90.cu",
|
| 385 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_softcap_sm90.cu",
|
| 386 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_sm90.cu",
|
| 387 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_bf16_split_softcap_sm90.cu",
|
| 388 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_packgqa_sm90.cu",
|
| 389 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_sm90.cu",
|
| 390 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_softcap_sm90.cu",
|
| 391 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_sm90.cu",
|
| 392 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_paged_split_softcap_sm90.cu",
|
| 393 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_sm90.cu",
|
| 394 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_packgqa_sm90.cu",
|
| 395 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_softcap_sm90.cu",
|
| 396 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_sm90.cu",
|
| 397 |
-
"flash-attn/instantiations/flash_fwd_hdim64_256_fp16_split_softcap_sm90.cu",
|
| 398 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_packgqa_sm90.cu",
|
| 399 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_sm90.cu",
|
| 400 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_softcap_sm90.cu",
|
| 401 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_sm90.cu",
|
| 402 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_paged_split_softcap_sm90.cu",
|
| 403 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_sm90.cu",
|
| 404 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_packgqa_sm90.cu",
|
| 405 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_softcap_sm90.cu",
|
| 406 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_sm90.cu",
|
| 407 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_bf16_split_softcap_sm90.cu",
|
| 408 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_packgqa_sm90.cu",
|
| 409 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_sm90.cu",
|
| 410 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_softcap_sm90.cu",
|
| 411 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_sm90.cu",
|
| 412 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_paged_split_softcap_sm90.cu",
|
| 413 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_sm90.cu",
|
| 414 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_packgqa_sm90.cu",
|
| 415 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_softcap_sm90.cu",
|
| 416 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_sm90.cu",
|
| 417 |
-
"flash-attn/instantiations/flash_fwd_hdim64_512_fp16_split_softcap_sm90.cu",
|
| 418 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_packgqa_sm90.cu",
|
| 419 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_sm90.cu",
|
| 420 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_softcap_sm90.cu",
|
| 421 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_sm90.cu",
|
| 422 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_paged_split_softcap_sm90.cu",
|
| 423 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_sm90.cu",
|
| 424 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_packgqa_sm90.cu",
|
| 425 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_softcap_sm90.cu",
|
| 426 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_split_sm90.cu",
|
| 427 |
-
"flash-attn/instantiations/flash_fwd_hdim64_bf16_split_softcap_sm90.cu",
|
| 428 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_packgqa_sm90.cu",
|
| 429 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_sm90.cu",
|
| 430 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_softcap_sm90.cu",
|
| 431 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_sm90.cu",
|
| 432 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_paged_split_softcap_sm90.cu",
|
| 433 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_sm90.cu",
|
| 434 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_packgqa_sm90.cu",
|
| 435 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_softcap_sm90.cu",
|
| 436 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_sm90.cu",
|
| 437 |
-
"flash-attn/instantiations/flash_fwd_hdim64_e4m3_split_softcap_sm90.cu",
|
| 438 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_packgqa_sm90.cu",
|
| 439 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_sm90.cu",
|
| 440 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_softcap_sm90.cu",
|
| 441 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_sm90.cu",
|
| 442 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_paged_split_softcap_sm90.cu",
|
| 443 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_sm90.cu",
|
| 444 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_packgqa_sm90.cu",
|
| 445 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_softcap_sm90.cu",
|
| 446 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_split_sm90.cu",
|
| 447 |
-
"flash-attn/instantiations/flash_fwd_hdim64_fp16_split_softcap_sm90.cu",
|
| 448 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_packgqa_sm90.cu",
|
| 449 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_sm90.cu",
|
| 450 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_softcap_sm90.cu",
|
| 451 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_sm90.cu",
|
| 452 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_paged_split_softcap_sm90.cu",
|
| 453 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_sm90.cu",
|
| 454 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_packgqa_sm90.cu",
|
| 455 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_softcap_sm90.cu",
|
| 456 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_split_sm90.cu",
|
| 457 |
-
"flash-attn/instantiations/flash_fwd_hdim96_bf16_split_softcap_sm90.cu",
|
| 458 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_packgqa_sm90.cu",
|
| 459 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_sm90.cu",
|
| 460 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_softcap_sm90.cu",
|
| 461 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_sm90.cu",
|
| 462 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_paged_split_softcap_sm90.cu",
|
| 463 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_sm90.cu",
|
| 464 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_packgqa_sm90.cu",
|
| 465 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_softcap_sm90.cu",
|
| 466 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_sm90.cu",
|
| 467 |
-
"flash-attn/instantiations/flash_fwd_hdim96_e4m3_split_softcap_sm90.cu",
|
| 468 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_packgqa_sm90.cu",
|
| 469 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_sm90.cu",
|
| 470 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_softcap_sm90.cu",
|
| 471 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_sm90.cu",
|
| 472 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_paged_split_softcap_sm90.cu",
|
| 473 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_sm90.cu",
|
| 474 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_packgqa_sm90.cu",
|
| 475 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_softcap_sm90.cu",
|
| 476 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_split_sm90.cu",
|
| 477 |
-
"flash-attn/instantiations/flash_fwd_hdim96_fp16_split_softcap_sm90.cu",
|
| 478 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_packgqa_sm90.cu",
|
| 479 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_sm90.cu",
|
| 480 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_softcap_sm90.cu",
|
| 481 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_sm90.cu",
|
| 482 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_paged_split_softcap_sm90.cu",
|
| 483 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_sm90.cu",
|
| 484 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_packgqa_sm90.cu",
|
| 485 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_softcap_sm90.cu",
|
| 486 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_split_sm90.cu",
|
| 487 |
-
"flash-attn/instantiations/flash_fwd_hdimall_bf16_split_softcap_sm90.cu",
|
| 488 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_packgqa_sm90.cu",
|
| 489 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_sm90.cu",
|
| 490 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_softcap_sm90.cu",
|
| 491 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_sm90.cu",
|
| 492 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_paged_split_softcap_sm90.cu",
|
| 493 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_sm90.cu",
|
| 494 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_packgqa_sm90.cu",
|
| 495 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_softcap_sm90.cu",
|
| 496 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_sm90.cu",
|
| 497 |
-
"flash-attn/instantiations/flash_fwd_hdimall_e4m3_split_softcap_sm90.cu",
|
| 498 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_packgqa_sm90.cu",
|
| 499 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_sm90.cu",
|
| 500 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_softcap_sm90.cu",
|
| 501 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_sm90.cu",
|
| 502 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_paged_split_softcap_sm90.cu",
|
| 503 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_sm90.cu",
|
| 504 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_packgqa_sm90.cu",
|
| 505 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_softcap_sm90.cu",
|
| 506 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_split_sm90.cu",
|
| 507 |
-
"flash-attn/instantiations/flash_fwd_hdimall_fp16_split_softcap_sm90.cu",
|
| 508 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_packgqa_sm90.cu",
|
| 509 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_sm90.cu",
|
| 510 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_softcap_sm90.cu",
|
| 511 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_sm90.cu",
|
| 512 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_paged_split_softcap_sm90.cu",
|
| 513 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_sm90.cu",
|
| 514 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_packgqa_sm90.cu",
|
| 515 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_softcap_sm90.cu",
|
| 516 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_sm90.cu",
|
| 517 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_bf16_split_softcap_sm90.cu",
|
| 518 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_packgqa_sm90.cu",
|
| 519 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_sm90.cu",
|
| 520 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_softcap_sm90.cu",
|
| 521 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_sm90.cu",
|
| 522 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_paged_split_softcap_sm90.cu",
|
| 523 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_sm90.cu",
|
| 524 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_packgqa_sm90.cu",
|
| 525 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_softcap_sm90.cu",
|
| 526 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_sm90.cu",
|
| 527 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_e4m3_split_softcap_sm90.cu",
|
| 528 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_packgqa_sm90.cu",
|
| 529 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_sm90.cu",
|
| 530 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_softcap_sm90.cu",
|
| 531 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_sm90.cu",
|
| 532 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_paged_split_softcap_sm90.cu",
|
| 533 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_sm90.cu",
|
| 534 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_packgqa_sm90.cu",
|
| 535 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_softcap_sm90.cu",
|
| 536 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_sm90.cu",
|
| 537 |
-
"flash-attn/instantiations/flash_fwd_hdimdiff_fp16_split_softcap_sm90.cu",
|
| 538 |
-
]
|
| 539 |
-
include = ["flash-attn"]
|
| 540 |
-
depends = ["torch", "cutlass_3_9"]
|
| 541 |
-
|
| 542 |
-
# [kernel.flash_attn_sm100]
|
| 543 |
-
# backend = "cuda"
|
| 544 |
-
# cuda-capabilities = ["8.0", "9.0a", "10.0"]
|
| 545 |
-
# cuda-flags = [
|
| 546 |
-
# "-O3",
|
| 547 |
-
# "-std=c++17",
|
| 548 |
-
# "--ftemplate-backtrace-limit=0", # To debug template code
|
| 549 |
-
# "--use_fast_math",
|
| 550 |
-
# "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
| 551 |
-
# "-DCUTLASS_ENABLE_GDC_FOR_SM90",
|
| 552 |
-
# "--expt-relaxed-constexpr",
|
| 553 |
-
# "--expt-extended-lambda",
|
| 554 |
-
# "--use_fast_math",
|
| 555 |
-
# "-DNDEBUG",
|
| 556 |
-
# ]
|
| 557 |
-
# src = [
|
| 558 |
-
# "flash-attn/block.h",
|
| 559 |
-
# "flash-attn/copy_sm90_bulk_reduce.hpp",
|
| 560 |
-
# "flash-attn/epilogue_bwd.hpp",
|
| 561 |
-
# "flash-attn/epilogue_fwd.hpp",
|
| 562 |
-
# "flash-attn/flash.h",
|
| 563 |
-
# "flash-attn/flash_bwd_kernel_sm80.h",
|
| 564 |
-
# "flash-attn/flash_bwd_kernel_sm90.h",
|
| 565 |
-
# "flash-attn/flash_bwd_launch_template.h",
|
| 566 |
-
# "flash-attn/flash_bwd_postprocess_kernel.h",
|
| 567 |
-
# "flash-attn/flash_bwd_preprocess_kernel.h",
|
| 568 |
-
# "flash-attn/flash_fwd_launch_template.h",
|
| 569 |
-
# "flash-attn/flash_fwd_kernel_sm80.h",
|
| 570 |
-
# "flash-attn/flash_fwd_kernel_sm90.h",
|
| 571 |
-
# "flash-attn/heuristics.h",
|
| 572 |
-
# "flash-attn/mainloop_bwd_sm80.hpp",
|
| 573 |
-
# "flash-attn/mainloop_fwd_sm80.hpp",
|
| 574 |
-
# "flash-attn/mainloop_bwd_sm90_tma_gmma_ws.hpp",
|
| 575 |
-
# "flash-attn/mainloop_fwd_sm90_tma_gmma_ws.hpp",
|
| 576 |
-
# "flash-attn/mask.h",
|
| 577 |
-
# "flash-attn/named_barrier.hpp",
|
| 578 |
-
# "flash-attn/pack_gqa.h",
|
| 579 |
-
# "flash-attn/paged_kv.h",
|
| 580 |
-
# "flash-attn/rotary.h",
|
| 581 |
-
# "flash-attn/sm90_pipeline_no_cluster.hpp",
|
| 582 |
-
# "flash-attn/softmax.h",
|
| 583 |
-
# "flash-attn/tile_size.h",
|
| 584 |
-
# "flash-attn/tile_scheduler.hpp",
|
| 585 |
-
#
|
| 586 |
-
# "flash-attn/instantiations/flash_fwd_hdim128_bf16_sm100.cu",
|
| 587 |
-
# ]
|
| 588 |
-
# include = ["flash-attn"]
|
| 589 |
-
# depends = ["torch", "cutlass_3_9"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flake.lock
DELETED
|
@@ -1,168 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"nodes": {
|
| 3 |
-
"flake-compat": {
|
| 4 |
-
"locked": {
|
| 5 |
-
"lastModified": 1747046372,
|
| 6 |
-
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 7 |
-
"owner": "edolstra",
|
| 8 |
-
"repo": "flake-compat",
|
| 9 |
-
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 10 |
-
"type": "github"
|
| 11 |
-
},
|
| 12 |
-
"original": {
|
| 13 |
-
"owner": "edolstra",
|
| 14 |
-
"repo": "flake-compat",
|
| 15 |
-
"type": "github"
|
| 16 |
-
}
|
| 17 |
-
},
|
| 18 |
-
"flake-compat_2": {
|
| 19 |
-
"locked": {
|
| 20 |
-
"lastModified": 1747046372,
|
| 21 |
-
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 22 |
-
"owner": "edolstra",
|
| 23 |
-
"repo": "flake-compat",
|
| 24 |
-
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 25 |
-
"type": "github"
|
| 26 |
-
},
|
| 27 |
-
"original": {
|
| 28 |
-
"owner": "edolstra",
|
| 29 |
-
"repo": "flake-compat",
|
| 30 |
-
"type": "github"
|
| 31 |
-
}
|
| 32 |
-
},
|
| 33 |
-
"flake-utils": {
|
| 34 |
-
"inputs": {
|
| 35 |
-
"systems": "systems"
|
| 36 |
-
},
|
| 37 |
-
"locked": {
|
| 38 |
-
"lastModified": 1731533236,
|
| 39 |
-
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 40 |
-
"owner": "numtide",
|
| 41 |
-
"repo": "flake-utils",
|
| 42 |
-
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 43 |
-
"type": "github"
|
| 44 |
-
},
|
| 45 |
-
"original": {
|
| 46 |
-
"owner": "numtide",
|
| 47 |
-
"repo": "flake-utils",
|
| 48 |
-
"type": "github"
|
| 49 |
-
}
|
| 50 |
-
},
|
| 51 |
-
"flake-utils_2": {
|
| 52 |
-
"inputs": {
|
| 53 |
-
"systems": "systems_2"
|
| 54 |
-
},
|
| 55 |
-
"locked": {
|
| 56 |
-
"lastModified": 1731533236,
|
| 57 |
-
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 58 |
-
"owner": "numtide",
|
| 59 |
-
"repo": "flake-utils",
|
| 60 |
-
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 61 |
-
"type": "github"
|
| 62 |
-
},
|
| 63 |
-
"original": {
|
| 64 |
-
"owner": "numtide",
|
| 65 |
-
"repo": "flake-utils",
|
| 66 |
-
"type": "github"
|
| 67 |
-
}
|
| 68 |
-
},
|
| 69 |
-
"hf-nix": {
|
| 70 |
-
"inputs": {
|
| 71 |
-
"flake-compat": "flake-compat_2",
|
| 72 |
-
"flake-utils": "flake-utils_2",
|
| 73 |
-
"nixpkgs": "nixpkgs"
|
| 74 |
-
},
|
| 75 |
-
"locked": {
|
| 76 |
-
"lastModified": 1759493343,
|
| 77 |
-
"narHash": "sha256-8fhl0gwMAnOkQbogPIVq+Fha+Yeq52FaRXfwF+F9Q+k=",
|
| 78 |
-
"owner": "huggingface",
|
| 79 |
-
"repo": "hf-nix",
|
| 80 |
-
"rev": "b1fc3a18b52447a0f24bc6884418edc5e66082b9",
|
| 81 |
-
"type": "github"
|
| 82 |
-
},
|
| 83 |
-
"original": {
|
| 84 |
-
"owner": "huggingface",
|
| 85 |
-
"repo": "hf-nix",
|
| 86 |
-
"type": "github"
|
| 87 |
-
}
|
| 88 |
-
},
|
| 89 |
-
"kernel-builder": {
|
| 90 |
-
"inputs": {
|
| 91 |
-
"flake-compat": "flake-compat",
|
| 92 |
-
"flake-utils": "flake-utils",
|
| 93 |
-
"hf-nix": "hf-nix",
|
| 94 |
-
"nixpkgs": [
|
| 95 |
-
"kernel-builder",
|
| 96 |
-
"hf-nix",
|
| 97 |
-
"nixpkgs"
|
| 98 |
-
]
|
| 99 |
-
},
|
| 100 |
-
"locked": {
|
| 101 |
-
"lastModified": 1759516823,
|
| 102 |
-
"narHash": "sha256-UJVvZHtS9c64Dm4iZRaOKWB+VHI7jzcazGH57KXWeg8=",
|
| 103 |
-
"owner": "huggingface",
|
| 104 |
-
"repo": "kernel-builder",
|
| 105 |
-
"rev": "e13610a05f67b7296be9ead89ad172a0a088a1c3",
|
| 106 |
-
"type": "github"
|
| 107 |
-
},
|
| 108 |
-
"original": {
|
| 109 |
-
"owner": "huggingface",
|
| 110 |
-
"repo": "kernel-builder",
|
| 111 |
-
"type": "github"
|
| 112 |
-
}
|
| 113 |
-
},
|
| 114 |
-
"nixpkgs": {
|
| 115 |
-
"locked": {
|
| 116 |
-
"lastModified": 1755963616,
|
| 117 |
-
"narHash": "sha256-6yD0ww/S8n+U2uPYcJZ3DRURP8Kx036GRpR2uPNZroE=",
|
| 118 |
-
"owner": "nixos",
|
| 119 |
-
"repo": "nixpkgs",
|
| 120 |
-
"rev": "73e96df7cff5783f45e21342a75a1540c4eddce4",
|
| 121 |
-
"type": "github"
|
| 122 |
-
},
|
| 123 |
-
"original": {
|
| 124 |
-
"owner": "nixos",
|
| 125 |
-
"ref": "nixos-unstable-small",
|
| 126 |
-
"repo": "nixpkgs",
|
| 127 |
-
"type": "github"
|
| 128 |
-
}
|
| 129 |
-
},
|
| 130 |
-
"root": {
|
| 131 |
-
"inputs": {
|
| 132 |
-
"kernel-builder": "kernel-builder"
|
| 133 |
-
}
|
| 134 |
-
},
|
| 135 |
-
"systems": {
|
| 136 |
-
"locked": {
|
| 137 |
-
"lastModified": 1681028828,
|
| 138 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 139 |
-
"owner": "nix-systems",
|
| 140 |
-
"repo": "default",
|
| 141 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 142 |
-
"type": "github"
|
| 143 |
-
},
|
| 144 |
-
"original": {
|
| 145 |
-
"owner": "nix-systems",
|
| 146 |
-
"repo": "default",
|
| 147 |
-
"type": "github"
|
| 148 |
-
}
|
| 149 |
-
},
|
| 150 |
-
"systems_2": {
|
| 151 |
-
"locked": {
|
| 152 |
-
"lastModified": 1681028828,
|
| 153 |
-
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 154 |
-
"owner": "nix-systems",
|
| 155 |
-
"repo": "default",
|
| 156 |
-
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 157 |
-
"type": "github"
|
| 158 |
-
},
|
| 159 |
-
"original": {
|
| 160 |
-
"owner": "nix-systems",
|
| 161 |
-
"repo": "default",
|
| 162 |
-
"type": "github"
|
| 163 |
-
}
|
| 164 |
-
}
|
| 165 |
-
},
|
| 166 |
-
"root": "root",
|
| 167 |
-
"version": 7
|
| 168 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flake.nix
DELETED
|
@@ -1,35 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
description = "Flake for Hopper Flash Attention kernel";
|
| 3 |
-
|
| 4 |
-
inputs = {
|
| 5 |
-
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
-
};
|
| 7 |
-
|
| 8 |
-
outputs =
|
| 9 |
-
{
|
| 10 |
-
self,
|
| 11 |
-
kernel-builder,
|
| 12 |
-
}:
|
| 13 |
-
kernel-builder.lib.genFlakeOutputs {
|
| 14 |
-
inherit self;
|
| 15 |
-
path = ./.;
|
| 16 |
-
# Building with CUDA later than 12.4 fails with:
|
| 17 |
-
#
|
| 18 |
-
# error: 'ptxas' died due to signal 11 (Invalid memory reference)
|
| 19 |
-
#
|
| 20 |
-
# So, build for 12.4 only and copy to all the other build variants
|
| 21 |
-
# by hand (which works fine thanks to backward compat).
|
| 22 |
-
torchVersions = _: [
|
| 23 |
-
{
|
| 24 |
-
torchVersion = "2.9";
|
| 25 |
-
cudaVersion = "12.4";
|
| 26 |
-
cxx11Abi = true;
|
| 27 |
-
systems = [
|
| 28 |
-
"x86_64-linux"
|
| 29 |
-
"aarch64-linux"
|
| 30 |
-
];
|
| 31 |
-
bundleBuild = true;
|
| 32 |
-
}
|
| 33 |
-
];
|
| 34 |
-
};
|
| 35 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/block.h
DELETED
|
@@ -1,139 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
namespace flash {
|
| 8 |
-
|
| 9 |
-
template <class SeqlenInfo_t, int kBlockM, int kBlockN, bool Is_causal, bool Is_local, bool PackGQA=false, bool Split=false>
|
| 10 |
-
struct BlockMN {
|
| 11 |
-
|
| 12 |
-
static
|
| 13 |
-
CUTLASS_DEVICE
|
| 14 |
-
cute::tuple<int, int> get_n_block_min_max(
|
| 15 |
-
SeqlenInfo_t const& seqlen_info,
|
| 16 |
-
int const m_block, int const bidb, int const split_idx, int const num_splits,
|
| 17 |
-
int const window_size_left, int const window_size_right,
|
| 18 |
-
cutlass::FastDivmod const& attention_chunk_divmod,
|
| 19 |
-
cutlass::FastDivmod const& qhead_per_khead_divmod) {
|
| 20 |
-
|
| 21 |
-
int const seqlen_k = seqlen_info.seqlen_k;
|
| 22 |
-
int const seqlen_q = seqlen_info.seqlen_q;
|
| 23 |
-
int n_block_max = cute::ceil_div(seqlen_k, kBlockN);
|
| 24 |
-
if constexpr (Is_causal || Is_local) {
|
| 25 |
-
int m_idx_max = (m_block + 1) * kBlockM;
|
| 26 |
-
// TODO: check off-by-1 error
|
| 27 |
-
if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
|
| 28 |
-
int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
|
| 29 |
-
int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
|
| 30 |
-
if (Is_local && attention_chunk_divmod.divisor > 0) {
|
| 31 |
-
n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
|
| 32 |
-
}
|
| 33 |
-
n_block_max = std::min(n_block_max, cute::ceil_div(n_idx_right, kBlockN));
|
| 34 |
-
}
|
| 35 |
-
int n_block_min = 0;
|
| 36 |
-
if constexpr (Is_local) {
|
| 37 |
-
int m_idx_min = m_block * kBlockM;
|
| 38 |
-
if (PackGQA) { m_idx_min = qhead_per_khead_divmod.divide(m_idx_min); }
|
| 39 |
-
int const n_idx = m_idx_min + seqlen_k - seqlen_q;
|
| 40 |
-
int n_idx_left = n_idx - window_size_left;
|
| 41 |
-
if (attention_chunk_divmod.divisor > 0) {
|
| 42 |
-
n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
|
| 43 |
-
}
|
| 44 |
-
n_block_min = std::max(int(0), n_idx_left / kBlockN);
|
| 45 |
-
}
|
| 46 |
-
// if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
|
| 47 |
-
if constexpr (Split) {
|
| 48 |
-
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
|
| 49 |
-
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
|
| 50 |
-
int split_idx_actual = split_idx & 0x0000FFFF;
|
| 51 |
-
int num_splits_actual = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
|
| 52 |
-
int num_n_blocks_per_split = n_block_max <= n_block_min ? 0 : cute::ceil_div(n_block_max - n_block_min, num_splits_actual);
|
| 53 |
-
n_block_min = n_block_min + split_idx_actual * num_n_blocks_per_split;
|
| 54 |
-
n_block_max = std::min(n_block_min + num_n_blocks_per_split, n_block_max);
|
| 55 |
-
// if (threadIdx.x == 128) { printf("Inside, bid.x = %d, bid.y = %d, bid.z = %d, split_idx = %d, num_splits_dynamic = %d, num_splits_actual = %d, num_n_blocks_per_split = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.x, blockIdx.y, blockIdx.z, split_idx, num_splits_dynamic, num_splits_actual, num_n_blocks_per_split, n_block_min, n_block_max); }
|
| 56 |
-
}
|
| 57 |
-
// if (threadIdx.x == 128) { printf("After split, inside, bid.y = %d, bid.z = %d, split_idx = %d, n_block_min: %d, n_block_max: %d\n", blockIdx.y, blockIdx.z, split_idx, n_block_min, n_block_max); }
|
| 58 |
-
return {n_block_min, n_block_max};
|
| 59 |
-
}
|
| 60 |
-
|
| 61 |
-
static
|
| 62 |
-
CUTLASS_DEVICE
|
| 63 |
-
cute::tuple<int, int> get_n_block_k_new_min_max(
|
| 64 |
-
SeqlenInfo_t const& seqlen_info,
|
| 65 |
-
int const m_block, int const bidb, int const split_idx, int const num_splits,
|
| 66 |
-
int const window_size_left, int const window_size_right,
|
| 67 |
-
cutlass::FastDivmod const& attention_chunk_divmod,
|
| 68 |
-
cutlass::FastDivmod const& qhead_per_khead_divmod) {
|
| 69 |
-
|
| 70 |
-
auto [n_block_min, n_block_max] = get_n_block_min_max(
|
| 71 |
-
seqlen_info, m_block, bidb, split_idx, num_splits,
|
| 72 |
-
window_size_left, window_size_right, attention_chunk_divmod, qhead_per_khead_divmod);
|
| 73 |
-
int const idx_k_new_min = std::max(n_block_min * kBlockN - seqlen_info.seqlen_k_og, 0);
|
| 74 |
-
int const idx_k_new_max = std::min(n_block_max * kBlockN - seqlen_info.seqlen_k_og, seqlen_info.seqlen_k_new);
|
| 75 |
-
int const n_block_new_min = idx_k_new_min / kBlockN;
|
| 76 |
-
int const n_block_new_max = idx_k_new_max > idx_k_new_min ? cute::ceil_div(idx_k_new_max, kBlockN) : n_block_new_min;
|
| 77 |
-
// if (threadIdx.x == 128 && m_block == 0) { printf("bidb = %d, seqlen_k_new = %d, seqlen_k_og = %d, n_block_min = %d, n_block_max = %d, idx_k_new_min = %d, idx_k_new_max = %d, n_block_new_min = %d, n_block_new_max = %d\n", bidb, seqlen_k_new, seqlen_k_og, n_block_min, n_block_max, idx_k_new_min, idx_k_new_max, n_block_new_min, n_block_new_max);}
|
| 78 |
-
return {n_block_new_min, n_block_new_max};
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
static
|
| 82 |
-
CUTLASS_DEVICE
|
| 83 |
-
cute::tuple<int, int> get_m_block_min_max(
|
| 84 |
-
SeqlenInfo_t const& seqlen_info,
|
| 85 |
-
int const n_block, int const bidb,
|
| 86 |
-
int const window_size_left, int const window_size_right, int const sink_token_length) {
|
| 87 |
-
// TODO: support attention_chunk
|
| 88 |
-
int const seqlen_q = seqlen_info.seqlen_q;
|
| 89 |
-
int const seqlen_k = seqlen_info.seqlen_k;
|
| 90 |
-
int m_block_max = cute::ceil_div(seqlen_q, kBlockM);
|
| 91 |
-
if constexpr (Is_local) {
|
| 92 |
-
if (n_block >= cute::ceil_div(sink_token_length, kBlockN)) {
|
| 93 |
-
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + seqlen_q - seqlen_k + window_size_left, kBlockM));
|
| 94 |
-
}
|
| 95 |
-
}
|
| 96 |
-
int m_block_min = 0;
|
| 97 |
-
if constexpr (Is_causal || Is_local) {
|
| 98 |
-
m_block_min = std::max(m_block_min, (n_block * kBlockN + seqlen_q - seqlen_k - window_size_right) / kBlockM);
|
| 99 |
-
}
|
| 100 |
-
return {m_block_min, m_block_max};
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
// If we have separate iterations with causal or local masking at the start, where do we stop
|
| 104 |
-
static
|
| 105 |
-
CUTLASS_DEVICE
|
| 106 |
-
int get_n_block_min_causal_local_mask(
|
| 107 |
-
SeqlenInfo_t const& seqlen_info,
|
| 108 |
-
int const m_block, int const n_block_min, int const window_size_right,
|
| 109 |
-
cutlass::FastDivmod const& attention_chunk_divmod,
|
| 110 |
-
cutlass::FastDivmod const& qhead_per_khead_divmod) {
|
| 111 |
-
int const m_idx_min = !PackGQA ? m_block * kBlockM : qhead_per_khead_divmod.divide(m_block * kBlockM);
|
| 112 |
-
int const n_idx = m_idx_min + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
|
| 113 |
-
int n_idx_right = !Is_local ? n_idx : n_idx + window_size_right;
|
| 114 |
-
if (Is_local && attention_chunk_divmod.divisor > 0) {
|
| 115 |
-
n_idx_right = std::min(n_idx_right, flash::round_up(attention_chunk_divmod, n_idx));
|
| 116 |
-
}
|
| 117 |
-
return std::max(n_block_min, n_idx_right / kBlockN);
|
| 118 |
-
}
|
| 119 |
-
|
| 120 |
-
// If we have separate iterations with local masking at the end, where do we stop the non-masked iterations
|
| 121 |
-
static
|
| 122 |
-
CUTLASS_DEVICE
|
| 123 |
-
int get_n_block_min_before_local_mask(
|
| 124 |
-
SeqlenInfo_t const& seqlen_info,
|
| 125 |
-
int const m_block, int const n_block_min, int const window_size_left,
|
| 126 |
-
cutlass::FastDivmod const& attention_chunk_divmod,
|
| 127 |
-
cutlass::FastDivmod const& qhead_per_khead_divmod) {
|
| 128 |
-
int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
|
| 129 |
-
int const n_idx = m_idx_max + seqlen_info.seqlen_k - seqlen_info.seqlen_q;
|
| 130 |
-
int n_idx_left = !Is_local ? n_idx : n_idx - window_size_left;
|
| 131 |
-
if (Is_local && attention_chunk_divmod.divisor > 0) {
|
| 132 |
-
n_idx_left = std::max(n_idx_left, flash::round_down(attention_chunk_divmod, n_idx));
|
| 133 |
-
}
|
| 134 |
-
return !Is_local ? n_block_min : std::max(n_block_min, cute::ceil_div(n_idx_left, kBlockN));
|
| 135 |
-
}
|
| 136 |
-
|
| 137 |
-
};
|
| 138 |
-
|
| 139 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/copy_sm90_bulk_reduce.hpp
DELETED
|
@@ -1,49 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include<cute/arch/copy_sm90_tma.hpp>
|
| 8 |
-
|
| 9 |
-
namespace cute
|
| 10 |
-
{
|
| 11 |
-
|
| 12 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 13 |
-
|
| 14 |
-
struct SM90_BULK_REDUCE_ADD
|
| 15 |
-
{
|
| 16 |
-
CUTE_HOST_DEVICE static void
|
| 17 |
-
copy(float const* smem_ptr,
|
| 18 |
-
float * gmem_ptr, int32_t store_bytes)
|
| 19 |
-
{
|
| 20 |
-
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
| 21 |
-
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
| 22 |
-
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.add.f32 [%0], [%1], %2;\n"
|
| 23 |
-
:
|
| 24 |
-
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes)
|
| 25 |
-
: "memory");
|
| 26 |
-
#else
|
| 27 |
-
CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
|
| 28 |
-
#endif
|
| 29 |
-
}
|
| 30 |
-
|
| 31 |
-
CUTE_HOST_DEVICE static void
|
| 32 |
-
copy(float const* smem_ptr,
|
| 33 |
-
float * gmem_ptr, int32_t store_bytes, uint64_t cache_hint)
|
| 34 |
-
{
|
| 35 |
-
#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
|
| 36 |
-
uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
|
| 37 |
-
asm volatile("cp.reduce.async.bulk.global.shared::cta.bulk_group.L2::cache_hint.add.f32 [%0], [%1], %2, %3;\n"
|
| 38 |
-
:
|
| 39 |
-
: "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes), "l"(cache_hint)
|
| 40 |
-
: "memory");
|
| 41 |
-
#else
|
| 42 |
-
CUTE_INVALID_CONTROL_PATH("Trying to use BULK_REDUCE_ADD without CUTE_ARCH_TMA_SM90_ENABLED.");
|
| 43 |
-
#endif
|
| 44 |
-
}
|
| 45 |
-
};
|
| 46 |
-
|
| 47 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 48 |
-
|
| 49 |
-
} // end namespace cute
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/cuda_check.h
DELETED
|
@@ -1,19 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include <assert.h>
|
| 8 |
-
#include <stdlib.h>
|
| 9 |
-
|
| 10 |
-
#define CHECK_CUDA(call) \
|
| 11 |
-
do { \
|
| 12 |
-
cudaError_t status_ = call; \
|
| 13 |
-
if (status_ != cudaSuccess) { \
|
| 14 |
-
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
|
| 15 |
-
exit(1); \
|
| 16 |
-
} \
|
| 17 |
-
} while(0)
|
| 18 |
-
|
| 19 |
-
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/epilogue_bwd.hpp
DELETED
|
@@ -1,533 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cutlass/cutlass.h"
|
| 8 |
-
#include "cutlass/barrier.h"
|
| 9 |
-
#include "cute/tensor.hpp"
|
| 10 |
-
|
| 11 |
-
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
| 12 |
-
|
| 13 |
-
#include "seqlen.h"
|
| 14 |
-
#include "named_barrier.hpp"
|
| 15 |
-
#include "utils.h"
|
| 16 |
-
|
| 17 |
-
namespace flash {
|
| 18 |
-
|
| 19 |
-
using namespace cute;
|
| 20 |
-
|
| 21 |
-
template <class TileShape_MNK_, class Element_, class ArchTag_,
|
| 22 |
-
int NumEpilogueThreads_, bool Varlen_, bool dKV_swapAB_, int AtomLayoutKdKV=1>
|
| 23 |
-
struct CollectiveEpilogueBwd {
|
| 24 |
-
|
| 25 |
-
using TileShape_MNK = TileShape_MNK_;
|
| 26 |
-
using Element = Element_;
|
| 27 |
-
using ArchTag = ArchTag_;
|
| 28 |
-
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
| 29 |
-
static constexpr bool Varlen = Varlen_;
|
| 30 |
-
static constexpr bool dKV_swapAB = dKV_swapAB_;
|
| 31 |
-
static constexpr bool Use_TMA = !Varlen && ArchTag::kMinComputeCapability >= 90;
|
| 32 |
-
|
| 33 |
-
static_assert(ArchTag::kMinComputeCapability >= 80);
|
| 34 |
-
|
| 35 |
-
using GmemTiledCopydKVTMA = cute::SM90_TMA_STORE;
|
| 36 |
-
|
| 37 |
-
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
|
| 38 |
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
| 39 |
-
static_assert(get<2>(TileShape_MNK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
| 40 |
-
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
|
| 41 |
-
static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, NumEpilogueThreads);
|
| 42 |
-
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
|
| 43 |
-
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
| 44 |
-
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
| 45 |
-
using GmemTiledCopydKV = decltype(
|
| 46 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
| 47 |
-
GmemLayoutAtom{},
|
| 48 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per store
|
| 49 |
-
|
| 50 |
-
using SmemLayoutAtomdKVTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
| 51 |
-
// TODO: do we have to change this if dKV_swapAB is true?
|
| 52 |
-
decltype(cute::get<1>(TileShape_MNK{})), Int<CUTE_STATIC_V(cute::get<2>(TileShape_MNK{})) / AtomLayoutKdKV>>());
|
| 53 |
-
using SmemLayoutdKVTMA = decltype(tile_to_shape(SmemLayoutAtomdKVTMA{}, select<1, 2>(TileShape_MNK{})));
|
| 54 |
-
using SmemLayoutdKVtTMA =
|
| 55 |
-
decltype(cute::composition(SmemLayoutdKVTMA{},
|
| 56 |
-
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
|
| 57 |
-
make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
|
| 58 |
-
|
| 59 |
-
// If we don't use TMA
|
| 60 |
-
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : (kHeadDim % 32 == 0 ? 32 : 16);
|
| 61 |
-
static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
|
| 62 |
-
using SmemLayoutAtomdKVSTG =
|
| 63 |
-
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
|
| 64 |
-
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
| 65 |
-
Stride<Int<kBlockKSmem>, _1>>{}));
|
| 66 |
-
|
| 67 |
-
using SmemLayoutAtomdKV = std::conditional_t<Use_TMA, SmemLayoutAtomdKVTMA, SmemLayoutAtomdKVSTG>;
|
| 68 |
-
using SmemLayoutdKV = decltype(tile_to_shape(SmemLayoutAtomdKV{}, select<1, 2>(TileShape_MNK{})));
|
| 69 |
-
using SmemLayoutdKVt =
|
| 70 |
-
decltype(cute::composition(SmemLayoutdKV{},
|
| 71 |
-
make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})),
|
| 72 |
-
make_stride(decltype(get<1>(TileShape_MNK{})){}, _1{}))));
|
| 73 |
-
|
| 74 |
-
using SmemCopyAtomdKV = Copy_Atom<
|
| 75 |
-
std::conditional_t<
|
| 76 |
-
ArchTag::kMinComputeCapability >= 90,
|
| 77 |
-
std::conditional_t<!dKV_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
|
| 78 |
-
AutoVectorizingCopyWithAssumedAlignment<128>
|
| 79 |
-
>,
|
| 80 |
-
Element>;
|
| 81 |
-
|
| 82 |
-
static constexpr size_t SmemAlignmentdKV = ArchTag::kMinComputeCapability >= 90 ? cutlass::detail::alignment_for_swizzle(SmemLayoutdKV{}) : 128;
|
| 83 |
-
static_assert(SmemAlignmentdKV >= 128, "Require at least 128B alignment");
|
| 84 |
-
|
| 85 |
-
struct TensorStorage : cute::aligned_struct<SmemAlignmentdKV> {
|
| 86 |
-
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dk;
|
| 87 |
-
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
|
| 88 |
-
};
|
| 89 |
-
|
| 90 |
-
using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
|
| 91 |
-
using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
| 92 |
-
|
| 93 |
-
using TMA_dKV = std::conditional_t<
|
| 94 |
-
Use_TMA,
|
| 95 |
-
decltype(make_tma_copy(
|
| 96 |
-
GmemTiledCopydKVTMA{},
|
| 97 |
-
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapedKV{}, StridedKV{}),
|
| 98 |
-
SmemLayoutdKVTMA{},
|
| 99 |
-
select<1, 2>(TileShape_MNK{}),
|
| 100 |
-
_1{})), // no mcast for dKV
|
| 101 |
-
std::nullptr_t
|
| 102 |
-
>;
|
| 103 |
-
|
| 104 |
-
// Host side kernel arguments
|
| 105 |
-
struct Arguments {
|
| 106 |
-
Element* ptr_dK;
|
| 107 |
-
ShapedKV const shape_dK;
|
| 108 |
-
StridedKV const stride_dK;
|
| 109 |
-
Element* ptr_dV;
|
| 110 |
-
ShapedKV const shape_dV;
|
| 111 |
-
StridedKV const stride_dV;
|
| 112 |
-
int const num_heads_q;
|
| 113 |
-
int* dk_semaphore;
|
| 114 |
-
int* dv_semaphore;
|
| 115 |
-
int const* cu_seqlens;
|
| 116 |
-
int const* seqused;
|
| 117 |
-
};
|
| 118 |
-
|
| 119 |
-
// Device side kernel params
|
| 120 |
-
struct Params {
|
| 121 |
-
Element* ptr_dK;
|
| 122 |
-
ShapedKV const shape_dK;
|
| 123 |
-
StridedKV const stride_dK;
|
| 124 |
-
Element* ptr_dV;
|
| 125 |
-
ShapedKV const shape_dV;
|
| 126 |
-
StridedKV const stride_dV;
|
| 127 |
-
TMA_dKV tma_store_dK, tma_store_dV;
|
| 128 |
-
int const* cu_seqlens = nullptr;
|
| 129 |
-
int const* seqused = nullptr;
|
| 130 |
-
};
|
| 131 |
-
|
| 132 |
-
static Params
|
| 133 |
-
to_underlying_arguments(Arguments const& args) {
|
| 134 |
-
Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK);
|
| 135 |
-
Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV);
|
| 136 |
-
TMA_dKV tma_store_dK = [&] {
|
| 137 |
-
if constexpr (Use_TMA) {
|
| 138 |
-
return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
|
| 139 |
-
} else {
|
| 140 |
-
return nullptr;
|
| 141 |
-
}
|
| 142 |
-
}();
|
| 143 |
-
TMA_dKV tma_store_dV = [&] {
|
| 144 |
-
if constexpr (Use_TMA) {
|
| 145 |
-
return make_tma_copy(GmemTiledCopydKVTMA{}, mdV, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV
|
| 146 |
-
} else {
|
| 147 |
-
return nullptr;
|
| 148 |
-
}
|
| 149 |
-
}();
|
| 150 |
-
return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV,
|
| 151 |
-
tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused};
|
| 152 |
-
}
|
| 153 |
-
|
| 154 |
-
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 155 |
-
CUTLASS_DEVICE
|
| 156 |
-
static void prefetch_tma_descriptors(Params const& params) {
|
| 157 |
-
if constexpr (Use_TMA) {
|
| 158 |
-
cute::prefetch_tma_descriptor(params.tma_store_dK.get_tma_descriptor());
|
| 159 |
-
cute::prefetch_tma_descriptor(params.tma_store_dV.get_tma_descriptor());
|
| 160 |
-
}
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
| 164 |
-
CUTLASS_DEVICE void
|
| 165 |
-
store(Params const& params,
|
| 166 |
-
FrgTensorO const& tdKrdK,
|
| 167 |
-
FrgTensorO const& tdVrdV,
|
| 168 |
-
SharedStorage& shared_storage,
|
| 169 |
-
TiledMma tiled_mma,
|
| 170 |
-
int thread_idx,
|
| 171 |
-
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
| 172 |
-
) {
|
| 173 |
-
|
| 174 |
-
auto [n_block, bidh, bidb] = block_coord;
|
| 175 |
-
Tensor sdK = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKV{}));
|
| 176 |
-
Tensor sdV = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKV{}));
|
| 177 |
-
Tensor sdKt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dk.data()), SmemLayoutdKVt{}));
|
| 178 |
-
Tensor sdVt = cute::as_position_independent_swizzle_tensor(make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dv.data()), SmemLayoutdKVt{}));
|
| 179 |
-
auto smem_tiled_copy_dKV = make_tiled_copy_C(SmemCopyAtomdKV{}, tiled_mma);
|
| 180 |
-
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
| 181 |
-
|
| 182 |
-
Tensor tdVrdV_out = make_tensor_like<Element>(tdVrdV);
|
| 183 |
-
flash::convert_type_out(tdVrdV, tdVrdV_out);
|
| 184 |
-
Tensor tdKrdK_out = make_tensor_like<Element>(tdKrdK);
|
| 185 |
-
flash::convert_type_out(tdKrdK, tdKrdK_out);
|
| 186 |
-
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
| 187 |
-
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
| 188 |
-
// if (blockIdx.x == 0 && threadIdx.x == 128) { print(smem_thr_copy_dKV); print(sdK); printf("\n"); print(sdKt); printf("\n"); }
|
| 189 |
-
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdK, sdKt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
| 190 |
-
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(cute::conditional_return<!dKV_swapAB>(sdV, sdVt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
| 191 |
-
|
| 192 |
-
// Make sure all WGs have finished reading K and V
|
| 193 |
-
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 194 |
-
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
|
| 195 |
-
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
|
| 196 |
-
if constexpr (Use_TMA) {
|
| 197 |
-
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
| 198 |
-
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
| 199 |
-
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 200 |
-
|
| 201 |
-
Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK);
|
| 202 |
-
Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV);
|
| 203 |
-
Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 204 |
-
Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 205 |
-
auto block_tma_dK = params.tma_store_dK.get_slice(_0{});
|
| 206 |
-
auto block_tma_dV = params.tma_store_dV.get_slice(_0{});
|
| 207 |
-
Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K)
|
| 208 |
-
Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K)
|
| 209 |
-
Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K)
|
| 210 |
-
Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
|
| 211 |
-
int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
|
| 212 |
-
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
|
| 213 |
-
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
| 214 |
-
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 215 |
-
if (cute::elect_one_sync()) {
|
| 216 |
-
cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
|
| 217 |
-
cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
|
| 218 |
-
tma_store_arrive();
|
| 219 |
-
}
|
| 220 |
-
}
|
| 221 |
-
tma_store_wait<0>();
|
| 222 |
-
// // Tell warp 0 that smem_k and smem_v are ready
|
| 223 |
-
// cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
| 224 |
-
|
| 225 |
-
} else {
|
| 226 |
-
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 227 |
-
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
| 228 |
-
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
|
| 229 |
-
bool const is_varlen = Varlen && params.cu_seqlens;
|
| 230 |
-
Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 231 |
-
Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 232 |
-
Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 233 |
-
Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 234 |
-
|
| 235 |
-
GmemTiledCopydKV gmem_tiled_copy_dKV;
|
| 236 |
-
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
| 237 |
-
Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
| 238 |
-
Tensor tdKVsdV = gmem_thr_copy_dKV.partition_S(sdV); // (TMA, TMA_M, TMA_K)
|
| 239 |
-
Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
| 240 |
-
Tensor tdKVsdK = gmem_thr_copy_dKV.partition_S(sdK); // (TMA, TMA_M, TMA_K)
|
| 241 |
-
Tensor tdKVrdV = make_fragment_like(tdKVgdV);
|
| 242 |
-
Tensor tdKVrdK = make_fragment_like(tdKVgdK);
|
| 243 |
-
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
| 244 |
-
// Repeat the partitioning with identity layouts
|
| 245 |
-
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
| 246 |
-
Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
|
| 247 |
-
Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
|
| 248 |
-
#pragma unroll
|
| 249 |
-
for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
|
| 250 |
-
#pragma unroll
|
| 251 |
-
for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
|
| 252 |
-
// Need to check OOB when reading from smem if kBlockN isn't evenly tiled
|
| 253 |
-
static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
|
| 254 |
-
flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
| 255 |
-
gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN);
|
| 256 |
-
flash::copy</*Is_even_MN=*/EvenN, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
| 257 |
-
gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN);
|
| 258 |
-
// // Tell warp 0 that smem_k and smem_v are ready
|
| 259 |
-
// cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v
|
| 260 |
-
// flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
| 261 |
-
// Construct identity layout for gdKV
|
| 262 |
-
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
| 263 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 264 |
-
gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
|
| 265 |
-
);
|
| 266 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 267 |
-
gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN)
|
| 268 |
-
);
|
| 269 |
-
}
|
| 270 |
-
}
|
| 271 |
-
|
| 272 |
-
CUTLASS_DEVICE void
|
| 273 |
-
store_tail() {
|
| 274 |
-
// if constexpr (Use_TMA) { tma_store_wait<0>(); }
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
// Write 0 to dK and dV
|
| 278 |
-
CUTLASS_DEVICE void
|
| 279 |
-
store_zero(
|
| 280 |
-
Params const& params,
|
| 281 |
-
int thread_idx,
|
| 282 |
-
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
| 283 |
-
) {
|
| 284 |
-
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
| 285 |
-
auto [n_block, bidh, bidb] = block_coord;
|
| 286 |
-
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dK), params.cu_seqlens, params.seqused};
|
| 287 |
-
bool const is_varlen = Varlen && params.cu_seqlens;
|
| 288 |
-
Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 289 |
-
Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 290 |
-
Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 291 |
-
Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K)
|
| 292 |
-
|
| 293 |
-
GmemTiledCopydKV gmem_tiled_copy_dKV;
|
| 294 |
-
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(thread_idx);
|
| 295 |
-
Tensor tdKVgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
| 296 |
-
Tensor tdKVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
| 297 |
-
Tensor tdKVrdKV = make_fragment_like(tdKVgdK);
|
| 298 |
-
clear(tdKVrdKV);
|
| 299 |
-
// Construct identity layout for gdKV
|
| 300 |
-
Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
| 301 |
-
// Repeat the partitioning with identity layouts
|
| 302 |
-
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
| 303 |
-
Tensor tdKVpdK = make_tensor<bool>(make_shape(size<2>(tdKVgdK)));
|
| 304 |
-
Tensor tdKVpdV = make_tensor<bool>(make_shape(size<2>(tdKVgdV)));
|
| 305 |
-
#pragma unroll
|
| 306 |
-
for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); }
|
| 307 |
-
#pragma unroll
|
| 308 |
-
for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); }
|
| 309 |
-
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
| 310 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 311 |
-
gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN
|
| 312 |
-
);
|
| 313 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 314 |
-
gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN
|
| 315 |
-
);
|
| 316 |
-
}
|
| 317 |
-
|
| 318 |
-
};
|
| 319 |
-
|
| 320 |
-
template <class TileShape_MNK_, class ElementAccum, class ArchTag_,
|
| 321 |
-
int NumEpilogueThreads_, bool Varlen_, bool Deterministic>
|
| 322 |
-
struct CollectiveEpilogueBwdGQA {
|
| 323 |
-
|
| 324 |
-
using TileShape_MNK = TileShape_MNK_;
|
| 325 |
-
using Element = ElementAccum;
|
| 326 |
-
using ArchTag = ArchTag_;
|
| 327 |
-
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
| 328 |
-
static constexpr bool Varlen = Varlen_;
|
| 329 |
-
static constexpr bool Use_TMA = ArchTag::kMinComputeCapability >= 90;
|
| 330 |
-
|
| 331 |
-
static_assert(ArchTag::kMinComputeCapability >= 80);
|
| 332 |
-
|
| 333 |
-
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
| 334 |
-
static constexpr int kHeadDim = get<2>(TileShape_MNK{});
|
| 335 |
-
static_assert(NumEpilogueThreads % cutlass::NumThreadsPerWarp == 0, "NumEpilogueThreads must be a multiple of NumThreadsPerWarp");
|
| 336 |
-
static constexpr int NumWarpGroups = NumEpilogueThreads / cutlass::NumThreadsPerWarpGroup;
|
| 337 |
-
// Thread layout, 256 or 384 threads per row
|
| 338 |
-
// We split into NumWarpGroups so that we can use the same postprocessing kernel as dQ
|
| 339 |
-
using R2SLayoutAtomdKVaccum = Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumWarpGroups>>>;
|
| 340 |
-
using R2STiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdKVaccum{},
|
| 341 |
-
Layout<Shape < _4>>{})); // Val layout, 4 vals per store
|
| 342 |
-
// For Sm80
|
| 343 |
-
using R2GLayoutAtomdKVaccum = Layout<Shape<Int<NumEpilogueThreads>>>;
|
| 344 |
-
using R2GTiledCopydKVaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2GLayoutAtomdKVaccum{},
|
| 345 |
-
Layout<Shape < _1>>{})); // Val layout, 1 vals per store
|
| 346 |
-
|
| 347 |
-
using SmemLayoutdKVaccum = Layout<Shape<Int<kBlockN * kHeadDim / NumWarpGroups>, Int<NumWarpGroups>>>;
|
| 348 |
-
using SmemLayoutdKVaccumFlat = Layout<Shape<Int<kBlockN * kHeadDim>>>;
|
| 349 |
-
|
| 350 |
-
// Strangely without this SmemAlignment, the total smem for hdim 128 (80 x 128) is 228KB even though we
|
| 351 |
-
// only need 227KB. We use the same alignment as the non-GQA epilogue to avoid this issue.
|
| 352 |
-
static constexpr int SmemAlignment = kHeadDim % 64 == 0 ? 1024 : (kHeadDim % 32 == 0 ? 512 : 256);
|
| 353 |
-
struct TensorStorageTMA : cute::aligned_struct<SmemAlignment> {
|
| 354 |
-
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccum>, SmemAlignment> smem_dkv;
|
| 355 |
-
};
|
| 356 |
-
struct TensorStorageSTG {
|
| 357 |
-
cute::array<ElementAccum, 0> smem_dkv;
|
| 358 |
-
};
|
| 359 |
-
using TensorStorage = std::conditional_t<Use_TMA, TensorStorageTMA, TensorStorageSTG>;
|
| 360 |
-
|
| 361 |
-
using ShapedKV = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_k_rounded * d, head, batch)
|
| 362 |
-
using StridedKV = cute::Stride<_1, int64_t, int64_t>;
|
| 363 |
-
|
| 364 |
-
// Host side kernel arguments
|
| 365 |
-
struct Arguments {
|
| 366 |
-
ElementAccum* ptr_dKaccum;
|
| 367 |
-
ShapedKV const shape_dKaccum;
|
| 368 |
-
StridedKV const stride_dKaccum;
|
| 369 |
-
ElementAccum* ptr_dVaccum;
|
| 370 |
-
ShapedKV const shape_dVaccum;
|
| 371 |
-
StridedKV const stride_dVaccum;
|
| 372 |
-
int num_heads_q;
|
| 373 |
-
int* dk_semaphore;
|
| 374 |
-
int* dv_semaphore;
|
| 375 |
-
int const* cu_seqlens;
|
| 376 |
-
int const* seqused;
|
| 377 |
-
};
|
| 378 |
-
|
| 379 |
-
// Device side kernel params
|
| 380 |
-
struct Params {
|
| 381 |
-
ElementAccum* ptr_dKaccum;
|
| 382 |
-
ShapedKV const shape_dKaccum;
|
| 383 |
-
StridedKV const stride_dKaccum;
|
| 384 |
-
ElementAccum* ptr_dVaccum;
|
| 385 |
-
ShapedKV const shape_dVaccum;
|
| 386 |
-
StridedKV const stride_dVaccum;
|
| 387 |
-
cutlass::FastDivmod qhead_per_khead_divmod;
|
| 388 |
-
int* dk_semaphore;
|
| 389 |
-
int* dv_semaphore;
|
| 390 |
-
int const* cu_seqlens = nullptr;
|
| 391 |
-
int const* seqused = nullptr;
|
| 392 |
-
};
|
| 393 |
-
|
| 394 |
-
static Params
|
| 395 |
-
to_underlying_arguments(Arguments const& args) {
|
| 396 |
-
if constexpr (Deterministic) {
|
| 397 |
-
assert(args.dk_semaphore != nullptr);
|
| 398 |
-
assert(args.dv_semaphore != nullptr);
|
| 399 |
-
}
|
| 400 |
-
return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum,
|
| 401 |
-
cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))),
|
| 402 |
-
args.dk_semaphore, args.dv_semaphore,
|
| 403 |
-
args.cu_seqlens, args.seqused};
|
| 404 |
-
}
|
| 405 |
-
|
| 406 |
-
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 407 |
-
CUTLASS_DEVICE
|
| 408 |
-
static void prefetch_tma_descriptors(Params const& params) {
|
| 409 |
-
}
|
| 410 |
-
|
| 411 |
-
template <typename SharedStorage, typename FrgTensorO, typename TiledMma>
|
| 412 |
-
CUTLASS_DEVICE void
|
| 413 |
-
store(Params const& params,
|
| 414 |
-
FrgTensorO const& tdKrdK,
|
| 415 |
-
FrgTensorO const& tdVrdV,
|
| 416 |
-
SharedStorage& shared_storage,
|
| 417 |
-
TiledMma tiled_mma,
|
| 418 |
-
int thread_idx,
|
| 419 |
-
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
| 420 |
-
) {
|
| 421 |
-
|
| 422 |
-
auto [n_block, bidh, bidb] = block_coord;
|
| 423 |
-
int bidh_idx_in_group;
|
| 424 |
-
int bidh_kv = params.qhead_per_khead_divmod.divmod(bidh_idx_in_group, bidh);
|
| 425 |
-
Tensor sdKV = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccum{});
|
| 426 |
-
Tensor sdKV_flat = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_dkv.data()), SmemLayoutdKVaccumFlat{});
|
| 427 |
-
static constexpr int dKV_TMA_num_bytes = CUTE_STATIC_V(size(sdKV_flat)) * sizeof(ElementAccum);
|
| 428 |
-
|
| 429 |
-
flash::SeqlenInfo<Varlen, kBlockN> seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused};
|
| 430 |
-
bool const is_varlen = Varlen && params.cu_seqlens;
|
| 431 |
-
Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
|
| 432 |
-
Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0);
|
| 433 |
-
Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
|
| 434 |
-
Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape<Int<kBlockN * kHeadDim>>{}, make_coord(n_block)); // (M * K)
|
| 435 |
-
|
| 436 |
-
R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum;
|
| 437 |
-
auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
|
| 438 |
-
Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);
|
| 439 |
-
|
| 440 |
-
// Only used if !Use_TMA
|
| 441 |
-
R2GTiledCopydKVaccum r2g_tiled_copy_dKVaccum;
|
| 442 |
-
auto r2g_thr_copy_dKVaccum = r2g_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
|
| 443 |
-
|
| 444 |
-
// Make sure all WGs have finished reading K and V, otherwise we get racy dQ
|
| 445 |
-
// because smem_q could be changed.
|
| 446 |
-
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 447 |
-
if constexpr (Use_TMA) {
|
| 448 |
-
Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
|
| 449 |
-
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);
|
| 450 |
-
}
|
| 451 |
-
|
| 452 |
-
// int const num_batch = params.num_batch;
|
| 453 |
-
int const num_batch = get<2>(params.shape_dKaccum);
|
| 454 |
-
int const num_head_kv = get<1>(params.shape_dKaccum);
|
| 455 |
-
int *lock_ptr = !Deterministic ? nullptr : params.dv_semaphore + bidb * num_head_kv + bidh_kv;
|
| 456 |
-
using Barrier = cutlass::GenericBarrier<cutlass::detail::SyncwarpSync>;
|
| 457 |
-
|
| 458 |
-
// if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
|
| 459 |
-
|
| 460 |
-
if constexpr (Deterministic) {
|
| 461 |
-
Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
|
| 462 |
-
}
|
| 463 |
-
// if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dv_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dv_semaphore);}
|
| 464 |
-
if constexpr (Use_TMA) {
|
| 465 |
-
cutlass::arch::fence_view_async_shared();
|
| 466 |
-
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 467 |
-
if (thread_idx == 0) {
|
| 468 |
-
SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdVaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
|
| 469 |
-
tma_store_arrive();
|
| 470 |
-
tma_store_wait<0>();
|
| 471 |
-
}
|
| 472 |
-
} else {
|
| 473 |
-
Tensor tdVrdV_atomic = r2g_thr_copy_dKVaccum.retile_S(tdVrdV);
|
| 474 |
-
Tensor tdVgdV_atomic = r2g_thr_copy_dKVaccum.partition_D(gdVaccum);
|
| 475 |
-
static_assert(CUTE_STATIC_V(size(tdVrdV_atomic)) == CUTE_STATIC_V(size(tdVgdV_atomic)));
|
| 476 |
-
#pragma unroll
|
| 477 |
-
for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdV_atomic(i), tdVrdV_atomic(i)); }
|
| 478 |
-
}
|
| 479 |
-
if constexpr (Deterministic) {
|
| 480 |
-
Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
|
| 481 |
-
}
|
| 482 |
-
|
| 483 |
-
if constexpr (Use_TMA) {
|
| 484 |
-
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 485 |
-
Tensor taccdKVrdK = r2s_thr_copy_dKVaccum.retile_S(tdKrdK); // ((Atom,AtomNum), MMA_M, MMA_N)
|
| 486 |
-
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdK, tdKVsdKVaccum);
|
| 487 |
-
}
|
| 488 |
-
lock_ptr = !Deterministic ? nullptr : params.dk_semaphore + bidb * num_head_kv + bidh_kv;
|
| 489 |
-
// if (thread_idx == 0) { printf("blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p, num_batch = %d, num_head_kv = %d, n_block = %d, bihd_idx_in_group = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore, num_batch, num_head_kv, n_block, bidh_idx_in_group);}
|
| 490 |
-
|
| 491 |
-
if constexpr (Deterministic) {
|
| 492 |
-
Barrier::wait_eq(lock_ptr, thread_idx, n_block * num_batch * num_head_kv, bidh_idx_in_group);
|
| 493 |
-
}
|
| 494 |
-
// if (thread_idx == 0) { printf("After barrier blockIdx.x = %d, blockIdx.y = %d, blockIdx.z = %d, bidb = %d, bidh_kv = %d, lock_ptr = %p, dk_semaphore = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, bidb, bidh_kv, lock_ptr, params.dk_semaphore);}
|
| 495 |
-
if constexpr (Use_TMA) {
|
| 496 |
-
cutlass::arch::fence_view_async_shared();
|
| 497 |
-
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 498 |
-
if (thread_idx == 0) {
|
| 499 |
-
SM90_BULK_REDUCE_ADD::copy(raw_pointer_cast(sdKV_flat.data()), raw_pointer_cast(gdKaccum.data()), dKV_TMA_num_bytes, static_cast<uint64_t>(TMA::CacheHintSm90::EVICT_LAST));
|
| 500 |
-
tma_store_arrive();
|
| 501 |
-
tma_store_wait<0>();
|
| 502 |
-
}
|
| 503 |
-
} else {
|
| 504 |
-
Tensor tdKrdK_atomic = r2g_thr_copy_dKVaccum.retile_S(tdKrdK);
|
| 505 |
-
Tensor tdKgdK_atomic = r2g_thr_copy_dKVaccum.partition_D(gdKaccum);
|
| 506 |
-
static_assert(CUTE_STATIC_V(size(tdKrdK_atomic)) == CUTE_STATIC_V(size(tdKgdK_atomic)));
|
| 507 |
-
#pragma unroll
|
| 508 |
-
for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdK_atomic(i), tdKrdK_atomic(i)); }
|
| 509 |
-
}
|
| 510 |
-
if constexpr (Deterministic) {
|
| 511 |
-
Barrier::arrive_inc(lock_ptr, thread_idx, n_block * num_batch * num_head_kv);
|
| 512 |
-
}
|
| 513 |
-
// // Tell warp 0 that smem_k and smem_v are ready
|
| 514 |
-
// flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast<uint32_t>(BwdNamedBarriers::KVEmpty) /*id*/);
|
| 515 |
-
}
|
| 516 |
-
|
| 517 |
-
CUTLASS_DEVICE void
|
| 518 |
-
store_tail() {
|
| 519 |
-
}
|
| 520 |
-
|
| 521 |
-
// Write 0 to dK and dV
|
| 522 |
-
CUTLASS_DEVICE void
|
| 523 |
-
store_zero(
|
| 524 |
-
Params const& params,
|
| 525 |
-
int thread_idx,
|
| 526 |
-
cute::tuple<int32_t, int32_t, int32_t> const& block_coord
|
| 527 |
-
) {
|
| 528 |
-
// Don't need to do anything since dKaccum and dVaccum are already zero-initialized
|
| 529 |
-
}
|
| 530 |
-
|
| 531 |
-
};
|
| 532 |
-
|
| 533 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/epilogue_fwd.hpp
DELETED
|
@@ -1,484 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include <cutlass/cutlass.h>
|
| 8 |
-
#include <cutlass/fast_math.h> // For FastDivMod
|
| 9 |
-
#include "cute/tensor.hpp"
|
| 10 |
-
|
| 11 |
-
#include "cutlass/gemm/collective/builders/sm90_common.inl"
|
| 12 |
-
#include "cutlass/epilogue/collective/builders/sm90_common.inl"
|
| 13 |
-
|
| 14 |
-
#include "seqlen.h"
|
| 15 |
-
#include "named_barrier.hpp"
|
| 16 |
-
#include "pack_gqa.h"
|
| 17 |
-
#include "utils.h"
|
| 18 |
-
|
| 19 |
-
namespace flash {
|
| 20 |
-
|
| 21 |
-
using namespace cute;
|
| 22 |
-
|
| 23 |
-
template <class TileShape_MNK_PV_, class ClusterShape_, class Element_, class ArchTag_,
|
| 24 |
-
int NumEpilogueThreads_, bool Varlen_, bool PackGQA_, bool Split_, bool FP8PermuteCol=false>
|
| 25 |
-
struct CollectiveEpilogueFwd {
|
| 26 |
-
|
| 27 |
-
using TileShape_MNK_PV = TileShape_MNK_PV_;
|
| 28 |
-
using ClusterShape = ClusterShape_;
|
| 29 |
-
using Element = Element_;
|
| 30 |
-
using ElementPartial = float;
|
| 31 |
-
using ArchTag = ArchTag_;
|
| 32 |
-
static constexpr int NumEpilogueThreads = NumEpilogueThreads_;
|
| 33 |
-
static constexpr bool Varlen = Varlen_;
|
| 34 |
-
static constexpr bool PackGQA = PackGQA_;
|
| 35 |
-
static constexpr bool Split = Split_;
|
| 36 |
-
static constexpr bool Use_smem = !(Split && !Varlen);
|
| 37 |
-
static constexpr bool Use_TMA_O = ArchTag::kMinComputeCapability >= 90 && !Varlen && !Split && !PackGQA;
|
| 38 |
-
|
| 39 |
-
static_assert(ArchTag::kMinComputeCapability >= 80);
|
| 40 |
-
static_assert(ArchTag::kMinComputeCapability >= 90 || CUTE_STATIC_V(size(ClusterShape{})) == 1);
|
| 41 |
-
static_assert(sizeof(Element) <= 2);
|
| 42 |
-
|
| 43 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
| 44 |
-
static constexpr int kHeadDimV = get<1>(TileShape_MNK_PV{});
|
| 45 |
-
|
| 46 |
-
static constexpr bool LargeHeadDimV = kHeadDimV > 256;
|
| 47 |
-
|
| 48 |
-
using GmemTiledCopyOTMA = cute::SM90_TMA_STORE;
|
| 49 |
-
|
| 50 |
-
// These are for storing the output tensor without TMA (e.g., for setting output to zero)
|
| 51 |
-
static constexpr int kGmemElemsPerStore = sizeof(cute::uint128_t) / sizeof(Element);
|
| 52 |
-
static_assert(kHeadDimV % kGmemElemsPerStore == 0, "Headdim must be a multiple of kGmemElemsPerStore");
|
| 53 |
-
// We want each "row" to have 64 elements (128 bytes, i.e. 1 cache line). We want each thread to have 4 elements
|
| 54 |
-
// in the M direction and 2 elements in the K direction. In the case of PackGQA, this reduces the number of times
|
| 55 |
-
// we need to call divmod.
|
| 56 |
-
static constexpr int kBytePerRow = kHeadDimV * sizeof(Element);
|
| 57 |
-
static constexpr int kBlockKGmem = (kBytePerRow % 128 == 0 ? 128 : (kBytePerRow % 64 == 0 ? 64 : 32)) / sizeof(Element);
|
| 58 |
-
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerStore;
|
| 59 |
-
// If PackGQA, we split the work of compute O_ptr among threads in the same row, so we need this to within a warp
|
| 60 |
-
static_assert(cutlass::NumThreadsPerWarp % kGmemThreadsPerRow == 0);
|
| 61 |
-
static_assert(NumEpilogueThreads % kGmemThreadsPerRow == 0, "NumEpilogueThreads must be a multiple of kGmemThreadsPerRow");
|
| 62 |
-
using GmemLayoutAtom = Layout<Shape <Int<NumEpilogueThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
| 63 |
-
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
| 64 |
-
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0, "kBlockM must be a multiple of NumEpilogueThreads / kGmemThreadsPerRow");
|
| 65 |
-
using GmemTiledCopyO = decltype(
|
| 66 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
| 67 |
-
GmemLayoutAtom{},
|
| 68 |
-
Layout<Shape<_1, Int<kGmemElemsPerStore>>>{})); // Val layout, 8 or 16 vals per store
|
| 69 |
-
|
| 70 |
-
using SmemLayoutAtomOTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
|
| 71 |
-
decltype(cute::get<0>(TileShape_MNK_PV{})), decltype(cute::get<1>(TileShape_MNK_PV{}))>());
|
| 72 |
-
using SmemLayoutOTMA = decltype(tile_to_shape(SmemLayoutAtomOTMA{}, select<0, 1>(TileShape_MNK_PV{})));
|
| 73 |
-
static constexpr int kSwizzle = kBlockKGmem == 128 ? 4 : (kBlockKGmem == 64 ? 3 : (kBlockKGmem == 32 ? 2 : 1));
|
| 74 |
-
static constexpr int kSwizzleBase = sizeof(Element) == 4 ? 2 : (sizeof(Element) == 2 ? 3 : 4);
|
| 75 |
-
using SmemLayoutAtomO = decltype(
|
| 76 |
-
composition(Swizzle<kSwizzle, kSwizzleBase, kSwizzleBase>{},
|
| 77 |
-
Layout<Shape<_8, Int<kBlockKGmem>>,
|
| 78 |
-
Stride<Int<kBlockKGmem>, _1>>{}));
|
| 79 |
-
using SmemLayoutOSTS = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_MNK_PV{})));
|
| 80 |
-
using SmemLayoutO = std::conditional_t<ArchTag::kMinComputeCapability >= 90, SmemLayoutOTMA, SmemLayoutOSTS>;
|
| 81 |
-
|
| 82 |
-
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch, num_splits)
|
| 83 |
-
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
|
| 84 |
-
using StrideLSE = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen_q, head, batch, num_splits)
|
| 85 |
-
// ((qhead_per_khead, seqlen_q), d, nheads_kv, batch, num_splits)
|
| 86 |
-
using ShapeOPacked = std::conditional_t<!PackGQA, ShapeO, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t, int32_t>>;
|
| 87 |
-
using StrideOPacked = std::conditional_t<!PackGQA, StrideO, cute::Stride<cute::Stride<int64_t, int64_t>, _1, int64_t, int64_t, int64_t>>;
|
| 88 |
-
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
|
| 89 |
-
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
|
| 90 |
-
using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
|
| 91 |
-
|
| 92 |
-
using CopyOpR2S = std::conditional_t<
|
| 93 |
-
ArchTag::kMinComputeCapability >= 90,
|
| 94 |
-
// cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
|
| 95 |
-
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
|
| 96 |
-
AutoVectorizingCopyWithAssumedAlignment<128>
|
| 97 |
-
>;
|
| 98 |
-
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
|
| 99 |
-
|
| 100 |
-
// static constexpr size_t SmemAlignmentO = cutlass::detail::alignment_for_swizzle(SmemLayoutO{});
|
| 101 |
-
// static_assert(SmemAlignmentO >= 128, "Require at least 128B alignment");
|
| 102 |
-
// struct TensorStorage : cute::aligned_struct<SmemAlignmentO> {
|
| 103 |
-
// cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0, SmemAlignmentO> smem_o;
|
| 104 |
-
// };
|
| 105 |
-
struct TensorStorage : cute::aligned_struct<128> {
|
| 106 |
-
cute::array_aligned<Element, Use_smem ? cute::cosize_v<SmemLayoutO> : 0> smem_o;
|
| 107 |
-
};
|
| 108 |
-
|
| 109 |
-
using TMA_O = std::conditional_t<
|
| 110 |
-
Use_TMA_O,
|
| 111 |
-
decltype(make_tma_copy(
|
| 112 |
-
GmemTiledCopyOTMA{},
|
| 113 |
-
make_tensor(make_gmem_ptr(static_cast<Element*>(nullptr)), ShapeO{}, StrideO{}),
|
| 114 |
-
SmemLayoutOTMA{},
|
| 115 |
-
select<0, 1>(TileShape_MNK_PV{}),
|
| 116 |
-
_1{})), // no mcast for O
|
| 117 |
-
std::nullptr_t
|
| 118 |
-
>;
|
| 119 |
-
|
| 120 |
-
// Host side kernel arguments
|
| 121 |
-
struct Arguments {
|
| 122 |
-
Element* ptr_O;
|
| 123 |
-
ShapeO const shape_O;
|
| 124 |
-
StrideO const stride_O;
|
| 125 |
-
ElementPartial* ptr_O_partial;
|
| 126 |
-
StrideO const stride_O_partial;
|
| 127 |
-
float* ptr_LSE;
|
| 128 |
-
StrideLSE const stride_LSE;
|
| 129 |
-
float* ptr_LSE_partial;
|
| 130 |
-
StrideLSE const stride_LSE_partial;
|
| 131 |
-
int32_t const nheads_kv;
|
| 132 |
-
int const* cu_seqlens = nullptr;
|
| 133 |
-
int const* seqused = nullptr;
|
| 134 |
-
};
|
| 135 |
-
|
| 136 |
-
// Device side kernel params
|
| 137 |
-
struct Params {
|
| 138 |
-
Element* ptr_O;
|
| 139 |
-
ShapeO const shape_O;
|
| 140 |
-
StrideO const stride_O;
|
| 141 |
-
ShapeOPacked const shape_O_packed;
|
| 142 |
-
StrideOPacked const stride_O_packed;
|
| 143 |
-
ElementPartial* ptr_O_partial;
|
| 144 |
-
StrideO const stride_O_partial;
|
| 145 |
-
StrideOPacked const stride_O_partial_packed;
|
| 146 |
-
float* ptr_LSE;
|
| 147 |
-
StrideLSE const stride_LSE;
|
| 148 |
-
ShapeLSEPacked const shape_LSE_packed;
|
| 149 |
-
StrideLSEPacked const stride_LSE_packed;
|
| 150 |
-
float* ptr_LSE_partial;
|
| 151 |
-
StrideLSE const stride_LSE_partial;
|
| 152 |
-
StrideLSEPacked const stride_LSE_partial_packed;
|
| 153 |
-
cutlass::FastDivmod qhead_per_khead_divmod;
|
| 154 |
-
TMA_O tma_store_O;
|
| 155 |
-
int const* cu_seqlens = nullptr;
|
| 156 |
-
int const* seqused = nullptr;
|
| 157 |
-
};
|
| 158 |
-
|
| 159 |
-
static Params
|
| 160 |
-
to_underlying_arguments(Arguments const& args) {
|
| 161 |
-
Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O);
|
| 162 |
-
TMA_O tma_store_O = [&]{
|
| 163 |
-
if constexpr (Use_TMA_O) {
|
| 164 |
-
return make_tma_copy(GmemTiledCopyOTMA{}, mO, SmemLayoutO{}, select<0, 1>(TileShape_MNK_PV{}), _1{}); // no mcast
|
| 165 |
-
} else {
|
| 166 |
-
return nullptr;
|
| 167 |
-
}
|
| 168 |
-
}();
|
| 169 |
-
// If PackGQA, reshape O to be ((qhead_per_khead, seqlen_q), head_size, nhead_k, batch_size, num_splits)
|
| 170 |
-
int const qhead_per_khead = !PackGQA ? 1 : cute::ceil_div(get<2>(args.shape_O), args.nheads_kv);
|
| 171 |
-
auto const shape_O_packed = cute::conditional_return<!PackGQA>(
|
| 172 |
-
args.shape_O,
|
| 173 |
-
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), get<1>(args.shape_O), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
|
| 174 |
-
);
|
| 175 |
-
auto const stride_O_packed = cute::conditional_return<!PackGQA>(
|
| 176 |
-
args.stride_O,
|
| 177 |
-
make_stride(make_stride(get<2>(args.stride_O), get<0>(args.stride_O)), get<1>(args.stride_O), get<2>(args.stride_O) * qhead_per_khead, get<3>(args.stride_O), get<4>(args.stride_O))
|
| 178 |
-
);
|
| 179 |
-
auto const stride_O_partial_packed = cute::conditional_return<!PackGQA>(
|
| 180 |
-
args.stride_O_partial,
|
| 181 |
-
make_stride(make_stride(get<2>(args.stride_O_partial), get<0>(args.stride_O_partial)), get<1>(args.stride_O_partial), get<2>(args.stride_O_partial) * qhead_per_khead, get<3>(args.stride_O_partial), get<4>(args.stride_O_partial))
|
| 182 |
-
);
|
| 183 |
-
// If PackGQA, Reshape LSE to be ((qhead_per_khead, seqlen_q), nhead_k, batch_size, num_splits)
|
| 184 |
-
auto const shape_LSE_packed = cute::conditional_return<!PackGQA>(
|
| 185 |
-
select<0, 2, 3, 4>(args.shape_O),
|
| 186 |
-
make_shape(make_shape(qhead_per_khead, get<0>(args.shape_O)), args.nheads_kv, get<3>(args.shape_O), get<4>(args.shape_O))
|
| 187 |
-
);
|
| 188 |
-
auto const stride_LSE_packed = cute::conditional_return<!PackGQA>(
|
| 189 |
-
args.stride_LSE,
|
| 190 |
-
make_stride(make_stride(get<1>(args.stride_LSE), get<0>(args.stride_LSE)), get<1>(args.stride_LSE) * qhead_per_khead, get<2>(args.stride_LSE), get<3>(args.stride_LSE))
|
| 191 |
-
);
|
| 192 |
-
auto const stride_LSE_partial_packed = cute::conditional_return<!PackGQA>(
|
| 193 |
-
args.stride_LSE_partial,
|
| 194 |
-
make_stride(make_stride(get<1>(args.stride_LSE_partial), get<0>(args.stride_LSE_partial)), get<1>(args.stride_LSE_partial) * qhead_per_khead, get<2>(args.stride_LSE_partial), get<3>(args.stride_LSE_partial))
|
| 195 |
-
);
|
| 196 |
-
return {args.ptr_O, args.shape_O, args.stride_O, shape_O_packed, stride_O_packed,
|
| 197 |
-
args.ptr_O_partial, args.stride_O_partial, stride_O_partial_packed,
|
| 198 |
-
args.ptr_LSE, args.stride_LSE, shape_LSE_packed, stride_LSE_packed,
|
| 199 |
-
args.ptr_LSE_partial, args.stride_LSE_partial, stride_LSE_partial_packed,
|
| 200 |
-
cutlass::FastDivmod(qhead_per_khead),
|
| 201 |
-
tma_store_O, args.cu_seqlens, args.seqused};
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
|
| 205 |
-
CUTLASS_DEVICE
|
| 206 |
-
static void prefetch_tma_descriptors(Params const& params) {
|
| 207 |
-
if constexpr (Use_TMA_O) {
|
| 208 |
-
cute::prefetch_tma_descriptor(params.tma_store_O.get_tma_descriptor());
|
| 209 |
-
}
|
| 210 |
-
}
|
| 211 |
-
|
| 212 |
-
template <typename SharedStorage, typename FrgTensorO, typename FrgTensorLSE, typename TiledMma>
|
| 213 |
-
CUTLASS_DEVICE void
|
| 214 |
-
store(Params const& params,
|
| 215 |
-
FrgTensorO& tOrO,
|
| 216 |
-
FrgTensorLSE const& lse,
|
| 217 |
-
SharedStorage& shared_storage,
|
| 218 |
-
TiledMma tiled_mma,
|
| 219 |
-
int thread_idx,
|
| 220 |
-
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
|
| 221 |
-
) {
|
| 222 |
-
|
| 223 |
-
auto [m_block, bidh, bidb, split_idx] = block_coord;
|
| 224 |
-
int num_splits = get<4>(params.shape_O_packed);
|
| 225 |
-
if constexpr (Split && Varlen) {
|
| 226 |
-
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
|
| 227 |
-
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
|
| 228 |
-
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
|
| 229 |
-
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
|
| 230 |
-
}
|
| 231 |
-
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
|
| 232 |
-
|
| 233 |
-
Tensor sO = make_tensor(make_smem_ptr(shared_storage.tensors.epilogue.smem_o.data()), SmemLayoutO{});
|
| 234 |
-
// Tensor sO_pi = cute::as_position_independent_swizzle_tensor(sO);
|
| 235 |
-
|
| 236 |
-
static constexpr bool NeedFP8Permute = FP8PermuteCol && (sizeof(Element) == 2 || sizeof(Element) == 4);
|
| 237 |
-
// If we will possibly need tOrO in FP32, we'd want to permute tOrO before type conversion.
|
| 238 |
-
// Otherwise we can permute after conversion.
|
| 239 |
-
if constexpr (NeedFP8Permute && Split) { flash::permute_output_fp8_Vcolmajor(tOrO); }
|
| 240 |
-
Tensor tOrO_out = make_tensor_like<Element>(tOrO);
|
| 241 |
-
flash::convert_type_out(tOrO, tOrO_out);
|
| 242 |
-
if constexpr (NeedFP8Permute && !Split) { flash::permute_output_fp8_Vcolmajor(tOrO_out); }
|
| 243 |
-
|
| 244 |
-
// Make sure all WGs have finished reading V
|
| 245 |
-
// Technically we don't need this if we're not using smem, but the mainloop makes the assumption that
|
| 246 |
-
// all epilogue threads sync at least once during the epilogue (so that we can start loading Q with
|
| 247 |
-
// cp.async if we need).
|
| 248 |
-
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 249 |
-
|
| 250 |
-
// Step 1: Write O from rmem -> smem
|
| 251 |
-
if constexpr (Use_smem) {
|
| 252 |
-
auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma);
|
| 253 |
-
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx);
|
| 254 |
-
Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N)
|
| 255 |
-
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
| 256 |
-
// Tensor taccOsO = smem_thr_copy_O.partition_D(sO_pi); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
| 257 |
-
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
| 258 |
-
if constexpr (Use_TMA_O) {
|
| 259 |
-
cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
|
| 260 |
-
cutlass::arch::NamedBarrier::arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
| 261 |
-
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 262 |
-
} else {
|
| 263 |
-
flash::named_barrier_sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 264 |
-
}
|
| 265 |
-
} else {
|
| 266 |
-
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
| 267 |
-
#pragma unroll
|
| 268 |
-
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
| 269 |
-
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
| 270 |
-
}
|
| 271 |
-
}
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
|
| 275 |
-
bool is_varlen = Varlen && params.cu_seqlens;
|
| 276 |
-
int offset_o = seqlen_info.offset;
|
| 277 |
-
int seqlen_o = seqlen_info.seqlen;
|
| 278 |
-
int warp_group_idx = __shfl_sync(0xFFFFFFFF, thread_idx / cutlass::NumThreadsPerWarpGroup, 0);
|
| 279 |
-
|
| 280 |
-
// Step 2: Write LSE from rmem -> gmem
|
| 281 |
-
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
|
| 282 |
-
// (MMA,MMA_M,MMA_K)
|
| 283 |
-
Tensor taccOcO = thread_mma.partition_C(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
| 284 |
-
static_assert(decltype(size<0, 0>(taccOcO))::value == 2);
|
| 285 |
-
static_assert(decltype(size<0, 1>(taccOcO))::value == 2);
|
| 286 |
-
Tensor taccOcO_rowcol = make_tensor(taccOcO.data(), flash::convert_layout_acc_rowcol(taccOcO.layout()));
|
| 287 |
-
Tensor taccOcO_row = taccOcO_rowcol(_, _0{});
|
| 288 |
-
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
| 289 |
-
|
| 290 |
-
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
|
| 291 |
-
using PackGQApartial_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, ElementPartial>;
|
| 292 |
-
|
| 293 |
-
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
|
| 294 |
-
params.shape_LSE_packed,
|
| 295 |
-
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
|
| 296 |
-
// if (thread_idx == 0) { printf("Before LSE write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o); print(mLSE); printf("\n"); }
|
| 297 |
-
if (!LargeHeadDimV || warp_group_idx == 0) {
|
| 298 |
-
if constexpr (!PackGQA) {
|
| 299 |
-
#pragma unroll
|
| 300 |
-
for (int mi = 0; mi < size(lse); ++mi) {
|
| 301 |
-
int const row = m_block * kBlockM + get<0>(taccOcO_row(mi));
|
| 302 |
-
if (get<1>(taccOcO_row(_0{})) == 0 && row < seqlen_o) { mLSE(row) = lse(mi); }
|
| 303 |
-
}
|
| 304 |
-
} else {
|
| 305 |
-
PackGQA_t::store_LSE(mLSE, lse, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
| 306 |
-
}
|
| 307 |
-
}
|
| 308 |
-
|
| 309 |
-
// Step 3: Write O from smem -> gmem
|
| 310 |
-
if constexpr (Use_TMA_O) {
|
| 311 |
-
Tensor mO = params.tma_store_O.get_tma_tensor(params.shape_O)(_, _, bidh, bidb, split_idx);
|
| 312 |
-
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
| 313 |
-
auto block_tma_O = params.tma_store_O.get_slice(_0{});
|
| 314 |
-
Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K)
|
| 315 |
-
Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K)
|
| 316 |
-
int warp_idx_sync = __shfl_sync(0xffffffff, thread_idx / cutlass::NumThreadsPerWarp, 0);
|
| 317 |
-
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
|
| 318 |
-
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
|
| 319 |
-
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
|
| 320 |
-
if (cute::elect_one_sync()) {
|
| 321 |
-
cute::copy(params.tma_store_O, tOsO, tOgO);
|
| 322 |
-
tma_store_arrive();
|
| 323 |
-
tma_store_wait<0>();
|
| 324 |
-
#pragma unroll
|
| 325 |
-
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
| 326 |
-
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
| 327 |
-
}
|
| 328 |
-
}
|
| 329 |
-
}
|
| 330 |
-
} else { // Don't use TMA in Varlen case since we don't want to overwrite the output of another sequence
|
| 331 |
-
if (!is_split) {
|
| 332 |
-
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
|
| 333 |
-
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
| 334 |
-
// if (thread_idx == 0) { printf("Before O write, m_block: %d, bidh: %d, bidb: %d, split_idx: %d, offset_o: %d, seqlen_o: %d, mO_addr = %p, addr diff = %d\n", m_block, bidh, bidb, split_idx, offset_o, seqlen_o, mO.data(), reinterpret_cast<int>(&mO(0)) - reinterpret_cast<int>(params.ptr_O)); }
|
| 335 |
-
GmemTiledCopyO gmem_tiled_copy_O;
|
| 336 |
-
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
| 337 |
-
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
| 338 |
-
// Tensor tOsO = gmem_thr_copy_O.partition_S(sO_pi); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
| 339 |
-
Tensor tOrO = make_fragment_like(tOsO);
|
| 340 |
-
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
| 341 |
-
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
| 342 |
-
cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_v
|
| 343 |
-
#pragma unroll
|
| 344 |
-
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
| 345 |
-
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
| 346 |
-
}
|
| 347 |
-
}
|
| 348 |
-
if constexpr (!PackGQA) {
|
| 349 |
-
// (BLK_M,BLK_K) -> (blk_m,blk_k)
|
| 350 |
-
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
| 351 |
-
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOsO)));
|
| 352 |
-
#pragma unroll
|
| 353 |
-
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
| 354 |
-
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
| 355 |
-
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
| 356 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 357 |
-
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
| 358 |
-
);
|
| 359 |
-
} else {
|
| 360 |
-
// If PackGQA, we split the work of compute O_ptr among threads in the same row
|
| 361 |
-
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
| 362 |
-
}
|
| 363 |
-
} else {
|
| 364 |
-
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset_o * get<0>(params.stride_O_partial)), params.shape_O_packed, params.stride_O_partial_packed)(_, _, bidh, !is_varlen ? bidb : 0, split_idx);
|
| 365 |
-
Tensor gOpartial = local_tile(mOpartial, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
| 366 |
-
// We already arrived on barrier_O earlier if !Use_smem
|
| 367 |
-
if constexpr (Use_smem) {
|
| 368 |
-
if constexpr (ArchTag::kMinComputeCapability >= 90) {
|
| 369 |
-
#pragma unroll
|
| 370 |
-
for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) {
|
| 371 |
-
shared_storage.pipelines.barrier_O.arrive(cta_id);
|
| 372 |
-
}
|
| 373 |
-
}
|
| 374 |
-
}
|
| 375 |
-
if constexpr (!PackGQA) {
|
| 376 |
-
static constexpr int kGmemElemsPerStoreDirect = 2;
|
| 377 |
-
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial> gmem_copy_direct;
|
| 378 |
-
// Reshape acc from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
|
| 379 |
-
Tensor tOrO_rowcol = make_tensor(tOrO.data(), flash::convert_layout_acc_rowcol(tOrO.layout()));
|
| 380 |
-
Tensor tOrO_copy = cute::tiled_divide(tOrO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
|
| 381 |
-
Tensor tOgO = thread_mma.partition_C(gOpartial);
|
| 382 |
-
Tensor tOgO_rowcol = make_tensor(tOgO.data(), flash::convert_layout_acc_rowcol(tOgO.layout()));
|
| 383 |
-
Tensor tOgO_copy = cute::tiled_divide(tOgO_rowcol, Shape<_1, Int<kGmemElemsPerStoreDirect>>{});
|
| 384 |
-
Tensor taccOcO_col = taccOcO_rowcol(_0{}, _);
|
| 385 |
-
#pragma unroll
|
| 386 |
-
for (int m = 0; m < size(taccOcO_row); ++m) {
|
| 387 |
-
if (get<0>(taccOcO_row(m)) < seqlen_o - m_block * kBlockM) {
|
| 388 |
-
#pragma unroll
|
| 389 |
-
for (int k = 0; k < size(taccOcO_col) / kGmemElemsPerStoreDirect; ++k) {
|
| 390 |
-
if (get<1>(taccOcO_col(k * kGmemElemsPerStoreDirect)) < get<1>(params.shape_O)) {
|
| 391 |
-
cute::copy(gmem_copy_direct, tOrO_copy(_, m, k), tOgO_copy(_, m, k));
|
| 392 |
-
}
|
| 393 |
-
}
|
| 394 |
-
}
|
| 395 |
-
}
|
| 396 |
-
} else {
|
| 397 |
-
PackGQApartial_t::store_O_direct(mOpartial, tOrO, tiled_mma, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
| 398 |
-
}
|
| 399 |
-
}
|
| 400 |
-
}
|
| 401 |
-
}
|
| 402 |
-
|
| 403 |
-
CUTLASS_DEVICE void
|
| 404 |
-
store_tail() {
|
| 405 |
-
// Don't need to do tma_store_wait<0>() here since we already did in @store
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
// Write 0 to output and -inf to LSE
|
| 409 |
-
CUTLASS_DEVICE void
|
| 410 |
-
store_zero(
|
| 411 |
-
Params const& params,
|
| 412 |
-
int thread_idx,
|
| 413 |
-
cute::tuple<int32_t, int32_t, int32_t, int32_t> const& block_coord
|
| 414 |
-
) {
|
| 415 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
| 416 |
-
auto [m_block, bidh, bidb, split_idx] = block_coord;
|
| 417 |
-
int num_splits = get<4>(params.shape_O_packed);
|
| 418 |
-
if constexpr (Split && Varlen) {
|
| 419 |
-
uint32_t num_splits_dynamic_u = reinterpret_cast<uint32_t const&>(split_idx) >> 16; // first 16 bits are for num_splits
|
| 420 |
-
int num_splits_dynamic = reinterpret_cast<int&>(num_splits_dynamic_u);
|
| 421 |
-
num_splits = num_splits_dynamic > 0 ? num_splits_dynamic : num_splits;
|
| 422 |
-
split_idx &= 0x0000FFFF; // Only use the lower 16 bits of split_idx
|
| 423 |
-
}
|
| 424 |
-
bool const is_split = !Split ? false : (!Varlen ? true : num_splits > 1);
|
| 425 |
-
|
| 426 |
-
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused};
|
| 427 |
-
bool const is_varlen = Varlen && params.cu_seqlens;
|
| 428 |
-
int offset_o = seqlen_info.offset;
|
| 429 |
-
int seqlen_o = seqlen_info.seqlen;
|
| 430 |
-
int qhead_per_khead = !PackGQA ? 1 : params.qhead_per_khead_divmod.divisor;
|
| 431 |
-
Tensor mLSE = make_tensor(make_gmem_ptr((!is_split ? params.ptr_LSE : params.ptr_LSE_partial) + offset_o * get<0>(!is_split ? params.stride_LSE : params.stride_LSE_partial)),
|
| 432 |
-
params.shape_LSE_packed,
|
| 433 |
-
!is_split ? params.stride_LSE_packed : params.stride_LSE_partial_packed)(_, bidh, !is_varlen ? bidb : 0, !is_split ? 0 : split_idx);
|
| 434 |
-
Tensor gLSE = local_tile(mLSE, Shape<Int<kBlockM>>{}, make_coord(m_block));
|
| 435 |
-
|
| 436 |
-
static_assert(kBlockM <= NumEpilogueThreads);
|
| 437 |
-
if (thread_idx < kBlockM) {
|
| 438 |
-
const int row = m_block * kBlockM + thread_idx;
|
| 439 |
-
if constexpr (!PackGQA) {
|
| 440 |
-
if (row < seqlen_o) { mLSE(row) = -INFINITY; }
|
| 441 |
-
} else {
|
| 442 |
-
if (row < seqlen_o * qhead_per_khead) {
|
| 443 |
-
int m_idx, h_idx;
|
| 444 |
-
m_idx = params.qhead_per_khead_divmod.divmod(h_idx, row);
|
| 445 |
-
// mLSE has shape ((qhead_per_khead, seqlen_q)) and it's unhappy with just 1 "make_coord"
|
| 446 |
-
mLSE(make_coord(make_coord(h_idx, m_idx))) = -INFINITY;
|
| 447 |
-
}
|
| 448 |
-
}
|
| 449 |
-
}
|
| 450 |
-
|
| 451 |
-
// If split, we don't have to write 0 to mOpartial if the mha_combine kernel is used,
|
| 452 |
-
// since it will not use the value of O if LSE is -inf.
|
| 453 |
-
if (!is_split) {
|
| 454 |
-
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset_o * get<0>(params.stride_O)), params.shape_O_packed, params.stride_O_packed)(_, _, bidh, !is_varlen ? bidb : 0, _0{});
|
| 455 |
-
|
| 456 |
-
GmemTiledCopyO gmem_tiled_copy_O;
|
| 457 |
-
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
| 458 |
-
Tensor tOcO = gmem_thr_copy_O.partition_D(cute::make_identity_tensor(select<0, 1>(TileShape_MNK_PV{})));
|
| 459 |
-
if constexpr (!PackGQA) {
|
| 460 |
-
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
|
| 461 |
-
#pragma unroll
|
| 462 |
-
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
| 463 |
-
Tensor gO = local_tile(mO, select<0, 1>(TileShape_MNK_PV{}), make_coord(m_block, _0{})); // (M, K)
|
| 464 |
-
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
| 465 |
-
Tensor tOrO = make_fragment_like(tOgO);
|
| 466 |
-
cute::clear(tOrO);
|
| 467 |
-
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
| 468 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 469 |
-
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
| 470 |
-
);
|
| 471 |
-
} else {
|
| 472 |
-
// If PackGQA, we split the work of compute O_ptr among threads in the same row
|
| 473 |
-
using PackGQA_t = flash::PackGQAManager<get<0>(TileShape_MNK_PV{}), get<1>(TileShape_MNK_PV{}), NumEpilogueThreads, Element>;
|
| 474 |
-
Tensor tOrO = make_tensor<Element>(make_shape(Shape<_1, Int<kGmemElemsPerStore>>{}, size<1>(tOcO), size<2>(tOcO)));
|
| 475 |
-
cute::clear(tOrO);
|
| 476 |
-
PackGQA_t::store_O(mO, tOrO, params.qhead_per_khead_divmod, thread_idx, seqlen_o, m_block);
|
| 477 |
-
}
|
| 478 |
-
}
|
| 479 |
-
|
| 480 |
-
}
|
| 481 |
-
|
| 482 |
-
};
|
| 483 |
-
|
| 484 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash.h
DELETED
|
@@ -1,218 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2023, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include <cuda.h>
|
| 8 |
-
#include <vector>
|
| 9 |
-
|
| 10 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 11 |
-
|
| 12 |
-
struct Qkv_params {
|
| 13 |
-
using index_t = int64_t;
|
| 14 |
-
// The QKV matrices.
|
| 15 |
-
void *__restrict__ q_ptr;
|
| 16 |
-
void *__restrict__ k_ptr;
|
| 17 |
-
void *__restrict__ v_ptr;
|
| 18 |
-
|
| 19 |
-
// The stride between rows of the Q, K and V matrices.
|
| 20 |
-
index_t q_batch_stride;
|
| 21 |
-
index_t k_batch_stride;
|
| 22 |
-
index_t v_batch_stride;
|
| 23 |
-
index_t q_row_stride;
|
| 24 |
-
index_t k_row_stride;
|
| 25 |
-
index_t v_row_stride;
|
| 26 |
-
index_t q_head_stride;
|
| 27 |
-
index_t k_head_stride;
|
| 28 |
-
index_t v_head_stride;
|
| 29 |
-
index_t v_dim_stride;
|
| 30 |
-
|
| 31 |
-
// The number of heads.
|
| 32 |
-
int h, h_k;
|
| 33 |
-
};
|
| 34 |
-
|
| 35 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 36 |
-
|
| 37 |
-
struct Flash_fwd_params : public Qkv_params {
|
| 38 |
-
using index_t = int64_t;
|
| 39 |
-
|
| 40 |
-
// The O matrix (output).
|
| 41 |
-
void * __restrict__ o_ptr;
|
| 42 |
-
void * __restrict__ oaccum_ptr;
|
| 43 |
-
|
| 44 |
-
// The stride between rows of O.
|
| 45 |
-
index_t o_batch_stride;
|
| 46 |
-
index_t o_row_stride;
|
| 47 |
-
index_t o_head_stride;
|
| 48 |
-
|
| 49 |
-
// The pointer to the softmax sum.
|
| 50 |
-
void * __restrict__ softmax_lse_ptr;
|
| 51 |
-
void * __restrict__ softmax_lseaccum_ptr;
|
| 52 |
-
|
| 53 |
-
// For FP8 scaling
|
| 54 |
-
float * __restrict__ q_descale_ptr;
|
| 55 |
-
float * __restrict__ k_descale_ptr;
|
| 56 |
-
float * __restrict__ v_descale_ptr;
|
| 57 |
-
index_t q_descale_batch_stride;
|
| 58 |
-
index_t q_descale_head_stride;
|
| 59 |
-
index_t k_descale_batch_stride;
|
| 60 |
-
index_t k_descale_head_stride;
|
| 61 |
-
index_t v_descale_batch_stride;
|
| 62 |
-
index_t v_descale_head_stride;
|
| 63 |
-
|
| 64 |
-
// The dimensions.
|
| 65 |
-
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
| 66 |
-
int total_q, total_k, total_knew;
|
| 67 |
-
int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q
|
| 68 |
-
int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim
|
| 69 |
-
|
| 70 |
-
// The scaling factors for the kernel.
|
| 71 |
-
float scale_softmax;
|
| 72 |
-
float softcap;
|
| 73 |
-
|
| 74 |
-
// array of length b+1 holding starting offset of each sequence.
|
| 75 |
-
int * __restrict__ cu_seqlens_q;
|
| 76 |
-
int * __restrict__ cu_seqlens_k;
|
| 77 |
-
int * __restrict__ cu_seqlens_knew;
|
| 78 |
-
int * __restrict__ leftpad_k;
|
| 79 |
-
|
| 80 |
-
// If provided, the actual length of each q/k sequence.
|
| 81 |
-
int *__restrict__ seqused_q;
|
| 82 |
-
int *__restrict__ seqused_k;
|
| 83 |
-
|
| 84 |
-
// The stride between rows of Oaccum.
|
| 85 |
-
index_t oaccum_split_stride;
|
| 86 |
-
index_t oaccum_batch_stride;
|
| 87 |
-
index_t oaccum_row_stride;
|
| 88 |
-
index_t oaccum_head_stride;
|
| 89 |
-
|
| 90 |
-
// The stride between rows of LSEaccum.
|
| 91 |
-
index_t lseaccum_split_stride;
|
| 92 |
-
index_t lseaccum_batch_stride;
|
| 93 |
-
index_t lseaccum_head_stride;
|
| 94 |
-
|
| 95 |
-
// The K_new and V_new matrices.
|
| 96 |
-
void * __restrict__ knew_ptr;
|
| 97 |
-
void * __restrict__ vnew_ptr;
|
| 98 |
-
|
| 99 |
-
// The stride between rows of the Q, K and V matrices.
|
| 100 |
-
index_t knew_batch_stride;
|
| 101 |
-
index_t vnew_batch_stride;
|
| 102 |
-
index_t knew_row_stride;
|
| 103 |
-
index_t vnew_row_stride;
|
| 104 |
-
index_t knew_head_stride;
|
| 105 |
-
index_t vnew_head_stride;
|
| 106 |
-
|
| 107 |
-
void *__restrict__ qv_ptr;
|
| 108 |
-
index_t qv_batch_stride;
|
| 109 |
-
index_t qv_row_stride;
|
| 110 |
-
index_t qv_head_stride;
|
| 111 |
-
|
| 112 |
-
// The cos and sin matrices for rotary embedding.
|
| 113 |
-
void * __restrict__ rotary_cos_ptr;
|
| 114 |
-
void * __restrict__ rotary_sin_ptr;
|
| 115 |
-
int *__restrict__ seqlens_rotary;
|
| 116 |
-
|
| 117 |
-
// The indices to index into the KV cache.
|
| 118 |
-
int * __restrict__ kv_batch_idx;
|
| 119 |
-
|
| 120 |
-
// Paged KV cache
|
| 121 |
-
int * __restrict__ page_table;
|
| 122 |
-
index_t page_table_batch_stride;
|
| 123 |
-
int page_size;
|
| 124 |
-
int num_pages;
|
| 125 |
-
bool pagedkv_tma;
|
| 126 |
-
|
| 127 |
-
// The dropout probability (probability of keeping an activation).
|
| 128 |
-
float p_dropout;
|
| 129 |
-
// uint32_t p_dropout_in_uint;
|
| 130 |
-
// uint16_t p_dropout_in_uint16_t;
|
| 131 |
-
uint8_t p_dropout_in_uint8_t;
|
| 132 |
-
|
| 133 |
-
// Scale factor of 1 / (1 - p_dropout).
|
| 134 |
-
float rp_dropout;
|
| 135 |
-
|
| 136 |
-
// Local window size
|
| 137 |
-
int window_size_left, window_size_right;
|
| 138 |
-
int attention_chunk;
|
| 139 |
-
|
| 140 |
-
// Pointer to the RNG seed (idx 0) and offset (idx 1).
|
| 141 |
-
uint64_t * rng_state;
|
| 142 |
-
|
| 143 |
-
bool is_bf16;
|
| 144 |
-
bool is_fp32;
|
| 145 |
-
bool is_e4m3;
|
| 146 |
-
bool is_causal;
|
| 147 |
-
bool is_local;
|
| 148 |
-
|
| 149 |
-
bool is_rotary_interleaved;
|
| 150 |
-
|
| 151 |
-
int num_splits; // For split-KV version
|
| 152 |
-
bool pack_gqa;
|
| 153 |
-
|
| 154 |
-
int * __restrict__ tile_count_semaphore;
|
| 155 |
-
// int * __restrict__ num_m_blocks_ptr;
|
| 156 |
-
// int * __restrict__ num_n_blocks_ptr;
|
| 157 |
-
int * __restrict__ num_splits_dynamic_ptr;
|
| 158 |
-
bool skip_scheduler_metadata_computation;
|
| 159 |
-
|
| 160 |
-
int arch;
|
| 161 |
-
int num_sm;
|
| 162 |
-
};
|
| 163 |
-
|
| 164 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 165 |
-
|
| 166 |
-
struct Flash_bwd_params : public Flash_fwd_params {
|
| 167 |
-
using index_t = int64_t;
|
| 168 |
-
|
| 169 |
-
// The dO and dQKV matrices.
|
| 170 |
-
void *__restrict__ do_ptr;
|
| 171 |
-
void *__restrict__ dq_ptr;
|
| 172 |
-
void *__restrict__ dk_ptr;
|
| 173 |
-
void *__restrict__ dv_ptr;
|
| 174 |
-
|
| 175 |
-
// To accumulate dQ
|
| 176 |
-
void *__restrict__ dq_accum_ptr;
|
| 177 |
-
void *__restrict__ dk_accum_ptr;
|
| 178 |
-
void *__restrict__ dv_accum_ptr;
|
| 179 |
-
|
| 180 |
-
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
|
| 181 |
-
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
|
| 182 |
-
// dv_accum_ptr;
|
| 183 |
-
|
| 184 |
-
// The stride between rows of the dO, dQ, dK and dV matrices.
|
| 185 |
-
index_t do_batch_stride;
|
| 186 |
-
index_t do_row_stride;
|
| 187 |
-
index_t do_head_stride;
|
| 188 |
-
index_t dq_batch_stride;
|
| 189 |
-
index_t dk_batch_stride;
|
| 190 |
-
index_t dv_batch_stride;
|
| 191 |
-
index_t dq_row_stride;
|
| 192 |
-
index_t dk_row_stride;
|
| 193 |
-
index_t dv_row_stride;
|
| 194 |
-
index_t dq_head_stride;
|
| 195 |
-
index_t dk_head_stride;
|
| 196 |
-
index_t dv_head_stride;
|
| 197 |
-
|
| 198 |
-
// The pointer to the softmax d sum.
|
| 199 |
-
void *__restrict__ dsoftmax_sum;
|
| 200 |
-
void *__restrict__ softmax_lse_log2_ptr;
|
| 201 |
-
|
| 202 |
-
int *__restrict__ dq_semaphore;
|
| 203 |
-
int *__restrict__ dk_semaphore;
|
| 204 |
-
int *__restrict__ dv_semaphore;
|
| 205 |
-
|
| 206 |
-
bool deterministic;
|
| 207 |
-
index_t dq_accum_split_stride;
|
| 208 |
-
};
|
| 209 |
-
|
| 210 |
-
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 211 |
-
|
| 212 |
-
template <int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
|
| 213 |
-
void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
| 214 |
-
void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa, int blockM, int blockN, bool enable_pdl);
|
| 215 |
-
template <int Arch, typename T, int kHeadDim, bool Has_softcap>
|
| 216 |
-
void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
|
| 217 |
-
template <typename T, typename Tpartial, int kBlockK>
|
| 218 |
-
void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_api.cpp
DELETED
|
@@ -1,1720 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#include <Python.h>
|
| 6 |
-
#include <torch/nn/functional/padding.h>
|
| 7 |
-
#include <ATen/cuda/CUDAContextLight.h>
|
| 8 |
-
#include <c10/cuda/CUDAGuard.h>
|
| 9 |
-
|
| 10 |
-
#include <cutlass/numeric_types.h>
|
| 11 |
-
|
| 12 |
-
#include "flash.h"
|
| 13 |
-
#include "static_switch.h"
|
| 14 |
-
#include "tile_size.h"
|
| 15 |
-
#include "heuristics.h"
|
| 16 |
-
#include "cuda_check.h"
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
extern "C" {
|
| 20 |
-
/* Creates a dummy empty _C module that can be imported from Python.
|
| 21 |
-
The import from Python will load the .so consisting of this file
|
| 22 |
-
in this extension, so that the TORCH_LIBRARY static initializers
|
| 23 |
-
below are run. */
|
| 24 |
-
PyObject* PyInit__C(void)
|
| 25 |
-
{
|
| 26 |
-
static struct PyModuleDef module_def = {
|
| 27 |
-
PyModuleDef_HEAD_INIT,
|
| 28 |
-
"_C", /* name of module */
|
| 29 |
-
NULL, /* module documentation, may be NULL */
|
| 30 |
-
-1, /* size of per-interpreter state of the module,
|
| 31 |
-
or -1 if the module keeps state in global variables. */
|
| 32 |
-
NULL, /* methods */
|
| 33 |
-
};
|
| 34 |
-
return PyModule_Create(&module_def);
|
| 35 |
-
}
|
| 36 |
-
}
|
| 37 |
-
|
| 38 |
-
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
| 39 |
-
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 40 |
-
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
| 41 |
-
|
| 42 |
-
void set_params_fprop(Flash_fwd_params ¶ms,
|
| 43 |
-
// sizes
|
| 44 |
-
const size_t b,
|
| 45 |
-
const size_t seqlen_q,
|
| 46 |
-
const size_t seqlen_k,
|
| 47 |
-
const size_t seqlen_q_rounded,
|
| 48 |
-
const size_t seqlen_k_rounded,
|
| 49 |
-
const size_t h,
|
| 50 |
-
const size_t h_k,
|
| 51 |
-
const size_t d,
|
| 52 |
-
const size_t d_rounded,
|
| 53 |
-
// device pointers
|
| 54 |
-
const at::Tensor q,
|
| 55 |
-
const at::Tensor k,
|
| 56 |
-
const at::Tensor v,
|
| 57 |
-
at::Tensor out,
|
| 58 |
-
void *cu_seqlens_q_d,
|
| 59 |
-
void *cu_seqlens_k_d,
|
| 60 |
-
void *seqused_q,
|
| 61 |
-
void *seqused_k,
|
| 62 |
-
void *softmax_lse_d,
|
| 63 |
-
float p_dropout,
|
| 64 |
-
float softmax_scale,
|
| 65 |
-
int window_size_left,
|
| 66 |
-
int window_size_right,
|
| 67 |
-
int attention_chunk,
|
| 68 |
-
const float softcap=0.f,
|
| 69 |
-
const int sm_margin=0) {
|
| 70 |
-
|
| 71 |
-
// Reset the parameters
|
| 72 |
-
params = {};
|
| 73 |
-
|
| 74 |
-
params.is_bf16 = q.dtype() == torch::kBFloat16;
|
| 75 |
-
params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn;
|
| 76 |
-
|
| 77 |
-
// Set the pointers and strides.
|
| 78 |
-
params.q_ptr = q.data_ptr();
|
| 79 |
-
params.k_ptr = k.data_ptr();
|
| 80 |
-
params.v_ptr = v.data_ptr();
|
| 81 |
-
// All stride are in elements, not bytes.
|
| 82 |
-
params.q_row_stride = q.stride(-3);
|
| 83 |
-
params.k_row_stride = k.stride(-3);
|
| 84 |
-
params.v_row_stride = v.stride(-3);
|
| 85 |
-
params.q_head_stride = q.stride(-2);
|
| 86 |
-
params.k_head_stride = k.stride(-2);
|
| 87 |
-
params.v_head_stride = v.stride(-2);
|
| 88 |
-
params.v_dim_stride = v.stride(-1);
|
| 89 |
-
params.o_ptr = out.data_ptr();
|
| 90 |
-
params.o_row_stride = out.stride(-3);
|
| 91 |
-
params.o_head_stride = out.stride(-2);
|
| 92 |
-
|
| 93 |
-
if (cu_seqlens_q_d == nullptr) {
|
| 94 |
-
params.q_batch_stride = q.stride(0);
|
| 95 |
-
params.o_batch_stride = out.stride(0);
|
| 96 |
-
}
|
| 97 |
-
if (cu_seqlens_k_d == nullptr) {
|
| 98 |
-
params.k_batch_stride = k.stride(0);
|
| 99 |
-
params.v_batch_stride = v.stride(0);
|
| 100 |
-
}
|
| 101 |
-
|
| 102 |
-
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
| 103 |
-
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
| 104 |
-
params.seqused_q = static_cast<int *>(seqused_q);
|
| 105 |
-
params.seqused_k = static_cast<int *>(seqused_k);
|
| 106 |
-
|
| 107 |
-
// Softmax sum
|
| 108 |
-
params.softmax_lse_ptr = softmax_lse_d;
|
| 109 |
-
|
| 110 |
-
// Set the dimensions.
|
| 111 |
-
params.b = b;
|
| 112 |
-
params.h = h;
|
| 113 |
-
params.h_k = h_k;
|
| 114 |
-
params.seqlen_q = seqlen_q;
|
| 115 |
-
params.seqlen_k = seqlen_k;
|
| 116 |
-
params.seqlen_q_rounded = seqlen_q_rounded;
|
| 117 |
-
params.seqlen_k_rounded = seqlen_k_rounded;
|
| 118 |
-
params.d = d;
|
| 119 |
-
params.d_rounded = d_rounded;
|
| 120 |
-
|
| 121 |
-
// Set the different scale values.
|
| 122 |
-
params.scale_softmax = softmax_scale;
|
| 123 |
-
params.softcap = softcap;
|
| 124 |
-
|
| 125 |
-
// Set this to probability of keeping an element to simplify things.
|
| 126 |
-
params.p_dropout = 1.f - p_dropout;
|
| 127 |
-
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
| 128 |
-
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
| 129 |
-
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
| 130 |
-
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
| 131 |
-
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
|
| 132 |
-
params.rp_dropout = 1.f / params.p_dropout;
|
| 133 |
-
TORCH_CHECK(p_dropout < 1.f);
|
| 134 |
-
#ifdef FLASHATTENTION_DISABLE_DROPOUT
|
| 135 |
-
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
|
| 136 |
-
#endif
|
| 137 |
-
|
| 138 |
-
// Causal is the special case where window_size_right == 0 and window_size_left < 0.
|
| 139 |
-
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
|
| 140 |
-
params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
|
| 141 |
-
params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
|
| 142 |
-
|
| 143 |
-
// TODO: check this
|
| 144 |
-
if (window_size_left < 0) { window_size_left = seqlen_k - 1; }
|
| 145 |
-
if (window_size_right < 0) { window_size_right = seqlen_q - 1; }
|
| 146 |
-
if (attention_chunk > 0) {
|
| 147 |
-
window_size_left = std::min(window_size_left, attention_chunk - 1);
|
| 148 |
-
window_size_right = std::min(window_size_right, attention_chunk - 1);
|
| 149 |
-
}
|
| 150 |
-
params.window_size_left = window_size_left;
|
| 151 |
-
params.window_size_right = window_size_right;
|
| 152 |
-
params.attention_chunk = attention_chunk;
|
| 153 |
-
|
| 154 |
-
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
| 155 |
-
params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
|
| 156 |
-
|
| 157 |
-
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
| 158 |
-
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
| 159 |
-
#endif
|
| 160 |
-
}
|
| 161 |
-
|
| 162 |
-
void set_params_dgrad(Flash_bwd_params ¶ms,
|
| 163 |
-
// sizes
|
| 164 |
-
const size_t b,
|
| 165 |
-
const size_t seqlen_q,
|
| 166 |
-
const size_t seqlen_k,
|
| 167 |
-
const size_t seqlen_q_rounded,
|
| 168 |
-
const size_t seqlen_k_rounded,
|
| 169 |
-
const size_t h,
|
| 170 |
-
const size_t h_k,
|
| 171 |
-
const size_t d,
|
| 172 |
-
const size_t d_rounded,
|
| 173 |
-
// device pointers
|
| 174 |
-
const at::Tensor q,
|
| 175 |
-
const at::Tensor k,
|
| 176 |
-
const at::Tensor v,
|
| 177 |
-
const at::Tensor out,
|
| 178 |
-
const at::Tensor dout,
|
| 179 |
-
at::Tensor dq,
|
| 180 |
-
at::Tensor dk,
|
| 181 |
-
at::Tensor dv,
|
| 182 |
-
void *cu_seqlens_q_d,
|
| 183 |
-
void *cu_seqlens_k_d,
|
| 184 |
-
void *seqused_q,
|
| 185 |
-
void *seqused_k,
|
| 186 |
-
void *dq_accum_d,
|
| 187 |
-
void *dk_accum_d,
|
| 188 |
-
void *dv_accum_d,
|
| 189 |
-
void *softmax_lse_d,
|
| 190 |
-
void *dsoftmax_sum_d,
|
| 191 |
-
float p_dropout,
|
| 192 |
-
float softmax_scale,
|
| 193 |
-
int window_size_left,
|
| 194 |
-
int window_size_right,
|
| 195 |
-
int attention_chunk,
|
| 196 |
-
const float softcap=0.f,
|
| 197 |
-
bool deterministic=false,
|
| 198 |
-
int const sm_margin=0) {
|
| 199 |
-
|
| 200 |
-
set_params_fprop(params,
|
| 201 |
-
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
| 202 |
-
q, k, v, out,
|
| 203 |
-
cu_seqlens_q_d,
|
| 204 |
-
cu_seqlens_k_d,
|
| 205 |
-
seqused_q,
|
| 206 |
-
seqused_k,
|
| 207 |
-
softmax_lse_d,
|
| 208 |
-
p_dropout,
|
| 209 |
-
softmax_scale,
|
| 210 |
-
window_size_left,
|
| 211 |
-
window_size_right,
|
| 212 |
-
attention_chunk,
|
| 213 |
-
softcap,
|
| 214 |
-
sm_margin);
|
| 215 |
-
|
| 216 |
-
// Set the pointers and strides.
|
| 217 |
-
params.do_ptr = dout.data_ptr();
|
| 218 |
-
params.do_row_stride = dout.stride(-3);
|
| 219 |
-
params.do_head_stride = dout.stride(-2);
|
| 220 |
-
params.dq_ptr = dq.data_ptr();
|
| 221 |
-
params.dk_ptr = dk.data_ptr();
|
| 222 |
-
params.dv_ptr = dv.data_ptr();
|
| 223 |
-
params.dq_row_stride = dq.stride(-3);
|
| 224 |
-
params.dk_row_stride = dk.stride(-3);
|
| 225 |
-
params.dv_row_stride = dv.stride(-3);
|
| 226 |
-
params.dq_head_stride = dq.stride(-2);
|
| 227 |
-
params.dk_head_stride = dk.stride(-2);
|
| 228 |
-
params.dv_head_stride = dv.stride(-2);
|
| 229 |
-
|
| 230 |
-
if (cu_seqlens_q_d == nullptr) {
|
| 231 |
-
params.do_batch_stride = dout.stride(0);
|
| 232 |
-
params.dq_batch_stride = dq.stride(0);
|
| 233 |
-
params.dk_batch_stride = dk.stride(0);
|
| 234 |
-
params.dv_batch_stride = dv.stride(0);
|
| 235 |
-
}
|
| 236 |
-
|
| 237 |
-
params.dq_accum_ptr = dq_accum_d;
|
| 238 |
-
params.dk_accum_ptr = dk_accum_d;
|
| 239 |
-
params.dv_accum_ptr = dv_accum_d;
|
| 240 |
-
|
| 241 |
-
// Softmax sum
|
| 242 |
-
params.dsoftmax_sum = dsoftmax_sum_d;
|
| 243 |
-
|
| 244 |
-
params.deterministic = deterministic;
|
| 245 |
-
}
|
| 246 |
-
|
| 247 |
-
template <int Arch, int Split, bool PagedKVNonTMA, bool PackGQA, bool Has_softcap>
|
| 248 |
-
void run_mha_fwd_constexpr(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
| 249 |
-
if (!params.is_e4m3) {
|
| 250 |
-
if (params.is_bf16) {
|
| 251 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 252 |
-
if (params.d <= 64) {
|
| 253 |
-
if constexpr (Arch == 90) {
|
| 254 |
-
if (params.dv > 256) {
|
| 255 |
-
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 256 |
-
} else if (params.dv > 64) {
|
| 257 |
-
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 258 |
-
}
|
| 259 |
-
}
|
| 260 |
-
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 261 |
-
}
|
| 262 |
-
#endif
|
| 263 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 264 |
-
if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 265 |
-
#endif
|
| 266 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 267 |
-
if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 268 |
-
#endif
|
| 269 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 270 |
-
if (params.d <= 192) {
|
| 271 |
-
if constexpr (Arch == 90) {
|
| 272 |
-
if (params.dv <= 128) {
|
| 273 |
-
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 274 |
-
}
|
| 275 |
-
}
|
| 276 |
-
return run_mha_fwd_<Arch, cutlass::bfloat16_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 277 |
-
}
|
| 278 |
-
#endif
|
| 279 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 280 |
-
if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::bfloat16_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 281 |
-
#endif
|
| 282 |
-
} else {
|
| 283 |
-
#ifndef FLASHATTENTION_DISABLE_FP16
|
| 284 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 285 |
-
if (params.d <= 64) {
|
| 286 |
-
if constexpr (Arch == 90) {
|
| 287 |
-
if (params.dv > 256) {
|
| 288 |
-
return run_mha_fwd_<Arch, cutlass::half_t, 64, 512, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 289 |
-
} else if (params.dv > 64) {
|
| 290 |
-
return run_mha_fwd_<Arch, cutlass::half_t, 64, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 291 |
-
}
|
| 292 |
-
}
|
| 293 |
-
return run_mha_fwd_<Arch, cutlass::half_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 294 |
-
}
|
| 295 |
-
#endif
|
| 296 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 297 |
-
if (params.d <= 96) { return run_mha_fwd_<Arch, cutlass::half_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 298 |
-
#endif
|
| 299 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 300 |
-
if (params.d <= 128) { return run_mha_fwd_<Arch, cutlass::half_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 301 |
-
#endif
|
| 302 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 303 |
-
if (params.d <= 192) {
|
| 304 |
-
if constexpr (Arch == 90) {
|
| 305 |
-
if (params.dv <= 128) {
|
| 306 |
-
return run_mha_fwd_<Arch, cutlass::half_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 307 |
-
}
|
| 308 |
-
}
|
| 309 |
-
return run_mha_fwd_<Arch, cutlass::half_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 310 |
-
}
|
| 311 |
-
#endif
|
| 312 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 313 |
-
if (params.d <= 256) { return run_mha_fwd_<Arch, cutlass::half_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 314 |
-
#endif
|
| 315 |
-
#else
|
| 316 |
-
TORCH_CHECK(false, "This flash attention build does not support FP16.");
|
| 317 |
-
#endif
|
| 318 |
-
}
|
| 319 |
-
} else {
|
| 320 |
-
#ifndef FLASHATTENTION_DISABLE_FP8
|
| 321 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 322 |
-
if (params.d <= 64) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 64, 64, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 323 |
-
#endif
|
| 324 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 325 |
-
if (params.d <= 96) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 96, 96, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 326 |
-
#endif
|
| 327 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 328 |
-
if (params.d <= 128) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 128, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 329 |
-
#endif
|
| 330 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 331 |
-
if (params.d <= 192) {
|
| 332 |
-
if constexpr (Arch == 90) {
|
| 333 |
-
if (params.dv <= 128) {
|
| 334 |
-
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 128, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 335 |
-
}
|
| 336 |
-
}
|
| 337 |
-
return run_mha_fwd_<90, cutlass::float_e4m3_t, 192, 192, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
|
| 338 |
-
}
|
| 339 |
-
#endif
|
| 340 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 341 |
-
if (params.d <= 256) { return run_mha_fwd_<90, cutlass::float_e4m3_t, 256, 256, Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream); }
|
| 342 |
-
#endif
|
| 343 |
-
#else
|
| 344 |
-
TORCH_CHECK(false, "This flash attention build does not support FP8.");
|
| 345 |
-
#endif
|
| 346 |
-
}
|
| 347 |
-
}
|
| 348 |
-
|
| 349 |
-
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
| 350 |
-
// HEADDIM_SWITCH(params.d, [&] {
|
| 351 |
-
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
|
| 352 |
-
// });
|
| 353 |
-
TORCH_CHECK(params.num_splits >= 1);
|
| 354 |
-
ARCH_SWITCH(params.arch, Arch, [&] {
|
| 355 |
-
SPLIT_SWITCH(params.num_splits > 1, Split, [&] {
|
| 356 |
-
PAGEDKV_SWITCH(params.page_table && !params.pagedkv_tma, PagedKVNonTMA, [&] {
|
| 357 |
-
PACKGQA_SWITCH(params.pack_gqa, PackGQA_, [&] {
|
| 358 |
-
// Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation
|
| 359 |
-
static constexpr bool PackGQA = PackGQA_ || Arch < 90 || PagedKVNonTMA || Split;
|
| 360 |
-
SOFTCAP_SWITCH(params.softcap > 0.0, Has_softcap, [&] {
|
| 361 |
-
run_mha_fwd_constexpr<Arch, Split, PagedKVNonTMA, PackGQA, Has_softcap>(params, stream);
|
| 362 |
-
});
|
| 363 |
-
});
|
| 364 |
-
});
|
| 365 |
-
});
|
| 366 |
-
});
|
| 367 |
-
}
|
| 368 |
-
|
| 369 |
-
void run_mha_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl=false) {
|
| 370 |
-
#ifndef FLASHATTENTION_DISABLE_SPLIT
|
| 371 |
-
// If hdim is 96 or 192, it's faster to round them to 128 or 256 respectively
|
| 372 |
-
// so that kBlockM is smaller and we have more parallelism.
|
| 373 |
-
if (params.is_fp32) {
|
| 374 |
-
if (params.dv <= 64) {
|
| 375 |
-
run_mha_fwd_combine_<float, float, 64>(params, stream, enable_pdl);
|
| 376 |
-
} else {
|
| 377 |
-
run_mha_fwd_combine_<float, float, 128>(params, stream, enable_pdl);
|
| 378 |
-
}
|
| 379 |
-
} else if (params.is_bf16) {
|
| 380 |
-
if (params.dv <= 64) {
|
| 381 |
-
run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(params, stream, enable_pdl);
|
| 382 |
-
} else {
|
| 383 |
-
run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(params, stream, enable_pdl);
|
| 384 |
-
}
|
| 385 |
-
} else {
|
| 386 |
-
if (params.dv <= 64) {
|
| 387 |
-
run_mha_fwd_combine_<cutlass::half_t, float, 64>(params, stream, enable_pdl);
|
| 388 |
-
} else {
|
| 389 |
-
run_mha_fwd_combine_<cutlass::half_t, float, 128>(params, stream, enable_pdl);
|
| 390 |
-
}
|
| 391 |
-
}
|
| 392 |
-
#else
|
| 393 |
-
TORCH_CHECK(false, "This flash attention build does not support combine kernels.");
|
| 394 |
-
#endif
|
| 395 |
-
}
|
| 396 |
-
|
| 397 |
-
inline bool get_pagedkv_tma(Flash_fwd_params const& params) {
|
| 398 |
-
if (params.arch < 90 || !params.page_table || params.leftpad_k || params.knew_ptr) { return false; }
|
| 399 |
-
// This needs to match the kernel configs
|
| 400 |
-
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, false /*paged_kv_non_TMA*/, params.softcap > 0.f);
|
| 401 |
-
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
|
| 402 |
-
int const kBlockN = std::get<1>(kBlockMN_kernel_args_sm90);
|
| 403 |
-
// Heuristic: when seqlen_q <= kBlockM, we're not compute bound, and somehow using TMA is slower,
|
| 404 |
-
// at least for MLA.
|
| 405 |
-
return params.page_size % kBlockN == 0 && params.seqlen_q * (params.h / params.h_k) > kBlockM;
|
| 406 |
-
}
|
| 407 |
-
|
| 408 |
-
inline bool get_pack_gqa(Flash_fwd_params const& params) {
|
| 409 |
-
// Always enable PackGQA for Sm8x or PagedKVNonTMA or Split to reduce compilation and binary size.
|
| 410 |
-
// Has little effect on speed.
|
| 411 |
-
if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
|
| 412 |
-
#ifdef FLASHATTENTION_DISABLE_PACKGQA
|
| 413 |
-
return false;
|
| 414 |
-
#else
|
| 415 |
-
// params.page_table must already be set
|
| 416 |
-
if (params.h == params.h_k) { return false; }
|
| 417 |
-
// This needs to match the kernel configs
|
| 418 |
-
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
| 419 |
-
int const kBlockM = std::get<0>(kBlockMN_kernel_args_sm90);
|
| 420 |
-
return should_pack_gqa(params.cu_seqlens_q || params.seqused_q, params.seqlen_q, params.h / params.h_k, kBlockM);
|
| 421 |
-
#endif
|
| 422 |
-
}
|
| 423 |
-
|
| 424 |
-
inline int get_num_splits(Flash_fwd_params const& params) {
|
| 425 |
-
#ifdef FLASHATTENTION_DISABLE_SPLIT
|
| 426 |
-
return 1;
|
| 427 |
-
#else
|
| 428 |
-
// Always enable PackGQA for Split
|
| 429 |
-
// params.page_table must already be set
|
| 430 |
-
// This needs to match the kernel configs
|
| 431 |
-
bool varlen = params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k;
|
| 432 |
-
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
| 433 |
-
// Strictly speaking we need to pass in (varlen && params.num_splits > 1) but num_splits
|
| 434 |
-
// has not been set here. It's OK though because we might just underestimate kBlockN a bit
|
| 435 |
-
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, varlen, params.softcap > 0.f, params.knew_ptr);
|
| 436 |
-
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
|
| 437 |
-
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
|
| 438 |
-
int seqlen_q_packgqa = params.seqlen_q * (params.h / params.h_k);
|
| 439 |
-
// If is_local, we're not going to load all of seqlen_k
|
| 440 |
-
int const seqlen_k_loaded = !params.is_local
|
| 441 |
-
? params.seqlen_k
|
| 442 |
-
: std::max(0, std::min(params.seqlen_k, params.window_size_right + params.window_size_left + 1 + kBlockM));
|
| 443 |
-
int const num_n_blocks = (seqlen_k_loaded + kBlockN - 1) / kBlockN;
|
| 444 |
-
int const num_m_blocks = (seqlen_q_packgqa + kBlockM - 1) / kBlockM;
|
| 445 |
-
int const size_one_kv_head = params.seqlen_k * (params.d + params.dv) * (params.is_e4m3 ? 1 : 2);
|
| 446 |
-
// Always enable PackGQA for Split
|
| 447 |
-
// If varlen, we use dynamic split, so this heuristic just needs to get an upper bound on num_splits.
|
| 448 |
-
// We assume the case where there's 1 long sequence and the rest are short, i.e. pretending
|
| 449 |
-
// that batch = 1.
|
| 450 |
-
int total_mblocks = (params.num_splits_dynamic_ptr ? 1 : params.b) * params.h_k * num_m_blocks;
|
| 451 |
-
return num_splits_heuristic(total_mblocks, params.num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, params.is_causal || params.is_local, 128);
|
| 452 |
-
#endif
|
| 453 |
-
}
|
| 454 |
-
|
| 455 |
-
inline int get_max_headdim() {
|
| 456 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 457 |
-
return 256;
|
| 458 |
-
#endif
|
| 459 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 460 |
-
return 192;
|
| 461 |
-
#endif
|
| 462 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 463 |
-
return 128;
|
| 464 |
-
#endif
|
| 465 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 466 |
-
return 96;
|
| 467 |
-
#endif
|
| 468 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 469 |
-
return 64;
|
| 470 |
-
#endif
|
| 471 |
-
return 0;
|
| 472 |
-
}
|
| 473 |
-
|
| 474 |
-
inline int round_up_headdim(int head_size) {
|
| 475 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 476 |
-
if (head_size <= 64) { return 64; }
|
| 477 |
-
#endif
|
| 478 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 479 |
-
if (head_size <= 96) { return 96; }
|
| 480 |
-
#endif
|
| 481 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 482 |
-
if (head_size <= 128) { return 128; }
|
| 483 |
-
#endif
|
| 484 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 485 |
-
if (head_size <= 192) { return 192; }
|
| 486 |
-
#endif
|
| 487 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 488 |
-
if (head_size <= 256) { return 256; }
|
| 489 |
-
#endif
|
| 490 |
-
return 256;
|
| 491 |
-
}
|
| 492 |
-
|
| 493 |
-
inline int round_up_headdimv(int head_size) {
|
| 494 |
-
if (head_size <= 64) { return 64; }
|
| 495 |
-
if (head_size <= 96) { return 96; }
|
| 496 |
-
if (head_size <= 128) { return 128; }
|
| 497 |
-
if (head_size <= 192) { return 192; }
|
| 498 |
-
if (head_size <= 256) { return 256; }
|
| 499 |
-
return 512;
|
| 500 |
-
}
|
| 501 |
-
|
| 502 |
-
// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
|
| 503 |
-
at::Tensor
|
| 504 |
-
mha_fwd_get_scheduler_metadata(
|
| 505 |
-
int64_t batch_size,
|
| 506 |
-
int64_t max_seqlen_q,
|
| 507 |
-
int64_t max_seqlen_k,
|
| 508 |
-
int64_t num_heads,
|
| 509 |
-
int64_t num_heads_k,
|
| 510 |
-
int64_t headdim,
|
| 511 |
-
int64_t headdim_v,
|
| 512 |
-
at::ScalarType qkv_dtype,
|
| 513 |
-
at::Tensor seqused_k, // b
|
| 514 |
-
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
| 515 |
-
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
| 516 |
-
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
|
| 517 |
-
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
| 518 |
-
std::optional<at::Tensor> leftpad_k_, // b
|
| 519 |
-
std::optional<int64_t> page_size,
|
| 520 |
-
int64_t max_seqlen_k_new, // 0 means we're not appending new KV
|
| 521 |
-
bool is_causal,
|
| 522 |
-
int64_t window_size_left,
|
| 523 |
-
int64_t window_size_right,
|
| 524 |
-
int64_t attention_chunk,
|
| 525 |
-
bool has_softcap,
|
| 526 |
-
int64_t num_splits,
|
| 527 |
-
std::optional<bool> pack_gqa_,
|
| 528 |
-
int64_t sm_margin
|
| 529 |
-
) {
|
| 530 |
-
|
| 531 |
-
TORCH_CHECK(qkv_dtype == at::ScalarType::Half || qkv_dtype == at::ScalarType::BFloat16 || qkv_dtype == at::ScalarType::Float8_e4m3fn,
|
| 532 |
-
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
|
| 533 |
-
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
| 534 |
-
|
| 535 |
-
// Reset the parameters
|
| 536 |
-
Flash_fwd_params params{};
|
| 537 |
-
params.is_bf16 = qkv_dtype == at::ScalarType::BFloat16;
|
| 538 |
-
params.is_e4m3 = qkv_dtype == at::ScalarType::Float8_e4m3fn;
|
| 539 |
-
params.b = batch_size;
|
| 540 |
-
params.seqlen_q = max_seqlen_q;
|
| 541 |
-
params.seqlen_k = max_seqlen_k;
|
| 542 |
-
params.h = num_heads;
|
| 543 |
-
params.h_k = num_heads_k;
|
| 544 |
-
params.d = headdim;
|
| 545 |
-
params.dv = headdim_v;
|
| 546 |
-
params.d_rounded = round_up_headdim(headdim);
|
| 547 |
-
params.dv_rounded = headdim_v == headdim ? params.d_rounded : round_up_headdimv(headdim_v);
|
| 548 |
-
params.seqlen_knew = max_seqlen_k_new;
|
| 549 |
-
|
| 550 |
-
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
| 551 |
-
params.cu_seqlens_q = is_varlen_q ? cu_seqlens_q_.value().data_ptr<int>() : nullptr;
|
| 552 |
-
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
| 553 |
-
params.cu_seqlens_k = is_varlen_k ? cu_seqlens_k_.value().data_ptr<int>() : nullptr;
|
| 554 |
-
params.cu_seqlens_knew = cu_seqlens_k_new_.has_value() ? cu_seqlens_k_new_.value().data_ptr<int>() : nullptr;
|
| 555 |
-
params.seqused_q = seqused_q_.has_value() ? seqused_q_.value().data_ptr<int>() : nullptr;
|
| 556 |
-
params.seqused_k = seqused_k.data_ptr<int>();
|
| 557 |
-
params.leftpad_k = leftpad_k_.has_value() ? leftpad_k_.value().data_ptr<int>() : nullptr;
|
| 558 |
-
params.knew_ptr = params.seqlen_knew > 0 ? reinterpret_cast<int*>(1) : nullptr;
|
| 559 |
-
if (window_size_left >= max_seqlen_k - 1) { window_size_left = -1; }
|
| 560 |
-
if (window_size_right >= max_seqlen_q - 1) { window_size_right = -1; }
|
| 561 |
-
// causal=true is the same as causal=false in this case
|
| 562 |
-
if (max_seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
|
| 563 |
-
// Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
|
| 564 |
-
if ((headdim <= 64 || headdim > 128) || !page_size.has_value()) {
|
| 565 |
-
is_causal = false;
|
| 566 |
-
}
|
| 567 |
-
}
|
| 568 |
-
if (is_causal) { window_size_right = 0; }
|
| 569 |
-
|
| 570 |
-
params.is_causal = window_size_left < 0 && window_size_right == 0 && attention_chunk == 0;
|
| 571 |
-
params.is_local = (window_size_left >= 0 || window_size_right >= 0 || attention_chunk >= 1) && !params.is_causal;
|
| 572 |
-
if (window_size_left < 0) { window_size_left = max_seqlen_k - 1; }
|
| 573 |
-
if (window_size_right < 0) { window_size_right = max_seqlen_q - 1; }
|
| 574 |
-
if (attention_chunk > 0) {
|
| 575 |
-
window_size_left = std::min(window_size_left, attention_chunk - 1);
|
| 576 |
-
window_size_right = std::min(window_size_right, attention_chunk - 1);
|
| 577 |
-
}
|
| 578 |
-
params.window_size_left = window_size_left;
|
| 579 |
-
params.window_size_right = window_size_right;
|
| 580 |
-
params.attention_chunk = attention_chunk;
|
| 581 |
-
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
| 582 |
-
params.num_sm = at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin;
|
| 583 |
-
params.softcap = has_softcap ? 1.0f : 0.0f;
|
| 584 |
-
|
| 585 |
-
params.page_size = page_size.has_value() ? page_size.value() : 1;
|
| 586 |
-
params.page_table = !page_size.has_value() ? nullptr : reinterpret_cast<int*>(1);
|
| 587 |
-
|
| 588 |
-
bool const use_dynamic_split = params.b <= 992;
|
| 589 |
-
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
|
| 590 |
-
|
| 591 |
-
params.pagedkv_tma = get_pagedkv_tma(params);
|
| 592 |
-
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
|
| 593 |
-
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
|
| 594 |
-
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
|
| 595 |
-
|
| 596 |
-
bool is_varlen = true;
|
| 597 |
-
|
| 598 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 599 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 600 |
-
at::cuda::CUDAGuard device_guard{(char)seqused_k.get_device()};
|
| 601 |
-
|
| 602 |
-
auto opts = seqused_k.options();
|
| 603 |
-
// This needs to be set after get_num_splits
|
| 604 |
-
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
|
| 605 |
-
bool const scheduler_needs_semaphore = params.arch >= 90 || params.num_splits > 1;
|
| 606 |
-
if (scheduler_needs_semaphore || use_dynamic_split) {
|
| 607 |
-
tile_count_semaphore = torch::empty({int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b}, opts.dtype(torch::kInt32));
|
| 608 |
-
if (scheduler_needs_semaphore) {
|
| 609 |
-
if (!use_dynamic_split) { tile_count_semaphore.zero_(); } // If varlen we'll manually do the zero-ing
|
| 610 |
-
params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
|
| 611 |
-
} else {
|
| 612 |
-
params.tile_count_semaphore = nullptr;
|
| 613 |
-
}
|
| 614 |
-
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
|
| 615 |
-
}
|
| 616 |
-
|
| 617 |
-
if (params.num_splits_dynamic_ptr) {
|
| 618 |
-
auto kBlockMN_kernel_args_sm90 = tile_size_fwd_sm90(params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, false /*v_colmajor*/, params.page_table && !params.pagedkv_tma, params.softcap > 0.f);
|
| 619 |
-
auto kBlockMN_kernel_args_sm8x = tile_size_fwd_sm8x(params.arch == 86 || params.arch == 89, params.d_rounded, params.dv_rounded, params.is_causal, params.is_local, params.is_e4m3 ? 1 : 2 /*element_size*/, params.page_table, is_varlen && params.num_splits > 1, params.softcap > 0.f, params.knew_ptr);
|
| 620 |
-
int const kBlockM = params.arch >= 90 ? std::get<0>(kBlockMN_kernel_args_sm90) : std::get<0>(kBlockMN_kernel_args_sm8x);
|
| 621 |
-
int const kBlockN = params.arch >= 90 ? std::get<1>(kBlockMN_kernel_args_sm90) : std::get<1>(kBlockMN_kernel_args_sm8x);
|
| 622 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 623 |
-
prepare_varlen_num_blocks(params, stream, params.pack_gqa, kBlockM, kBlockN, false /*enable_pdl*/);
|
| 624 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 625 |
-
}
|
| 626 |
-
return tile_count_semaphore;
|
| 627 |
-
}
|
| 628 |
-
|
| 629 |
-
// b: batch_size
|
| 630 |
-
// b_k: batch_size_k
|
| 631 |
-
// s_q: seqlen_q
|
| 632 |
-
// s_k: seqlen_k
|
| 633 |
-
// s_k_new: seqlen_k_new
|
| 634 |
-
// h: num_heads
|
| 635 |
-
// h_k: num_heads_k
|
| 636 |
-
// d: head_size
|
| 637 |
-
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
| 638 |
-
mha_fwd(at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
| 639 |
-
at::Tensor k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, h_k, d) if there is page_table.
|
| 640 |
-
at::Tensor v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table.
|
| 641 |
-
std::optional<at::Tensor> k_new_, // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
|
| 642 |
-
std::optional<at::Tensor> v_new_, // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
|
| 643 |
-
std::optional<at::Tensor> q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
|
| 644 |
-
std::optional<at::Tensor> out_, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 645 |
-
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
| 646 |
-
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
| 647 |
-
std::optional<at::Tensor> cu_seqlens_k_new_, // b+1
|
| 648 |
-
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
| 649 |
-
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
| 650 |
-
std::optional<int64_t> max_seqlen_q_,
|
| 651 |
-
// TODO: check if we need max_seqlen_k
|
| 652 |
-
std::optional<int64_t> max_seqlen_k_,
|
| 653 |
-
std::optional<at::Tensor> page_table_, // (b_k, max_num_pages_per_seq)
|
| 654 |
-
std::optional<at::Tensor> kv_batch_idx_, // b. indices to index into the KV cache
|
| 655 |
-
std::optional<at::Tensor> leftpad_k_, // b
|
| 656 |
-
std::optional<at::Tensor> rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
| 657 |
-
std::optional<at::Tensor> rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
| 658 |
-
std::optional<at::Tensor> seqlens_rotary_, // b
|
| 659 |
-
std::optional<at::Tensor> q_descale_, // (b, h_k), not (b, h)
|
| 660 |
-
std::optional<at::Tensor> k_descale_, // (b, h_k)
|
| 661 |
-
std::optional<at::Tensor> v_descale_, // (b, h_k)
|
| 662 |
-
std::optional<double> softmax_scale_,
|
| 663 |
-
bool is_causal,
|
| 664 |
-
int64_t window_size_left,
|
| 665 |
-
int64_t window_size_right,
|
| 666 |
-
int64_t attention_chunk,
|
| 667 |
-
double softcap,
|
| 668 |
-
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
| 669 |
-
std::optional<at::Tensor> scheduler_metadata_, // (b + 1)
|
| 670 |
-
int64_t num_splits,
|
| 671 |
-
std::optional<bool> pack_gqa_,
|
| 672 |
-
int64_t sm_margin
|
| 673 |
-
) {
|
| 674 |
-
|
| 675 |
-
auto dprops = at::cuda::getCurrentDeviceProperties();
|
| 676 |
-
bool is_sm8x = dprops->major >= 8;
|
| 677 |
-
TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
| 678 |
-
|
| 679 |
-
auto q_type = q.scalar_type();
|
| 680 |
-
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
|
| 681 |
-
"FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
|
| 682 |
-
if (dprops->major < 9) {
|
| 683 |
-
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
| 684 |
-
"FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
|
| 685 |
-
}
|
| 686 |
-
TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
|
| 687 |
-
TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
|
| 688 |
-
|
| 689 |
-
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
| 690 |
-
|
| 691 |
-
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 692 |
-
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 693 |
-
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 694 |
-
|
| 695 |
-
at::Tensor page_table;
|
| 696 |
-
const bool paged_KV = page_table_.has_value();
|
| 697 |
-
if (paged_KV) {
|
| 698 |
-
page_table = page_table_.value();
|
| 699 |
-
CHECK_DEVICE(page_table);
|
| 700 |
-
TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32");
|
| 701 |
-
TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension");
|
| 702 |
-
}
|
| 703 |
-
|
| 704 |
-
at::Tensor cu_seqlens_q;
|
| 705 |
-
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
| 706 |
-
if (is_varlen_q) {
|
| 707 |
-
cu_seqlens_q = cu_seqlens_q_.value();
|
| 708 |
-
CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
|
| 709 |
-
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
|
| 710 |
-
TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
|
| 711 |
-
}
|
| 712 |
-
at::Tensor cu_seqlens_k;
|
| 713 |
-
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
| 714 |
-
if (is_varlen_k) {
|
| 715 |
-
cu_seqlens_k = cu_seqlens_k_.value();
|
| 716 |
-
CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
|
| 717 |
-
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
|
| 718 |
-
TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
|
| 719 |
-
TORCH_CHECK(!paged_KV, "If cu_seqlens_k is passed in, then page table is not supported");
|
| 720 |
-
TORCH_CHECK(!kv_batch_idx_.has_value(), "If cu_seqlens_k is passed in, then page table is not supported");
|
| 721 |
-
}
|
| 722 |
-
|
| 723 |
-
auto const sizes = q.sizes();
|
| 724 |
-
const int batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
|
| 725 |
-
int seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
|
| 726 |
-
int total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
|
| 727 |
-
int num_heads = q.size(-2);
|
| 728 |
-
int const head_size = q.size(-1);
|
| 729 |
-
int const head_size_v = v.size(-1);
|
| 730 |
-
int const max_num_pages_per_seq = !paged_KV ? 0 : page_table.size(1);
|
| 731 |
-
int const num_pages = !paged_KV ? 0 : k.size(0);
|
| 732 |
-
int const page_size = !paged_KV ? 1 : k.size(1);
|
| 733 |
-
int const seqlen_k = !is_varlen_k ? (!paged_KV ? k.size(1) : max_num_pages_per_seq * page_size) : max_seqlen_k_.value();
|
| 734 |
-
int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
|
| 735 |
-
int const num_heads_k = k.size(-2);
|
| 736 |
-
int const batch_size_k = !paged_KV ? (!is_varlen_k ? k.size(0) : cu_seqlens_k.size(0) - 1) : page_table.size(0);
|
| 737 |
-
double softmax_scale = 1.0 / sqrt(double(head_size));
|
| 738 |
-
if (softmax_scale_.has_value()) {
|
| 739 |
-
softmax_scale = softmax_scale_.value();
|
| 740 |
-
}
|
| 741 |
-
if (!kv_batch_idx_.has_value()) {
|
| 742 |
-
TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
|
| 743 |
-
}
|
| 744 |
-
int const max_headdim = get_max_headdim();
|
| 745 |
-
TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
|
| 746 |
-
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
| 747 |
-
if (head_size_v != head_size) {
|
| 748 |
-
TORCH_CHECK((head_size > 128 && head_size <= 192 && head_size_v > 96 && head_size_v <= 128) ||
|
| 749 |
-
(head_size <= 64 && head_size_v <= 512),
|
| 750 |
-
"If V headdim is different from Q/K dim, we only support Q/K headdim in (128, 192] and V headdim in (96, 128], "
|
| 751 |
-
"or (Q/K <= 64 and V <= 512).");
|
| 752 |
-
TORCH_CHECK(dprops->major == 9, "Only Hopper supports different V headdim");
|
| 753 |
-
if (head_size_v > 256) {
|
| 754 |
-
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
| 755 |
-
"HeaddimV > 256 requires fp16 and bf16 data type");
|
| 756 |
-
}
|
| 757 |
-
}
|
| 758 |
-
|
| 759 |
-
// This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
|
| 760 |
-
// TODO: check this
|
| 761 |
-
if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
|
| 762 |
-
if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
|
| 763 |
-
// causal=true is the same as causal=false in this case
|
| 764 |
-
if (seqlen_q == 1 && window_size_left == -1 && window_size_right == -1 && attention_chunk == 0) {
|
| 765 |
-
// Special case of hdim 128 where we want causal to have kBlockN=128, better for pagedKV and TMA
|
| 766 |
-
if ((head_size <= 64 || head_size > 128) || !paged_KV) {
|
| 767 |
-
is_causal = false;
|
| 768 |
-
}
|
| 769 |
-
}
|
| 770 |
-
if (is_causal) { window_size_right = 0; }
|
| 771 |
-
|
| 772 |
-
if (!is_varlen_q) {
|
| 773 |
-
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
| 774 |
-
} else {
|
| 775 |
-
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
| 776 |
-
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
| 777 |
-
}
|
| 778 |
-
if (!paged_KV) {
|
| 779 |
-
if (!is_varlen_k) {
|
| 780 |
-
CHECK_SHAPE(k, batch_size_k, seqlen_k, num_heads_k, head_size);
|
| 781 |
-
CHECK_SHAPE(v, batch_size_k, seqlen_k, num_heads_k, head_size_v);
|
| 782 |
-
} else {
|
| 783 |
-
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
| 784 |
-
CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
|
| 785 |
-
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
| 786 |
-
}
|
| 787 |
-
} else {
|
| 788 |
-
CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
|
| 789 |
-
CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
|
| 790 |
-
CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq);
|
| 791 |
-
}
|
| 792 |
-
|
| 793 |
-
if (seqused_q_.has_value()){
|
| 794 |
-
auto seqused_q = seqused_q_.value();
|
| 795 |
-
TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
|
| 796 |
-
CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
|
| 797 |
-
CHECK_SHAPE(seqused_q, batch_size);
|
| 798 |
-
}
|
| 799 |
-
if (seqused_k_.has_value()) {
|
| 800 |
-
auto seqused_k = seqused_k_.value();
|
| 801 |
-
TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
|
| 802 |
-
CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
|
| 803 |
-
CHECK_SHAPE(seqused_k, batch_size);
|
| 804 |
-
}
|
| 805 |
-
|
| 806 |
-
if (leftpad_k_.has_value()) {
|
| 807 |
-
auto leftpad_k = leftpad_k_.value();
|
| 808 |
-
TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
|
| 809 |
-
CHECK_DEVICE(leftpad_k); CHECK_CONTIGUOUS(leftpad_k);
|
| 810 |
-
CHECK_SHAPE(leftpad_k, batch_size);
|
| 811 |
-
}
|
| 812 |
-
|
| 813 |
-
// This is what we will template on
|
| 814 |
-
bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value() || leftpad_k_.has_value();
|
| 815 |
-
#ifdef FLASHATTENTION_DISABLE_VARLEN
|
| 816 |
-
TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
|
| 817 |
-
#endif
|
| 818 |
-
|
| 819 |
-
int const alignment = q_type == torch::kFloat8_e4m3fn ? 16 : 8;
|
| 820 |
-
TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
|
| 821 |
-
TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
|
| 822 |
-
|
| 823 |
-
auto opts = q.options();
|
| 824 |
-
auto out_type = q_type == at::ScalarType::Float8_e4m3fn ? at::ScalarType::BFloat16 : q_type;
|
| 825 |
-
at::Tensor out;
|
| 826 |
-
if (out_.has_value()) {
|
| 827 |
-
out = out_.value();
|
| 828 |
-
TORCH_CHECK(out.scalar_type() == out_type, "For FP16/BF16 input, output must have the same dtype as inputs. For FP8 input, output must have dtype BF16");
|
| 829 |
-
CHECK_DEVICE(out);
|
| 830 |
-
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
| 831 |
-
if (!is_varlen_q) {
|
| 832 |
-
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
|
| 833 |
-
} else {
|
| 834 |
-
CHECK_SHAPE(out, total_q, num_heads, head_size_v);
|
| 835 |
-
}
|
| 836 |
-
} else {
|
| 837 |
-
out = !is_varlen_q
|
| 838 |
-
? torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts.dtype(out_type))
|
| 839 |
-
: torch::empty({total_q, num_heads, head_size_v}, opts.dtype(out_type));
|
| 840 |
-
}
|
| 841 |
-
|
| 842 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 843 |
-
int const head_size_rounded = round_up_headdim(head_size);
|
| 844 |
-
int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv(head_size_v);
|
| 845 |
-
int const seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
| 846 |
-
int const seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
| 847 |
-
|
| 848 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 849 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 850 |
-
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
| 851 |
-
|
| 852 |
-
at::Tensor softmax_lse;
|
| 853 |
-
if (!is_varlen_q) {
|
| 854 |
-
softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
| 855 |
-
} else {
|
| 856 |
-
softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
|
| 857 |
-
}
|
| 858 |
-
|
| 859 |
-
Flash_fwd_params params;
|
| 860 |
-
set_params_fprop(params,
|
| 861 |
-
batch_size,
|
| 862 |
-
seqlen_q, seqlen_k,
|
| 863 |
-
seqlen_q_rounded, seqlen_k_rounded,
|
| 864 |
-
num_heads, num_heads_k,
|
| 865 |
-
head_size, head_size_rounded,
|
| 866 |
-
q, k, v, out,
|
| 867 |
-
!is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
|
| 868 |
-
!is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
|
| 869 |
-
seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
|
| 870 |
-
seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
|
| 871 |
-
softmax_lse.data_ptr(),
|
| 872 |
-
/*p_dropout=*/0.f,
|
| 873 |
-
softmax_scale,
|
| 874 |
-
window_size_left,
|
| 875 |
-
window_size_right,
|
| 876 |
-
attention_chunk,
|
| 877 |
-
softcap,
|
| 878 |
-
sm_margin);
|
| 879 |
-
params.total_q = total_q;
|
| 880 |
-
params.total_k = total_k;
|
| 881 |
-
params.b_k = batch_size_k;
|
| 882 |
-
params.dv = head_size_v;
|
| 883 |
-
params.dv_rounded = head_size_v_rounded;
|
| 884 |
-
if (leftpad_k_.has_value()) { // This needs to be set before get_pagedkv_tma
|
| 885 |
-
params.leftpad_k = static_cast<int *>(leftpad_k_.value().data_ptr());
|
| 886 |
-
}
|
| 887 |
-
if (paged_KV) {
|
| 888 |
-
params.page_table = page_table.data_ptr<int>();
|
| 889 |
-
params.page_table_batch_stride = page_table.stride(0);
|
| 890 |
-
}
|
| 891 |
-
params.page_size = page_size;
|
| 892 |
-
params.num_pages = num_pages;
|
| 893 |
-
|
| 894 |
-
if (k_new_.has_value()) { // This needs to be set before get_pagedkv_tma
|
| 895 |
-
at::Tensor k_new, v_new;
|
| 896 |
-
TORCH_CHECK(v_new_.has_value(), "If k_new is supplied, v_new must also be passed in");
|
| 897 |
-
TORCH_CHECK(seqused_k_.has_value(), "If k_new is supplied, seqlens_k must also be passed in");
|
| 898 |
-
TORCH_CHECK(seqlen_q <= seqlen_k, "If k_new is supplied, it must have seqlen <= the seqlen of the KV cache");
|
| 899 |
-
at::Tensor cu_seqlens_k_new;
|
| 900 |
-
bool const is_varlen_k_new = cu_seqlens_k_new_.has_value();
|
| 901 |
-
if (is_varlen_k_new) {
|
| 902 |
-
cu_seqlens_k_new = cu_seqlens_k_new_.value();
|
| 903 |
-
CHECK_DEVICE(cu_seqlens_k_new); CHECK_CONTIGUOUS(cu_seqlens_k_new);
|
| 904 |
-
TORCH_CHECK(cu_seqlens_k_new.dtype() == torch::kInt32, "cu_seqlens_k_new must have dtype torch.int32");
|
| 905 |
-
}
|
| 906 |
-
k_new = k_new_.value();
|
| 907 |
-
v_new = v_new_.value();
|
| 908 |
-
TORCH_CHECK(k_new.dtype() == q_type, "k_new must have the same dtype as query");
|
| 909 |
-
TORCH_CHECK(v_new.dtype() == q_type, "v_new must have the same dtype as query");
|
| 910 |
-
CHECK_DEVICE(k_new); CHECK_DEVICE(v_new);
|
| 911 |
-
TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
|
| 912 |
-
TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
|
| 913 |
-
// We don't need max_seqlen_k_new, so seqlen_k_new can be whatever when is_varlen_k_new
|
| 914 |
-
int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 0;
|
| 915 |
-
int total_k_new = !is_varlen_k_new ? batch_size * k_new.size(1): k_new.size(0);
|
| 916 |
-
if (!is_varlen_k_new) {
|
| 917 |
-
CHECK_SHAPE(k_new, batch_size, seqlen_k_new, num_heads_k, head_size);
|
| 918 |
-
CHECK_SHAPE(v_new, batch_size, seqlen_k_new, num_heads_k, head_size_v);
|
| 919 |
-
} else {
|
| 920 |
-
CHECK_SHAPE(k_new, total_k_new, num_heads_k, head_size);
|
| 921 |
-
CHECK_SHAPE(v_new, total_k_new, num_heads_k, head_size_v);
|
| 922 |
-
CHECK_SHAPE(cu_seqlens_k_new, batch_size + 1);
|
| 923 |
-
}
|
| 924 |
-
params.seqlen_knew = seqlen_k_new;
|
| 925 |
-
params.total_knew = total_k_new;
|
| 926 |
-
params.knew_ptr = k_new.data_ptr();
|
| 927 |
-
params.vnew_ptr = v_new.data_ptr();
|
| 928 |
-
// All stride are in elements, not bytes.
|
| 929 |
-
params.knew_row_stride = k_new.stride(-3);
|
| 930 |
-
params.vnew_row_stride = v_new.stride(-3);
|
| 931 |
-
params.knew_head_stride = k_new.stride(-2);
|
| 932 |
-
params.vnew_head_stride = v_new.stride(-2);
|
| 933 |
-
if (!is_varlen_k_new) {
|
| 934 |
-
params.knew_batch_stride = k_new.stride(0);
|
| 935 |
-
params.vnew_batch_stride = v_new.stride(0);
|
| 936 |
-
}
|
| 937 |
-
if (is_varlen_k_new) {
|
| 938 |
-
params.cu_seqlens_knew = static_cast<int*>(cu_seqlens_k_new.data_ptr());
|
| 939 |
-
}
|
| 940 |
-
}
|
| 941 |
-
|
| 942 |
-
// 992 = 32 * 31 is the max supported batch in prepare_varlen_num_blocks kernel
|
| 943 |
-
bool const use_dynamic_split = is_varlen && params.b <= 992;
|
| 944 |
-
// Temporarily set num_splits_dynamic_ptr to 1 since get_num_splits checks it
|
| 945 |
-
params.num_splits_dynamic_ptr = !use_dynamic_split ? nullptr : reinterpret_cast<int*>(1);
|
| 946 |
-
|
| 947 |
-
params.pagedkv_tma = get_pagedkv_tma(params);
|
| 948 |
-
params.num_splits = num_splits <= 0 ? get_num_splits(params) : num_splits;
|
| 949 |
-
// Always enable PackGQA for Split, and get_pack_gqa requires params.num_splits to decide
|
| 950 |
-
params.pack_gqa = pack_gqa_.has_value() ? pack_gqa_.value() : get_pack_gqa(params);
|
| 951 |
-
|
| 952 |
-
// This needs to be set after get_num_splits
|
| 953 |
-
at::Tensor tile_count_semaphore; // Contains the semaphore and optionally num_splits_dynamic
|
| 954 |
-
// We don't use the persistent scheduler if Split and not Varlen
|
| 955 |
-
bool const scheduler_needs_semaphore = params.arch >= 90
|
| 956 |
-
? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
|
| 957 |
-
: ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
|
| 958 |
-
if (scheduler_needs_semaphore || use_dynamic_split) {
|
| 959 |
-
int metadata_size = int(scheduler_needs_semaphore) + int(use_dynamic_split) * params.b;
|
| 960 |
-
params.skip_scheduler_metadata_computation = scheduler_metadata_.has_value();
|
| 961 |
-
if (scheduler_metadata_.has_value()) {
|
| 962 |
-
at::Tensor scheduler_metadata = scheduler_metadata_.value();
|
| 963 |
-
CHECK_DEVICE(scheduler_metadata);
|
| 964 |
-
CHECK_SHAPE(scheduler_metadata, metadata_size);
|
| 965 |
-
CHECK_CONTIGUOUS(scheduler_metadata);
|
| 966 |
-
TORCH_CHECK(scheduler_metadata.dtype() == torch::kInt32, "scheduler_metadata must have dtype int32");
|
| 967 |
-
tile_count_semaphore = scheduler_metadata;
|
| 968 |
-
} else {
|
| 969 |
-
tile_count_semaphore = torch::empty({metadata_size}, opts.dtype(torch::kInt32));
|
| 970 |
-
}
|
| 971 |
-
if (scheduler_needs_semaphore && !use_dynamic_split) {
|
| 972 |
-
tile_count_semaphore.zero_(); // If varlen we'll manually do the zero-ing
|
| 973 |
-
}
|
| 974 |
-
params.tile_count_semaphore = scheduler_needs_semaphore ? tile_count_semaphore.data_ptr<int>() : nullptr;
|
| 975 |
-
params.num_splits_dynamic_ptr = use_dynamic_split ? tile_count_semaphore.data_ptr<int>() + 1 : nullptr;
|
| 976 |
-
}
|
| 977 |
-
|
| 978 |
-
if (q_v_.has_value()) {
|
| 979 |
-
TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
|
| 980 |
-
TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
|
| 981 |
-
"q_v is only supported for fp16 and bf16 data type");
|
| 982 |
-
TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs");
|
| 983 |
-
at::Tensor q_v = q_v_.value();
|
| 984 |
-
TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
|
| 985 |
-
CHECK_DEVICE(q_v);
|
| 986 |
-
TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
|
| 987 |
-
if (!is_varlen_q) {
|
| 988 |
-
CHECK_SHAPE(q_v, batch_size, seqlen_q, num_heads, head_size_v);
|
| 989 |
-
} else {
|
| 990 |
-
CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
|
| 991 |
-
}
|
| 992 |
-
params.qv_ptr = q_v.data_ptr();
|
| 993 |
-
// All stride are in elements, not bytes.
|
| 994 |
-
params.qv_row_stride = q_v.stride(-3);
|
| 995 |
-
params.qv_head_stride = q_v.stride(-2);
|
| 996 |
-
if (!is_varlen_q) {
|
| 997 |
-
params.qv_batch_stride = q_v.stride(0);
|
| 998 |
-
}
|
| 999 |
-
}
|
| 1000 |
-
|
| 1001 |
-
if (rotary_cos_.has_value()) {
|
| 1002 |
-
TORCH_CHECK(k_new_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided");
|
| 1003 |
-
auto rotary_cos = rotary_cos_.value();
|
| 1004 |
-
CHECK_DEVICE(rotary_cos); CHECK_CONTIGUOUS(rotary_cos);
|
| 1005 |
-
params.rotary_dim = rotary_cos.size(1) * 2;
|
| 1006 |
-
TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
|
| 1007 |
-
TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
|
| 1008 |
-
const int seqlen_ro = rotary_cos.size(0);
|
| 1009 |
-
if (paged_KV) {
|
| 1010 |
-
TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
|
| 1011 |
-
}
|
| 1012 |
-
CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
|
| 1013 |
-
TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
|
| 1014 |
-
|
| 1015 |
-
TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
|
| 1016 |
-
auto rotary_sin = rotary_sin_.value();
|
| 1017 |
-
CHECK_DEVICE(rotary_sin); CHECK_CONTIGUOUS(rotary_sin);
|
| 1018 |
-
CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
|
| 1019 |
-
TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
|
| 1020 |
-
params.rotary_cos_ptr = rotary_cos.data_ptr();
|
| 1021 |
-
params.rotary_sin_ptr = rotary_sin.data_ptr();
|
| 1022 |
-
params.is_rotary_interleaved = is_rotary_interleaved;
|
| 1023 |
-
if (seqlens_rotary_.has_value()) {
|
| 1024 |
-
at::Tensor seqlens_rotary = seqlens_rotary_.value();
|
| 1025 |
-
CHECK_DEVICE(seqlens_rotary); CHECK_CONTIGUOUS(seqlens_rotary);
|
| 1026 |
-
TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
|
| 1027 |
-
CHECK_SHAPE(seqlens_rotary, batch_size);
|
| 1028 |
-
params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
|
| 1029 |
-
}
|
| 1030 |
-
} else {
|
| 1031 |
-
params.rotary_dim = 0;
|
| 1032 |
-
}
|
| 1033 |
-
|
| 1034 |
-
if (kv_batch_idx_.has_value()) {
|
| 1035 |
-
auto kv_batch_idx = kv_batch_idx_.value();
|
| 1036 |
-
CHECK_DEVICE(kv_batch_idx); CHECK_CONTIGUOUS(kv_batch_idx);
|
| 1037 |
-
TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
|
| 1038 |
-
params.kv_batch_idx = reinterpret_cast<int *>(kv_batch_idx.data_ptr());
|
| 1039 |
-
}
|
| 1040 |
-
|
| 1041 |
-
at::Tensor out_accum, softmax_lse_accum;
|
| 1042 |
-
auto outaccum_type = at::ScalarType::Float;
|
| 1043 |
-
if (params.num_splits > 1) {
|
| 1044 |
-
TORCH_CHECK(params.num_splits <= 256, "num_splits > 256 not supported");
|
| 1045 |
-
if (!is_varlen_q) {
|
| 1046 |
-
out_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_v}, opts.dtype(outaccum_type));
|
| 1047 |
-
softmax_lse_accum = torch::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
| 1048 |
-
params.oaccum_batch_stride = out_accum.stride(1);
|
| 1049 |
-
params.lseaccum_batch_stride = softmax_lse_accum.stride(1);
|
| 1050 |
-
} else {
|
| 1051 |
-
out_accum = torch::empty({params.num_splits, num_heads, total_q, head_size_v}, opts.dtype(outaccum_type));
|
| 1052 |
-
softmax_lse_accum = torch::empty({params.num_splits, num_heads, total_q}, opts.dtype(at::kFloat));
|
| 1053 |
-
}
|
| 1054 |
-
params.is_fp32 = false;
|
| 1055 |
-
params.oaccum_ptr = out_accum.data_ptr();
|
| 1056 |
-
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
| 1057 |
-
params.oaccum_split_stride = out_accum.stride(0);
|
| 1058 |
-
params.oaccum_row_stride = out_accum.stride(-2);
|
| 1059 |
-
params.oaccum_head_stride = out_accum.stride(-3);
|
| 1060 |
-
params.lseaccum_split_stride = softmax_lse_accum.stride(0);
|
| 1061 |
-
params.lseaccum_head_stride = softmax_lse_accum.stride(-2);
|
| 1062 |
-
}
|
| 1063 |
-
|
| 1064 |
-
if (q_type == at::ScalarType::Float8_e4m3fn) {
|
| 1065 |
-
if (q_descale_.has_value()) {
|
| 1066 |
-
auto q_descale = q_descale_.value();
|
| 1067 |
-
CHECK_DEVICE(q_descale);
|
| 1068 |
-
CHECK_SHAPE(q_descale, batch_size, num_heads_k);
|
| 1069 |
-
params.q_descale_ptr = q_descale.data_ptr<float>();
|
| 1070 |
-
params.q_descale_batch_stride = q_descale.stride(0);
|
| 1071 |
-
params.q_descale_head_stride = q_descale.stride(1);
|
| 1072 |
-
} else {
|
| 1073 |
-
params.q_descale_ptr = nullptr;
|
| 1074 |
-
}
|
| 1075 |
-
if (k_descale_.has_value()) {
|
| 1076 |
-
auto k_descale = k_descale_.value();
|
| 1077 |
-
CHECK_DEVICE(k_descale);
|
| 1078 |
-
CHECK_SHAPE(k_descale, batch_size, num_heads_k);
|
| 1079 |
-
params.k_descale_ptr = k_descale.data_ptr<float>();
|
| 1080 |
-
params.k_descale_batch_stride = k_descale.stride(0);
|
| 1081 |
-
params.k_descale_head_stride = k_descale.stride(1);
|
| 1082 |
-
} else {
|
| 1083 |
-
params.k_descale_ptr = nullptr;
|
| 1084 |
-
}
|
| 1085 |
-
if (v_descale_.has_value()) {
|
| 1086 |
-
auto v_descale = v_descale_.value();
|
| 1087 |
-
CHECK_DEVICE(v_descale);
|
| 1088 |
-
CHECK_SHAPE(v_descale, batch_size, num_heads_k);
|
| 1089 |
-
params.v_descale_ptr = v_descale.data_ptr<float>();
|
| 1090 |
-
params.v_descale_batch_stride = v_descale.stride(0);
|
| 1091 |
-
params.v_descale_head_stride = v_descale.stride(1);
|
| 1092 |
-
} else {
|
| 1093 |
-
params.v_descale_ptr = nullptr;
|
| 1094 |
-
}
|
| 1095 |
-
}
|
| 1096 |
-
|
| 1097 |
-
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
| 1098 |
-
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
| 1099 |
-
#endif
|
| 1100 |
-
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
| 1101 |
-
TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
|
| 1102 |
-
#endif
|
| 1103 |
-
#ifdef FLASHATTENTION_DISABLE_SPLIT
|
| 1104 |
-
TORCH_CHECK(params.num_splits == 1, "This flash attention build does not support splits.");
|
| 1105 |
-
#endif
|
| 1106 |
-
#ifdef FLASHATTENTION_DISABLE_PACKGQA
|
| 1107 |
-
TORCH_CHECK(!params.pack_gqa || params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1, "This flash attention build does not support pack_gqa.");
|
| 1108 |
-
#endif
|
| 1109 |
-
#ifdef FLASHATTENTION_DISABLE_PAGEDKV
|
| 1110 |
-
TORCH_CHECK(!(params.page_table && !params.pagedkv_tma), "This flash attention build does not support paged KV.");
|
| 1111 |
-
#endif
|
| 1112 |
-
#ifdef FLASHATTENTION_DISABLE_APPENDKV
|
| 1113 |
-
TORCH_CHECK(!k_new_.has_value(), "This flash attention build does not support appending KV.");
|
| 1114 |
-
#endif
|
| 1115 |
-
|
| 1116 |
-
if (total_q > 0 && (total_k + params.total_knew) > 0 && num_heads_k > 0) {
|
| 1117 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 1118 |
-
run_mha_fwd(params, stream);
|
| 1119 |
-
if (params.num_splits > 1) {
|
| 1120 |
-
if (out_type == at::ScalarType::BFloat16) {
|
| 1121 |
-
// Since we want output in BF16. Otherwise fwd_combine will output to FP16
|
| 1122 |
-
params.is_bf16 = true;
|
| 1123 |
-
}
|
| 1124 |
-
// Unless there's seqused_q, for the purpose of attn_combine, we can just treat it as batch=1
|
| 1125 |
-
// and seqlen = total_q, and don't need to dispatch to Varlen there.
|
| 1126 |
-
// However, with dynamic split, each row needs to know which batch it belongs to
|
| 1127 |
-
// to read the number of splits, so we just use the varlen version of combine kernel.
|
| 1128 |
-
// if (is_varlen_q && !seqused_q_.has_value()) {
|
| 1129 |
-
// if (is_varlen_q) {
|
| 1130 |
-
// params.b = 1;
|
| 1131 |
-
// params.seqlen_q = total_q;
|
| 1132 |
-
// }
|
| 1133 |
-
// This will zero out the semaphore if needed
|
| 1134 |
-
run_mha_fwd_combine(params, stream, true /*enable_pdl*/);
|
| 1135 |
-
} else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation) {
|
| 1136 |
-
// need to zero out the semaphore in this case
|
| 1137 |
-
tile_count_semaphore.index({torch::indexing::Slice(0, 1)}).zero_();
|
| 1138 |
-
}
|
| 1139 |
-
} else if (total_q > 0 && num_heads_k > 0) {
|
| 1140 |
-
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
|
| 1141 |
-
out.zero_();
|
| 1142 |
-
softmax_lse.fill_(std::numeric_limits<float>::infinity());
|
| 1143 |
-
}
|
| 1144 |
-
|
| 1145 |
-
// return {out, softmax_lse};
|
| 1146 |
-
return {out, softmax_lse, out_accum, softmax_lse_accum};
|
| 1147 |
-
}
|
| 1148 |
-
|
| 1149 |
-
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
| 1150 |
-
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 1151 |
-
TORCH_CHECK(false, "Flash-Attention was built with backward disabled");
|
| 1152 |
-
}
|
| 1153 |
-
#else
|
| 1154 |
-
template <int Arch, bool Has_softcap>
|
| 1155 |
-
void run_mha_bwd_constexpr(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 1156 |
-
if (!params.is_bf16) {
|
| 1157 |
-
#ifndef FLASHATTENTION_DISABLE_FP16
|
| 1158 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 1159 |
-
if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::half_t, 64, Has_softcap>(params, stream); }
|
| 1160 |
-
#endif
|
| 1161 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 1162 |
-
if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::half_t, 96, Has_softcap>(params, stream); }
|
| 1163 |
-
#endif
|
| 1164 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 1165 |
-
if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::half_t, 128, Has_softcap>(params, stream); }
|
| 1166 |
-
#endif
|
| 1167 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 1168 |
-
if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::half_t, 192, Has_softcap>(params, stream); }
|
| 1169 |
-
#endif
|
| 1170 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 1171 |
-
if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::half_t, 256, Has_softcap>(params, stream); }
|
| 1172 |
-
#endif
|
| 1173 |
-
#else
|
| 1174 |
-
TORCH_CHECK(false, "This flash attention build does not support FP16.");
|
| 1175 |
-
#endif
|
| 1176 |
-
} else {
|
| 1177 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM64
|
| 1178 |
-
if (params.d_rounded == 64) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 64, Has_softcap>(params, stream); }
|
| 1179 |
-
#endif
|
| 1180 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM96
|
| 1181 |
-
if (params.d_rounded == 96) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 96, Has_softcap>(params, stream); }
|
| 1182 |
-
#endif
|
| 1183 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 1184 |
-
if (params.d_rounded == 128) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 128, Has_softcap>(params, stream); }
|
| 1185 |
-
#endif
|
| 1186 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 1187 |
-
if (params.d_rounded == 192) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 192, Has_softcap>(params, stream); }
|
| 1188 |
-
#endif
|
| 1189 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 1190 |
-
if (params.d_rounded == 256) { return run_mha_bwd_<Arch, cutlass::bfloat16_t, 256, Has_softcap>(params, stream); }
|
| 1191 |
-
#endif
|
| 1192 |
-
}
|
| 1193 |
-
}
|
| 1194 |
-
|
| 1195 |
-
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 1196 |
-
// FP16_SWITCH(!params.is_bf16, [&] {
|
| 1197 |
-
// HEADDIM_SWITCH(params.d, [&] {
|
| 1198 |
-
// run_mha_bwd_<elem_type, kHeadDim>(params, stream);
|
| 1199 |
-
// });
|
| 1200 |
-
// });
|
| 1201 |
-
ARCH_SWITCH(params.arch, Arch, [&] {
|
| 1202 |
-
SOFTCAP_SWITCH(params.softcap > 0.f, Has_softcap, [&] {
|
| 1203 |
-
run_mha_bwd_constexpr<Arch, Has_softcap>(params, stream);
|
| 1204 |
-
});
|
| 1205 |
-
});
|
| 1206 |
-
}
|
| 1207 |
-
#endif
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
// b: batch_size
|
| 1211 |
-
// s_q: seqlen_q
|
| 1212 |
-
// s_k: seqlen_k
|
| 1213 |
-
// h: num_heads
|
| 1214 |
-
// h_k: num_heads_k
|
| 1215 |
-
// d: head_size
|
| 1216 |
-
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
| 1217 |
-
at::Tensor dout, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 1218 |
-
at::Tensor q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
| 1219 |
-
at::Tensor k, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
|
| 1220 |
-
at::Tensor v, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
|
| 1221 |
-
at::Tensor out, // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
|
| 1222 |
-
at::Tensor softmax_lse, // (b, h, s_q) or (h, total_q) if there is cu_seqlens_q
|
| 1223 |
-
std::optional<at::Tensor> dq_, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
|
| 1224 |
-
std::optional<at::Tensor> dk_, // (b, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k
|
| 1225 |
-
std::optional<at::Tensor> dv_, // (b, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k
|
| 1226 |
-
std::optional<at::Tensor> cu_seqlens_q_, // b+1
|
| 1227 |
-
std::optional<at::Tensor> cu_seqlens_k_, // b+1
|
| 1228 |
-
std::optional<at::Tensor> seqused_q_, // b. If given, only this many elements of each batch element's queries and outputs are used.
|
| 1229 |
-
std::optional<at::Tensor> seqused_k_, // b. If given, only this many elements of each batch element's keys are used.
|
| 1230 |
-
std::optional<int64_t> max_seqlen_q_,
|
| 1231 |
-
std::optional<int64_t> max_seqlen_k_,
|
| 1232 |
-
std::optional<double> softmax_scale_,
|
| 1233 |
-
bool is_causal,
|
| 1234 |
-
int64_t window_size_left,
|
| 1235 |
-
int64_t window_size_right,
|
| 1236 |
-
double softcap,
|
| 1237 |
-
bool deterministic,
|
| 1238 |
-
int64_t sm_margin
|
| 1239 |
-
) {
|
| 1240 |
-
|
| 1241 |
-
#ifdef FLASHATTENTION_DISABLE_BACKWARD
|
| 1242 |
-
TORCH_CHECK(false, "This flash attention build does not support backward.");
|
| 1243 |
-
#endif
|
| 1244 |
-
|
| 1245 |
-
auto dprops = at::cuda::getCurrentDeviceProperties();
|
| 1246 |
-
bool is_sm8x = dprops->major >= 8;
|
| 1247 |
-
TORCH_CHECK(is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
| 1248 |
-
|
| 1249 |
-
auto q_type = q.dtype();
|
| 1250 |
-
TORCH_CHECK(q_type == torch::kFloat16 || q_type == torch::kBFloat16,
|
| 1251 |
-
"FlashAttention only support fp16 and bf16 data type");
|
| 1252 |
-
TORCH_CHECK(k.dtype() == q_type, "query and key must have the same dtype");
|
| 1253 |
-
TORCH_CHECK(v.dtype() == q_type, "query and value must have the same dtype");
|
| 1254 |
-
TORCH_CHECK(out.dtype() == q_type, "query and out must have the same dtype");
|
| 1255 |
-
TORCH_CHECK(dout.dtype() == q_type, "query and dout must have the same dtype");
|
| 1256 |
-
|
| 1257 |
-
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v);
|
| 1258 |
-
CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse);
|
| 1259 |
-
|
| 1260 |
-
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 1261 |
-
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 1262 |
-
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 1263 |
-
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
| 1264 |
-
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
| 1265 |
-
|
| 1266 |
-
at::Tensor cu_seqlens_q;
|
| 1267 |
-
bool const is_varlen_q = cu_seqlens_q_.has_value();
|
| 1268 |
-
if (is_varlen_q) {
|
| 1269 |
-
cu_seqlens_q = cu_seqlens_q_.value();
|
| 1270 |
-
CHECK_DEVICE(cu_seqlens_q); CHECK_CONTIGUOUS(cu_seqlens_q);
|
| 1271 |
-
TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
|
| 1272 |
-
TORCH_CHECK(max_seqlen_q_.has_value(), "max_seqlen_q must be provided if cu_seqlens_q is provided");
|
| 1273 |
-
}
|
| 1274 |
-
at::Tensor cu_seqlens_k;
|
| 1275 |
-
bool const is_varlen_k = cu_seqlens_k_.has_value();
|
| 1276 |
-
if (is_varlen_k) {
|
| 1277 |
-
cu_seqlens_k = cu_seqlens_k_.value();
|
| 1278 |
-
CHECK_DEVICE(cu_seqlens_k); CHECK_CONTIGUOUS(cu_seqlens_k);
|
| 1279 |
-
TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
|
| 1280 |
-
TORCH_CHECK(max_seqlen_k_.has_value(), "max_seqlen_k must be provided if cu_seqlens_k is provided");
|
| 1281 |
-
}
|
| 1282 |
-
// This is what we will template on
|
| 1283 |
-
bool const is_varlen = is_varlen_q || is_varlen_k || seqused_q_.has_value() || seqused_k_.has_value();
|
| 1284 |
-
#ifdef FLASHATTENTION_DISABLE_VARLEN
|
| 1285 |
-
TORCH_CHECK(!is_varlen, "This flash attention build does not support varlen.");
|
| 1286 |
-
#endif
|
| 1287 |
-
|
| 1288 |
-
auto const sizes = q.sizes();
|
| 1289 |
-
int const batch_size = !is_varlen_q ? sizes[0] : cu_seqlens_q.size(0) - 1;
|
| 1290 |
-
int const seqlen_q = !is_varlen_q ? sizes[1] : max_seqlen_q_.value();
|
| 1291 |
-
int const total_q = !is_varlen_q ? batch_size * sizes[1] : sizes[0];
|
| 1292 |
-
int const num_heads = q.size(-2);
|
| 1293 |
-
int const head_size = q.size(-1);
|
| 1294 |
-
int const head_size_v = v.size(-1);
|
| 1295 |
-
int const seqlen_k = !is_varlen_k ? k.size(1) : max_seqlen_k_.value();
|
| 1296 |
-
int const total_k = !is_varlen_k ? batch_size * k.size(1) : k.size(0);
|
| 1297 |
-
int const num_heads_k = k.size(-2);
|
| 1298 |
-
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
| 1299 |
-
TORCH_CHECK(head_size_v % 8 == 0, "head_size_v should be a multiple of 8");
|
| 1300 |
-
int const max_headdim = get_max_headdim();
|
| 1301 |
-
TORCH_CHECK(std::max(head_size, head_size_v) <= max_headdim, "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim));
|
| 1302 |
-
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
| 1303 |
-
double softmax_scale = 1.0 / sqrt(double(head_size));
|
| 1304 |
-
if (softmax_scale_.has_value()) {
|
| 1305 |
-
softmax_scale = softmax_scale_.value();
|
| 1306 |
-
}
|
| 1307 |
-
|
| 1308 |
-
// This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
|
| 1309 |
-
if (window_size_left >= seqlen_k - 1) { window_size_left = -1; }
|
| 1310 |
-
if (window_size_right >= seqlen_q - 1) { window_size_right = -1; }
|
| 1311 |
-
if (is_causal) { window_size_right = 0; }
|
| 1312 |
-
// There's a case where is_causal=false, window_size=(-1, 0). Then set_params_bprop will set params.is_causal=true.
|
| 1313 |
-
// If we don't have is_causal here matching params.is_causal, we might get the wrong kBlockM (and cause IMA).
|
| 1314 |
-
is_causal = window_size_left < 0 && window_size_right == 0;
|
| 1315 |
-
|
| 1316 |
-
int const arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
| 1317 |
-
int const head_size_rounded = round_up_headdim(std::max(head_size, head_size_v));
|
| 1318 |
-
int const head_size_v_rounded = head_size_rounded;
|
| 1319 |
-
// Very important that these match the kernel configs
|
| 1320 |
-
bool const is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal;
|
| 1321 |
-
int const kBlockM_sm90 = head_size_rounded <= 64 ? (is_causal && softcap > 0.0 ? 96 : 128)
|
| 1322 |
-
: (head_size_rounded <= 96 ? 64
|
| 1323 |
-
: (head_size_rounded <= 128 ? (is_causal || is_local || softcap > 0.0 ? 64 : 80)
|
| 1324 |
-
: 64));
|
| 1325 |
-
int const kBlockM_sm80 = head_size_rounded <= 64 ? 128 : 64;
|
| 1326 |
-
int const kBlockM_sm86 = head_size_rounded <= 192 ? 64 : 32;
|
| 1327 |
-
int const kBlockM = arch >= 90 ? kBlockM_sm90 : (arch == 86 || arch == 89 ? kBlockM_sm86 : kBlockM_sm80);
|
| 1328 |
-
int const kBlockN_sm90 = head_size_rounded <= 128
|
| 1329 |
-
? 128
|
| 1330 |
-
: (head_size_rounded <= 192 ? 96 : 80);
|
| 1331 |
-
int const kBlockN_sm80 = head_size_rounded <= 128
|
| 1332 |
-
? 128
|
| 1333 |
-
: (head_size_rounded <= 192 ? 80 : 64);
|
| 1334 |
-
int const kBlockN_sm86 = head_size_rounded <= 64 ? 128
|
| 1335 |
-
: (head_size_rounded <= 96 ? 128
|
| 1336 |
-
: (head_size_rounded <= 128 ? 96
|
| 1337 |
-
: (head_size_rounded <= 192 ? 64 : 64)));
|
| 1338 |
-
int const kBlockN = arch >= 90 ? kBlockN_sm90 : (arch == 86 || arch == 89 ? kBlockN_sm86 : kBlockN_sm80);
|
| 1339 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 1340 |
-
int const seqlen_q_rounded = round_multiple(seqlen_q, kBlockM);
|
| 1341 |
-
int const seqlen_k_rounded = round_multiple(seqlen_k, kBlockN);
|
| 1342 |
-
int const total_q_padded_rounded = round_multiple(total_q + batch_size * kBlockM, kBlockM);
|
| 1343 |
-
int const total_k_padded_rounded = round_multiple(total_k + batch_size * kBlockN, kBlockN);
|
| 1344 |
-
|
| 1345 |
-
if (!is_varlen_q) {
|
| 1346 |
-
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
| 1347 |
-
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_v);
|
| 1348 |
-
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_v);
|
| 1349 |
-
} else {
|
| 1350 |
-
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
| 1351 |
-
CHECK_SHAPE(out, total_q, num_heads, head_size_v);
|
| 1352 |
-
CHECK_SHAPE(dout, total_q, num_heads, head_size_v);
|
| 1353 |
-
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
| 1354 |
-
}
|
| 1355 |
-
if (!is_varlen_k) {
|
| 1356 |
-
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
| 1357 |
-
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_v);
|
| 1358 |
-
} else {
|
| 1359 |
-
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
| 1360 |
-
CHECK_SHAPE(v, total_k, num_heads_k, head_size_v);
|
| 1361 |
-
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
| 1362 |
-
}
|
| 1363 |
-
|
| 1364 |
-
if (seqused_q_.has_value()){
|
| 1365 |
-
auto seqused_q = seqused_q_.value();
|
| 1366 |
-
TORCH_CHECK(seqused_q.dtype() == torch::kInt32, "seqused_q must have dtype int32");
|
| 1367 |
-
CHECK_DEVICE(seqused_q); CHECK_CONTIGUOUS(seqused_q);
|
| 1368 |
-
CHECK_SHAPE(seqused_q, batch_size);
|
| 1369 |
-
}
|
| 1370 |
-
if (seqused_k_.has_value()){
|
| 1371 |
-
auto seqused_k = seqused_k_.value();
|
| 1372 |
-
TORCH_CHECK(seqused_k.dtype() == torch::kInt32, "seqused_k must have dtype int32");
|
| 1373 |
-
CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
|
| 1374 |
-
CHECK_SHAPE(seqused_k, batch_size);
|
| 1375 |
-
}
|
| 1376 |
-
|
| 1377 |
-
at::Tensor dq, dk, dv;
|
| 1378 |
-
if (dq_.has_value()) {
|
| 1379 |
-
dq = dq_.value();
|
| 1380 |
-
TORCH_CHECK(dq.dtype() == q_type, "dq must have the same dtype as q");
|
| 1381 |
-
CHECK_DEVICE(dq);
|
| 1382 |
-
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
| 1383 |
-
if (!is_varlen_q) {
|
| 1384 |
-
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
|
| 1385 |
-
} else {
|
| 1386 |
-
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
| 1387 |
-
}
|
| 1388 |
-
} else {
|
| 1389 |
-
dq = torch::empty_like(q);
|
| 1390 |
-
}
|
| 1391 |
-
if (dk_.has_value()) {
|
| 1392 |
-
dk = dk_.value();
|
| 1393 |
-
TORCH_CHECK(dk.dtype() == q_type, "dk must have the same dtype as q");
|
| 1394 |
-
CHECK_DEVICE(dk);
|
| 1395 |
-
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
| 1396 |
-
if (!is_varlen_k) {
|
| 1397 |
-
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
|
| 1398 |
-
} else {
|
| 1399 |
-
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
|
| 1400 |
-
}
|
| 1401 |
-
} else {
|
| 1402 |
-
dk = torch::empty_like(k);
|
| 1403 |
-
}
|
| 1404 |
-
if (dv_.has_value()) {
|
| 1405 |
-
dv = dv_.value();
|
| 1406 |
-
TORCH_CHECK(dv.dtype() == q_type, "dv must have the same dtype as q");
|
| 1407 |
-
CHECK_DEVICE(dv);
|
| 1408 |
-
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
| 1409 |
-
if (!is_varlen_k) {
|
| 1410 |
-
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size_v);
|
| 1411 |
-
} else {
|
| 1412 |
-
CHECK_SHAPE(dv, total_k, num_heads_k, head_size_v);
|
| 1413 |
-
}
|
| 1414 |
-
} else {
|
| 1415 |
-
dv = torch::empty_like(v);
|
| 1416 |
-
}
|
| 1417 |
-
|
| 1418 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 1419 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 1420 |
-
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
| 1421 |
-
|
| 1422 |
-
auto opts = q.options();
|
| 1423 |
-
// Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
|
| 1424 |
-
at::Tensor softmax_d, softmax_lse_log2;
|
| 1425 |
-
if (!is_varlen) {
|
| 1426 |
-
// Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
|
| 1427 |
-
softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
| 1428 |
-
softmax_lse_log2 = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
| 1429 |
-
} else {
|
| 1430 |
-
softmax_d = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
|
| 1431 |
-
softmax_lse_log2 = torch::empty({num_heads, total_q_padded_rounded}, opts.dtype(at::kFloat));
|
| 1432 |
-
}
|
| 1433 |
-
at::Tensor dq_accum, dk_accum, dv_accum;
|
| 1434 |
-
if (!is_varlen) {
|
| 1435 |
-
dq_accum = torch::empty({batch_size, num_heads, seqlen_q_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
| 1436 |
-
} else {
|
| 1437 |
-
dq_accum = torch::empty({num_heads, total_q_padded_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
| 1438 |
-
}
|
| 1439 |
-
if (num_heads_k != num_heads) { // MQA / GQA
|
| 1440 |
-
if (!is_varlen) {
|
| 1441 |
-
dk_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_rounded}, opts.dtype(at::kFloat));
|
| 1442 |
-
dv_accum = torch::zeros({batch_size, num_heads_k, seqlen_k_rounded * head_size_v_rounded}, opts.dtype(at::kFloat));
|
| 1443 |
-
} else {
|
| 1444 |
-
dk_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
| 1445 |
-
dv_accum = torch::zeros({num_heads_k, total_k_padded_rounded, head_size_v_rounded}, opts.dtype(at::kFloat));
|
| 1446 |
-
}
|
| 1447 |
-
}
|
| 1448 |
-
|
| 1449 |
-
Flash_bwd_params params;
|
| 1450 |
-
set_params_dgrad(params,
|
| 1451 |
-
batch_size,
|
| 1452 |
-
seqlen_q, seqlen_k,
|
| 1453 |
-
seqlen_q_rounded, seqlen_k_rounded,
|
| 1454 |
-
num_heads, num_heads_k,
|
| 1455 |
-
head_size, head_size_rounded,
|
| 1456 |
-
q, k, v, out,
|
| 1457 |
-
dout, dq, dk, dv,
|
| 1458 |
-
!is_varlen_q ? nullptr : cu_seqlens_q.data_ptr(),
|
| 1459 |
-
!is_varlen_k ? nullptr : cu_seqlens_k.data_ptr(),
|
| 1460 |
-
seqused_q_.has_value() ? seqused_q_.value().data_ptr() : nullptr,
|
| 1461 |
-
seqused_k_.has_value() ? seqused_k_.value().data_ptr() : nullptr,
|
| 1462 |
-
dq_accum.data_ptr(),
|
| 1463 |
-
num_heads_k != num_heads ? dk_accum.data_ptr() : nullptr,
|
| 1464 |
-
num_heads_k != num_heads ? dv_accum.data_ptr() : nullptr,
|
| 1465 |
-
softmax_lse.data_ptr(),
|
| 1466 |
-
softmax_d.data_ptr(),
|
| 1467 |
-
/*p_dropout=*/0.f,
|
| 1468 |
-
softmax_scale,
|
| 1469 |
-
window_size_left,
|
| 1470 |
-
window_size_right,
|
| 1471 |
-
0, // attention_chunk
|
| 1472 |
-
softcap,
|
| 1473 |
-
deterministic,
|
| 1474 |
-
sm_margin);
|
| 1475 |
-
params.total_q = total_q;
|
| 1476 |
-
params.total_k = total_k;
|
| 1477 |
-
params.softmax_lse_log2_ptr = softmax_lse_log2.data_ptr();
|
| 1478 |
-
params.dv = head_size_v;
|
| 1479 |
-
params.dv_rounded = head_size_v_rounded;
|
| 1480 |
-
|
| 1481 |
-
// auto tile_count_semaphore = (params.is_causal || params.is_local) ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
|
| 1482 |
-
// params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
|
| 1483 |
-
// Will be zero'ed out in the backward preprocess kernel
|
| 1484 |
-
at::Tensor dq_semaphore = torch::empty({(seqlen_q + kBlockM - 1) / kBlockM, batch_size, num_heads}, opts.dtype(torch::kInt32));
|
| 1485 |
-
params.dq_semaphore = dq_semaphore.data_ptr<int>();
|
| 1486 |
-
if (num_heads_k != num_heads && params.deterministic) {
|
| 1487 |
-
// TODO: do we need to zero them out?
|
| 1488 |
-
at::Tensor dk_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
|
| 1489 |
-
at::Tensor dv_semaphore = torch::empty({(seqlen_k + kBlockN - 1) / kBlockN, batch_size, num_heads_k}, opts.dtype(torch::kInt32));
|
| 1490 |
-
params.dk_semaphore = dk_semaphore.data_ptr<int>();
|
| 1491 |
-
params.dv_semaphore = dv_semaphore.data_ptr<int>();
|
| 1492 |
-
}
|
| 1493 |
-
|
| 1494 |
-
#ifdef FLASHATTENTION_DISABLE_LOCAL
|
| 1495 |
-
TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
|
| 1496 |
-
#endif
|
| 1497 |
-
#ifdef FLASHATTENTION_DISABLE_SOFTCAP
|
| 1498 |
-
TORCH_CHECK(params.softcap == 0.0, "This flash attention build does not support tanh softcapping.");
|
| 1499 |
-
#endif
|
| 1500 |
-
|
| 1501 |
-
if (total_q > 0 && total_k > 0 && num_heads_k > 0) {
|
| 1502 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 1503 |
-
run_mha_bwd(params, stream);
|
| 1504 |
-
} else if (total_k > 0 && num_heads_k > 0) {
|
| 1505 |
-
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
|
| 1506 |
-
dk.zero_();
|
| 1507 |
-
dv.zero_();
|
| 1508 |
-
softmax_d.zero_();
|
| 1509 |
-
} else if (total_q > 0 && num_heads_k > 0) {
|
| 1510 |
-
dq.zero_();
|
| 1511 |
-
softmax_d.zero_();
|
| 1512 |
-
}
|
| 1513 |
-
|
| 1514 |
-
return { dq, dk, dv, softmax_d, softmax_lse_log2, dq_accum, dk_accum, dv_accum };
|
| 1515 |
-
}
|
| 1516 |
-
|
| 1517 |
-
std::tuple<at::Tensor, at::Tensor>
|
| 1518 |
-
mha_combine(at::Tensor out_partial, // num_splits x batch_size x seqlen x num_heads x head_size
|
| 1519 |
-
at::Tensor lse_partial, // num_splits x batch_size x seqlen x num_heads
|
| 1520 |
-
std::optional<at::Tensor> out_, // batch_size x seqlen x num_heads x head_size
|
| 1521 |
-
std::optional<at::ScalarType> out_dtype_
|
| 1522 |
-
) {
|
| 1523 |
-
|
| 1524 |
-
auto dprops = at::cuda::getCurrentDeviceProperties();
|
| 1525 |
-
bool is_sm8x = dprops->major >= 8;
|
| 1526 |
-
TORCH_CHECK(is_sm8x, "Attention combine function only supports Ampere GPUs or newer.");
|
| 1527 |
-
|
| 1528 |
-
auto out_partial_type = out_partial.scalar_type();
|
| 1529 |
-
TORCH_CHECK(out_partial_type == at::ScalarType::Float, "Attention combine function only support fp32 data type");
|
| 1530 |
-
TORCH_CHECK(lse_partial.scalar_type() == at::ScalarType::Float, "Attention combine function only support fp32 data type");
|
| 1531 |
-
|
| 1532 |
-
CHECK_DEVICE(out_partial); CHECK_DEVICE(lse_partial);
|
| 1533 |
-
|
| 1534 |
-
TORCH_CHECK(out_partial.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
| 1535 |
-
TORCH_CHECK(lse_partial.stride(-2) == 1, "LSE tensor must be contiguous in the seqlen dimension");
|
| 1536 |
-
|
| 1537 |
-
const auto sizes = out_partial.sizes();
|
| 1538 |
-
|
| 1539 |
-
const int num_splits = sizes[0];
|
| 1540 |
-
const int batch_size = sizes[1];
|
| 1541 |
-
const int seqlen = sizes[2];
|
| 1542 |
-
const int num_heads = sizes[3];
|
| 1543 |
-
const int head_size_og = sizes[4];
|
| 1544 |
-
TORCH_CHECK(num_splits <= 256, "FlashAttention combine only supports num_splits at most 256");
|
| 1545 |
-
|
| 1546 |
-
CHECK_SHAPE(out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
|
| 1547 |
-
CHECK_SHAPE(lse_partial, num_splits, batch_size, seqlen, num_heads);
|
| 1548 |
-
|
| 1549 |
-
int const alignment = 4;
|
| 1550 |
-
at::Tensor out_partial_padded;
|
| 1551 |
-
auto pad = [](at::Tensor x, int alignment) {
|
| 1552 |
-
return x.size(-1) % alignment == 0 ? x : torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}));
|
| 1553 |
-
};
|
| 1554 |
-
out_partial_padded = pad(out_partial, alignment);
|
| 1555 |
-
|
| 1556 |
-
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
| 1557 |
-
const int head_size = round_multiple(head_size_og, alignment);
|
| 1558 |
-
|
| 1559 |
-
auto opts = out_partial.options();
|
| 1560 |
-
at::ScalarType out_type = out_dtype_.value_or(out_partial.scalar_type());
|
| 1561 |
-
TORCH_CHECK(out_type == at::ScalarType::Float || out_type == at::ScalarType::BFloat16 || out_type == at::ScalarType::Half, "Output type must be FP32, FP16 or BF16");
|
| 1562 |
-
at::Tensor out;
|
| 1563 |
-
if (out_.has_value()) {
|
| 1564 |
-
out = out_.value();
|
| 1565 |
-
TORCH_CHECK(out.scalar_type() == out_type);
|
| 1566 |
-
CHECK_DEVICE(out);
|
| 1567 |
-
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
| 1568 |
-
CHECK_SHAPE(out, batch_size, seqlen, num_heads, head_size_og);
|
| 1569 |
-
if (head_size_og % alignment != 0) {
|
| 1570 |
-
out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
|
| 1571 |
-
}
|
| 1572 |
-
} else {
|
| 1573 |
-
out = torch::empty({batch_size, seqlen, num_heads, head_size}, opts.dtype(out_type));
|
| 1574 |
-
}
|
| 1575 |
-
|
| 1576 |
-
// Otherwise the kernel will be launched from cuda:0 device
|
| 1577 |
-
// Cast to char to avoid compiler warning about narrowing
|
| 1578 |
-
at::cuda::CUDAGuard device_guard{(char)out_partial.get_device()};
|
| 1579 |
-
|
| 1580 |
-
auto softmax_lse = torch::empty({batch_size, num_heads, seqlen}, opts.dtype(at::kFloat)).transpose(1, 2);
|
| 1581 |
-
|
| 1582 |
-
Flash_fwd_params params {}; // Need to reset the params to set everything to zero
|
| 1583 |
-
params.is_fp32 = out_type == at::ScalarType::Float;
|
| 1584 |
-
params.is_bf16 = out_type == at::ScalarType::BFloat16;
|
| 1585 |
-
params.oaccum_ptr = out_partial_padded.data_ptr();
|
| 1586 |
-
params.softmax_lseaccum_ptr = lse_partial.data_ptr();
|
| 1587 |
-
params.o_ptr = out.data_ptr();
|
| 1588 |
-
params.softmax_lse_ptr = softmax_lse.data_ptr();
|
| 1589 |
-
params.b = batch_size;
|
| 1590 |
-
params.h = num_heads;
|
| 1591 |
-
params.seqlen_q = seqlen;
|
| 1592 |
-
params.dv = head_size;
|
| 1593 |
-
params.num_splits = num_splits;
|
| 1594 |
-
params.oaccum_split_stride = out_partial_padded.stride(0);
|
| 1595 |
-
params.oaccum_row_stride = out_partial_padded.stride(2);
|
| 1596 |
-
params.oaccum_head_stride = out_partial_padded.stride(3);
|
| 1597 |
-
params.oaccum_batch_stride = out_partial_padded.stride(1);
|
| 1598 |
-
params.lseaccum_split_stride = lse_partial.stride(0);
|
| 1599 |
-
params.lseaccum_head_stride = lse_partial.stride(3);
|
| 1600 |
-
params.lseaccum_batch_stride = lse_partial.stride(1);
|
| 1601 |
-
params.o_row_stride = out.stride(1);
|
| 1602 |
-
params.o_head_stride = out.stride(2);
|
| 1603 |
-
params.o_batch_stride = out.stride(0);
|
| 1604 |
-
params.arch = at::cuda::getCurrentDeviceProperties()->major * 10 + at::cuda::getCurrentDeviceProperties()->minor;
|
| 1605 |
-
|
| 1606 |
-
if (seqlen > 0 && batch_size > 0) {
|
| 1607 |
-
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 1608 |
-
run_mha_fwd_combine(params, stream, false /*enable_pdl*/);
|
| 1609 |
-
}
|
| 1610 |
-
|
| 1611 |
-
at::Tensor out_padded = out;
|
| 1612 |
-
if (head_size_og % alignment != 0) {
|
| 1613 |
-
out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
|
| 1614 |
-
// if (out_.has_value()) { out_.value().copy_(out); }
|
| 1615 |
-
}
|
| 1616 |
-
|
| 1617 |
-
return {out, softmax_lse};
|
| 1618 |
-
}
|
| 1619 |
-
|
| 1620 |
-
#ifdef false
|
| 1621 |
-
|
| 1622 |
-
TORCH_LIBRARY(flash_attn_3, m) {
|
| 1623 |
-
m.def("fwd("
|
| 1624 |
-
"Tensor q,"
|
| 1625 |
-
"Tensor k,"
|
| 1626 |
-
"Tensor v,"
|
| 1627 |
-
"Tensor(k_new!)? k_new = None,"
|
| 1628 |
-
"Tensor(v_new!)? v_new = None,"
|
| 1629 |
-
"Tensor? q_v = None,"
|
| 1630 |
-
"Tensor(out!)? out = None,"
|
| 1631 |
-
"Tensor? cu_seqlens_q = None,"
|
| 1632 |
-
"Tensor? cu_seqlens_k = None,"
|
| 1633 |
-
"Tensor? cu_seqlens_k_new = None,"
|
| 1634 |
-
"Tensor? seqused_q = None,"
|
| 1635 |
-
"Tensor? seqused_k = None,"
|
| 1636 |
-
"int? max_seqlen_q = None,"
|
| 1637 |
-
"int? max_seqlen_k = None,"
|
| 1638 |
-
"Tensor? page_table = None,"
|
| 1639 |
-
"Tensor? kv_batch_idx = None,"
|
| 1640 |
-
"Tensor? leftpad_k = None,"
|
| 1641 |
-
"Tensor? rotary_cos = None,"
|
| 1642 |
-
"Tensor? rotary_sin = None,"
|
| 1643 |
-
"Tensor? seqlens_rotary = None,"
|
| 1644 |
-
"Tensor? q_descale = None,"
|
| 1645 |
-
"Tensor? k_descale = None,"
|
| 1646 |
-
"Tensor? v_descale = None,"
|
| 1647 |
-
"float? softmax_scale = None,"
|
| 1648 |
-
"bool is_causal = False,"
|
| 1649 |
-
"int window_size_left = -1,"
|
| 1650 |
-
"int window_size_right = -1,"
|
| 1651 |
-
"int attention_chunk = 0,"
|
| 1652 |
-
"float softcap = 0.0,"
|
| 1653 |
-
"bool is_rotary_interleaved = False,"
|
| 1654 |
-
"Tensor? scheduler_metadata = None,"
|
| 1655 |
-
"int num_splits = 0,"
|
| 1656 |
-
"bool? pack_gqa = None,"
|
| 1657 |
-
"int sm_margin = 0) -> (Tensor(out!), Tensor, Tensor, Tensor)");
|
| 1658 |
-
m.def("bwd("
|
| 1659 |
-
"Tensor dout,"
|
| 1660 |
-
"Tensor q,"
|
| 1661 |
-
"Tensor k,"
|
| 1662 |
-
"Tensor v,"
|
| 1663 |
-
"Tensor out,"
|
| 1664 |
-
"Tensor softmax_lse,"
|
| 1665 |
-
"Tensor(dq!)? dq = None,"
|
| 1666 |
-
"Tensor(dk!)? dk = None,"
|
| 1667 |
-
"Tensor(dv!)? dv = None,"
|
| 1668 |
-
"Tensor? cu_seqlens_q = None,"
|
| 1669 |
-
"Tensor? cu_seqlens_k = None,"
|
| 1670 |
-
"Tensor? seqused_q = None,"
|
| 1671 |
-
"Tensor? seqused_k = None,"
|
| 1672 |
-
"int? max_seqlen_q = None,"
|
| 1673 |
-
"int? max_seqlen_k = None,"
|
| 1674 |
-
"float? softmax_scale = None,"
|
| 1675 |
-
"bool is_causal = False,"
|
| 1676 |
-
"int window_size_left = -1,"
|
| 1677 |
-
"int window_size_right = -1,"
|
| 1678 |
-
"float softcap = 0.0,"
|
| 1679 |
-
"bool deterministic = False,"
|
| 1680 |
-
"int sm_margin = 0) -> (Tensor(dq!), Tensor(dk!), Tensor(dv!), Tensor, Tensor, Tensor, Tensor, Tensor)");
|
| 1681 |
-
m.def("fwd_combine("
|
| 1682 |
-
"Tensor out_partial,"
|
| 1683 |
-
"Tensor lse_partial,"
|
| 1684 |
-
"Tensor(out!)? out = None,"
|
| 1685 |
-
"ScalarType? out_dtype = None) -> (Tensor(out!), Tensor)");
|
| 1686 |
-
m.def("get_scheduler_metadata("
|
| 1687 |
-
"int batch_size,"
|
| 1688 |
-
"int max_seqlen_q,"
|
| 1689 |
-
"int max_seqlen_k,"
|
| 1690 |
-
"int num_heads,"
|
| 1691 |
-
"int num_heads_k,"
|
| 1692 |
-
"int headdim,"
|
| 1693 |
-
"int headdim_v,"
|
| 1694 |
-
"ScalarType qkv_dtype,"
|
| 1695 |
-
"Tensor seqused_k,"
|
| 1696 |
-
"Tensor? cu_seqlens_q = None,"
|
| 1697 |
-
"Tensor? cu_seqlens_k = None,"
|
| 1698 |
-
"Tensor? cu_seqlens_k_new = None,"
|
| 1699 |
-
"Tensor? seqused_q = None,"
|
| 1700 |
-
"Tensor? leftpad_k = None,"
|
| 1701 |
-
"int? page_size = None,"
|
| 1702 |
-
"int max_seqlen_k_new = 0,"
|
| 1703 |
-
"bool is_causal = False,"
|
| 1704 |
-
"int window_size_left = -1,"
|
| 1705 |
-
"int window_size_right = -1,"
|
| 1706 |
-
"int attention_chunk = 0,"
|
| 1707 |
-
"bool has_softcap = False,"
|
| 1708 |
-
"int num_splits = 0,"
|
| 1709 |
-
"bool? pack_gqa = None,"
|
| 1710 |
-
"int sm_margin = 0) -> Tensor");
|
| 1711 |
-
}
|
| 1712 |
-
|
| 1713 |
-
TORCH_LIBRARY_IMPL(flash_attn_3, CUDA, m) {
|
| 1714 |
-
m.impl("fwd", &mha_fwd);
|
| 1715 |
-
m.impl("bwd", &mha_bwd);
|
| 1716 |
-
m.impl("fwd_combine", &mha_combine);
|
| 1717 |
-
m.impl("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata);
|
| 1718 |
-
}
|
| 1719 |
-
|
| 1720 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_bwd_kernel_sm80.h
DELETED
|
@@ -1,173 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/array.h>
|
| 11 |
-
#include <cutlass/numeric_types.h>
|
| 12 |
-
#include <cutlass/kernel_hardware_info.h>
|
| 13 |
-
|
| 14 |
-
#include "utils.h"
|
| 15 |
-
|
| 16 |
-
namespace flash {
|
| 17 |
-
|
| 18 |
-
using namespace cute;
|
| 19 |
-
|
| 20 |
-
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
| 21 |
-
class FlashAttnBwdSm80 {
|
| 22 |
-
|
| 23 |
-
public:
|
| 24 |
-
|
| 25 |
-
// Type Aliases
|
| 26 |
-
static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
|
| 27 |
-
static constexpr bool Is_local = CollectiveMainloop_::Is_local;
|
| 28 |
-
static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
|
| 29 |
-
static constexpr bool Varlen = CollectiveMainloop_::Varlen;
|
| 30 |
-
|
| 31 |
-
// Mainloop derived types
|
| 32 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 33 |
-
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
| 34 |
-
using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
|
| 35 |
-
using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
|
| 36 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 37 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 38 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 39 |
-
static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
|
| 40 |
-
|
| 41 |
-
// Epilogue derived types
|
| 42 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 43 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 44 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 45 |
-
|
| 46 |
-
static_assert(ArchTag::kMinComputeCapability >= 80);
|
| 47 |
-
|
| 48 |
-
using TileScheduler = TileScheduler_;
|
| 49 |
-
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
| 50 |
-
using TileSchedulerParams = typename TileScheduler::Params;
|
| 51 |
-
|
| 52 |
-
static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMmaSdP{}));
|
| 53 |
-
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{}));
|
| 54 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 55 |
-
|
| 56 |
-
// Kernel level shared memory storage
|
| 57 |
-
struct SharedStorage {
|
| 58 |
-
struct TensorStorage : cute::aligned_struct<128> {
|
| 59 |
-
union {
|
| 60 |
-
typename CollectiveMainloop::TensorStorage mainloop;
|
| 61 |
-
typename CollectiveEpilogue::TensorStorage epilogue;
|
| 62 |
-
};
|
| 63 |
-
} tensors;
|
| 64 |
-
|
| 65 |
-
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
| 66 |
-
|
| 67 |
-
};
|
| 68 |
-
|
| 69 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 70 |
-
|
| 71 |
-
// Device side arguments
|
| 72 |
-
struct Arguments {
|
| 73 |
-
MainloopArguments mainloop{};
|
| 74 |
-
EpilogueArguments epilogue{};
|
| 75 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 76 |
-
TileSchedulerArguments scheduler{};
|
| 77 |
-
};
|
| 78 |
-
|
| 79 |
-
// Kernel entry point API
|
| 80 |
-
struct Params {
|
| 81 |
-
MainloopParams mainloop{};
|
| 82 |
-
EpilogueParams epilogue{};
|
| 83 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 84 |
-
TileSchedulerParams scheduler{};
|
| 85 |
-
};
|
| 86 |
-
|
| 87 |
-
//
|
| 88 |
-
// Methods
|
| 89 |
-
//
|
| 90 |
-
|
| 91 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 92 |
-
static
|
| 93 |
-
Params
|
| 94 |
-
to_underlying_arguments(Arguments const& args) {
|
| 95 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
| 96 |
-
|
| 97 |
-
// Get SM count if needed, otherwise use user supplied SM count
|
| 98 |
-
int sm_count = args.hw_info.sm_count;
|
| 99 |
-
if (sm_count <= 0) {
|
| 100 |
-
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
| 101 |
-
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
| 102 |
-
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
| 103 |
-
}
|
| 104 |
-
|
| 105 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
| 106 |
-
|
| 107 |
-
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
| 108 |
-
return {
|
| 109 |
-
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
| 110 |
-
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
| 111 |
-
hw_info,
|
| 112 |
-
TileScheduler::to_underlying_arguments(args.scheduler)
|
| 113 |
-
};
|
| 114 |
-
}
|
| 115 |
-
|
| 116 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 117 |
-
static dim3
|
| 118 |
-
get_grid_shape(Params const& params) {
|
| 119 |
-
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
| 120 |
-
}
|
| 121 |
-
|
| 122 |
-
static dim3
|
| 123 |
-
get_block_shape() {
|
| 124 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 125 |
-
}
|
| 126 |
-
|
| 127 |
-
CUTLASS_DEVICE
|
| 128 |
-
void
|
| 129 |
-
operator()(Params const& params, char* smem_buf) {
|
| 130 |
-
|
| 131 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
| 132 |
-
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
| 133 |
-
|
| 134 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 135 |
-
|
| 136 |
-
CollectiveMainloop mainloop;
|
| 137 |
-
CollectiveEpilogue epilogue;
|
| 138 |
-
|
| 139 |
-
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
|
| 140 |
-
// Initialize matmul objects.
|
| 141 |
-
TiledMmadKV tiled_mma_dKV;
|
| 142 |
-
|
| 143 |
-
scheduler.init_consumer();
|
| 144 |
-
|
| 145 |
-
int warp_idx = cutlass::canonical_warp_idx_sync();
|
| 146 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 147 |
-
for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 148 |
-
work_tile_info.is_valid(params.scheduler);
|
| 149 |
-
work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
| 150 |
-
|
| 151 |
-
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
| 152 |
-
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
| 153 |
-
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
| 154 |
-
|
| 155 |
-
// dK and dV output accumulator.
|
| 156 |
-
Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
| 157 |
-
Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
| 158 |
-
bool tile_valid = mainloop.mma(params.mainloop, tdKrdK, tdVrdV, threadIdx.x,
|
| 159 |
-
block_coord, shared_storage);
|
| 160 |
-
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
| 161 |
-
if (tile_valid) {
|
| 162 |
-
epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
|
| 163 |
-
threadIdx.x, block_coord);
|
| 164 |
-
} else {
|
| 165 |
-
epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
|
| 166 |
-
}
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
}
|
| 170 |
-
|
| 171 |
-
};
|
| 172 |
-
|
| 173 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_bwd_kernel_sm90.h
DELETED
|
@@ -1,282 +0,0 @@
|
|
| 1 |
-
|
| 2 |
-
/******************************************************************************
|
| 3 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 4 |
-
******************************************************************************/
|
| 5 |
-
|
| 6 |
-
#pragma once
|
| 7 |
-
|
| 8 |
-
#include "cute/tensor.hpp"
|
| 9 |
-
|
| 10 |
-
#include <cutlass/cutlass.h>
|
| 11 |
-
#include <cutlass/arch/reg_reconfig.h>
|
| 12 |
-
#include <cutlass/array.h>
|
| 13 |
-
#include <cutlass/numeric_types.h>
|
| 14 |
-
#include <cutlass/numeric_conversion.h>
|
| 15 |
-
#include <cutlass/kernel_hardware_info.h>
|
| 16 |
-
#include "cutlass/pipeline/pipeline.hpp"
|
| 17 |
-
|
| 18 |
-
#include "utils.h"
|
| 19 |
-
|
| 20 |
-
namespace flash {
|
| 21 |
-
|
| 22 |
-
using namespace cute;
|
| 23 |
-
|
| 24 |
-
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
| 25 |
-
class FlashAttnBwdSm90 {
|
| 26 |
-
|
| 27 |
-
public:
|
| 28 |
-
|
| 29 |
-
// Type Aliases
|
| 30 |
-
static constexpr bool Is_causal = CollectiveMainloop_::Is_causal;
|
| 31 |
-
static constexpr bool Is_local = CollectiveMainloop_::Is_local;
|
| 32 |
-
static_assert(CollectiveMainloop_::Varlen == CollectiveEpilogue_::Varlen);
|
| 33 |
-
static constexpr bool Varlen = CollectiveMainloop_::Varlen;
|
| 34 |
-
|
| 35 |
-
// Mainloop derived types
|
| 36 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 37 |
-
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
| 38 |
-
using TiledMmaSdP = typename CollectiveMainloop::TiledMmaSdP;
|
| 39 |
-
using TiledMmadKV = typename CollectiveMainloop::TiledMmadKV;
|
| 40 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 41 |
-
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
| 42 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 43 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 44 |
-
static constexpr bool dKV_swapAB = CollectiveMainloop::dKV_swapAB;
|
| 45 |
-
|
| 46 |
-
// Epilogue derived types
|
| 47 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 48 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 49 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 50 |
-
|
| 51 |
-
static_assert(ArchTag::kMinComputeCapability >= 90);
|
| 52 |
-
|
| 53 |
-
using TileScheduler = TileScheduler_;
|
| 54 |
-
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
| 55 |
-
using TileSchedulerParams = typename TileScheduler::Params;
|
| 56 |
-
|
| 57 |
-
static constexpr uint32_t NumLoadWarpGroups = 1;
|
| 58 |
-
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaSdP{})) / cutlass::NumThreadsPerWarpGroup;
|
| 59 |
-
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaSdP{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
|
| 60 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 61 |
-
static_assert(NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
|
| 62 |
-
|
| 63 |
-
/// Register requirement for Load and Math WGs
|
| 64 |
-
static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 2 ? 24 : 32;
|
| 65 |
-
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 240 : 160;
|
| 66 |
-
// If you want to print from the producer warp, you'd need to increase the number of registers
|
| 67 |
-
// Otherwise you'll get CUDA error.
|
| 68 |
-
// static constexpr uint32_t LoadRegisterRequirement = 40;
|
| 69 |
-
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
|
| 70 |
-
|
| 71 |
-
// Kernel level shared memory storage
|
| 72 |
-
struct SharedStorage {
|
| 73 |
-
struct TensorStorage : cute::aligned_struct<128> {
|
| 74 |
-
union {
|
| 75 |
-
typename CollectiveMainloop::TensorStorage mainloop;
|
| 76 |
-
typename CollectiveEpilogue::TensorStorage epilogue;
|
| 77 |
-
};
|
| 78 |
-
} tensors;
|
| 79 |
-
|
| 80 |
-
struct PipelineStorage : cute::aligned_struct<16> {
|
| 81 |
-
alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_KV;
|
| 82 |
-
alignas(16) typename CollectiveMainloop::MainloopPipeline::SharedStorage pipeline_q;
|
| 83 |
-
alignas(16) typename CollectiveMainloop::MainloopPipeline_dO::SharedStorage pipeline_do;
|
| 84 |
-
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
| 85 |
-
} pipelines;
|
| 86 |
-
|
| 87 |
-
};
|
| 88 |
-
|
| 89 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 90 |
-
|
| 91 |
-
// Device side arguments
|
| 92 |
-
struct Arguments {
|
| 93 |
-
MainloopArguments mainloop{};
|
| 94 |
-
EpilogueArguments epilogue{};
|
| 95 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 96 |
-
TileSchedulerArguments scheduler{};
|
| 97 |
-
};
|
| 98 |
-
|
| 99 |
-
// Kernel entry point API
|
| 100 |
-
struct Params {
|
| 101 |
-
MainloopParams mainloop{};
|
| 102 |
-
EpilogueParams epilogue{};
|
| 103 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 104 |
-
TileSchedulerParams scheduler{};
|
| 105 |
-
};
|
| 106 |
-
|
| 107 |
-
//
|
| 108 |
-
// Methods
|
| 109 |
-
//
|
| 110 |
-
|
| 111 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 112 |
-
static
|
| 113 |
-
Params
|
| 114 |
-
to_underlying_arguments(Arguments const& args) {
|
| 115 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
| 116 |
-
|
| 117 |
-
// Get SM count if needed, otherwise use user supplied SM count
|
| 118 |
-
int sm_count = args.hw_info.sm_count;
|
| 119 |
-
if (sm_count <= 0) {
|
| 120 |
-
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
| 121 |
-
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
| 122 |
-
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
| 126 |
-
|
| 127 |
-
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
| 128 |
-
return {
|
| 129 |
-
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
| 130 |
-
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
| 131 |
-
hw_info,
|
| 132 |
-
TileScheduler::to_underlying_arguments(args.scheduler)
|
| 133 |
-
};
|
| 134 |
-
}
|
| 135 |
-
|
| 136 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 137 |
-
static dim3
|
| 138 |
-
get_grid_shape(Params const& params) {
|
| 139 |
-
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
-
static dim3
|
| 143 |
-
get_block_shape() {
|
| 144 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 145 |
-
}
|
| 146 |
-
|
| 147 |
-
CUTLASS_DEVICE
|
| 148 |
-
void
|
| 149 |
-
operator()(Params const& params, char* smem_buf) {
|
| 150 |
-
|
| 151 |
-
static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
| 152 |
-
static constexpr int NumCopyThreads = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
| 153 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
| 154 |
-
static constexpr int kBlockN = get<1>(TileShape_MNK{});
|
| 155 |
-
|
| 156 |
-
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
|
| 157 |
-
using PipelineParams = typename MainloopPipeline::Params;
|
| 158 |
-
using PipelineState = typename MainloopPipeline::PipelineState;
|
| 159 |
-
using MainloopPipeline_dO = typename CollectiveMainloop::MainloopPipeline_dO;
|
| 160 |
-
using PipelineParams_dO = typename MainloopPipeline_dO::Params;
|
| 161 |
-
using PipelineState_dO = typename MainloopPipeline_dO::PipelineState;
|
| 162 |
-
static constexpr bool Q_dO_same_stages = std::is_same_v<MainloopPipeline, MainloopPipeline_dO>;
|
| 163 |
-
|
| 164 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 165 |
-
|
| 166 |
-
int const lane_predicate = cute::elect_one_sync();
|
| 167 |
-
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
| 168 |
-
|
| 169 |
-
// Issue Tma Descriptor Prefetch from a single thread
|
| 170 |
-
if (warp_idx == 0 && lane_predicate) {
|
| 171 |
-
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
| 172 |
-
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
| 173 |
-
}
|
| 174 |
-
|
| 175 |
-
// Obtain warp index
|
| 176 |
-
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
| 177 |
-
|
| 178 |
-
PipelineParams pipeline_params;
|
| 179 |
-
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesQ + CollectiveMainloop::TmaTransactionBytesLSE;
|
| 180 |
-
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
| 181 |
-
pipeline_params.role = warp_group_idx == 0
|
| 182 |
-
? MainloopPipeline::ThreadCategory::Producer
|
| 183 |
-
: MainloopPipeline::ThreadCategory::Consumer;
|
| 184 |
-
pipeline_params.is_leader = warp_group_thread_idx == 0;
|
| 185 |
-
pipeline_params.num_consumers = NumMmaThreads;
|
| 186 |
-
|
| 187 |
-
if (warp_idx == 0 && lane_predicate) {
|
| 188 |
-
shared_storage.pipelines.barrier_KV.init(1 /*numThreads*/);
|
| 189 |
-
}
|
| 190 |
-
// We're counting on pipeline_q to call cutlass::arch::fence_barrier_init();
|
| 191 |
-
MainloopPipeline pipeline_q(shared_storage.pipelines.pipeline_q, pipeline_params, ClusterShape{});
|
| 192 |
-
auto role_dO = warp_group_idx == 0
|
| 193 |
-
? MainloopPipeline_dO::ThreadCategory::Producer
|
| 194 |
-
: MainloopPipeline_dO::ThreadCategory::Consumer;
|
| 195 |
-
PipelineParams_dO pipeline_params_dO {pipeline_params.transaction_bytes, role_dO, pipeline_params.is_leader, pipeline_params.num_consumers};
|
| 196 |
-
MainloopPipeline_dO pipeline_do(shared_storage.pipelines.pipeline_do, cute::conditional_return<Q_dO_same_stages>(pipeline_params, pipeline_params_dO), ClusterShape{});
|
| 197 |
-
|
| 198 |
-
CollectiveMainloop mainloop;
|
| 199 |
-
CollectiveEpilogue epilogue;
|
| 200 |
-
|
| 201 |
-
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
|
| 202 |
-
if constexpr (size(ClusterShape{}) > 1) {
|
| 203 |
-
cute::cluster_arrive_relaxed();
|
| 204 |
-
cute::cluster_wait();
|
| 205 |
-
} else {
|
| 206 |
-
__syncthreads();
|
| 207 |
-
}
|
| 208 |
-
|
| 209 |
-
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
|
| 210 |
-
|
| 211 |
-
if (warp_group_idx == 0) { // Producer
|
| 212 |
-
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
| 213 |
-
|
| 214 |
-
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
| 215 |
-
if (warp_idx_in_warpgroup == 0) { // Load K, V, and do TMA on Q and dO
|
| 216 |
-
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
|
| 217 |
-
PipelineState_dO smem_pipe_write_do = cutlass::make_producer_start_state<MainloopPipeline_dO>();
|
| 218 |
-
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler);
|
| 219 |
-
work_tile_info.is_valid(params.scheduler);
|
| 220 |
-
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info)) {
|
| 221 |
-
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
| 222 |
-
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
| 223 |
-
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
| 224 |
-
auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() {
|
| 225 |
-
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
| 226 |
-
};
|
| 227 |
-
mainloop.load(params.mainloop, pipeline_q, pipeline_do, smem_pipe_write,
|
| 228 |
-
smem_pipe_write_do, shared_storage, scheduler_prefetch, block_coord);
|
| 229 |
-
}
|
| 230 |
-
mainloop.load_tail(pipeline_q, pipeline_do, smem_pipe_write, smem_pipe_write_do);
|
| 231 |
-
} else if (warp_idx_in_warpgroup == 1) {
|
| 232 |
-
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 233 |
-
work_tile_info.is_valid(params.scheduler);
|
| 234 |
-
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
| 235 |
-
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
| 236 |
-
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
| 237 |
-
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
| 238 |
-
mainloop.store_dq(params.mainloop, shared_storage, block_coord);
|
| 239 |
-
}
|
| 240 |
-
}
|
| 241 |
-
} else { // Consumer
|
| 242 |
-
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
| 243 |
-
// Initialize matmul objects.
|
| 244 |
-
TiledMmadKV tiled_mma_dKV;
|
| 245 |
-
|
| 246 |
-
PipelineState smem_pipe_read;
|
| 247 |
-
PipelineState_dO smem_pipe_read_do;
|
| 248 |
-
|
| 249 |
-
mainloop.mma_init();
|
| 250 |
-
scheduler.init_consumer();
|
| 251 |
-
|
| 252 |
-
int work_idx = 0;
|
| 253 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 254 |
-
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 255 |
-
work_tile_info.is_valid(params.scheduler);
|
| 256 |
-
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
| 257 |
-
auto block_coord_ = work_tile_info.get_block_coord(params.scheduler);
|
| 258 |
-
auto [n_block, bidh, bidb, _ /*split_idx*/] = block_coord_;
|
| 259 |
-
cute::tuple<int32_t, int32_t, int32_t> block_coord = {n_block, bidh, bidb};
|
| 260 |
-
|
| 261 |
-
// dK and dV output accumulator.
|
| 262 |
-
Tensor tdKrdK = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
| 263 |
-
Tensor tdVrdV = partition_fragment_C(tiled_mma_dKV, select<!dKV_swapAB ? 1 : 2, !dKV_swapAB? 2 : 1>(TileShape_MNK{}));
|
| 264 |
-
bool tile_valid = mainloop.mma(
|
| 265 |
-
params.mainloop, pipeline_q, pipeline_do, smem_pipe_read, smem_pipe_read_do,
|
| 266 |
-
tdKrdK, tdVrdV, threadIdx.x - NumCopyThreads, work_idx, block_coord, shared_storage);
|
| 267 |
-
if (tile_valid) {
|
| 268 |
-
epilogue.store(params.epilogue, tdKrdK, tdVrdV, shared_storage, tiled_mma_dKV,
|
| 269 |
-
threadIdx.x - NumCopyThreads, block_coord);
|
| 270 |
-
} else {
|
| 271 |
-
epilogue.store_zero(params.epilogue, threadIdx.x - NumCopyThreads, block_coord);
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
}
|
| 275 |
-
epilogue.store_tail();
|
| 276 |
-
}
|
| 277 |
-
|
| 278 |
-
}
|
| 279 |
-
|
| 280 |
-
};
|
| 281 |
-
|
| 282 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_bwd_launch_template.h
DELETED
|
@@ -1,390 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include "cutlass/device_kernel.h" // For device_kernel
|
| 10 |
-
#include "cutlass/kernel_launch.h" // For kernel_launch
|
| 11 |
-
#include "cutlass/cluster_launch.hpp" // For ClusterLauncher
|
| 12 |
-
|
| 13 |
-
#include "static_switch.h"
|
| 14 |
-
#include "flash.h"
|
| 15 |
-
#include "flash_bwd_preprocess_kernel.h"
|
| 16 |
-
#include "flash_bwd_postprocess_kernel.h"
|
| 17 |
-
#include "tile_scheduler.hpp"
|
| 18 |
-
#include "mainloop_bwd_sm90_tma_gmma_ws.hpp"
|
| 19 |
-
#include "mainloop_bwd_sm80.hpp"
|
| 20 |
-
#include "epilogue_bwd.hpp"
|
| 21 |
-
#include "flash_bwd_kernel_sm90.h"
|
| 22 |
-
#include "flash_bwd_kernel_sm80.h"
|
| 23 |
-
|
| 24 |
-
using namespace cute;
|
| 25 |
-
|
| 26 |
-
template <int Arch, int kHeadDim, int kBlockM, int kBlockN, typename Element,
|
| 27 |
-
bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool Deterministic, bool GQA,
|
| 28 |
-
int Stages_dO=2, int Stages_dS_or_QSm80=2,
|
| 29 |
-
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
|
| 30 |
-
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
|
| 31 |
-
bool V_in_regs=false>
|
| 32 |
-
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 33 |
-
static_assert(!(Is_causal && Is_local), "Is_causal and Is_local cannot be true at the same time.");
|
| 34 |
-
using ElementAccum = float;
|
| 35 |
-
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
|
| 36 |
-
|
| 37 |
-
int const total_q_padded_rounded = cute::round_up(params.total_q + params.b * kBlockM, kBlockM);
|
| 38 |
-
int const total_k_padded_rounded = cute::round_up(params.total_k + params.b * kBlockN, kBlockN);
|
| 39 |
-
bool const is_varlen_q = params.cu_seqlens_q;
|
| 40 |
-
bool const is_varlen_k = params.cu_seqlens_k;
|
| 41 |
-
int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
|
| 42 |
-
int seqlen_k = !is_varlen_k ? params.seqlen_k : params.total_k;
|
| 43 |
-
int seqlen_q_rounded = !is_varlen_q ? params.seqlen_q_rounded : total_q_padded_rounded;
|
| 44 |
-
int seqlen_k_rounded = !is_varlen_k ? params.seqlen_k_rounded : total_k_padded_rounded;
|
| 45 |
-
int batch_q = !is_varlen_q ? params.b : 1;
|
| 46 |
-
int batch_k = !is_varlen_k ? params.b : 1;
|
| 47 |
-
|
| 48 |
-
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kHeadDim>>;
|
| 49 |
-
using PreprocessKernel = flash::FlashAttnBwdPreprocess<TileShape_MK, Element, ElementAccum, ArchTag, /*Clear_dQaccum=*/true, Varlen>;
|
| 50 |
-
typename PreprocessKernel::Arguments preprocess_args {
|
| 51 |
-
static_cast<Element const*>(params.o_ptr),
|
| 52 |
-
{seqlen_q, params.dv, params.h, batch_q}, // shape_O
|
| 53 |
-
{params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O
|
| 54 |
-
static_cast<Element const*>(params.do_ptr),
|
| 55 |
-
{params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
|
| 56 |
-
static_cast<float*>(params.dsoftmax_sum),
|
| 57 |
-
{seqlen_q_rounded, params.h, batch_q}, // shape_dPsum
|
| 58 |
-
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
|
| 59 |
-
static_cast<float*>(params.softmax_lse_ptr),
|
| 60 |
-
{_1{}, seqlen_q, !is_varlen_q ? params.h * params.seqlen_q : 0}, // stride_LSE
|
| 61 |
-
static_cast<float*>(params.softmax_lse_log2_ptr),
|
| 62 |
-
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
|
| 63 |
-
static_cast<ElementAccum*>(params.dq_accum_ptr),
|
| 64 |
-
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
| 65 |
-
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
| 66 |
-
params.b,
|
| 67 |
-
params.dq_semaphore,
|
| 68 |
-
params.cu_seqlens_q,
|
| 69 |
-
params.seqused_q
|
| 70 |
-
};
|
| 71 |
-
typename PreprocessKernel::Params preprocess_params = PreprocessKernel::to_underlying_arguments(preprocess_args);
|
| 72 |
-
int num_m_block = cute::ceil_div(params.seqlen_q, kBlockM);
|
| 73 |
-
dim3 grid_m(num_m_block, params.h, params.b);
|
| 74 |
-
cutlass::kernel_launch<PreprocessKernel>(grid_m, PreprocessKernel::MaxThreadsPerBlock, PreprocessKernel::SharedStorageSize, stream, preprocess_params, false /*launch_with_pdl*/);
|
| 75 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 76 |
-
|
| 77 |
-
using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
| 78 |
-
using ClusterShape = cute::Shape<_1, Int<1>, _1>; // Currently doesn't not support cluster
|
| 79 |
-
// Stages_dS_or_QSm80 is Stages_dS if Sm90 and Stages if Sm80
|
| 80 |
-
static constexpr int Stages = Arch >= 90 ? 2 : Stages_dS_or_QSm80;
|
| 81 |
-
static constexpr int Stages_dS = Arch >= 90 ? Stages_dS_or_QSm80 : 1;
|
| 82 |
-
using CollectiveMainloop = std::conditional_t<
|
| 83 |
-
Arch >= 90,
|
| 84 |
-
flash::CollectiveMainloopBwdSm90<Stages, Stages_dO, Stages_dS, ClusterShape, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm90,
|
| 85 |
-
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
|
| 86 |
-
SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>,
|
| 87 |
-
flash::CollectiveMainloopBwdSm80<Stages, Stages_dO, TileShape_MNK, Element, ElementAccum, cutlass::arch::Sm80,
|
| 88 |
-
Is_causal, Is_local, Has_softcap, Varlen, Deterministic,
|
| 89 |
-
SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>
|
| 90 |
-
>;
|
| 91 |
-
using CollectiveEpilogue = std::conditional_t<
|
| 92 |
-
!GQA,
|
| 93 |
-
flash::CollectiveEpilogueBwd<TileShape_MNK, Element, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, dKV_swapAB, NumMmaWarpGroups * (Arch >= 90 ? 1 : cutlass::NumWarpsPerWarpGroup) / AtomLayoutNdKV>,
|
| 94 |
-
flash::CollectiveEpilogueBwdGQA<TileShape_MNK, ElementAccum, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, Deterministic>
|
| 95 |
-
>;
|
| 96 |
-
using Scheduler = std::conditional_t<
|
| 97 |
-
Is_causal && !Varlen,
|
| 98 |
-
flash::SingleTileBwdLPTScheduler,
|
| 99 |
-
flash::SingleTileScheduler<Varlen, false /*Split*/, false /*PackGQA*/, kBlockN>
|
| 100 |
-
>;
|
| 101 |
-
using AttnKernel = std::conditional_t<
|
| 102 |
-
Arch >= 90,
|
| 103 |
-
flash::enable_sm90_or_later<flash::FlashAttnBwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
|
| 104 |
-
flash::enable_sm80_to_sm89<flash::FlashAttnBwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
|
| 105 |
-
>;
|
| 106 |
-
|
| 107 |
-
typename CollectiveMainloop::Arguments mainloop_args {
|
| 108 |
-
static_cast<Element const*>(params.q_ptr),
|
| 109 |
-
{seqlen_q, params.d, params.h, batch_q}, // shape_Q
|
| 110 |
-
{params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
|
| 111 |
-
static_cast<Element const*>(params.k_ptr),
|
| 112 |
-
{seqlen_k, params.d, params.h_k, batch_k}, // shape_K
|
| 113 |
-
{params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
|
| 114 |
-
static_cast<Element const*>(params.v_ptr),
|
| 115 |
-
{seqlen_k, params.dv, params.h_k, batch_k}, // shape_V
|
| 116 |
-
{params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V
|
| 117 |
-
static_cast<Element const*>(params.do_ptr),
|
| 118 |
-
{seqlen_q, params.dv, params.h, batch_q}, // shape_dO
|
| 119 |
-
{params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO
|
| 120 |
-
static_cast<ElementAccum*>(params.dq_accum_ptr),
|
| 121 |
-
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
| 122 |
-
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
| 123 |
-
static_cast<float*>(params.softmax_lse_log2_ptr),
|
| 124 |
-
{seqlen_q_rounded, params.h, batch_q}, // shape_LSE
|
| 125 |
-
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2
|
| 126 |
-
static_cast<float*>(params.dsoftmax_sum),
|
| 127 |
-
{_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_dPsum
|
| 128 |
-
params.scale_softmax,
|
| 129 |
-
params.window_size_left, params.window_size_right, 0 /*attention_chunk*/,
|
| 130 |
-
params.softcap,
|
| 131 |
-
params.b,
|
| 132 |
-
params.dq_semaphore,
|
| 133 |
-
params.cu_seqlens_q, params.cu_seqlens_k,
|
| 134 |
-
params.seqused_q, params.seqused_k
|
| 135 |
-
};
|
| 136 |
-
// The case work with GQA is ugly but idk how to fix it.
|
| 137 |
-
typename CollectiveEpilogue::Arguments epilogue_args {
|
| 138 |
-
static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dk_ptr : params.dk_accum_ptr),
|
| 139 |
-
[&] {
|
| 140 |
-
if constexpr (!GQA) {
|
| 141 |
-
return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK
|
| 142 |
-
} else {
|
| 143 |
-
return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum
|
| 144 |
-
}
|
| 145 |
-
}(),
|
| 146 |
-
[&] {
|
| 147 |
-
if constexpr (!GQA) {
|
| 148 |
-
return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK
|
| 149 |
-
} else {
|
| 150 |
-
return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum
|
| 151 |
-
}
|
| 152 |
-
}(),
|
| 153 |
-
static_cast<typename CollectiveEpilogue::Element*>(!GQA ? params.dv_ptr : params.dv_accum_ptr),
|
| 154 |
-
[&] {
|
| 155 |
-
if constexpr (!GQA) {
|
| 156 |
-
return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV
|
| 157 |
-
} else {
|
| 158 |
-
return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum
|
| 159 |
-
}
|
| 160 |
-
}(),
|
| 161 |
-
[&] {
|
| 162 |
-
if constexpr (!GQA) {
|
| 163 |
-
return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV
|
| 164 |
-
} else {
|
| 165 |
-
return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum
|
| 166 |
-
}
|
| 167 |
-
}(),
|
| 168 |
-
params.h,
|
| 169 |
-
params.dk_semaphore,
|
| 170 |
-
params.dv_semaphore,
|
| 171 |
-
params.cu_seqlens_k,
|
| 172 |
-
params.seqused_k,
|
| 173 |
-
};
|
| 174 |
-
|
| 175 |
-
int num_blocks_n = cutlass::ceil_div(params.seqlen_k, get<1>(TileShape_MNK{}));
|
| 176 |
-
num_blocks_n = cutlass::round_up(num_blocks_n, size<1>(ClusterShape{}));
|
| 177 |
-
typename flash::TileSchedulerArguments scheduler_args {
|
| 178 |
-
num_blocks_n, params.h, params.b, 1 /*num_splits*/,
|
| 179 |
-
params.h / params.h_k,
|
| 180 |
-
params.seqlen_k,
|
| 181 |
-
params.seqlen_q, params.d, params.dv, sizeof(Element),
|
| 182 |
-
params.tile_count_semaphore, params.cu_seqlens_k, params.seqused_k
|
| 183 |
-
};
|
| 184 |
-
|
| 185 |
-
int device;
|
| 186 |
-
cudaGetDevice(&device);
|
| 187 |
-
typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
|
| 188 |
-
mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
|
| 189 |
-
});
|
| 190 |
-
|
| 191 |
-
dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
|
| 192 |
-
dim3 block_dims = AttnKernel::get_block_shape();
|
| 193 |
-
int smem_size = AttnKernel::SharedStorageSize;
|
| 194 |
-
// int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
|
| 195 |
-
// int smem_size_do = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_do));
|
| 196 |
-
// int smem_size_ds = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_ds));
|
| 197 |
-
// int smem_size_dqacc = [&] {
|
| 198 |
-
// if constexpr (Arch >= 90) {
|
| 199 |
-
// return sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dqacc));
|
| 200 |
-
// } else {
|
| 201 |
-
// return 0;
|
| 202 |
-
// }
|
| 203 |
-
// }();
|
| 204 |
-
// int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
|
| 205 |
-
// int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
|
| 206 |
-
// int smem_size_lse = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_lse));
|
| 207 |
-
// int smem_size_dpsum = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_dpsum));
|
| 208 |
-
// printf("smem_size = %d, q = %d, k = %d, v = %d, do = %d, ds = %d, dqacc = %d, lse = %d, dpsum = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v, smem_size_do, smem_size_ds, smem_size_dqacc, smem_size_lse, smem_size_dpsum);
|
| 209 |
-
if constexpr (size(ClusterShape{}) > 1) {
|
| 210 |
-
void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
|
| 211 |
-
if (smem_size >= 48 * 1024) {
|
| 212 |
-
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 213 |
-
}
|
| 214 |
-
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
| 215 |
-
cutlass::ClusterLauncher::launch(
|
| 216 |
-
grid_dims, cluster_dims, block_dims, smem_size, stream, kernel, kernel_params, false /*launch_with_pdl*/);
|
| 217 |
-
} else {
|
| 218 |
-
if (smem_size >= 48 * 1024) {
|
| 219 |
-
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<AttnKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 220 |
-
}
|
| 221 |
-
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params, false /*launch_with_pdl*/);
|
| 222 |
-
}
|
| 223 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 224 |
-
|
| 225 |
-
using PostprocessKernel = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_MK, Element, ElementAccum, ArchTag,
|
| 226 |
-
AttnKernel::CollectiveMainloop::NumMmaThreads,
|
| 227 |
-
typename AttnKernel::CollectiveMainloop::TiledMmadQ,
|
| 228 |
-
AttnKernel::CollectiveMainloop::dQ_swapAB
|
| 229 |
-
>;
|
| 230 |
-
typename PostprocessKernel::Arguments postprocess_args {
|
| 231 |
-
static_cast<ElementAccum const*>(params.dq_accum_ptr),
|
| 232 |
-
{seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum
|
| 233 |
-
{_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum
|
| 234 |
-
static_cast<Element*>(params.dq_ptr),
|
| 235 |
-
{seqlen_q, params.d, params.h, batch_q}, // shape_dQ
|
| 236 |
-
{params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride}, // stride_dQ
|
| 237 |
-
params.scale_softmax,
|
| 238 |
-
params.cu_seqlens_q,
|
| 239 |
-
params.seqused_q
|
| 240 |
-
};
|
| 241 |
-
typename PostprocessKernel::Params postprocess_params = PostprocessKernel::to_underlying_arguments(postprocess_args);
|
| 242 |
-
int num_m_block_postprocess = cute::ceil_div(params.seqlen_q, get<0>(TileShape_MK{}));
|
| 243 |
-
dim3 grid_m_postprocess(num_m_block_postprocess, params.h, params.b);
|
| 244 |
-
int smem_size_postprocess = PostprocessKernel::SharedStorageSize;
|
| 245 |
-
if (smem_size_postprocess >= 48 * 1024) {
|
| 246 |
-
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
|
| 247 |
-
}
|
| 248 |
-
cutlass::kernel_launch<PostprocessKernel>(grid_m_postprocess, PostprocessKernel::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_params, false /*launch_with_pdl*/);
|
| 249 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 250 |
-
|
| 251 |
-
if constexpr (GQA) {
|
| 252 |
-
using TileShape_NK = cute::Shape<Int<kBlockN>, Int<kHeadDim>>;
|
| 253 |
-
using PostprocessKerneldKV = flash::FlashAttnBwdPostprocessConvertdQ<TileShape_NK, Element, ElementAccum, ArchTag,
|
| 254 |
-
AttnKernel::CollectiveEpilogue::NumEpilogueThreads,
|
| 255 |
-
typename AttnKernel::CollectiveMainloop::TiledMmadKV,
|
| 256 |
-
AttnKernel::CollectiveMainloop::dKV_swapAB
|
| 257 |
-
>;
|
| 258 |
-
typename PostprocessKerneldKV::Arguments postprocess_dK_args {
|
| 259 |
-
static_cast<ElementAccum const*>(params.dk_accum_ptr),
|
| 260 |
-
{seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum
|
| 261 |
-
{_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum
|
| 262 |
-
static_cast<Element*>(params.dk_ptr),
|
| 263 |
-
{seqlen_k, params.d, params.h_k, batch_k}, // shape_dK
|
| 264 |
-
{params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK
|
| 265 |
-
1.f,
|
| 266 |
-
params.cu_seqlens_k,
|
| 267 |
-
params.seqused_k
|
| 268 |
-
};
|
| 269 |
-
typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args);
|
| 270 |
-
typename PostprocessKerneldKV::Arguments postprocess_dV_args {
|
| 271 |
-
static_cast<ElementAccum const*>(params.dv_accum_ptr),
|
| 272 |
-
{seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum
|
| 273 |
-
{_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum
|
| 274 |
-
static_cast<Element*>(params.dv_ptr),
|
| 275 |
-
{seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV
|
| 276 |
-
{params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV
|
| 277 |
-
1.f,
|
| 278 |
-
params.cu_seqlens_k,
|
| 279 |
-
params.seqused_k
|
| 280 |
-
};
|
| 281 |
-
typename PostprocessKerneldKV::Params postprocess_dV_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dV_args);
|
| 282 |
-
int num_n_block_postprocess = cute::ceil_div(params.seqlen_k, get<0>(TileShape_NK{}));
|
| 283 |
-
dim3 grid_n_postprocess(num_n_block_postprocess, params.h_k, params.b);
|
| 284 |
-
int smem_size_postprocess = PostprocessKerneldKV::SharedStorageSize;
|
| 285 |
-
if (smem_size_postprocess >= 48 * 1024) {
|
| 286 |
-
CHECK_CUDA(cudaFuncSetAttribute(cutlass::device_kernel<PostprocessKerneldKV>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_postprocess));
|
| 287 |
-
}
|
| 288 |
-
cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dK_params, false /*launch_with_pdl*/);
|
| 289 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 290 |
-
cutlass::kernel_launch<PostprocessKerneldKV>(grid_n_postprocess, PostprocessKerneldKV::MaxThreadsPerBlock, smem_size_postprocess, stream, postprocess_dV_params, false /*launch_with_pdl*/);
|
| 291 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 292 |
-
}
|
| 293 |
-
|
| 294 |
-
}
|
| 295 |
-
|
| 296 |
-
template<int Arch, typename T, int kBlockM, int kBlockN, int kHeadDim, bool Is_causal, bool Is_local, bool Has_softcap,
|
| 297 |
-
int Stages_dO=2, int Stages_dS_or_QSm80=2,
|
| 298 |
-
bool SdP_swapAB=true, bool dKV_swapAB=false, bool dQ_swapAB=false,
|
| 299 |
-
int NumMmaWarpGroups=2, int AtomLayoutMSdP=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=1,
|
| 300 |
-
bool V_in_regs=false>
|
| 301 |
-
void run_mha_bwd_dispatch(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 302 |
-
VARLEN_SWITCH(params.cu_seqlens_q != nullptr || params.cu_seqlens_k != nullptr, Varlen, [&] {
|
| 303 |
-
BOOL_SWITCH(params.h != params.h_k, GQA, [&] {
|
| 304 |
-
// BOOL_SWITCH(params.deterministic, Deterministic, [&] {
|
| 305 |
-
// run_flash_bwd<kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen, false, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ>(params, stream);
|
| 306 |
-
run_flash_bwd<Arch, kHeadDim, kBlockM, kBlockN, T, Is_causal, Is_local, Has_softcap, Varlen /*Varlen*/, false /*Deterministic*/, GQA, Stages_dO, Stages_dS_or_QSm80, SdP_swapAB, dKV_swapAB, dQ_swapAB, NumMmaWarpGroups, AtomLayoutMSdP, AtomLayoutNdKV, AtomLayoutMdQ, V_in_regs>(params, stream);
|
| 307 |
-
// });
|
| 308 |
-
});
|
| 309 |
-
});
|
| 310 |
-
}
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
template<int Arch, typename T, bool Has_softcap>
|
| 314 |
-
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 315 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 316 |
-
if constexpr (Arch >= 90) {
|
| 317 |
-
if constexpr (Is_causal && Has_softcap) {
|
| 318 |
-
// register spill with 128 x 128
|
| 319 |
-
run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 2, false>(params, stream);
|
| 320 |
-
} else {
|
| 321 |
-
// With ShuffleStats we no longer have register spilling when Has_softcap and using 128 x 128 block.
|
| 322 |
-
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 2, false>(params, stream);
|
| 323 |
-
}
|
| 324 |
-
} else if constexpr (Arch == 86 || Arch == 89) {
|
| 325 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
|
| 326 |
-
// run_mha_bwd_dispatch<Arch, T, 96, 96, 64, Is_causal, Is_local, Has_softcap, 1, 2, false, true, true, 2, 2, 4, 4, false>(params, stream);
|
| 327 |
-
// run_mha_bwd_dispatch<Arch, T, 80, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 2, 4, 2, true>(params, stream);
|
| 328 |
-
// run_mha_bwd_dispatch<Arch, T, 96, 128, 64, Is_causal, Is_local, Has_softcap, 1, 2, true, false, true, 2, 1, 8, 4, false>(params, stream);
|
| 329 |
-
} else {
|
| 330 |
-
run_mha_bwd_dispatch<Arch, T, 128, 128, 64, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 4, 4, 4, false>(params, stream);
|
| 331 |
-
}
|
| 332 |
-
});
|
| 333 |
-
}
|
| 334 |
-
|
| 335 |
-
template<int Arch, typename T, bool Has_softcap>
|
| 336 |
-
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 337 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 338 |
-
if constexpr (Arch >= 90) {
|
| 339 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, true>(params, stream);
|
| 340 |
-
} else if constexpr (Arch == 86 || Arch == 89) {
|
| 341 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 4, 2, true>(params, stream);
|
| 342 |
-
} else {
|
| 343 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 96, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 4, 2, false>(params, stream);
|
| 344 |
-
}
|
| 345 |
-
});
|
| 346 |
-
}
|
| 347 |
-
|
| 348 |
-
template<int Arch, typename T, bool Has_softcap>
|
| 349 |
-
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 350 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 351 |
-
if constexpr (Arch >= 90) {
|
| 352 |
-
if constexpr (Is_causal || Is_local || Has_softcap) {
|
| 353 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, false, 2, 1, 2, 1, false>(params, stream);
|
| 354 |
-
} else {
|
| 355 |
-
run_mha_bwd_dispatch<Arch, T, 80, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, true, false, true, 2, 1, 2, 1, false>(params, stream);
|
| 356 |
-
}
|
| 357 |
-
} else if constexpr (Arch == 86 || Arch == 89) {
|
| 358 |
-
run_mha_bwd_dispatch<Arch, T, 64, 96, 128, Is_causal, Is_local, Has_softcap, 1, 2, false, false, false, 2, 2, 2, 2, true>(params, stream);
|
| 359 |
-
} else {
|
| 360 |
-
run_mha_bwd_dispatch<Arch, T, 64, 128, 128, Is_causal, Is_local, Has_softcap, 2, 2, false, false, false, 2, 2, 2, 2, false>(params, stream);
|
| 361 |
-
}
|
| 362 |
-
});
|
| 363 |
-
}
|
| 364 |
-
|
| 365 |
-
template<int Arch, typename T, bool Has_softcap>
|
| 366 |
-
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 367 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 368 |
-
if constexpr (Arch >= 90) {
|
| 369 |
-
run_mha_bwd_dispatch<Arch, T, 64, 96, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, true, false, 3, 1, 1, 1, false>(params, stream);
|
| 370 |
-
} else if constexpr (Arch == 86 || Arch == 89) {
|
| 371 |
-
run_mha_bwd_dispatch<Arch, T, 64, 64, 192, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 2, true>(params, stream);
|
| 372 |
-
} else {
|
| 373 |
-
run_mha_bwd_dispatch<Arch, T, 64, 80, 192, Is_causal, Is_local, Has_softcap, 1, 2, false, true, false, 2, 4, 2, 2, false>(params, stream);
|
| 374 |
-
}
|
| 375 |
-
});
|
| 376 |
-
}
|
| 377 |
-
|
| 378 |
-
template<int Arch, typename T, bool Has_softcap>
|
| 379 |
-
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 380 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 381 |
-
if constexpr (Arch >= 90) {
|
| 382 |
-
run_mha_bwd_dispatch<Arch, T, 64, 80, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, true, true, 2, 1, 1, 1, false>(params, stream);
|
| 383 |
-
} else if constexpr (Arch == 86 || Arch == 89) {
|
| 384 |
-
run_mha_bwd_dispatch<Arch, T, 32, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 2, 2, 1, true>(params, stream);
|
| 385 |
-
// run_mha_bwd_dispatch<Arch, T, 64, 32, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 1, 2, true>(params, stream);
|
| 386 |
-
} else {
|
| 387 |
-
run_mha_bwd_dispatch<Arch, T, 64, 64, 256, Is_causal, Is_local, Has_softcap, 1, 1, false, false, false, 2, 4, 2, 2, false>(params, stream);
|
| 388 |
-
}
|
| 389 |
-
});
|
| 390 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_bwd_postprocess_kernel.h
DELETED
|
@@ -1,256 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/array.h>
|
| 11 |
-
#include <cutlass/numeric_types.h>
|
| 12 |
-
#include <cutlass/numeric_conversion.h>
|
| 13 |
-
#include "cutlass/arch/barrier.h"
|
| 14 |
-
|
| 15 |
-
#include "seqlen.h"
|
| 16 |
-
#include "utils.h"
|
| 17 |
-
|
| 18 |
-
namespace flash {
|
| 19 |
-
|
| 20 |
-
using namespace cute;
|
| 21 |
-
|
| 22 |
-
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, int kNThreads, class TiledMma, bool dQ_swapAB>
|
| 23 |
-
class FlashAttnBwdPostprocessConvertdQ {
|
| 24 |
-
|
| 25 |
-
public:
|
| 26 |
-
|
| 27 |
-
// Type Aliases
|
| 28 |
-
using TileShape_MK = TileShape_MK_;
|
| 29 |
-
using ArchTag = ArchTag_;
|
| 30 |
-
|
| 31 |
-
static_assert(ArchTag::kMinComputeCapability >= 75);
|
| 32 |
-
static constexpr bool IsSm90 = ArchTag::kMinComputeCapability >= 90;
|
| 33 |
-
|
| 34 |
-
static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
|
| 35 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
| 36 |
-
|
| 37 |
-
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
| 38 |
-
static constexpr int kHeadDim = get<1>(TileShape_MK{});
|
| 39 |
-
static_assert(!IsSm90 || kNThreads % cutlass::NumThreadsPerWarpGroup == 0, "kNThreads must be a multiple of NumThreadsPerWarpGroup");
|
| 40 |
-
static constexpr int NumdQWarpGgroups = kNThreads / cutlass::NumThreadsPerWarpGroup;
|
| 41 |
-
using R2SLayoutAtomdQaccum = std::conditional_t<
|
| 42 |
-
IsSm90,
|
| 43 |
-
Layout<Shape<Int<cutlass::NumThreadsPerWarpGroup>, Int<NumdQWarpGgroups>>>,
|
| 44 |
-
Layout<Shape<Int<kNThreads>>>
|
| 45 |
-
>;
|
| 46 |
-
using R2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, R2SLayoutAtomdQaccum{},
|
| 47 |
-
Layout<Shape<Int<IsSm90 ? 4 : 1>>>{})); // Val layout, 1 or 4 vals per read
|
| 48 |
-
using G2SLayoutAtomdQaccum = Layout<Shape<Int<kNThreads>>>;
|
| 49 |
-
// UniversalCopy instead of AutoVectorizingCopyWithAssumedAlignment as the latter generates cp.async instructions
|
| 50 |
-
using G2STiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, ElementAccum>{}, G2SLayoutAtomdQaccum{},
|
| 51 |
-
Layout<Shape<_4>>{})); // Val layout, 4 vals per read
|
| 52 |
-
// We don't do bound checking for the gmem -> smem load so we just assert here.
|
| 53 |
-
static_assert(IsSm90 || (kBlockM * kHeadDim) % (kNThreads * 4) == 0);
|
| 54 |
-
static constexpr int SmemdQaccumSize = size(TileShape_MK{});
|
| 55 |
-
using SmemLayoutdQaccumFlat = Layout<Shape<Int<SmemdQaccumSize>>>;
|
| 56 |
-
using SmemLayoutdQaccum = std::conditional_t<
|
| 57 |
-
IsSm90,
|
| 58 |
-
Layout<Shape<Int<kBlockM * kHeadDim / NumdQWarpGgroups>, Int<NumdQWarpGgroups>>>,
|
| 59 |
-
Layout<Shape<Int<kBlockM * kHeadDim>>>
|
| 60 |
-
>;
|
| 61 |
-
|
| 62 |
-
// We can't just use kHeadDim here. E.g. if MMA shape is 64 x 96 but split across 2 WGs,
|
| 63 |
-
// then setting kBlockKSmem to 32 will cause "Static shape_div failure".
|
| 64 |
-
// We want to treat it as 64 x 48, so kBlockKSmem should be 16.
|
| 65 |
-
static constexpr int MmaShapeN = get<1>(typename TiledMma::AtomShape_MNK{});
|
| 66 |
-
static constexpr int kBlockKSmem = MmaShapeN % 64 == 0 ? 64 : (MmaShapeN % 32 == 0 ? 32 : 16);
|
| 67 |
-
static constexpr int kSwizzle = kBlockKSmem == 64 ? 3 : (kBlockKSmem == 32 ? 2 : 1);
|
| 68 |
-
using SmemLayoutAtomdQ =
|
| 69 |
-
decltype(composition(Swizzle<kSwizzle, 3, 3>{},
|
| 70 |
-
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
| 71 |
-
Stride<Int<kBlockKSmem>, _1>>{}));
|
| 72 |
-
using SmemLayoutdQ = decltype(tile_to_shape(SmemLayoutAtomdQ{}, TileShape_MK{}));
|
| 73 |
-
using SmemLayoutdQt =
|
| 74 |
-
decltype(cute::composition(SmemLayoutdQ{},
|
| 75 |
-
make_layout(make_shape(get<1>(TileShape_MK{}), get<0>(TileShape_MK{})),
|
| 76 |
-
make_stride(Int<get<0>(TileShape_MK{})>{}, _1{}))));
|
| 77 |
-
|
| 78 |
-
using SmemCopyAtomdQ = Copy_Atom<
|
| 79 |
-
std::conditional_t<
|
| 80 |
-
IsSm90,
|
| 81 |
-
std::conditional_t<!dQ_swapAB, cute::SM90_U32x4_STSM_N, cute::SM90_U16x8_STSM_T>,
|
| 82 |
-
AutoVectorizingCopyWithAssumedAlignment<128>
|
| 83 |
-
>,
|
| 84 |
-
Element>;
|
| 85 |
-
|
| 86 |
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
| 87 |
-
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
| 88 |
-
static constexpr int kGmemThreadsPerRow = cutlass::gcd(kHeadDim / kGmemElemsPerLoad, int(MaxThreadsPerBlock));
|
| 89 |
-
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
| 90 |
-
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
| 91 |
-
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
| 92 |
-
using GmemTiledCopy = decltype(
|
| 93 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
| 94 |
-
GmemLayoutAtom{},
|
| 95 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
|
| 96 |
-
|
| 97 |
-
struct SharedStorage : cute::aligned_struct<128> {
|
| 98 |
-
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdQaccum>> smem_dqacc;
|
| 99 |
-
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdQ>> smem_dq;
|
| 100 |
-
alignas(16) cutlass::arch::ClusterTransactionBarrier barrier_dQaccum;
|
| 101 |
-
};
|
| 102 |
-
|
| 103 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 104 |
-
|
| 105 |
-
using ShapedQ = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
|
| 106 |
-
using StridedQ = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
| 107 |
-
using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
|
| 108 |
-
using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
|
| 109 |
-
|
| 110 |
-
// Device side arguments
|
| 111 |
-
struct Arguments {
|
| 112 |
-
ElementAccum const* ptr_dQaccum;
|
| 113 |
-
ShapedQaccum const shape_dQaccum;
|
| 114 |
-
StridedQaccum const stride_dQaccum;
|
| 115 |
-
Element* ptr_dQ;
|
| 116 |
-
ShapedQ const shape_dQ;
|
| 117 |
-
StridedQ const stride_dQ;
|
| 118 |
-
float const softmax_scale;
|
| 119 |
-
int const* cu_seqlens = nullptr;
|
| 120 |
-
int const* seqused = nullptr;
|
| 121 |
-
};
|
| 122 |
-
|
| 123 |
-
// Kernel entry point API
|
| 124 |
-
struct Params {
|
| 125 |
-
ElementAccum const* ptr_dQaccum;
|
| 126 |
-
ShapedQaccum const shape_dQaccum;
|
| 127 |
-
StridedQaccum const stride_dQaccum;
|
| 128 |
-
Element* ptr_dQ;
|
| 129 |
-
ShapedQ const shape_dQ;
|
| 130 |
-
StridedQ const stride_dQ;
|
| 131 |
-
float const softmax_scale;
|
| 132 |
-
int const* cu_seqlens = nullptr;
|
| 133 |
-
int const* seqused = nullptr;
|
| 134 |
-
};
|
| 135 |
-
|
| 136 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 137 |
-
static
|
| 138 |
-
Params
|
| 139 |
-
to_underlying_arguments(Arguments const& args) {
|
| 140 |
-
return {
|
| 141 |
-
args.ptr_dQaccum,
|
| 142 |
-
args.shape_dQaccum,
|
| 143 |
-
args.stride_dQaccum,
|
| 144 |
-
args.ptr_dQ,
|
| 145 |
-
args.shape_dQ,
|
| 146 |
-
args.stride_dQ,
|
| 147 |
-
args.softmax_scale,
|
| 148 |
-
args.cu_seqlens,
|
| 149 |
-
args.seqused
|
| 150 |
-
};
|
| 151 |
-
}
|
| 152 |
-
|
| 153 |
-
CUTLASS_DEVICE
|
| 154 |
-
void
|
| 155 |
-
operator()(Params const& params, char* smem_buf) {
|
| 156 |
-
|
| 157 |
-
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
| 158 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 159 |
-
|
| 160 |
-
Tensor sdQaccum = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccum{});
|
| 161 |
-
Tensor sdQaccum_flat = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), SmemLayoutdQaccumFlat{});
|
| 162 |
-
Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQ{});
|
| 163 |
-
Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), SmemLayoutdQt{});
|
| 164 |
-
|
| 165 |
-
int const thread_idx = threadIdx.x;
|
| 166 |
-
int const m_block = blockIdx.x;
|
| 167 |
-
int const bidh = blockIdx.y;
|
| 168 |
-
int const bidb = blockIdx.z;
|
| 169 |
-
|
| 170 |
-
flash::SeqlenInfo<true /*Varlen*/, kBlockM> seqlen_info(bidb, size<0>(params.shape_dQ), params.cu_seqlens, params.seqused);
|
| 171 |
-
bool const is_varlen = params.cu_seqlens;
|
| 172 |
-
if (is_varlen && m_block * kBlockM >= seqlen_info.seqlen) { return; }
|
| 173 |
-
|
| 174 |
-
// Step 1: load dQaccum from gmem to smem
|
| 175 |
-
Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum const*>(params.ptr_dQaccum)),
|
| 176 |
-
params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
|
| 177 |
-
Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block)); // (M * K)
|
| 178 |
-
if constexpr (IsSm90) { // Use BulkCopy
|
| 179 |
-
static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast<uint32_t>(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v<ElementAccum> / 8);
|
| 180 |
-
auto bulk_copy = Copy_Traits<SM90_BULK_COPY_AUTO>{};
|
| 181 |
-
// if (thread0()) { print(gdQaccum); printf("\n"); print(sdQaccum_flat); printf("\n"); }
|
| 182 |
-
if (thread_idx == 0) {
|
| 183 |
-
shared_storage.barrier_dQaccum.init(1 /*numThreads*/);
|
| 184 |
-
shared_storage.barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum);
|
| 185 |
-
copy(bulk_copy.with(*reinterpret_cast<uint64_t*>(&shared_storage.barrier_dQaccum)), gdQaccum, sdQaccum_flat);
|
| 186 |
-
}
|
| 187 |
-
__syncthreads();
|
| 188 |
-
shared_storage.barrier_dQaccum.wait(0);
|
| 189 |
-
} else {
|
| 190 |
-
G2STiledCopydQaccum g2s_tiled_copy_dQaccum;
|
| 191 |
-
auto g2s_thr_copy_dQaccum = g2s_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
| 192 |
-
Tensor tdQgdQaccumg2s = g2s_thr_copy_dQaccum.partition_S(gdQaccum);
|
| 193 |
-
Tensor tdQsdQaccumg2s = g2s_thr_copy_dQaccum.partition_D(sdQaccum);
|
| 194 |
-
cute::copy(g2s_tiled_copy_dQaccum, tdQgdQaccumg2s, tdQsdQaccumg2s);
|
| 195 |
-
__syncthreads();
|
| 196 |
-
}
|
| 197 |
-
|
| 198 |
-
// __syncthreads(); if (cute::thread0()) { print_tensor(sdQaccum); }
|
| 199 |
-
|
| 200 |
-
// Step 2: Load dQaccum from smem to register, then convert fp32 -> fp16/bf16
|
| 201 |
-
R2STiledCopydQaccum s2r_tiled_copy_dQaccum;
|
| 202 |
-
auto s2r_thr_copy_dQaccum = s2r_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
| 203 |
-
Tensor tdQsdQaccum = s2r_thr_copy_dQaccum.partition_S(sdQaccum);
|
| 204 |
-
TiledMma tiled_mma_dQ;
|
| 205 |
-
Tensor taccdQrdQaccum = partition_fragment_C(tiled_mma_dQ, select<!dQ_swapAB ? 0 : 1, !dQ_swapAB ? 1 : 0>(TileShape_MK{}));
|
| 206 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tiled_mma_dQ); printf("\n"); }
|
| 207 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(tdQsdQaccum); }
|
| 208 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 1) { print(taccdQrdQaccum); }
|
| 209 |
-
CUTE_STATIC_ASSERT_V(size(taccdQrdQaccum) == size(tdQsdQaccum));
|
| 210 |
-
Tensor tdQrdQaccum = s2r_thr_copy_dQaccum.retile_D(taccdQrdQaccum);
|
| 211 |
-
cute::copy(s2r_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum);
|
| 212 |
-
#pragma unroll
|
| 213 |
-
for (int i = 0; i < size(taccdQrdQaccum); ++i) { taccdQrdQaccum(i) *= params.softmax_scale; }
|
| 214 |
-
// Convert tdQrdQ from fp32 to fp16
|
| 215 |
-
Tensor rdQ = make_tensor_like<Element>(taccdQrdQaccum);
|
| 216 |
-
flash::convert_type_out(taccdQrdQaccum, rdQ);
|
| 217 |
-
|
| 218 |
-
// Step 3: Copy dQ from register to smem
|
| 219 |
-
auto smem_tiled_copy_dQ = make_tiled_copy_C(SmemCopyAtomdQ{}, tiled_mma_dQ);
|
| 220 |
-
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(thread_idx);
|
| 221 |
-
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
| 222 |
-
// if (cute::thread0()) { print(smem_tiled_copy_dQ); }
|
| 223 |
-
// if (cute::thread0()) { print(smem_thr_copy_dQ); }
|
| 224 |
-
// if (cute::thread0()) { print(sdQ); }
|
| 225 |
-
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(cute::conditional_return<!dQ_swapAB>(sdQ, sdQt)); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
| 226 |
-
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
| 227 |
-
__syncthreads();
|
| 228 |
-
|
| 229 |
-
// Step 4: Copy dQ from smem to register to prepare for coalesced write to gmem
|
| 230 |
-
Tensor mdQ = make_tensor(make_gmem_ptr(params.ptr_dQ), params.shape_dQ, params.stride_dQ)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 231 |
-
Tensor gdQ = local_tile(domain_offset(make_coord(seqlen_info.offset, _0{}), mdQ), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
| 232 |
-
GmemTiledCopy gmem_tiled_copy_dQ;
|
| 233 |
-
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(thread_idx);
|
| 234 |
-
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
| 235 |
-
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
| 236 |
-
|
| 237 |
-
Tensor tdQrdQ = make_fragment_like(tdQsdQ);
|
| 238 |
-
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cute::make_identity_tensor(TileShape_MK{}));
|
| 239 |
-
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
|
| 240 |
-
#pragma unroll
|
| 241 |
-
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(_0{}, _0{}, k)) < get<1>(params.shape_dQ); }
|
| 242 |
-
// Need to check OOB when reading from smem if kBlockM isn't evenly tiled
|
| 243 |
-
static constexpr bool EvenM = kBlockM % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0;
|
| 244 |
-
flash::copy</*Is_even_MN=*/EvenM, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false>(
|
| 245 |
-
gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ, tdQcdQ, tdQpdQ, kBlockM);
|
| 246 |
-
|
| 247 |
-
// Step 5: Copy dQ from register to gmem
|
| 248 |
-
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
| 249 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
| 250 |
-
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, std::min(seqlen_info.seqlen - m_block * kBlockM, kBlockM)
|
| 251 |
-
);
|
| 252 |
-
}
|
| 253 |
-
|
| 254 |
-
};
|
| 255 |
-
|
| 256 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_bwd_preprocess_kernel.h
DELETED
|
@@ -1,252 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/array.h>
|
| 11 |
-
#include <cutlass/numeric_types.h>
|
| 12 |
-
#include <cutlass/numeric_conversion.h>
|
| 13 |
-
|
| 14 |
-
#include "seqlen.h"
|
| 15 |
-
#include "utils.h"
|
| 16 |
-
|
| 17 |
-
namespace flash {
|
| 18 |
-
|
| 19 |
-
using namespace cute;
|
| 20 |
-
|
| 21 |
-
template <class TileShape_MK_, class Element, class ElementAccum, class ArchTag_, bool Clear_dQaccum, bool Varlen>
|
| 22 |
-
class FlashAttnBwdPreprocess {
|
| 23 |
-
|
| 24 |
-
public:
|
| 25 |
-
|
| 26 |
-
// Type Aliases
|
| 27 |
-
using TileShape_MK = TileShape_MK_;
|
| 28 |
-
using ArchTag = ArchTag_;
|
| 29 |
-
|
| 30 |
-
static_assert(std::is_same_v<Element, cutlass::half_t> && ArchTag::kMinComputeCapability >= 75 ||
|
| 31 |
-
std::is_same_v<Element, cutlass::bfloat16_t> && ArchTag::kMinComputeCapability >= 80 ||
|
| 32 |
-
std::is_same_v<Element, cutlass::float_e4m3_t> && ArchTag::kMinComputeCapability >= 89);
|
| 33 |
-
|
| 34 |
-
static constexpr uint32_t MaxThreadsPerBlock = 256;
|
| 35 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
| 36 |
-
static constexpr int SharedStorageSize = 0;
|
| 37 |
-
|
| 38 |
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
| 39 |
-
static_assert(get<1>(TileShape_MK{}) % kGmemElemsPerLoad == 0, "Headdim must be a multiple of kGmemElemsPerLoad");
|
| 40 |
-
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
| 41 |
-
static constexpr int kHeadDim = get<1>(TileShape_MK{});
|
| 42 |
-
// We want kBlockKGmem to be a power of 2 so that when we do the summing,
|
| 43 |
-
// it's just between threads in the same warp
|
| 44 |
-
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
| 45 |
-
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
|
| 46 |
-
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
| 47 |
-
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
| 48 |
-
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
| 49 |
-
using GmemTiledCopy = decltype(
|
| 50 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
| 51 |
-
GmemLayoutAtom{},
|
| 52 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 8 or 16 vals per load
|
| 53 |
-
|
| 54 |
-
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
|
| 55 |
-
static_assert((kBlockM * kHeadDim / kGmemElemsPerLoadAccum) % MaxThreadsPerBlock == 0, "MaxThreadsPerBlock must divide kBlockM * kHeadDim / kGmemElemsPerLoadAccum");
|
| 56 |
-
using GmemLayoutAtomAccum = Layout<Shape<Int<MaxThreadsPerBlock>>>;
|
| 57 |
-
using GmemTiledCopyAccum = decltype(
|
| 58 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
| 59 |
-
GmemLayoutAtomAccum{},
|
| 60 |
-
Layout<Shape<Int<kGmemElemsPerLoadAccum>>>{})); // Val layout, 4 vals per store
|
| 61 |
-
|
| 62 |
-
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
|
| 63 |
-
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
| 64 |
-
using ShapedPsum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q, head, batch)
|
| 65 |
-
using StridedPsum = cute::Stride<_1, int64_t, int64_t>;
|
| 66 |
-
using ShapedQaccum = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen_q * d, head, batch)
|
| 67 |
-
using StridedQaccum = cute::Stride<_1, int64_t, int64_t>;
|
| 68 |
-
|
| 69 |
-
// Device side arguments
|
| 70 |
-
struct Arguments {
|
| 71 |
-
Element const* ptr_O;
|
| 72 |
-
ShapeO const shape_O;
|
| 73 |
-
StrideO const stride_O;
|
| 74 |
-
Element const* ptr_dO;
|
| 75 |
-
StrideO const stride_dO;
|
| 76 |
-
float* ptr_dPsum;
|
| 77 |
-
ShapedPsum const shape_dPsum;
|
| 78 |
-
StridedPsum const stride_dPsum;
|
| 79 |
-
float const* ptr_LSE;
|
| 80 |
-
StridedPsum const stride_LSE;
|
| 81 |
-
float *ptr_LSE_log2;
|
| 82 |
-
StridedPsum const stride_LSE_log2;
|
| 83 |
-
ElementAccum* ptr_dQaccum;
|
| 84 |
-
ShapedQaccum const shape_dQaccum;
|
| 85 |
-
StridedQaccum const stride_dQaccum;
|
| 86 |
-
int num_batch; // We need this to know the size of dq_semaphore in case of varlen
|
| 87 |
-
int* dq_semaphore;
|
| 88 |
-
int const* cu_seqlens = nullptr;
|
| 89 |
-
int const* seqused = nullptr;
|
| 90 |
-
};
|
| 91 |
-
|
| 92 |
-
// Kernel entry point API
|
| 93 |
-
struct Params {
|
| 94 |
-
Element const* ptr_O;
|
| 95 |
-
ShapeO const shape_O;
|
| 96 |
-
StrideO const stride_O;
|
| 97 |
-
Element const* ptr_dO;
|
| 98 |
-
StrideO const stride_dO;
|
| 99 |
-
float* ptr_dPsum;
|
| 100 |
-
ShapedPsum const shape_dPsum;
|
| 101 |
-
StridedPsum const stride_dPsum;
|
| 102 |
-
float const* ptr_LSE;
|
| 103 |
-
StridedPsum const stride_LSE;
|
| 104 |
-
float* ptr_LSE_log2;
|
| 105 |
-
StridedPsum const stride_LSE_log2;
|
| 106 |
-
ElementAccum* ptr_dQaccum;
|
| 107 |
-
ShapedQaccum const shape_dQaccum;
|
| 108 |
-
StridedQaccum const stride_dQaccum;
|
| 109 |
-
int num_batch;
|
| 110 |
-
int* dq_semaphore;
|
| 111 |
-
int const* cu_seqlens = nullptr;
|
| 112 |
-
int const* seqused = nullptr;
|
| 113 |
-
};
|
| 114 |
-
|
| 115 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 116 |
-
static
|
| 117 |
-
Params
|
| 118 |
-
to_underlying_arguments(Arguments const& args) {
|
| 119 |
-
return {
|
| 120 |
-
args.ptr_O,
|
| 121 |
-
args.shape_O,
|
| 122 |
-
args.stride_O,
|
| 123 |
-
args.ptr_dO,
|
| 124 |
-
args.stride_dO,
|
| 125 |
-
args.ptr_dPsum,
|
| 126 |
-
args.shape_dPsum,
|
| 127 |
-
args.stride_dPsum,
|
| 128 |
-
args.ptr_LSE,
|
| 129 |
-
args.stride_LSE,
|
| 130 |
-
args.ptr_LSE_log2,
|
| 131 |
-
args.stride_LSE_log2,
|
| 132 |
-
args.ptr_dQaccum,
|
| 133 |
-
args.shape_dQaccum,
|
| 134 |
-
args.stride_dQaccum,
|
| 135 |
-
args.num_batch,
|
| 136 |
-
args.dq_semaphore,
|
| 137 |
-
args.cu_seqlens,
|
| 138 |
-
args.seqused
|
| 139 |
-
};
|
| 140 |
-
}
|
| 141 |
-
|
| 142 |
-
CUTLASS_DEVICE
|
| 143 |
-
void
|
| 144 |
-
operator()(Params const& params, [[maybe_unused]] char* smem_buf) {
|
| 145 |
-
|
| 146 |
-
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
| 147 |
-
|
| 148 |
-
int const thread_idx = threadIdx.x;
|
| 149 |
-
int const m_block = blockIdx.x;
|
| 150 |
-
int const bidh = blockIdx.y;
|
| 151 |
-
int const bidb = blockIdx.z;
|
| 152 |
-
|
| 153 |
-
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info(bidb, size<0>(params.shape_O), params.cu_seqlens, params.seqused);
|
| 154 |
-
bool const is_varlen = Varlen && params.cu_seqlens;
|
| 155 |
-
int const seqlen_o = seqlen_info.seqlen;
|
| 156 |
-
if (is_varlen && m_block * kBlockM >= seqlen_o) { return; }
|
| 157 |
-
|
| 158 |
-
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O), params.shape_O, params.stride_O)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 159 |
-
Tensor gO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
| 160 |
-
Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_O, params.stride_dO)(_, _, bidh, !is_varlen ? bidb : 0);
|
| 161 |
-
Tensor gdO = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdO), TileShape_MK{}, make_coord(m_block, _0{})); // (M, K)
|
| 162 |
-
|
| 163 |
-
auto shape_LSE = select<0, 2, 3>(params.shape_O);
|
| 164 |
-
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), shape_LSE, params.stride_LSE)(_, bidh, !is_varlen ? bidb : 0);
|
| 165 |
-
Tensor gLSE = local_tile(cute::domain_offset(make_coord(seqlen_info.offset), mLSE), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
| 166 |
-
static_assert(kBlockM <= MaxThreadsPerBlock);
|
| 167 |
-
float lse = thread_idx < seqlen_o - m_block * kBlockM && thread_idx < kBlockM ? gLSE(thread_idx) : INFINITY;
|
| 168 |
-
|
| 169 |
-
GmemTiledCopy gmem_tiled_copy_O;
|
| 170 |
-
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
| 171 |
-
|
| 172 |
-
Tensor tOgO = gmem_thr_copy_O.partition_S(gO);
|
| 173 |
-
Tensor tOgdO = gmem_thr_copy_O.partition_S(gdO);
|
| 174 |
-
// Construct identity layout for gO
|
| 175 |
-
Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
| 176 |
-
// Repeat the partitioning with identity layouts
|
| 177 |
-
Tensor tOcO = gmem_thr_copy_O.partition_D(cO);
|
| 178 |
-
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
| 179 |
-
#pragma unroll
|
| 180 |
-
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O); }
|
| 181 |
-
|
| 182 |
-
// (8, kBlockM / 32, kHeadDim / 64) or (8, kBlockM / 16, kHeadDim / 128)
|
| 183 |
-
Tensor tOrO = make_fragment_like(tOgO);
|
| 184 |
-
Tensor tOrdO = make_fragment_like(tOgdO);
|
| 185 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
|
| 186 |
-
gmem_tiled_copy_O, tOgO, tOrO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
| 187 |
-
);
|
| 188 |
-
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true, /*Clearn_OOB_K=*/true>(
|
| 189 |
-
gmem_tiled_copy_O, tOgdO, tOrdO, tOcO, tOpO, seqlen_o - m_block * kBlockM
|
| 190 |
-
);
|
| 191 |
-
// if (threadIdx.x == 222) { printf("bidx = %d, bidy = %d, bidz = %d, seqlen_o = %d, m_block = %d, seqlen_o - m_block * kBlockM = %d, tOgO addr = %p\n", blockIdx.x, blockIdx.y, blockIdx.z, seqlen_o, m_block, seqlen_o - m_block * kBlockM, &tOgO(0));}
|
| 192 |
-
|
| 193 |
-
// Reshape from e.g. (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, (8, kHeadDim / 64))
|
| 194 |
-
Layout l = make_layout(get<1>(tOrO.layout()), make_layout(get<0>(tOrO.layout()), get<2>(tOrO.layout())));
|
| 195 |
-
Tensor tOrO_l = make_tensor(tOrO.data(), l);
|
| 196 |
-
Tensor o_fp32 = make_tensor_like<float>(tOrO_l);
|
| 197 |
-
flash::convert_type_out(tOrO_l, o_fp32);
|
| 198 |
-
Tensor tOrdO_l = make_tensor(tOrdO.data(), l);
|
| 199 |
-
Tensor do_fp32 = make_tensor_like<float>(tOrdO_l);
|
| 200 |
-
flash::convert_type_out(tOrdO_l, do_fp32);
|
| 201 |
-
// Sum across the last dimension
|
| 202 |
-
Tensor dP_sum = make_tensor<float>(make_shape(size<0>(o_fp32)));
|
| 203 |
-
#pragma unroll
|
| 204 |
-
for (int mi = 0; mi < size<0>(o_fp32); ++mi) {
|
| 205 |
-
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
|
| 206 |
-
#pragma unroll
|
| 207 |
-
for (int ni = 1; ni < size<1>(o_fp32); ni++) {
|
| 208 |
-
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
|
| 209 |
-
}
|
| 210 |
-
flash::SumOp<float> sum_op;
|
| 211 |
-
dP_sum(mi) = flash::Allreduce<kGmemThreadsPerRow>::run(dP_sum_cur, sum_op);
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_dPsum, params.stride_dPsum)(_, bidh, !is_varlen ? bidb : 0);
|
| 215 |
-
Tensor gdPsum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mdPsum), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
| 216 |
-
if (get<1>(tOcO(_0{}, _0{}, _0{})) == 0) {
|
| 217 |
-
#pragma unroll
|
| 218 |
-
for (int mi = 0; mi < size(dP_sum); ++mi) {
|
| 219 |
-
int const row = get<0>(tOcO(_0{}, mi, _0{}));
|
| 220 |
-
gdPsum(row) = row < seqlen_o - m_block * kBlockM ? dP_sum(mi) : 0;
|
| 221 |
-
}
|
| 222 |
-
}
|
| 223 |
-
|
| 224 |
-
int const seqlen_rounded = cute::round_up(seqlen_o, kBlockM);
|
| 225 |
-
Tensor mLSElog2 = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_dPsum, params.stride_LSE_log2)(_, bidh, !is_varlen ? bidb : 0);
|
| 226 |
-
Tensor gLSElog2 = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded), mLSElog2), Shape<Int<kBlockM>>{}, make_coord(m_block));
|
| 227 |
-
if (thread_idx < seqlen_rounded - m_block * kBlockM && thread_idx < kBlockM) {
|
| 228 |
-
gLSElog2(thread_idx) = lse == -INFINITY ? 0.f : lse * float(M_LOG2E);
|
| 229 |
-
}
|
| 230 |
-
|
| 231 |
-
if constexpr (Clear_dQaccum) {
|
| 232 |
-
Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0);
|
| 233 |
-
Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape<Int<kBlockM * kHeadDim>>{}, make_coord(m_block));
|
| 234 |
-
GmemTiledCopyAccum gmem_tiled_copy_dQaccum;
|
| 235 |
-
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx);
|
| 236 |
-
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
| 237 |
-
Tensor zero = make_fragment_like(tdQgdQaccum);
|
| 238 |
-
clear(zero);
|
| 239 |
-
cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{}, zero, tdQgdQaccum);
|
| 240 |
-
}
|
| 241 |
-
|
| 242 |
-
if (params.dq_semaphore != nullptr && thread_idx == 0) {
|
| 243 |
-
int const num_batch = params.num_batch;
|
| 244 |
-
int const num_head = get<2>(params.shape_O);
|
| 245 |
-
params.dq_semaphore[bidh + bidb * num_head + m_block * num_head * num_batch] = 0;
|
| 246 |
-
}
|
| 247 |
-
|
| 248 |
-
}
|
| 249 |
-
|
| 250 |
-
};
|
| 251 |
-
|
| 252 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_combine.cu
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Tri Dao.
|
| 2 |
-
// Splitting the different head dimensions to different files to speed up compilation.
|
| 3 |
-
|
| 4 |
-
#include "flash_fwd_combine_launch_template.h"
|
| 5 |
-
|
| 6 |
-
template void run_mha_fwd_combine_<float, float, 64>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
| 7 |
-
template void run_mha_fwd_combine_<float, float, 128>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
| 8 |
-
|
| 9 |
-
template void run_mha_fwd_combine_<cutlass::half_t, float, 64>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
| 10 |
-
template void run_mha_fwd_combine_<cutlass::half_t, float, 128>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
| 11 |
-
|
| 12 |
-
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 64>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
| 13 |
-
template void run_mha_fwd_combine_<cutlass::bfloat16_t, float, 128>(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_combine_kernel.h
DELETED
|
@@ -1,482 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/arch/memory.h>
|
| 11 |
-
#include <cutlass/array.h>
|
| 12 |
-
#include <cutlass/numeric_types.h>
|
| 13 |
-
#include <cutlass/numeric_conversion.h>
|
| 14 |
-
|
| 15 |
-
#include "cutlass/arch/grid_dependency_control.h"
|
| 16 |
-
|
| 17 |
-
#include "seqlen.h"
|
| 18 |
-
#include "utils.h"
|
| 19 |
-
|
| 20 |
-
namespace flash {
|
| 21 |
-
|
| 22 |
-
using namespace cute;
|
| 23 |
-
|
| 24 |
-
template <class TileShape_MK_, int kLogMaxSplits_, int kNThreads, int AlignmentLSE_,
|
| 25 |
-
bool Is_even_K, bool Varlen, class Element, class ElementPartial, class ArchTag_>
|
| 26 |
-
class FlashAttnFwdCombine {
|
| 27 |
-
|
| 28 |
-
public:
|
| 29 |
-
|
| 30 |
-
// Type Aliases
|
| 31 |
-
using TileShape_MK = TileShape_MK_;
|
| 32 |
-
using ArchTag = ArchTag_;
|
| 33 |
-
static constexpr int kMaxSplits = 1 << kLogMaxSplits_;
|
| 34 |
-
static constexpr int AlignmentLSE = std::min(AlignmentLSE_, int(128 / 8 / sizeof(float)));
|
| 35 |
-
static_assert(AlignmentLSE >= 1);
|
| 36 |
-
static constexpr int kStages = 4;
|
| 37 |
-
|
| 38 |
-
static_assert(ArchTag::kMinComputeCapability >= 75);
|
| 39 |
-
static constexpr bool Has_cp_async = ArchTag::kMinComputeCapability >= 80;
|
| 40 |
-
|
| 41 |
-
static constexpr uint32_t MaxThreadsPerBlock = kNThreads;
|
| 42 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 2;
|
| 43 |
-
|
| 44 |
-
static constexpr int kBlockM = get<0>(TileShape_MK{});
|
| 45 |
-
static constexpr int kBlockK = get<1>(TileShape_MK{});
|
| 46 |
-
|
| 47 |
-
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(ElementPartial);
|
| 48 |
-
static_assert(kBlockK % kGmemElemsPerLoad == 0, "kBlockK must be a multiple of kGmemElemsPerLoad");
|
| 49 |
-
static constexpr int kBlockKGmem = kBlockK % 128 == 0 ? 128 : (kBlockK % 64 == 0 ? 64 : 32);
|
| 50 |
-
static constexpr int kGmemThreadsPerRow = kBlockKGmem / kGmemElemsPerLoad;
|
| 51 |
-
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRow == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRow");
|
| 52 |
-
using GmemCopyAtom = std::conditional_t<
|
| 53 |
-
Has_cp_async,
|
| 54 |
-
cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, ElementPartial>,
|
| 55 |
-
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>
|
| 56 |
-
>;
|
| 57 |
-
using GmemLayoutAtom = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
| 58 |
-
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
| 59 |
-
static_assert(kBlockM % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0);
|
| 60 |
-
using GmemTiledCopyAccum = decltype(
|
| 61 |
-
make_tiled_copy(GmemCopyAtom{},
|
| 62 |
-
GmemLayoutAtom{},
|
| 63 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
|
| 64 |
-
using GmemTiledCopy = decltype(
|
| 65 |
-
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
| 66 |
-
GmemLayoutAtom{},
|
| 67 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoad>>>{})); // Val layout, 4 vals per load
|
| 68 |
-
|
| 69 |
-
using AlignmentTypeLSE = cute::uint_byte_t<static_cast<int>(sizeof(float)) * AlignmentLSE>;
|
| 70 |
-
static constexpr int kGmemElemsPerLoadLSE = sizeof(AlignmentTypeLSE) / sizeof(float);
|
| 71 |
-
static_assert(kBlockM % kGmemElemsPerLoadLSE == 0, "kBlockM must be a multiple of kGmemElemsPerLoadLSE");
|
| 72 |
-
static_assert(kBlockM % 8 == 0, "kBlockM must be a multiple of 8");
|
| 73 |
-
static constexpr int kBlockMSmem = kBlockM % 128 == 0 ? 128 : (kBlockM % 64 == 0 ? 64 : (kBlockM % 32 == 0 ? 32 : (kBlockM % 16 == 0 ? 16 : 8)));
|
| 74 |
-
static constexpr int kGmemThreadsPerRowLSE = kBlockMSmem / kGmemElemsPerLoadLSE;
|
| 75 |
-
static_assert(MaxThreadsPerBlock % kGmemThreadsPerRowLSE == 0, "MaxThreadsPerBlock must be a multiple of kGmemThreadsPerRowLSE");
|
| 76 |
-
using GmemLayoutAtomLSE = Layout<Shape <Int<MaxThreadsPerBlock / kGmemThreadsPerRowLSE>, Int<kGmemThreadsPerRowLSE>>,
|
| 77 |
-
Stride<Int<kGmemThreadsPerRowLSE>, _1>>;
|
| 78 |
-
static_assert(kMaxSplits % CUTE_STATIC_V(shape<0>(GmemLayoutAtomLSE{})) == 0);
|
| 79 |
-
using GmemCopyAtomLSE = std::conditional_t<
|
| 80 |
-
Has_cp_async,
|
| 81 |
-
cute::Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<AlignmentTypeLSE>, float>,
|
| 82 |
-
cute::Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<AlignmentLSE * sizeof(float) * 8>, float>
|
| 83 |
-
>;
|
| 84 |
-
using GmemTiledCopyLSE = decltype(
|
| 85 |
-
make_tiled_copy(GmemCopyAtomLSE{},
|
| 86 |
-
GmemLayoutAtomLSE{},
|
| 87 |
-
Layout<Shape<_1, Int<kGmemElemsPerLoadLSE>>>{})); // Val layout, 4 vals per load
|
| 88 |
-
|
| 89 |
-
// Otherwise we get IMA when some threads access sLSE, as we're not doing any masking
|
| 90 |
-
static_assert((kBlockM * kMaxSplits * AlignmentLSE) % kNThreads == 0, "kNThreads must divide kBlockM * kMaxSplits * AlignmentLSE");
|
| 91 |
-
// This works for kBlockMSmem = 8, 16, 32, 64, 128, no bank conflicts
|
| 92 |
-
using SmemLSESwizzle = std::conditional_t<
|
| 93 |
-
kBlockMSmem == 8,
|
| 94 |
-
Swizzle<5, 0, 5>,
|
| 95 |
-
std::conditional_t<kBlockMSmem == 16, Swizzle<4, 0, 4>, Swizzle<3, 2, 3>>
|
| 96 |
-
>;
|
| 97 |
-
using SmemLayoutAtomLSE =
|
| 98 |
-
decltype(composition(SmemLSESwizzle{},
|
| 99 |
-
Layout<Shape<Int<8>, Int<kBlockMSmem>>,
|
| 100 |
-
Stride<Int<kBlockMSmem>, _1>>{}));
|
| 101 |
-
using SmemLayoutLSE = decltype(tile_to_shape(SmemLayoutAtomLSE{}, Shape<Int<kMaxSplits>, Int<kBlockM>>{}));
|
| 102 |
-
|
| 103 |
-
using SmemLayoutO = Layout<Shape<Int<kBlockM>, Int<kBlockK>, Int<kStages>>,
|
| 104 |
-
Stride<Int<kBlockK>, _1, Int<kBlockM * kBlockK>>>;
|
| 105 |
-
|
| 106 |
-
// We want each column (kMaxSplits) to be processed by threads in the same warp.
|
| 107 |
-
// To reduce the number of shuffles, we want as few threads on the same column as possible.
|
| 108 |
-
// E.g., if kBlockM is divisible by 64, and there are 256 threads, we want 4 threads (0, 1, 2, 4) per column
|
| 109 |
-
// have have 64 such quads.
|
| 110 |
-
static_assert(MaxThreadsPerBlock % kBlockMSmem == 0, "MaxThreadsPerBlock must be a multiple of kBlockMSmem");
|
| 111 |
-
static constexpr int kSmemThreadsPerColLSEt = MaxThreadsPerBlock / kBlockMSmem;
|
| 112 |
-
static_assert(cutlass::NumThreadsPerWarp % kSmemThreadsPerColLSEt == 0, "kSmemThreadsPerColLSEt must divide NumThreadsPerWarp");
|
| 113 |
-
using S2RLayoutAtomLSE = Layout<Shape<Int<kSmemThreadsPerColLSEt>, Int<MaxThreadsPerBlock / kSmemThreadsPerColLSEt>>>;
|
| 114 |
-
using S2RTiledCopyLSE = decltype(make_tiled_copy(cute::Copy_Atom<cute::DefaultCopy, float>{}, S2RLayoutAtomLSE{}, Layout<_1>{}));
|
| 115 |
-
|
| 116 |
-
using ShapeOPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, num_splits, head, batch)
|
| 117 |
-
using StrideOPartial = cute::Stride<int64_t, _1, int64_t, int64_t, int64_t>;
|
| 118 |
-
using ShapeLSEPartial = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, num_splits, head, batch)
|
| 119 |
-
using StrideLSEPartial = cute::Stride<_1, int64_t, int64_t, int64_t>; // (seqlen, num_splits, head, batch)
|
| 120 |
-
using ShapeO = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen, d, head, batch)
|
| 121 |
-
using StrideO = cute::Stride<int64_t, _1, int64_t, int64_t>;
|
| 122 |
-
using ShapeLSE = cute::Shape<int32_t, int32_t, int32_t>; // (seqlen, head, batch)
|
| 123 |
-
using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch)
|
| 124 |
-
|
| 125 |
-
struct SharedStorage : cute::aligned_struct<128> {
|
| 126 |
-
cute::array_aligned<float, cute::cosize_v<SmemLayoutLSE>> smem_lse_partial;
|
| 127 |
-
cute::array_aligned<int, kBlockM> smem_max_valid_split;
|
| 128 |
-
cute::array_aligned<ElementPartial, cute::cosize_v<SmemLayoutO>> smem_o_partial;
|
| 129 |
-
};
|
| 130 |
-
|
| 131 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 132 |
-
|
| 133 |
-
// Device side arguments
|
| 134 |
-
struct Arguments {
|
| 135 |
-
ElementPartial const* const ptr_O_partial;
|
| 136 |
-
ShapeOPartial const shape_O_partial;
|
| 137 |
-
StrideOPartial const stride_O_partial;
|
| 138 |
-
float const* const ptr_LSE_partial;
|
| 139 |
-
ShapeLSEPartial const shape_LSE_partial;
|
| 140 |
-
StrideLSEPartial const stride_LSE_partial;
|
| 141 |
-
Element* const ptr_O;
|
| 142 |
-
StrideO const stride_O;
|
| 143 |
-
float* const ptr_LSE;
|
| 144 |
-
StrideLSE const stride_LSE;
|
| 145 |
-
int const* const cu_seqlens = nullptr;
|
| 146 |
-
int const* const seqused = nullptr;
|
| 147 |
-
int const* const num_splits_dynamic_ptr = nullptr;
|
| 148 |
-
int* const semaphore_to_reset = nullptr;
|
| 149 |
-
};
|
| 150 |
-
|
| 151 |
-
// Kernel entry point API
|
| 152 |
-
struct Params {
|
| 153 |
-
ElementPartial const* const ptr_O_partial;
|
| 154 |
-
ShapeOPartial const shape_O_partial;
|
| 155 |
-
StrideOPartial const stride_O_partial;
|
| 156 |
-
float const* const ptr_LSE_partial;
|
| 157 |
-
ShapeLSEPartial const shape_LSE_partial;
|
| 158 |
-
StrideLSEPartial const stride_LSE_partial;
|
| 159 |
-
Element* const ptr_O;
|
| 160 |
-
StrideO const stride_O;
|
| 161 |
-
float* const ptr_LSE;
|
| 162 |
-
StrideLSE const stride_LSE;
|
| 163 |
-
cutlass::FastDivmod seqlen_divmod, head_divmod;
|
| 164 |
-
int const* const cu_seqlens = nullptr;
|
| 165 |
-
int const* const seqused = nullptr;
|
| 166 |
-
int const* const num_splits_dynamic_ptr = nullptr;
|
| 167 |
-
int* const semaphore_to_reset = nullptr;
|
| 168 |
-
};
|
| 169 |
-
|
| 170 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 171 |
-
static
|
| 172 |
-
Params
|
| 173 |
-
to_underlying_arguments(Arguments const& args) {
|
| 174 |
-
assert(get<1>(args.shape_LSE_partial) <= kMaxSplits);
|
| 175 |
-
return {
|
| 176 |
-
args.ptr_O_partial,
|
| 177 |
-
args.shape_O_partial,
|
| 178 |
-
args.stride_O_partial,
|
| 179 |
-
args.ptr_LSE_partial,
|
| 180 |
-
args.shape_LSE_partial,
|
| 181 |
-
args.stride_LSE_partial,
|
| 182 |
-
args.ptr_O,
|
| 183 |
-
args.stride_O,
|
| 184 |
-
args.ptr_LSE,
|
| 185 |
-
args.stride_LSE,
|
| 186 |
-
cutlass::FastDivmod(get<0>(args.shape_LSE_partial)), cutlass::FastDivmod(get<2>(args.shape_LSE_partial)),
|
| 187 |
-
args.cu_seqlens,
|
| 188 |
-
args.seqused,
|
| 189 |
-
args.num_splits_dynamic_ptr,
|
| 190 |
-
args.semaphore_to_reset
|
| 191 |
-
};
|
| 192 |
-
}
|
| 193 |
-
|
| 194 |
-
CUTLASS_DEVICE
|
| 195 |
-
void
|
| 196 |
-
operator()(Params const& params, char* smem_buf) {
|
| 197 |
-
|
| 198 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 199 |
-
Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse_partial.data()), SmemLayoutLSE{});
|
| 200 |
-
Tensor sMaxValidSplit = make_tensor(make_smem_ptr(shared_storage.smem_max_valid_split.data()), Shape<Int<kBlockM>>{});
|
| 201 |
-
Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o_partial.data()), SmemLayoutO{});
|
| 202 |
-
|
| 203 |
-
int const thread_idx = threadIdx.x;
|
| 204 |
-
int const m_block = blockIdx.x;
|
| 205 |
-
int const k_block = blockIdx.y;
|
| 206 |
-
int const batch = blockIdx.z;
|
| 207 |
-
int const num_splits = params.num_splits_dynamic_ptr ? params.num_splits_dynamic_ptr[batch] : get<1>(params.shape_LSE_partial);
|
| 208 |
-
|
| 209 |
-
if (params.semaphore_to_reset && threadIdx.x == 0 && blockIdx.x == gridDim.x - 1 && blockIdx.y == gridDim.y - 1 && blockIdx.z == gridDim.z - 1) {
|
| 210 |
-
cutlass::arch::wait_on_dependent_grids();
|
| 211 |
-
*params.semaphore_to_reset = 0;
|
| 212 |
-
}
|
| 213 |
-
if (num_splits <= 1) { return; }
|
| 214 |
-
flash::SeqlenInfo<Varlen, kBlockM> seqlen_info{batch, size<0>(params.shape_LSE_partial), params.cu_seqlens, params.seqused};
|
| 215 |
-
int const offset = seqlen_info.offset;
|
| 216 |
-
int const seqlen = seqlen_info.seqlen;
|
| 217 |
-
int max_idx = seqlen * get<2>(params.shape_LSE_partial);
|
| 218 |
-
if constexpr (Varlen) {
|
| 219 |
-
if (m_block * kBlockM >= max_idx) { return; }
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
cutlass::FastDivmod seqlen_divmod_dynamic(seqlen);
|
| 223 |
-
|
| 224 |
-
// Step 1: load LSE_partial from gmem -> smem
|
| 225 |
-
Tensor mLSEpartial = make_tensor(make_gmem_ptr(params.ptr_LSE_partial + offset * get<0>(params.stride_LSE_partial)),
|
| 226 |
-
select<1, 0, 2, 3>(params.shape_LSE_partial),
|
| 227 |
-
select<1, 0, 2, 3>(params.stride_LSE_partial))(_, _, _, !Varlen ? batch : 0); // (num_splits, seqlen, head)
|
| 228 |
-
Tensor mLSEpartial_copy = cute::tiled_divide(mLSEpartial, Shape<_1, Int<kGmemElemsPerLoadLSE>>{});
|
| 229 |
-
GmemTiledCopyLSE gmem_tiled_copy_LSE;
|
| 230 |
-
auto gmem_thr_copy_LSE = gmem_tiled_copy_LSE.get_thread_slice(thread_idx);
|
| 231 |
-
Tensor tLSEsLSE = gmem_thr_copy_LSE.partition_D(sLSE);
|
| 232 |
-
|
| 233 |
-
// Construct identity layout for sLSE
|
| 234 |
-
Tensor cLSE = make_identity_tensor(make_shape(size<0>(sLSE), size<1>(sLSE))); // (NUM_SPLITS, BLK_M) -> (num_splits, blk_m)
|
| 235 |
-
// Repeat the partitioning with identity layouts
|
| 236 |
-
Tensor tLSEcLSE = gmem_thr_copy_LSE.partition_S(cLSE);
|
| 237 |
-
|
| 238 |
-
cutlass::arch::wait_on_dependent_grids();
|
| 239 |
-
|
| 240 |
-
#pragma unroll
|
| 241 |
-
for (int m = 0; m < size<2>(tLSEcLSE); ++m) {
|
| 242 |
-
int mi = int(get<1>(tLSEcLSE(_0{}, _0{}, m)));
|
| 243 |
-
int idx = m_block * kBlockM + mi;
|
| 244 |
-
if (idx < max_idx) {
|
| 245 |
-
int m_idx, bidh;
|
| 246 |
-
if constexpr (!Varlen) {
|
| 247 |
-
bidh = params.seqlen_divmod.divmod(m_idx, idx);
|
| 248 |
-
} else {
|
| 249 |
-
bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
|
| 250 |
-
}
|
| 251 |
-
Tensor mLSEpartial_cur_copy = mLSEpartial_copy(_, _, m_idx, bidh);
|
| 252 |
-
#pragma unroll
|
| 253 |
-
for (int s = 0; s < size<1>(tLSEcLSE); ++s) {
|
| 254 |
-
int si = get<0>(tLSEcLSE(_0{}, s, _0{}));
|
| 255 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(tLSEsLSE(_0{}, s, m))), reinterpret_cast<int>(&(tLSEsLSE(_0{}, s, m))) / 4 % 32);}
|
| 256 |
-
if (si < num_splits) {
|
| 257 |
-
cute::copy(gmem_tiled_copy_LSE, mLSEpartial_cur_copy(_, si), tLSEsLSE(_, s, m));
|
| 258 |
-
} else {
|
| 259 |
-
cute::fill(tLSEsLSE(_, s, m), -INFINITY);
|
| 260 |
-
}
|
| 261 |
-
}
|
| 262 |
-
} else {
|
| 263 |
-
// We don't need to zero out the rest of the LSEs, as we will not write the output to gmem
|
| 264 |
-
// cute::fill(tLSEsLSE(_, _, m), -INFINITY);
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
if constexpr (Has_cp_async) { cute::cp_async_fence(); }
|
| 268 |
-
|
| 269 |
-
// Step 2: Load O_partial from gmem -> smem for split = 0, 1, ..., kStages - 2.
|
| 270 |
-
// We want these async loads to be in flight as we compute the LSE.
|
| 271 |
-
GmemTiledCopyAccum gmem_tiled_copy_O_partial;
|
| 272 |
-
auto gmem_thr_copy_O_partial = gmem_tiled_copy_O_partial.get_thread_slice(thread_idx);
|
| 273 |
-
// Construct identity layout for gO
|
| 274 |
-
Tensor cO = cute::make_identity_tensor(TileShape_MK{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
| 275 |
-
// Repeat the partitioning with identity layouts
|
| 276 |
-
Tensor tOcO = gmem_thr_copy_O_partial.partition_D(cO);
|
| 277 |
-
Tensor mOpartial = make_tensor(make_gmem_ptr(params.ptr_O_partial + offset * get<0>(params.stride_O_partial)),
|
| 278 |
-
params.shape_O_partial, params.stride_O_partial)(_, _, _, _, !Varlen ? batch : 0); // (seqlen, d, num_splits, head)
|
| 279 |
-
|
| 280 |
-
// Precompute these values to avoid recomputing them in the loop
|
| 281 |
-
Tensor tOmidx = make_tensor<int>(make_shape(size<1>(tOcO)));
|
| 282 |
-
Tensor tObidh = make_tensor<int>(make_shape(size<1>(tOcO)));
|
| 283 |
-
Tensor tOrOptr = make_tensor<ElementPartial const*>(make_shape(size<1>(tOcO)));
|
| 284 |
-
#pragma unroll
|
| 285 |
-
for (int m = 0; m < size<1>(tOcO); ++m) {
|
| 286 |
-
int mi = get<0>(tOcO(_0{}, m, _0{}));
|
| 287 |
-
int idx = m_block * kBlockM + mi;
|
| 288 |
-
if constexpr (!Varlen) {
|
| 289 |
-
tObidh(m) = params.seqlen_divmod.divmod(tOmidx(m), idx);
|
| 290 |
-
} else {
|
| 291 |
-
tObidh[m] = seqlen_divmod_dynamic.divmod(tOmidx(m), idx);
|
| 292 |
-
}
|
| 293 |
-
tOrOptr[m] = &mOpartial(tOmidx(m), k_block * kBlockK, _0{}, tObidh(m));
|
| 294 |
-
if (idx >= max_idx) {
|
| 295 |
-
tObidh[m] = -1;
|
| 296 |
-
}
|
| 297 |
-
}
|
| 298 |
-
|
| 299 |
-
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOcO)));
|
| 300 |
-
if constexpr (!(Is_even_K)) {
|
| 301 |
-
#pragma unroll
|
| 302 |
-
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(params.shape_O_partial) - k_block * kBlockK; }
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
Tensor tOsOpartial = gmem_thr_copy_O_partial.partition_D(sO);
|
| 306 |
-
|
| 307 |
-
auto load_O_partial = [&] (int split, int stage) {
|
| 308 |
-
Tensor tOsOpartial_cur = tOsOpartial(_, _, _, stage);
|
| 309 |
-
#pragma unroll
|
| 310 |
-
for (int m = 0; m < size<1>(tOcO); ++m) {
|
| 311 |
-
if (tObidh(m) >= 0) {
|
| 312 |
-
Tensor mOpartial_cur = make_tensor(make_gmem_ptr(tOrOptr[m]), mOpartial(_0{}, _, _, _0{}).layout());
|
| 313 |
-
Tensor mOpartial_cur_copy = cute::tiled_divide(mOpartial_cur, Shape<Int<kGmemElemsPerLoad>>{});
|
| 314 |
-
#pragma unroll
|
| 315 |
-
for (int k = 0; k < size<2>(tOcO); ++k) {
|
| 316 |
-
int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
|
| 317 |
-
if (Is_even_K || tOpO(k)) {
|
| 318 |
-
cute::copy(gmem_tiled_copy_O_partial, mOpartial_cur_copy(_, k_idx, split), tOsOpartial_cur(_, m, k));
|
| 319 |
-
}
|
| 320 |
-
}
|
| 321 |
-
}
|
| 322 |
-
}
|
| 323 |
-
};
|
| 324 |
-
|
| 325 |
-
for (int s = 0; s < kStages - 1; ++s) {
|
| 326 |
-
if (s < num_splits) { load_O_partial(s, s); }
|
| 327 |
-
if constexpr (Has_cp_async) { cute::cp_async_fence(); }
|
| 328 |
-
}
|
| 329 |
-
|
| 330 |
-
// Step 3: load and transpose LSE_partial from smem -> rmem
|
| 331 |
-
if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
|
| 332 |
-
__syncthreads();
|
| 333 |
-
|
| 334 |
-
S2RTiledCopyLSE s2r_tiled_copy_LSE;
|
| 335 |
-
auto s2r_thr_copy_LSE = s2r_tiled_copy_LSE.get_thread_slice(thread_idx);
|
| 336 |
-
Tensor ts2rsLSE = s2r_thr_copy_LSE.partition_S(sLSE);
|
| 337 |
-
Tensor ts2rrLSE = make_fragment_like(ts2rsLSE);
|
| 338 |
-
cute::copy(s2r_tiled_copy_LSE, ts2rsLSE, ts2rrLSE);
|
| 339 |
-
|
| 340 |
-
// Step 4: compute the final LSE along the split dimension
|
| 341 |
-
Tensor lse_sum = make_tensor<float>(make_shape(size<2>(ts2rrLSE)));
|
| 342 |
-
Tensor ts2rcLSE = s2r_thr_copy_LSE.partition_D(cLSE);
|
| 343 |
-
// We compute the max valid split for each row to short-circuit the computation later
|
| 344 |
-
Tensor max_valid_split = make_tensor<int>(make_shape(size<2>(ts2rrLSE)));
|
| 345 |
-
static_assert(CUTE_STATIC_V(size<0>(ts2rrLSE)) == 1);
|
| 346 |
-
#pragma unroll
|
| 347 |
-
for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
|
| 348 |
-
float lse_max = ts2rrLSE(_0{}, _0{}, m);
|
| 349 |
-
#pragma unroll
|
| 350 |
-
for (int s = 1; s < size<1>(ts2rrLSE); ++s) { lse_max = max(lse_max, ts2rrLSE(_0{}, s, m)); }
|
| 351 |
-
MaxOp<float> max_op;
|
| 352 |
-
lse_max = Allreduce<kSmemThreadsPerColLSEt>::run(lse_max, max_op);
|
| 353 |
-
int max_valid_idx = -1;
|
| 354 |
-
#pragma unroll
|
| 355 |
-
for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
|
| 356 |
-
if (ts2rrLSE(_0{}, s, m) != -INFINITY) { max_valid_idx = get<0>(ts2rcLSE(_0{}, s, _0{})); }
|
| 357 |
-
}
|
| 358 |
-
MaxOp<int> max_int_op;
|
| 359 |
-
max_valid_split[m] = Allreduce<kSmemThreadsPerColLSEt>::run(max_valid_idx, max_int_op);
|
| 360 |
-
float lse_max_cur = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf
|
| 361 |
-
float lse_sum_cur = 0.f;
|
| 362 |
-
#pragma unroll
|
| 363 |
-
for (int s = 0; s < size<1>(ts2rrLSE); ++s) {
|
| 364 |
-
float scale = expf(ts2rrLSE(_0{}, s, m) - lse_max_cur);
|
| 365 |
-
lse_sum_cur += scale;
|
| 366 |
-
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && thread_idx < 32) { printf("thread_idx = %d, m = %d, s = %d, addr = %p, bank = %d\n", thread_idx, m, s, reinterpret_cast<float *>(&(ts2rsLSE(_0{}, s, m))), reinterpret_cast<int>(&(ts2rsLSE(_0{}, s, m))) / 4 % 32);}
|
| 367 |
-
// ts2rsLSE(_0{}, m, s) = scale;
|
| 368 |
-
ts2rrLSE(_0{}, s, m) = scale;
|
| 369 |
-
}
|
| 370 |
-
SumOp<float> sum_op;
|
| 371 |
-
lse_sum_cur = Allreduce<kSmemThreadsPerColLSEt>::run(lse_sum_cur, sum_op);
|
| 372 |
-
lse_sum(m) = logf(lse_sum_cur) + lse_max;
|
| 373 |
-
float inv_sum = (lse_sum_cur == 0.f || lse_sum_cur != lse_sum_cur) ? 0.f : 1.f / lse_sum_cur;
|
| 374 |
-
#pragma unroll
|
| 375 |
-
for (int s = 0; s < size<1>(ts2rrLSE); ++s) { ts2rrLSE(_0{}, s, m) *= inv_sum; }
|
| 376 |
-
}
|
| 377 |
-
// Store the scales exp(lse - lse_logsum) back to smem
|
| 378 |
-
cute::copy(s2r_tiled_copy_LSE, ts2rrLSE, ts2rsLSE);
|
| 379 |
-
|
| 380 |
-
// Store max_valid_split to smem
|
| 381 |
-
#pragma unroll
|
| 382 |
-
for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
|
| 383 |
-
if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to smem
|
| 384 |
-
int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
|
| 385 |
-
if (mi < kBlockM) { sMaxValidSplit[mi] = max_valid_split[m]; }
|
| 386 |
-
}
|
| 387 |
-
}
|
| 388 |
-
|
| 389 |
-
// Step 5: store final LSE back to gmem
|
| 390 |
-
if (k_block == 0) {
|
| 391 |
-
auto shape_LSE = select<0, 2, 3>(params.shape_LSE_partial);
|
| 392 |
-
Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE + offset * get<0>(params.stride_LSE)), shape_LSE, params.stride_LSE)(_, _, !Varlen ? batch : 0);
|
| 393 |
-
#pragma unroll
|
| 394 |
-
for (int m = 0; m < size<2>(ts2rrLSE); ++m) {
|
| 395 |
-
if (get<0>(ts2rcLSE(_0{}, _0{}, m)) == 0) { // Only the thread responsible for s=0 writes to gmem
|
| 396 |
-
int mi = int(get<1>(ts2rcLSE(_0{}, _0{}, m)));
|
| 397 |
-
int idx = m_block * kBlockM + mi;
|
| 398 |
-
if (idx < max_idx) {
|
| 399 |
-
int m_idx, bidh;
|
| 400 |
-
if constexpr (!Varlen) {
|
| 401 |
-
bidh = params.seqlen_divmod.divmod(m_idx, idx);
|
| 402 |
-
} else {
|
| 403 |
-
bidh = seqlen_divmod_dynamic.divmod(m_idx, idx);
|
| 404 |
-
}
|
| 405 |
-
// printf("thread_idx = %d, m = %d, mi = %d, idx = %d, m_idx = %d, bidh = %d, bidb = %d, lse_sum = %f\n", thread_idx, m, mi, idx, m_idx, bidh, bidb, lse_sum(m));
|
| 406 |
-
mLSE(m_idx, bidh) = lse_sum(m);
|
| 407 |
-
}
|
| 408 |
-
}
|
| 409 |
-
}
|
| 410 |
-
}
|
| 411 |
-
|
| 412 |
-
// Step 6: read O_partial from gmem -> smem -> rmem and accumulate the final O
|
| 413 |
-
__syncthreads();
|
| 414 |
-
int thr_max_valid_split = sMaxValidSplit[get<0>(tOcO(_0{}, _0{}, _0{}))];
|
| 415 |
-
#pragma unroll
|
| 416 |
-
for (int m = 1; m < size<1>(tOcO); ++m) { thr_max_valid_split = max(thr_max_valid_split, sMaxValidSplit[get<0>(tOcO(_0{}, m, _0{}))]); }
|
| 417 |
-
Layout tOrOpartial_layout = gmem_thr_copy_O_partial.partition_S(make_tensor<ElementPartial>(TileShape_MK{})).layout();
|
| 418 |
-
Tensor tOrOpartial = make_fragment_like<ElementPartial>(tOrOpartial_layout);
|
| 419 |
-
Tensor tOrO = make_fragment_like<float>(tOrOpartial);
|
| 420 |
-
clear(tOrO);
|
| 421 |
-
int stage_load = kStages - 1, stage_compute = 0;
|
| 422 |
-
#pragma unroll 4 // Already tuned for speed
|
| 423 |
-
for (int s = 0; s <= thr_max_valid_split; ++s) {
|
| 424 |
-
Tensor scale = make_tensor<float>(make_shape(size<1>(tOrOpartial)));
|
| 425 |
-
#pragma unroll
|
| 426 |
-
for (int m = 0; m < size<1>(tOrOpartial); ++m) { scale(m) = sLSE(s, get<0>(tOcO(_0{}, m, _0{}))); }
|
| 427 |
-
|
| 428 |
-
if (s + kStages - 1 <= thr_max_valid_split) { load_O_partial(s + kStages - 1, stage_load); }
|
| 429 |
-
if constexpr (Has_cp_async) { cute::cp_async_fence(); }
|
| 430 |
-
stage_load = stage_load < kStages - 1 ? stage_load + 1 : 0;
|
| 431 |
-
if constexpr (Has_cp_async) { cutlass::arch::cp_async_wait<kStages - 1>(); }
|
| 432 |
-
// We don't need __syncthreads() because each thread is just reading its own data from smem
|
| 433 |
-
cute::copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementPartial>{},
|
| 434 |
-
tOsOpartial(_, _, _, stage_compute), tOrOpartial);
|
| 435 |
-
stage_compute = stage_compute < kStages - 1 ? stage_compute + 1 : 0;
|
| 436 |
-
|
| 437 |
-
#pragma unroll
|
| 438 |
-
for (int m = 0; m < size<1>(tOrOpartial); ++m) {
|
| 439 |
-
if (tObidh(m) >= 0 && scale(m) > 0.f) {
|
| 440 |
-
#pragma unroll
|
| 441 |
-
for (int k = 0; k < size<2>(tOrOpartial); ++k) {
|
| 442 |
-
if (Is_even_K || tOpO(k)) {
|
| 443 |
-
Tensor rOpartial = make_tensor_like<float>(tOrOpartial(_, m, k));
|
| 444 |
-
flash::convert_type_out(tOrOpartial(_, m, k), rOpartial);
|
| 445 |
-
#pragma unroll
|
| 446 |
-
for (int i = 0; i < size<0>(tOrOpartial); ++i) {
|
| 447 |
-
tOrO(i, m, k) += scale(m) * rOpartial[i];
|
| 448 |
-
}
|
| 449 |
-
}
|
| 450 |
-
}
|
| 451 |
-
}
|
| 452 |
-
}
|
| 453 |
-
}
|
| 454 |
-
|
| 455 |
-
// Step 7: Write the final O to gmem
|
| 456 |
-
Tensor rO = make_tensor_like<Element>(tOrO);
|
| 457 |
-
flash::convert_type_out(tOrO, rO);
|
| 458 |
-
auto shape_O = make_shape(get<0>(params.shape_O_partial), get<1>(params.shape_O_partial) - k_block * kBlockK, get<3>(params.shape_O_partial), get<4>(params.shape_O_partial));
|
| 459 |
-
Tensor mO = make_tensor(make_gmem_ptr(params.ptr_O + offset * get<0>(params.stride_O) + k_block * kBlockK * get<1>(params.stride_O)),
|
| 460 |
-
shape_O, params.stride_O)(_, _, _, !Varlen ? batch : 0);
|
| 461 |
-
Tensor mO_copy = cute::tiled_divide(mO, Shape<_1, Int<kGmemElemsPerLoad>>{});
|
| 462 |
-
GmemTiledCopy gmem_tiled_copy_O;
|
| 463 |
-
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx);
|
| 464 |
-
|
| 465 |
-
#pragma unroll
|
| 466 |
-
for (int m = 0; m < size<1>(tOcO); ++m) {
|
| 467 |
-
if (tObidh(m) >= 0) {
|
| 468 |
-
#pragma unroll
|
| 469 |
-
for (int k = 0; k < size<2>(tOcO); ++k) {
|
| 470 |
-
int k_idx = get<1>(tOcO(_0{}, _0{}, k)) / kGmemElemsPerLoad;
|
| 471 |
-
if (Is_even_K || tOpO(k)) {
|
| 472 |
-
cute::copy(gmem_tiled_copy_O, rO(_, m, k), mO_copy(_, tOmidx(m), k_idx, tObidh(m)));
|
| 473 |
-
}
|
| 474 |
-
}
|
| 475 |
-
}
|
| 476 |
-
}
|
| 477 |
-
|
| 478 |
-
}
|
| 479 |
-
|
| 480 |
-
};
|
| 481 |
-
|
| 482 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_combine_launch_template.h
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include "cutlass/cutlass.h"
|
| 10 |
-
#include "cutlass/arch/arch.h" // For cutlass::arch::Sm80
|
| 11 |
-
#include "cutlass/device_kernel.h" // For device_kernel
|
| 12 |
-
#include "cutlass/kernel_launch.h" // For kernel_launch
|
| 13 |
-
|
| 14 |
-
#include "static_switch.h"
|
| 15 |
-
#include "flash.h"
|
| 16 |
-
#include "flash_fwd_combine_kernel.h"
|
| 17 |
-
|
| 18 |
-
using namespace cute;
|
| 19 |
-
|
| 20 |
-
template <int Arch, int kBlockM, int kBlockK, int kLogMaxSplits, bool IsEvenK, bool Varlen, typename Element, typename ElementPartial>
|
| 21 |
-
void run_flash_fwd_combine(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) {
|
| 22 |
-
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
|
| 23 |
-
using TileShape_MK = cute::Shape<Int<kBlockM>, Int<kBlockK>>;
|
| 24 |
-
using CombineKernel = flash::FlashAttnFwdCombine<TileShape_MK, kLogMaxSplits, 256 /*kNThreads*/, 1 /*AlignmentLSE*/,
|
| 25 |
-
IsEvenK, Varlen, Element, ElementPartial, ArchTag>;
|
| 26 |
-
|
| 27 |
-
typename CombineKernel::Arguments args {
|
| 28 |
-
static_cast<ElementPartial const*>(params.oaccum_ptr),
|
| 29 |
-
{!Varlen ? params.seqlen_q : params.total_q, params.dv, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_O_partial
|
| 30 |
-
{params.oaccum_row_stride, _1{}, params.oaccum_split_stride, params.oaccum_head_stride, !Varlen ? params.oaccum_batch_stride : 0}, // stride_O_partial
|
| 31 |
-
static_cast<float*>(params.softmax_lseaccum_ptr),
|
| 32 |
-
{!Varlen ? params.seqlen_q : params.total_q, params.num_splits, params.h, !Varlen ? params.b : 1}, // shape_LSE_partial
|
| 33 |
-
{_1{}, params.lseaccum_split_stride, params.lseaccum_head_stride, !Varlen ? params.lseaccum_batch_stride : 0}, // stride_LSE_partial
|
| 34 |
-
static_cast<Element*>(params.o_ptr),
|
| 35 |
-
{params.o_row_stride, _1{}, params.o_head_stride, !Varlen ? params.o_batch_stride : 0}, // stride_O
|
| 36 |
-
static_cast<float*>(params.softmax_lse_ptr),
|
| 37 |
-
{_1{}, !Varlen ? params.seqlen_q : params.total_q, !Varlen ? params.h * params.seqlen_q : 0}, // stride_LSE
|
| 38 |
-
params.cu_seqlens_q, params.seqused_q, params.num_splits_dynamic_ptr, params.tile_count_semaphore
|
| 39 |
-
};
|
| 40 |
-
|
| 41 |
-
typename CombineKernel::Params kernel_params = CombineKernel::to_underlying_arguments(args);
|
| 42 |
-
int num_blocks_k = cute::ceil_div(params.dv, kBlockK);
|
| 43 |
-
int num_blocks_m = cute::ceil_div(params.seqlen_q * params.h, kBlockM);
|
| 44 |
-
dim3 grid_m(num_blocks_m, num_blocks_k, params.b);
|
| 45 |
-
auto kernel = cutlass::device_kernel<CombineKernel>;
|
| 46 |
-
int smem_size = CombineKernel::SharedStorageSize;
|
| 47 |
-
if (smem_size >= 48 * 1024) {
|
| 48 |
-
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 49 |
-
}
|
| 50 |
-
// kernel<<<grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream>>>(kernel_params);
|
| 51 |
-
cutlass::kernel_launch<CombineKernel>(grid_m, CombineKernel::MaxThreadsPerBlock, smem_size, stream, kernel_params, Arch >= 90 && enable_pdl /*launch_with_pdl*/);
|
| 52 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 53 |
-
}
|
| 54 |
-
|
| 55 |
-
template<typename T, typename Tpartial, int kBlockK>
|
| 56 |
-
void run_mha_fwd_combine_(Flash_fwd_params ¶ms, cudaStream_t stream, bool enable_pdl) {
|
| 57 |
-
// We want kBlockM to be as small as possible to maximize parallelism.
|
| 58 |
-
// E.g., if hdim is 64, we want kBlockM to be 16 so that we can use 256 threads, each reading 4 elements (floats).
|
| 59 |
-
static_assert(kBlockK % 32 == 0, "kBlockK must be a multiple of 32");
|
| 60 |
-
static constexpr int kBlockM = kBlockK % 128 == 0 ? 8 : (kBlockK % 64 == 0 ? 16 : 32);
|
| 61 |
-
ARCH_SWITCH(params.arch, Arch, [&] {
|
| 62 |
-
BOOL_SWITCH(params.cu_seqlens_q || params.seqused_q, Varlen, [&] {
|
| 63 |
-
if constexpr (kBlockM >= 16) { // If kBlockM == 8 then the minimum number of splits is 32.
|
| 64 |
-
if (params.num_splits <= 16) {
|
| 65 |
-
run_flash_fwd_combine<Arch, kBlockM, kBlockK, 4, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
|
| 66 |
-
return;
|
| 67 |
-
}
|
| 68 |
-
}
|
| 69 |
-
if (params.num_splits <= 32) {
|
| 70 |
-
run_flash_fwd_combine<Arch, kBlockM, kBlockK, 5, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
|
| 71 |
-
} else if (params.num_splits <= 64) {
|
| 72 |
-
run_flash_fwd_combine<Arch, kBlockM, kBlockK, 6, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
|
| 73 |
-
} else if (params.num_splits <= 128) {
|
| 74 |
-
run_flash_fwd_combine<Arch, kBlockM, kBlockK, 7, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
|
| 75 |
-
} else {
|
| 76 |
-
run_flash_fwd_combine<Arch, kBlockM, kBlockK, 8, false /*IsEvenK*/, Varlen, T, Tpartial>(params, stream, enable_pdl);
|
| 77 |
-
}
|
| 78 |
-
});
|
| 79 |
-
});
|
| 80 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_kernel_sm80.h
DELETED
|
@@ -1,215 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/array.h>
|
| 11 |
-
#include <cutlass/numeric_types.h>
|
| 12 |
-
#include <cutlass/kernel_hardware_info.h>
|
| 13 |
-
|
| 14 |
-
#include "seqlen.h"
|
| 15 |
-
#include "utils.h"
|
| 16 |
-
#include "softmax.h"
|
| 17 |
-
|
| 18 |
-
namespace flash {
|
| 19 |
-
|
| 20 |
-
using namespace cute;
|
| 21 |
-
|
| 22 |
-
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
| 23 |
-
class FlashAttnFwdSm80 {
|
| 24 |
-
|
| 25 |
-
public:
|
| 26 |
-
|
| 27 |
-
// Type Aliases
|
| 28 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 29 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 30 |
-
static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
|
| 31 |
-
static constexpr bool Is_local = CollectiveMainloop::Is_local;
|
| 32 |
-
static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
|
| 33 |
-
static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
|
| 34 |
-
static constexpr bool Varlen = CollectiveMainloop::Varlen;
|
| 35 |
-
static constexpr bool PagedKV = CollectiveMainloop::PagedKV;
|
| 36 |
-
static constexpr bool Split = CollectiveMainloop::Split;
|
| 37 |
-
static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
|
| 38 |
-
static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
|
| 39 |
-
static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
|
| 40 |
-
static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
|
| 41 |
-
static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
|
| 42 |
-
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
|
| 43 |
-
|
| 44 |
-
// Mainloop derived types
|
| 45 |
-
using TileShape_MNK = typename CollectiveMainloop::TileShape_MNK;
|
| 46 |
-
using TiledMma = typename CollectiveMainloop::TiledMma;
|
| 47 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 48 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 49 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 50 |
-
|
| 51 |
-
// Epilogue derived types
|
| 52 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 53 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 54 |
-
|
| 55 |
-
static_assert(ArchTag::kMinComputeCapability >= 80);
|
| 56 |
-
|
| 57 |
-
using TileScheduler = TileScheduler_;
|
| 58 |
-
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
| 59 |
-
using TileSchedulerParams = typename TileScheduler::Params;
|
| 60 |
-
|
| 61 |
-
static constexpr uint32_t NumThreads = CUTE_STATIC_V(size(TiledMma{}));
|
| 62 |
-
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{}));
|
| 63 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = NumThreads == 128 ? 2 : 1;
|
| 64 |
-
|
| 65 |
-
// Kernel level shared memory storage
|
| 66 |
-
// We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v + smem_k and not smem_q
|
| 67 |
-
// and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v) + sizeof(smem_k).
|
| 68 |
-
static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage))
|
| 69 |
-
- int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)))
|
| 70 |
-
- int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k)));
|
| 71 |
-
static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
|
| 72 |
-
struct SharedStorage {
|
| 73 |
-
struct TensorStorage : cute::aligned_struct<128> {
|
| 74 |
-
union {
|
| 75 |
-
struct {
|
| 76 |
-
cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
|
| 77 |
-
typename CollectiveMainloop::TensorStorage mainloop;
|
| 78 |
-
};
|
| 79 |
-
// We want smem_o to line up with the start of smem_v
|
| 80 |
-
typename CollectiveEpilogue::TensorStorage epilogue;
|
| 81 |
-
};
|
| 82 |
-
} tensors;
|
| 83 |
-
|
| 84 |
-
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
| 85 |
-
|
| 86 |
-
};
|
| 87 |
-
|
| 88 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 89 |
-
|
| 90 |
-
// Device side arguments
|
| 91 |
-
struct Arguments {
|
| 92 |
-
MainloopArguments mainloop{};
|
| 93 |
-
EpilogueArguments epilogue{};
|
| 94 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 95 |
-
TileSchedulerArguments scheduler{};
|
| 96 |
-
};
|
| 97 |
-
|
| 98 |
-
// Kernel entry point API
|
| 99 |
-
struct Params {
|
| 100 |
-
MainloopParams mainloop{};
|
| 101 |
-
EpilogueParams epilogue{};
|
| 102 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 103 |
-
TileSchedulerParams scheduler{};
|
| 104 |
-
};
|
| 105 |
-
|
| 106 |
-
//
|
| 107 |
-
// Methods
|
| 108 |
-
//
|
| 109 |
-
|
| 110 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 111 |
-
static
|
| 112 |
-
Params
|
| 113 |
-
to_underlying_arguments(Arguments const& args) {
|
| 114 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
| 115 |
-
|
| 116 |
-
// Get SM count if needed, otherwise use user supplied SM count
|
| 117 |
-
int sm_count = args.hw_info.sm_count;
|
| 118 |
-
if (sm_count <= 0) {
|
| 119 |
-
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
| 120 |
-
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
| 121 |
-
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
| 122 |
-
}
|
| 123 |
-
|
| 124 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
| 125 |
-
|
| 126 |
-
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
| 127 |
-
return {
|
| 128 |
-
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
| 129 |
-
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
| 130 |
-
hw_info,
|
| 131 |
-
TileScheduler::to_underlying_arguments(args.scheduler)
|
| 132 |
-
};
|
| 133 |
-
}
|
| 134 |
-
|
| 135 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 136 |
-
static dim3
|
| 137 |
-
get_grid_shape(Params const& params) {
|
| 138 |
-
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count * MinBlocksPerMultiprocessor);
|
| 139 |
-
}
|
| 140 |
-
|
| 141 |
-
static dim3
|
| 142 |
-
get_block_shape() {
|
| 143 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 144 |
-
}
|
| 145 |
-
|
| 146 |
-
CUTLASS_DEVICE
|
| 147 |
-
void
|
| 148 |
-
operator()(Params const& params, char* smem_buf) {
|
| 149 |
-
|
| 150 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK{});
|
| 151 |
-
|
| 152 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 153 |
-
|
| 154 |
-
CollectiveMainloop mainloop;
|
| 155 |
-
CollectiveEpilogue epilogue;
|
| 156 |
-
|
| 157 |
-
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.smem_scheduler));
|
| 158 |
-
// Initialize matmul objects.
|
| 159 |
-
TiledMma tiled_mma;
|
| 160 |
-
|
| 161 |
-
scheduler.init_consumer();
|
| 162 |
-
|
| 163 |
-
int warp_idx = cutlass::canonical_warp_idx_sync();
|
| 164 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 165 |
-
for (auto work_tile_info = warp_idx == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 166 |
-
work_tile_info.is_valid(params.scheduler);
|
| 167 |
-
work_tile_info = warp_idx == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
| 168 |
-
// Attention output (GEMM-II) accumulator.
|
| 169 |
-
Tensor tOrO = partition_fragment_C(tiled_mma, select<0, 2>(TileShape_MNK{}));
|
| 170 |
-
float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
|
| 171 |
-
// If there's tanh softcap, the scaling will be done before tanh.
|
| 172 |
-
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
| 173 |
-
int const bidb = get<2>(block_coord);
|
| 174 |
-
if constexpr (Is_FP8 && !Has_softcap) {
|
| 175 |
-
int const bidh = get<1>(block_coord);
|
| 176 |
-
int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
|
| 177 |
-
float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
|
| 178 |
-
float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
|
| 179 |
-
softmax_scale_log2 *= q_descale * k_descale;
|
| 180 |
-
}
|
| 181 |
-
flash::Softmax<2 * (2 * kBlockM / NumThreads), /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
|
| 182 |
-
|
| 183 |
-
SeqlenInfo_t seqlen_info{
|
| 184 |
-
bidb,
|
| 185 |
-
get<0>(params.mainloop.shape_Q),
|
| 186 |
-
!PagedKV ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
| 187 |
-
get<0>(params.mainloop.shape_K_new),
|
| 188 |
-
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
| 189 |
-
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
| 190 |
-
params.mainloop.seqlens_rotary
|
| 191 |
-
};
|
| 192 |
-
if constexpr (AppendKV) {
|
| 193 |
-
bool tile_new_valid = mainloop.store_kv_new(
|
| 194 |
-
params.mainloop, threadIdx.x, shared_storage, seqlen_info, block_coord);
|
| 195 |
-
if (tile_new_valid) { __syncthreads(); }
|
| 196 |
-
}
|
| 197 |
-
bool tile_valid = mainloop.mma(
|
| 198 |
-
params.mainloop, tOrO, softmax, threadIdx.x, seqlen_info, block_coord,
|
| 199 |
-
shared_storage);
|
| 200 |
-
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
| 201 |
-
if (tile_valid) {
|
| 202 |
-
// if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
|
| 203 |
-
epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma,
|
| 204 |
-
threadIdx.x, block_coord);
|
| 205 |
-
} else {
|
| 206 |
-
// Write 0 to gO and -inf to gLSE.
|
| 207 |
-
epilogue.store_zero(params.epilogue, threadIdx.x, block_coord);
|
| 208 |
-
}
|
| 209 |
-
}
|
| 210 |
-
|
| 211 |
-
}
|
| 212 |
-
|
| 213 |
-
};
|
| 214 |
-
|
| 215 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_kernel_sm90.h
DELETED
|
@@ -1,458 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include <cutlass/cutlass.h>
|
| 10 |
-
#include <cutlass/arch/reg_reconfig.h>
|
| 11 |
-
#include <cutlass/array.h>
|
| 12 |
-
#include <cutlass/numeric_types.h>
|
| 13 |
-
#include <cutlass/numeric_conversion.h>
|
| 14 |
-
#include <cutlass/kernel_hardware_info.h>
|
| 15 |
-
#include "cutlass/pipeline/pipeline.hpp"
|
| 16 |
-
|
| 17 |
-
#include "cutlass/arch/grid_dependency_control.h"
|
| 18 |
-
|
| 19 |
-
#include "seqlen.h"
|
| 20 |
-
#include "utils.h"
|
| 21 |
-
#include "softmax.h"
|
| 22 |
-
|
| 23 |
-
namespace flash {
|
| 24 |
-
|
| 25 |
-
using namespace cute;
|
| 26 |
-
|
| 27 |
-
template <class CollectiveMainloop_, class CollectiveEpilogue_, class TileScheduler_>
|
| 28 |
-
class FlashAttnFwdSm90 {
|
| 29 |
-
|
| 30 |
-
public:
|
| 31 |
-
|
| 32 |
-
// Type Aliases
|
| 33 |
-
using CollectiveMainloop = CollectiveMainloop_;
|
| 34 |
-
using CollectiveEpilogue = CollectiveEpilogue_;
|
| 35 |
-
static constexpr bool Is_causal = CollectiveMainloop::Is_causal;
|
| 36 |
-
static constexpr bool Is_local = CollectiveMainloop::Is_local;
|
| 37 |
-
static_assert(CollectiveMainloop::Varlen == CollectiveEpilogue::Varlen);
|
| 38 |
-
static constexpr bool Has_softcap = CollectiveMainloop::Has_softcap;
|
| 39 |
-
static constexpr bool Varlen = CollectiveMainloop::Varlen;
|
| 40 |
-
static constexpr bool Split = CollectiveMainloop::Split;
|
| 41 |
-
static constexpr bool Is_FP8 = CollectiveMainloop::Is_FP8;
|
| 42 |
-
static constexpr bool Transpose_V = CollectiveMainloop::Transpose_V;
|
| 43 |
-
static constexpr bool AppendKV = CollectiveMainloop::AppendKV;
|
| 44 |
-
static constexpr bool HasQv = CollectiveMainloop::HasQv;
|
| 45 |
-
static constexpr bool Use_TMA_Q = CollectiveMainloop::Use_TMA_Q;
|
| 46 |
-
static constexpr bool Use_TMA_KV = CollectiveMainloop::Use_TMA_KV;
|
| 47 |
-
static constexpr bool Use_TMA_O = CollectiveEpilogue::Use_TMA_O;
|
| 48 |
-
static constexpr bool PackGQA = CollectiveMainloop::PackGQA;
|
| 49 |
-
static constexpr int NumProducerThreads = CollectiveMainloop::NumProducerThreads;
|
| 50 |
-
static constexpr bool SameHeadDim = CollectiveMainloop::SameHeadDim;
|
| 51 |
-
static constexpr bool LargeHeadDimV = CollectiveMainloop::LargeHeadDimV;
|
| 52 |
-
static_assert(CollectiveMainloop::LargeHeadDimV == CollectiveEpilogue::LargeHeadDimV);
|
| 53 |
-
using SeqlenInfo_t = typename CollectiveMainloop::SeqlenInfo_t;
|
| 54 |
-
|
| 55 |
-
// Mainloop derived types
|
| 56 |
-
using TileShape_MNK_PV = typename CollectiveMainloop::TileShape_MNK_PV;
|
| 57 |
-
using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV;
|
| 58 |
-
using ArchTag = typename CollectiveMainloop::ArchTag;
|
| 59 |
-
using ClusterShape = typename CollectiveMainloop::ClusterShape;
|
| 60 |
-
using MainloopArguments = typename CollectiveMainloop::Arguments;
|
| 61 |
-
using MainloopParams = typename CollectiveMainloop::Params;
|
| 62 |
-
using BarrierQ = std::conditional_t<Use_TMA_Q, cutlass::arch::ClusterTransactionBarrier, cutlass::arch::ClusterBarrier>;
|
| 63 |
-
|
| 64 |
-
// Epilogue derived types
|
| 65 |
-
using EpilogueArguments = typename CollectiveEpilogue::Arguments;
|
| 66 |
-
using EpilogueParams = typename CollectiveEpilogue::Params;
|
| 67 |
-
|
| 68 |
-
static_assert(ArchTag::kMinComputeCapability >= 90);
|
| 69 |
-
|
| 70 |
-
using TileScheduler = TileScheduler_;
|
| 71 |
-
using TileSchedulerArguments = typename flash::TileSchedulerArguments;
|
| 72 |
-
using TileSchedulerParams = typename TileScheduler::Params;
|
| 73 |
-
|
| 74 |
-
static constexpr uint32_t NumLoadWarpGroups = 1;
|
| 75 |
-
static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(size(TiledMmaPV{})) / cutlass::NumThreadsPerWarpGroup;
|
| 76 |
-
static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMmaPV{})) + (NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup);
|
| 77 |
-
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;
|
| 78 |
-
static_assert(NumMmaWarpGroups == 1 || NumMmaWarpGroups == 2 || NumMmaWarpGroups == 3);
|
| 79 |
-
|
| 80 |
-
/// Register requirement for Load and Math WGs
|
| 81 |
-
// If we use cp.async to load K and V, we need more registers for the producer WG.
|
| 82 |
-
static constexpr uint32_t LoadRegisterRequirement = NumMmaWarpGroups == 1 ? 56 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 24 : 40) : 32);
|
| 83 |
-
static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 1 ? 256 : (NumMmaWarpGroups == 2 ? (Use_TMA_KV ? 240 : 232) : 160);
|
| 84 |
-
// If you want to print from the producer warp, you'd need to increase the number of registers
|
| 85 |
-
// Otherwise you'll get CUDA error.
|
| 86 |
-
// static constexpr uint32_t LoadRegisterRequirement = 40;
|
| 87 |
-
// static constexpr uint32_t MmaRegisterRequirement = NumMmaWarpGroups == 2 ? 232 : 152;
|
| 88 |
-
|
| 89 |
-
// Kernel level shared memory storage
|
| 90 |
-
// We overlap the shared memory for the mainloop and epilogue. However, we only want smem_o to overlap with smem_v
|
| 91 |
-
// and nothing else, so we'll pad in case sizeof(smem_o) > sizeof(smem_v).
|
| 92 |
-
static constexpr int mainloop_smem_padding_ = int(sizeof(typename CollectiveEpilogue::TensorStorage)) - int(sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v)));
|
| 93 |
-
static constexpr int mainloop_smem_padding = mainloop_smem_padding_ < 0 ? 0 : mainloop_smem_padding_;
|
| 94 |
-
struct SharedStorage {
|
| 95 |
-
struct TensorStorage : cute::aligned_struct<128, _1> {
|
| 96 |
-
union {
|
| 97 |
-
struct {
|
| 98 |
-
cute::array<uint32_t, mainloop_smem_padding / sizeof(uint32_t)> padding_;
|
| 99 |
-
typename CollectiveMainloop::TensorStorage mainloop;
|
| 100 |
-
};
|
| 101 |
-
// We want smem_o to line up with the start of smem_v
|
| 102 |
-
typename CollectiveEpilogue::TensorStorage epilogue;
|
| 103 |
-
};
|
| 104 |
-
} tensors;
|
| 105 |
-
struct PipelineStorage : cute::aligned_struct<16, _1> {
|
| 106 |
-
alignas(16) BarrierQ barrier_Q;
|
| 107 |
-
alignas(16) BarrierQ barrier_Qv;
|
| 108 |
-
alignas(16) cutlass::arch::ClusterBarrier barrier_O;
|
| 109 |
-
alignas(16) typename CollectiveMainloop::MainloopPipelineK::SharedStorage pipeline_k;
|
| 110 |
-
alignas(16) typename CollectiveMainloop::MainloopPipelineV::SharedStorage pipeline_v;
|
| 111 |
-
alignas(16) typename CollectiveMainloop::MainloopPipelineVt::SharedStorage pipeline_vt;
|
| 112 |
-
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_k_new;
|
| 113 |
-
alignas(16) typename CollectiveMainloop::MainloopPipelineKVNew::SharedStorage pipeline_v_new;
|
| 114 |
-
alignas(16) typename TileScheduler::SharedStorage smem_scheduler;
|
| 115 |
-
} pipelines;
|
| 116 |
-
|
| 117 |
-
};
|
| 118 |
-
|
| 119 |
-
static constexpr int SharedStorageSize = sizeof(SharedStorage);
|
| 120 |
-
|
| 121 |
-
// Device side arguments
|
| 122 |
-
struct Arguments {
|
| 123 |
-
MainloopArguments mainloop{};
|
| 124 |
-
EpilogueArguments epilogue{};
|
| 125 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 126 |
-
TileSchedulerArguments scheduler{};
|
| 127 |
-
};
|
| 128 |
-
|
| 129 |
-
// Kernel entry point API
|
| 130 |
-
struct Params {
|
| 131 |
-
MainloopParams mainloop{};
|
| 132 |
-
EpilogueParams epilogue{};
|
| 133 |
-
cutlass::KernelHardwareInfo hw_info{};
|
| 134 |
-
TileSchedulerParams scheduler{};
|
| 135 |
-
};
|
| 136 |
-
|
| 137 |
-
//
|
| 138 |
-
// Methods
|
| 139 |
-
//
|
| 140 |
-
|
| 141 |
-
// Convert to underlying arguments. In this case, a simple copy for the aliased type.
|
| 142 |
-
static
|
| 143 |
-
Params
|
| 144 |
-
to_underlying_arguments(Arguments const& args) {
|
| 145 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments():");
|
| 146 |
-
|
| 147 |
-
// Get SM count if needed, otherwise use user supplied SM count
|
| 148 |
-
int sm_count = args.hw_info.sm_count;
|
| 149 |
-
if (sm_count <= 0) {
|
| 150 |
-
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
| 151 |
-
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
| 152 |
-
sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id);
|
| 153 |
-
}
|
| 154 |
-
|
| 155 |
-
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
| 156 |
-
|
| 157 |
-
cutlass::KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count};
|
| 158 |
-
return {
|
| 159 |
-
CollectiveMainloop::to_underlying_arguments(args.mainloop),
|
| 160 |
-
CollectiveEpilogue::to_underlying_arguments(args.epilogue),
|
| 161 |
-
hw_info,
|
| 162 |
-
TileScheduler::to_underlying_arguments(args.scheduler)
|
| 163 |
-
};
|
| 164 |
-
}
|
| 165 |
-
|
| 166 |
-
// Computes the kernel launch grid shape based on runtime parameters
|
| 167 |
-
static dim3
|
| 168 |
-
get_grid_shape(Params const& params) {
|
| 169 |
-
return TileScheduler::get_grid_shape(params.scheduler, params.hw_info.sm_count);
|
| 170 |
-
}
|
| 171 |
-
|
| 172 |
-
static dim3
|
| 173 |
-
get_block_shape() {
|
| 174 |
-
return dim3(MaxThreadsPerBlock, 1, 1);
|
| 175 |
-
}
|
| 176 |
-
|
| 177 |
-
CUTLASS_DEVICE
|
| 178 |
-
void
|
| 179 |
-
operator()(Params const& params, char* smem_buf) {
|
| 180 |
-
|
| 181 |
-
static constexpr int NumMmaThreads = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
| 182 |
-
static constexpr int MmaThreadOffset = NumLoadWarpGroups * cutlass::NumThreadsPerWarpGroup;
|
| 183 |
-
static constexpr int kBlockM = get<0>(TileShape_MNK_PV{});
|
| 184 |
-
|
| 185 |
-
using MainloopPipelineK = typename CollectiveMainloop::MainloopPipelineK;
|
| 186 |
-
using MainloopPipelineV = typename CollectiveMainloop::MainloopPipelineV;
|
| 187 |
-
using MainloopPipelineVt = typename CollectiveMainloop::MainloopPipelineVt;
|
| 188 |
-
using MainloopPipelineKVNew = typename CollectiveMainloop::MainloopPipelineKVNew;
|
| 189 |
-
using PipelineState = typename CollectiveMainloop::PipelineState;
|
| 190 |
-
using PipelineParamsK = typename MainloopPipelineK::Params;
|
| 191 |
-
using PipelineParamsV = typename MainloopPipelineV::Params;
|
| 192 |
-
using PipelineParamsVt = typename MainloopPipelineVt::Params;
|
| 193 |
-
using PipelineParamsKVNew = typename MainloopPipelineKVNew::Params;
|
| 194 |
-
|
| 195 |
-
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
|
| 196 |
-
|
| 197 |
-
int const lane_predicate = cute::elect_one_sync();
|
| 198 |
-
int const warp_idx = cutlass::canonical_warp_idx_sync();
|
| 199 |
-
|
| 200 |
-
// Issue Tma Descriptor Prefetch from a single thread
|
| 201 |
-
if (warp_idx == 0 && lane_predicate) {
|
| 202 |
-
CollectiveMainloop::prefetch_tma_descriptors(params.mainloop);
|
| 203 |
-
CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue);
|
| 204 |
-
}
|
| 205 |
-
|
| 206 |
-
// Obtain warp index
|
| 207 |
-
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
|
| 208 |
-
int warp_group_idx = cutlass::canonical_warp_group_idx();
|
| 209 |
-
|
| 210 |
-
if (warp_idx == 0 && lane_predicate) {
|
| 211 |
-
shared_storage.pipelines.barrier_Q.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
|
| 212 |
-
if constexpr (HasQv) {
|
| 213 |
-
shared_storage.pipelines.barrier_Qv.init(Use_TMA_Q ? 1 : NumProducerThreads /*numThreads*/);
|
| 214 |
-
}
|
| 215 |
-
shared_storage.pipelines.barrier_O.init(size(ClusterShape{}) * (Use_TMA_O ? 1 : NumMmaThreads) /*numThreads*/);
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
|
| 219 |
-
PipelineParamsK pipeline_params_k;
|
| 220 |
-
pipeline_params_k.role = warp_group_idx == 0
|
| 221 |
-
? MainloopPipelineK::ThreadCategory::Producer
|
| 222 |
-
: MainloopPipelineK::ThreadCategory::Consumer;
|
| 223 |
-
if constexpr (Use_TMA_KV) {
|
| 224 |
-
pipeline_params_k.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
| 225 |
-
pipeline_params_k.is_leader = warp_group_thread_idx == 0;
|
| 226 |
-
pipeline_params_k.num_consumers = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
|
| 227 |
-
} else {
|
| 228 |
-
pipeline_params_k.consumer_arv_count = !LargeHeadDimV ? NumMmaThreads : cutlass::NumThreadsPerWarpGroup;
|
| 229 |
-
pipeline_params_k.producer_arv_count = NumProducerThreads;
|
| 230 |
-
}
|
| 231 |
-
|
| 232 |
-
static_assert(is_same_v<PipelineParamsK, PipelineParamsVt>);
|
| 233 |
-
PipelineParamsVt pipeline_params_vt = pipeline_params_k;
|
| 234 |
-
if constexpr (Use_TMA_KV && !SameHeadDim) {
|
| 235 |
-
pipeline_params_vt.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
|
| 236 |
-
if constexpr (LargeHeadDimV) { pipeline_params_vt.num_consumers = NumMmaThreads; }
|
| 237 |
-
} else {
|
| 238 |
-
if constexpr (LargeHeadDimV) { pipeline_params_vt.consumer_arv_count = NumMmaThreads; }
|
| 239 |
-
}
|
| 240 |
-
|
| 241 |
-
MainloopPipelineK pipeline_k = [&] {
|
| 242 |
-
if constexpr (Use_TMA_KV) {
|
| 243 |
-
return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k, ClusterShape{});
|
| 244 |
-
} else {
|
| 245 |
-
return MainloopPipelineK(shared_storage.pipelines.pipeline_k, pipeline_params_k);
|
| 246 |
-
}
|
| 247 |
-
}();
|
| 248 |
-
// MainloopPipelineV pipeline_v(shared_storage.pipelines.pipeline_v, pipeline_params_v, ClusterShape{});
|
| 249 |
-
MainloopPipelineV pipeline_v = [&] {
|
| 250 |
-
if constexpr (!Transpose_V) {
|
| 251 |
-
static_assert(is_same_v<PipelineParamsK, PipelineParamsV>);
|
| 252 |
-
if constexpr (Use_TMA_KV) {
|
| 253 |
-
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt, ClusterShape{});
|
| 254 |
-
} else {
|
| 255 |
-
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_vt);
|
| 256 |
-
}
|
| 257 |
-
} else {
|
| 258 |
-
PipelineParamsV pipeline_params_v;
|
| 259 |
-
pipeline_params_v.role = warp_group_idx == 0
|
| 260 |
-
? MainloopPipelineV::ThreadCategory::Producer
|
| 261 |
-
: MainloopPipelineV::ThreadCategory::Consumer;
|
| 262 |
-
pipeline_params_v.producer_arv_count = NumProducerThreads;
|
| 263 |
-
pipeline_params_v.consumer_arv_count = NumMmaThreads;
|
| 264 |
-
return MainloopPipelineV(shared_storage.pipelines.pipeline_v, pipeline_params_v);
|
| 265 |
-
}
|
| 266 |
-
}();
|
| 267 |
-
// If we need to transpose V (e.g. FP8 and V is row-major), we use pipeline_vt for the TMA, then
|
| 268 |
-
// the producer WG will read from pipeline_vt and write to pipeline_v.
|
| 269 |
-
// If we don't need to transpose V, we use pipeline_v for the TMA, and pipeline_vt won't be used.
|
| 270 |
-
// Technically for pipeline_params_vt, warp0 of WG0 is the producer and all of WG0 are consumers.
|
| 271 |
-
// However, the thread role isn't used in the pipeline implementation.
|
| 272 |
-
MainloopPipelineVt pipeline_vt = [&] {
|
| 273 |
-
if constexpr (Use_TMA_KV) {
|
| 274 |
-
pipeline_params_vt.num_consumers = NumProducerThreads; // TMA_V is only consumed by the producer WG
|
| 275 |
-
return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt, ClusterShape{});
|
| 276 |
-
} else {
|
| 277 |
-
pipeline_params_vt.consumer_arv_count = NumProducerThreads; // TMA_V is only consumed by the producer WG
|
| 278 |
-
return MainloopPipelineVt(shared_storage.pipelines.pipeline_vt, pipeline_params_vt);
|
| 279 |
-
}
|
| 280 |
-
}();
|
| 281 |
-
|
| 282 |
-
PipelineParamsKVNew pipeline_params_kv_new;
|
| 283 |
-
pipeline_params_kv_new.role = warp_group_idx == 0
|
| 284 |
-
? MainloopPipelineKVNew::ThreadCategory::Producer
|
| 285 |
-
: MainloopPipelineKVNew::ThreadCategory::Consumer;
|
| 286 |
-
pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
|
| 287 |
-
pipeline_params_kv_new.is_leader = warp_group_thread_idx == 0;
|
| 288 |
-
pipeline_params_kv_new.num_consumers = NumMmaThreads;
|
| 289 |
-
auto pipeline_k_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_k_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
|
| 290 |
-
if constexpr (!SameHeadDim) {
|
| 291 |
-
pipeline_params_kv_new.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
|
| 292 |
-
}
|
| 293 |
-
auto pipeline_v_new = cute::conditional_return<AppendKV>(MainloopPipelineKVNew(shared_storage.pipelines.pipeline_v_new, pipeline_params_kv_new, ClusterShape{}), nullptr);
|
| 294 |
-
|
| 295 |
-
CollectiveMainloop mainloop;
|
| 296 |
-
CollectiveEpilogue epilogue;
|
| 297 |
-
|
| 298 |
-
// We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster
|
| 299 |
-
if constexpr (size(ClusterShape{}) > 1) {
|
| 300 |
-
cute::cluster_arrive_relaxed();
|
| 301 |
-
cute::cluster_wait();
|
| 302 |
-
} else {
|
| 303 |
-
__syncthreads();
|
| 304 |
-
}
|
| 305 |
-
|
| 306 |
-
TileScheduler scheduler(reinterpret_cast<typename TileScheduler::SharedStorage*>(&shared_storage.pipelines.smem_scheduler));
|
| 307 |
-
|
| 308 |
-
if (warp_group_idx == 0) { // Producer
|
| 309 |
-
cutlass::arch::warpgroup_reg_dealloc<LoadRegisterRequirement>();
|
| 310 |
-
|
| 311 |
-
// The pipelines for AppendKV and main attention are different, since e.g. main attention
|
| 312 |
-
// might use cp.async to load KV (if PagedKVNonTMA) while AppendKV always uses TMA to load
|
| 313 |
-
// KV_new. Since the pipeline states are different, we have to manually sync to make
|
| 314 |
-
// sure the two pipelines don't race when accessing smem_k and smem_v.
|
| 315 |
-
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipelineK>();
|
| 316 |
-
PipelineState smem_pipe_write_new = cutlass::make_producer_start_state<MainloopPipelineKVNew>();
|
| 317 |
-
int work_idx = 0;
|
| 318 |
-
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
|
| 319 |
-
static constexpr bool SingleProducerWarp = NumProducerThreads == cutlass::NumThreadsPerWarp;
|
| 320 |
-
if constexpr (SingleProducerWarp) {
|
| 321 |
-
if (warp_idx_in_warpgroup != 0) { return; }
|
| 322 |
-
}
|
| 323 |
-
if (!SingleProducerWarp && warp_idx_in_warpgroup != 0) { scheduler.init_consumer(); }
|
| 324 |
-
|
| 325 |
-
cutlass::arch::wait_on_dependent_grids();
|
| 326 |
-
|
| 327 |
-
// Load Q, K, V
|
| 328 |
-
for (auto work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_initial_work</*IsProducerWarp=*/true>(params.scheduler) : scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 329 |
-
work_tile_info.is_valid(params.scheduler);
|
| 330 |
-
work_tile_info = SingleProducerWarp || warp_idx_in_warpgroup == 0 ? scheduler.template get_next_work</*IsProducerWarp=*/true>(params.scheduler, work_tile_info) : scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info)) {
|
| 331 |
-
|
| 332 |
-
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
| 333 |
-
SeqlenInfo_t seqlen_info{
|
| 334 |
-
get<2>(block_coord) /*bidb*/,
|
| 335 |
-
get<0>(params.mainloop.shape_Q),
|
| 336 |
-
!params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
| 337 |
-
get<0>(params.mainloop.shape_K_new),
|
| 338 |
-
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
| 339 |
-
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
| 340 |
-
params.mainloop.seqlens_rotary
|
| 341 |
-
};
|
| 342 |
-
if constexpr (AppendKV) {
|
| 343 |
-
bool tile_new_valid = mainloop.load_kv_new(
|
| 344 |
-
params.mainloop, pipeline_k_new, pipeline_v_new,
|
| 345 |
-
smem_pipe_write_new, shared_storage, seqlen_info, block_coord, work_idx);
|
| 346 |
-
if (tile_new_valid) {
|
| 347 |
-
// if (threadIdx.x == 0) { printf("Producer: Before sync\n"); }
|
| 348 |
-
cutlass::arch::NamedBarrier::sync(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
|
| 349 |
-
// if (threadIdx.x == 0) { printf("Producer: After sync\n"); }
|
| 350 |
-
}
|
| 351 |
-
}
|
| 352 |
-
auto scheduler_prefetch = [&scheduler, ¶ms, &work_tile_info]() {
|
| 353 |
-
scheduler.prefetch_next_work(params.scheduler, work_tile_info);
|
| 354 |
-
};
|
| 355 |
-
// pipeline_vt won't be used if we don't need to transpose V.
|
| 356 |
-
mainloop.load(params.mainloop, pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write,
|
| 357 |
-
shared_storage, scheduler_prefetch, seqlen_info, block_coord, work_idx);
|
| 358 |
-
}
|
| 359 |
-
mainloop.load_tail(pipeline_k, pipeline_v, pipeline_vt, smem_pipe_write, shared_storage, work_idx);
|
| 360 |
-
} else { // Consumer
|
| 361 |
-
cutlass::arch::warpgroup_reg_alloc<MmaRegisterRequirement>();
|
| 362 |
-
|
| 363 |
-
// Initialize matmul objects.
|
| 364 |
-
TiledMmaPV tiled_mma_pv;
|
| 365 |
-
|
| 366 |
-
PipelineState smem_pipe_read;
|
| 367 |
-
PipelineState smem_pipe_read_new;
|
| 368 |
-
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
|
| 369 |
-
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
|
| 370 |
-
|
| 371 |
-
scheduler.init_consumer();
|
| 372 |
-
mainloop.mma_init();
|
| 373 |
-
|
| 374 |
-
int work_idx = 0;
|
| 375 |
-
CUTLASS_PRAGMA_NO_UNROLL
|
| 376 |
-
for (auto work_tile_info = scheduler.template get_initial_work</*IsProducerWarp=*/false>(params.scheduler);
|
| 377 |
-
work_tile_info.is_valid(params.scheduler);
|
| 378 |
-
// get_next_work will be called before the epilogue
|
| 379 |
-
) {
|
| 380 |
-
auto block_coord = work_tile_info.get_block_coord(params.scheduler);
|
| 381 |
-
int const bidb = get<2>(block_coord);
|
| 382 |
-
SeqlenInfo_t seqlen_info{
|
| 383 |
-
bidb,
|
| 384 |
-
get<0>(params.mainloop.shape_Q),
|
| 385 |
-
!params.mainloop.ptr_pagetable ? size<0>(params.mainloop.shape_K) : size<0>(params.mainloop.shape_K) * size<1>(params.mainloop.shape_pagetable),
|
| 386 |
-
get<0>(params.mainloop.shape_K_new),
|
| 387 |
-
params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
|
| 388 |
-
params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
|
| 389 |
-
params.mainloop.seqlens_rotary
|
| 390 |
-
};
|
| 391 |
-
if constexpr (AppendKV) {
|
| 392 |
-
bool tile_new_valid = mainloop.store_kv_new(
|
| 393 |
-
params.mainloop, pipeline_k_new, pipeline_v_new, smem_pipe_read_new,
|
| 394 |
-
threadIdx.x - MmaThreadOffset, shared_storage, seqlen_info, block_coord);
|
| 395 |
-
if (tile_new_valid) {
|
| 396 |
-
// if (threadIdx.x == 128) { printf("Consumer: Before sync\n"); }
|
| 397 |
-
// We need this sync so that the gmem write from the consumers is visible to the producer
|
| 398 |
-
// that might do TMA read after that.
|
| 399 |
-
asm volatile ("fence.proxy.async.global;");
|
| 400 |
-
cutlass::arch::NamedBarrier::arrive(NumMmaThreads + NumProducerThreads, static_cast<uint32_t>(FwdNamedBarriers::AppendKV) /*id*/);
|
| 401 |
-
// arrive is enough, we don't need sync. The producer will sync, which means
|
| 402 |
-
// after that sync we're guaranteed that the AppendKV pipeline have finished
|
| 403 |
-
// loading and consumer smem_k and smem_v.
|
| 404 |
-
// if (threadIdx.x == 128) { printf("Consumer: After sync\n"); }
|
| 405 |
-
}
|
| 406 |
-
}
|
| 407 |
-
// If there's tanh softcap, the scaling will be done before tanh.
|
| 408 |
-
float softmax_scale_log2 = params.mainloop.softmax_scale_log2;
|
| 409 |
-
if constexpr (Is_FP8 && !Has_softcap) {
|
| 410 |
-
int const bidh = get<1>(block_coord);
|
| 411 |
-
int const bidh_kv = !PackGQA ? params.mainloop.qhead_per_khead_divmod.divide(bidh) : bidh;
|
| 412 |
-
float const q_descale = params.mainloop.ptr_q_descale == nullptr ? 1.0f : params.mainloop.ptr_q_descale[bidb * get<0>(params.mainloop.stride_q_descale) + bidh_kv * get<1>(params.mainloop.stride_q_descale)];
|
| 413 |
-
float const k_descale = params.mainloop.ptr_k_descale == nullptr ? 1.0f : params.mainloop.ptr_k_descale[bidb * get<0>(params.mainloop.stride_k_descale) + bidh_kv * get<1>(params.mainloop.stride_k_descale)];
|
| 414 |
-
softmax_scale_log2 *= q_descale * k_descale;
|
| 415 |
-
}
|
| 416 |
-
flash::Softmax<!LargeHeadDimV ? 2 * (2 * kBlockM / NumMmaThreads) : 2, /*Max_offset=*/!Is_FP8 ? 0 : 8> softmax(softmax_scale_log2);
|
| 417 |
-
// Attention output (GEMM-II) accumulator.
|
| 418 |
-
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_MNK_PV{}));
|
| 419 |
-
bool tile_valid;
|
| 420 |
-
if constexpr (!LargeHeadDimV) {
|
| 421 |
-
tile_valid = mainloop.mma(
|
| 422 |
-
params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
|
| 423 |
-
tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
|
| 424 |
-
} else { // mma_pv might not compile if !LargeHeadDimV
|
| 425 |
-
if (warp_group_idx == 1) {
|
| 426 |
-
tile_valid = mainloop.mma(
|
| 427 |
-
params.mainloop, pipeline_k, pipeline_v, smem_pipe_read,
|
| 428 |
-
tOrO, softmax, threadIdx.x - MmaThreadOffset, work_idx, seqlen_info, block_coord, shared_storage);
|
| 429 |
-
} else {
|
| 430 |
-
tile_valid = mainloop.mma_pv(
|
| 431 |
-
params.mainloop, pipeline_v, smem_pipe_read,
|
| 432 |
-
tOrO, softmax, threadIdx.x - MmaThreadOffset, seqlen_info, block_coord, shared_storage);
|
| 433 |
-
}
|
| 434 |
-
}
|
| 435 |
-
// Do this here before the epilogue so that the next tile is ready to go.
|
| 436 |
-
work_tile_info = scheduler.template get_next_work</*IsProducerWarp=*/false>(params.scheduler, work_tile_info);
|
| 437 |
-
if constexpr (Split && Varlen) {
|
| 438 |
-
if (!work_tile_info.is_valid(params.scheduler)) { // Last tile
|
| 439 |
-
cutlass::arch::launch_dependent_grids();
|
| 440 |
-
}
|
| 441 |
-
}
|
| 442 |
-
if (tile_valid) {
|
| 443 |
-
// if (threadIdx.x == 128) { printf("Before epilogue, bid.x = %d, bid.y = %d, bid.z = %d, m_block = %d, bidb = %d, split_idx = %d\n", blockIdx.x, blockIdx.y, blockIdx.z, m_block, bidb, split_idx); }
|
| 444 |
-
epilogue.store(params.epilogue, tOrO, softmax.row_sum, shared_storage, tiled_mma_pv,
|
| 445 |
-
threadIdx.x - MmaThreadOffset, block_coord);
|
| 446 |
-
} else {
|
| 447 |
-
// Write 0 to gO and -inf to gLSE.
|
| 448 |
-
epilogue.store_zero(params.epilogue, threadIdx.x - MmaThreadOffset, block_coord);
|
| 449 |
-
}
|
| 450 |
-
}
|
| 451 |
-
epilogue.store_tail();
|
| 452 |
-
}
|
| 453 |
-
|
| 454 |
-
}
|
| 455 |
-
|
| 456 |
-
};
|
| 457 |
-
|
| 458 |
-
} // namespace flash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_fwd_launch_template.h
DELETED
|
@@ -1,223 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include "cute/tensor.hpp"
|
| 8 |
-
|
| 9 |
-
#include "cutlass/cutlass.h"
|
| 10 |
-
#include "cutlass/device_kernel.h" // For device_kernel
|
| 11 |
-
#include <cutlass/kernel_hardware_info.h>
|
| 12 |
-
#include "cutlass/cluster_launch.hpp"
|
| 13 |
-
#include "cutlass/kernel_launch.h"
|
| 14 |
-
|
| 15 |
-
#include "static_switch.h"
|
| 16 |
-
#include "flash.h"
|
| 17 |
-
#include "tile_size.h"
|
| 18 |
-
#include "tile_scheduler.hpp"
|
| 19 |
-
#include "flash_fwd_kernel_sm90.h"
|
| 20 |
-
#include "flash_fwd_kernel_sm80.h"
|
| 21 |
-
#include "mainloop_fwd_sm90_tma_gmma_ws.hpp"
|
| 22 |
-
#include "mainloop_fwd_sm80.hpp"
|
| 23 |
-
#include "epilogue_fwd.hpp"
|
| 24 |
-
|
| 25 |
-
using namespace cute;
|
| 26 |
-
|
| 27 |
-
template <int Arch, int kHeadDim, int kHeadDimV, int ClusterM, typename Element, typename ElementOut,
|
| 28 |
-
bool Is_causal, bool Is_local, bool Has_softcap, bool Varlen, bool PagedKVNonTMA, bool AppendKV, bool HasQv,
|
| 29 |
-
bool PackGQA, bool Split, bool V_colmajor>
|
| 30 |
-
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
| 31 |
-
static_assert(!(Is_causal && Is_local), "Causal and Local cannot be enabled at the same time");
|
| 32 |
-
static_assert(!(AppendKV && V_colmajor), "AppendKV and V_colmajor cannot be enabled at the same time");
|
| 33 |
-
static_assert(!(AppendKV && !Varlen), "AppendKV requires Varlen");
|
| 34 |
-
static constexpr bool Is_FP8 = cute::is_same_v<Element, cutlass::float_e4m3_t> || cute::is_same_v<Element, cutlass::float_e5m2_t>;
|
| 35 |
-
static constexpr bool FP8_TransposeV = Is_FP8 && !V_colmajor;
|
| 36 |
-
using ArchTag = std::conditional_t<Arch >= 90, cutlass::arch::Sm90, cutlass::arch::Sm80>;
|
| 37 |
-
|
| 38 |
-
// Can't use structured binding since it's not compatible with constexpr
|
| 39 |
-
static constexpr std::tuple<int, int, bool, bool> kBlockMN_RS_IntraWGOverlap = tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap);
|
| 40 |
-
static constexpr std::tuple<int, int, int, int, bool> kBlockMN_kNWarps_Stages_RS = tile_size_fwd_sm8x(Arch == 86 || Arch == 89, kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(Element) /*element_size*/, PagedKVNonTMA, Varlen && Split, Has_softcap, AppendKV);
|
| 41 |
-
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(kBlockMN_RS_IntraWGOverlap) : std::get<0>(kBlockMN_kNWarps_Stages_RS);
|
| 42 |
-
static constexpr int kBlockN = Arch >= 90 ? std::get<1>(kBlockMN_RS_IntraWGOverlap) : std::get<1>(kBlockMN_kNWarps_Stages_RS);
|
| 43 |
-
static constexpr bool MmaPV_is_RS = std::get<2>(kBlockMN_RS_IntraWGOverlap);
|
| 44 |
-
static constexpr bool IntraWGOverlap = std::get<3>(kBlockMN_RS_IntraWGOverlap);
|
| 45 |
-
static constexpr int kNWarps = std::get<2>(kBlockMN_kNWarps_Stages_RS);
|
| 46 |
-
static constexpr int kStages = Arch >= 90 ? 2 : std::get<3>(kBlockMN_kNWarps_Stages_RS);
|
| 47 |
-
static constexpr bool Q_in_regs = Arch >= 90 ? false : std::get<4>(kBlockMN_kNWarps_Stages_RS);
|
| 48 |
-
|
| 49 |
-
using TileShape_MNK = cute::Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>;
|
| 50 |
-
using TileShape_MNK_PV = cute::Shape<Int<kBlockM>, Int<kHeadDimV>, Int<kBlockN>>;
|
| 51 |
-
using ClusterShape = cute::Shape<Int<ClusterM>, _1, _1>;
|
| 52 |
-
using CollectiveMainloop = std::conditional_t<
|
| 53 |
-
Arch >= 90,
|
| 54 |
-
flash::CollectiveMainloopFwdSm90<kStages, ClusterShape, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm90, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, HasQv, MmaPV_is_RS, IntraWGOverlap, PackGQA, Split, V_colmajor>,
|
| 55 |
-
flash::CollectiveMainloopFwdSm80<kNWarps, kStages, Q_in_regs, TileShape_MNK, kHeadDimV, Element, float, cutlass::arch::Sm80, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV, PackGQA, Split>
|
| 56 |
-
>;
|
| 57 |
-
using CollectiveEpilogue = flash::CollectiveEpilogueFwd<TileShape_MNK_PV, ClusterShape, ElementOut, ArchTag, CollectiveMainloop::NumMmaThreads, Varlen, PackGQA, Split, FP8_TransposeV>;
|
| 58 |
-
|
| 59 |
-
static constexpr int NumProducerThreads = Arch >= 90 ? CollectiveMainloop::NumProducerThreads : CollectiveMainloop::NumMmaThreads;
|
| 60 |
-
using SchedulerPersistent = std::conditional_t<Varlen,
|
| 61 |
-
flash::VarlenDynamicPersistentTileScheduler<kBlockM, CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>,
|
| 62 |
-
std::conditional_t<!Is_causal && !Is_local,
|
| 63 |
-
flash::StaticPersistentTileScheduler<Split>,
|
| 64 |
-
flash::DynamicPersistentTileScheduler<CollectiveMainloop::NumMmaThreads, NumProducerThreads, Split, PackGQA, Arch >= 90 /*WarpSpecialized*/>
|
| 65 |
-
>
|
| 66 |
-
>;
|
| 67 |
-
using SchedulerSingleTile = flash::SingleTileScheduler<Varlen, Split, PackGQA, kBlockM>;
|
| 68 |
-
// If Split then we probably don't have enough work for PersistentScheduler to be useful.
|
| 69 |
-
// However, if Varlen (e.g., during decode where we have max_seqlens), using PersistentScheduler is better
|
| 70 |
-
// since we'll avoid launching a bunch of thread blocks that immediately exit.
|
| 71 |
-
// On Sm80, noncausal persistent seems a bit slower.
|
| 72 |
-
static constexpr bool UsePersistentScheduler = Arch >= 90 ? !(Split && !Varlen) : ((Is_causal && !Varlen) || (Varlen && Split));
|
| 73 |
-
using Scheduler = std::conditional_t<!UsePersistentScheduler, SchedulerSingleTile, SchedulerPersistent>;
|
| 74 |
-
using AttnKernel = std::conditional_t<
|
| 75 |
-
Arch >= 90,
|
| 76 |
-
flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<CollectiveMainloop, CollectiveEpilogue, Scheduler>>,
|
| 77 |
-
flash::enable_sm80_to_sm89<flash::FlashAttnFwdSm80<CollectiveMainloop, CollectiveEpilogue, Scheduler>>
|
| 78 |
-
>;
|
| 79 |
-
|
| 80 |
-
bool const is_varlen_q = params.cu_seqlens_q;
|
| 81 |
-
bool const is_varlen_k = params.cu_seqlens_k;
|
| 82 |
-
bool const is_varlen_k_new = params.cu_seqlens_knew;
|
| 83 |
-
int seqlen_q = !is_varlen_q ? params.seqlen_q : params.total_q;
|
| 84 |
-
int batch_q = !is_varlen_q ? params.b : 1;
|
| 85 |
-
int batch_k = !is_varlen_k ? (params.kv_batch_idx ? params.b_k : params.b) : 1;
|
| 86 |
-
typename CollectiveMainloop::StrideV v_strides =
|
| 87 |
-
cute::conditional_return<!V_colmajor>(
|
| 88 |
-
make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
|
| 89 |
-
make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
|
| 90 |
-
typename CollectiveMainloop::Arguments mainloop_args {
|
| 91 |
-
static_cast<Element const*>(params.q_ptr),
|
| 92 |
-
{seqlen_q, params.d, params.h, batch_q}, // shape_Q
|
| 93 |
-
{params.q_row_stride, _1{}, params.q_head_stride, !is_varlen_q ? params.q_batch_stride : 0}, // stride_Q
|
| 94 |
-
static_cast<Element*>(params.k_ptr),
|
| 95 |
-
{!params.page_table ? (!is_varlen_k ? params.seqlen_k : params.total_k) : params.page_size,
|
| 96 |
-
params.d, params.h_k, !params.page_table ? batch_k : params.num_pages}, // shape_K
|
| 97 |
-
{params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K
|
| 98 |
-
static_cast<Element*>(params.v_ptr),
|
| 99 |
-
params.dv, // headdim_v
|
| 100 |
-
v_strides, // stride_V
|
| 101 |
-
static_cast<Element const*>(params.knew_ptr),
|
| 102 |
-
{!is_varlen_k_new ? params.seqlen_knew : params.total_knew, params.d, params.h_k, !is_varlen_k_new ? params.b : 1}, // shape_K_new
|
| 103 |
-
{params.knew_row_stride, _1{}, params.knew_head_stride, !is_varlen_k_new ? params.knew_batch_stride : 0}, // stride_K_new
|
| 104 |
-
static_cast<Element const*>(params.vnew_ptr),
|
| 105 |
-
{params.vnew_row_stride, _1{}, params.vnew_head_stride, !is_varlen_k_new ? params.vnew_batch_stride : 0}, // stride_V_new
|
| 106 |
-
static_cast<Element const*>(params.qv_ptr),
|
| 107 |
-
{params.qv_row_stride, _1{}, params.qv_head_stride, !is_varlen_q ? params.qv_batch_stride : 0}, // stride_Qv
|
| 108 |
-
static_cast<Element const*>(params.rotary_cos_ptr),
|
| 109 |
-
{params.seqlen_k, params.rotary_dim / 2}, // shape_rotary, the seqlen shape doesn't matter
|
| 110 |
-
{params.rotary_dim / 2, _1{}}, // stride_rotary_cos
|
| 111 |
-
static_cast<Element const*>(params.rotary_sin_ptr),
|
| 112 |
-
{params.rotary_dim / 2, _1{}}, // stride_rotary_sin
|
| 113 |
-
params.is_rotary_interleaved,
|
| 114 |
-
params.page_table,
|
| 115 |
-
// if page_size is not set, avoid dividing by zero
|
| 116 |
-
{params.kv_batch_idx ? params.b_k : params.b, !params.page_table ? 0 : params.seqlen_k / params.page_size}, // shape_page_table
|
| 117 |
-
{params.page_table_batch_stride, _1{}}, // stride_page_table
|
| 118 |
-
params.scale_softmax,
|
| 119 |
-
params.q_descale_ptr, params.k_descale_ptr, params.v_descale_ptr,
|
| 120 |
-
{params.q_descale_batch_stride, params.q_descale_head_stride},
|
| 121 |
-
{params.k_descale_batch_stride, params.k_descale_head_stride},
|
| 122 |
-
{params.v_descale_batch_stride, params.v_descale_head_stride},
|
| 123 |
-
params.window_size_left, params.window_size_right, params.attention_chunk,
|
| 124 |
-
params.softcap,
|
| 125 |
-
params.num_splits,
|
| 126 |
-
params.kv_batch_idx,
|
| 127 |
-
params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
|
| 128 |
-
params.seqused_q, params.seqused_k,
|
| 129 |
-
params.leftpad_k, params.seqlens_rotary
|
| 130 |
-
};
|
| 131 |
-
typename CollectiveEpilogue::Arguments epilogue_args {
|
| 132 |
-
static_cast<ElementOut*>(params.o_ptr),
|
| 133 |
-
{seqlen_q, params.dv, params.h, batch_q, params.num_splits}, // shape_O
|
| 134 |
-
{params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0, 0}, // stride_O
|
| 135 |
-
static_cast<float*>(params.oaccum_ptr),
|
| 136 |
-
{params.oaccum_row_stride, _1{}, params.oaccum_head_stride, !is_varlen_q ? params.oaccum_batch_stride : 0, params.oaccum_split_stride}, // stride_O_partial
|
| 137 |
-
static_cast<float*>(params.softmax_lse_ptr),
|
| 138 |
-
{_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, 0}, // stride_LSE
|
| 139 |
-
static_cast<float*>(params.softmax_lseaccum_ptr),
|
| 140 |
-
{_1{}, seqlen_q, !is_varlen_q ? params.h * seqlen_q : 0, params.h * seqlen_q * batch_q}, // stride_LSE_partial
|
| 141 |
-
params.h_k,
|
| 142 |
-
params.cu_seqlens_q, params.seqused_q
|
| 143 |
-
};
|
| 144 |
-
|
| 145 |
-
int qhead_per_khead = !PackGQA ? 1 : cutlass::ceil_div(params.h, params.h_k);
|
| 146 |
-
int num_blocks_m = cutlass::ceil_div(params.seqlen_q * qhead_per_khead, get<0>(TileShape_MNK{}));
|
| 147 |
-
num_blocks_m = cutlass::round_up(num_blocks_m, size<0>(ClusterShape{}));
|
| 148 |
-
typename flash::TileSchedulerArguments scheduler_args {
|
| 149 |
-
num_blocks_m, !PackGQA ? params.h : params.h_k, params.b, params.num_splits,
|
| 150 |
-
params.h / params.h_k,
|
| 151 |
-
params.seqlen_q,
|
| 152 |
-
params.seqlen_k, params.d, params.dv, sizeof(Element),
|
| 153 |
-
params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
|
| 154 |
-
// params.num_m_blocks_ptr,
|
| 155 |
-
params.num_splits_dynamic_ptr,
|
| 156 |
-
};
|
| 157 |
-
|
| 158 |
-
if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {
|
| 159 |
-
prepare_varlen_num_blocks(params, stream, PackGQA, kBlockM, kBlockN, Arch >= 90 /*enable_pdl*/);
|
| 160 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 161 |
-
}
|
| 162 |
-
|
| 163 |
-
int device;
|
| 164 |
-
CHECK_CUDA(cudaGetDevice(&device));
|
| 165 |
-
typename AttnKernel::Params kernel_params = AttnKernel::to_underlying_arguments({
|
| 166 |
-
mainloop_args, epilogue_args, {device, params.num_sm}, scheduler_args
|
| 167 |
-
});
|
| 168 |
-
|
| 169 |
-
dim3 grid_dims = AttnKernel::get_grid_shape(kernel_params);
|
| 170 |
-
dim3 block_dims = AttnKernel::get_block_shape();
|
| 171 |
-
int smem_size = AttnKernel::SharedStorageSize;
|
| 172 |
-
// int smem_size_q = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_q));
|
| 173 |
-
// int smem_size_k = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_k));
|
| 174 |
-
// int smem_size_v = sizeof(decltype((typename CollectiveMainloop::TensorStorage{}).smem_v));
|
| 175 |
-
// printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
|
| 176 |
-
// Get the ptr to kernel function.
|
| 177 |
-
if constexpr (size(ClusterShape{}) > 1) {
|
| 178 |
-
void const* kernel = (void const*) cutlass::device_kernel<AttnKernel>;
|
| 179 |
-
if (smem_size >= 48 * 1024) {
|
| 180 |
-
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 181 |
-
}
|
| 182 |
-
dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
|
| 183 |
-
cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
|
| 184 |
-
cutlass::launch_kernel_on_cluster(launch_params, kernel, kernel_params);
|
| 185 |
-
} else {
|
| 186 |
-
auto kernel = cutlass::device_kernel<AttnKernel>;
|
| 187 |
-
if (smem_size >= 48 * 1024) {
|
| 188 |
-
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
| 189 |
-
}
|
| 190 |
-
// kernel<<<grid_dims, block_dims, smem_size, stream>>>(kernel_params);
|
| 191 |
-
cutlass::kernel_launch<AttnKernel>(grid_dims, block_dims, smem_size, stream, kernel_params,
|
| 192 |
-
Arch >= 90 && Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation /*launch_with_pdl*/);
|
| 193 |
-
}
|
| 194 |
-
CHECK_CUDA_KERNEL_LAUNCH();
|
| 195 |
-
}
|
| 196 |
-
|
| 197 |
-
template<int Arch, typename T, int kHeadDim, int kHeadDimV, bool Split, bool PagedKVNonTMA, bool Has_softcap, bool PackGQA>
|
| 198 |
-
void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
| 199 |
-
static_assert(sizeof(T) == 2 || sizeof(T) == 1, "Only 16bit and 8bit are supported");
|
| 200 |
-
static constexpr bool Is_FP8 = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
|
| 201 |
-
using T_out = std::conditional_t<!Is_FP8, T, cutlass::bfloat16_t>;
|
| 202 |
-
CAUSAL_LOCAL_SWITCH(params.is_causal, params.is_local, Is_causal, Is_local, [&] {
|
| 203 |
-
VCOLMAJOR_SWITCH(params.v_dim_stride != 1, V_colmajor_, [&] {
|
| 204 |
-
static constexpr bool V_colmajor = V_colmajor_ && sizeof(T) == 1;
|
| 205 |
-
VARLEN_SWITCH(params.cu_seqlens_q || params.cu_seqlens_k || params.seqused_q || params.seqused_k || params.leftpad_k, Varlen, [&] {
|
| 206 |
-
// Only needed here to decide if we should use cluster
|
| 207 |
-
static constexpr int kBlockM = Arch >= 90 ? std::get<0>(tile_size_fwd_sm90(kHeadDim, kHeadDimV, Is_causal, Is_local, sizeof(T) /*element_size*/, V_colmajor, PagedKVNonTMA, Has_softcap)) : 128;
|
| 208 |
-
|
| 209 |
-
static constexpr bool Enable_cluster = Arch == 90 && (sizeof(T) == 2 ? (kHeadDim >= 128) : (kHeadDim == 192)) && !Is_causal && !Is_local && !Split && !PagedKVNonTMA && !Varlen;
|
| 210 |
-
BOOL_SWITCH(params.qv_ptr, HasQV_, [&] {
|
| 211 |
-
static constexpr bool HasQv = HasQV_ && Arch == 90 && !Is_FP8 && kHeadDim == 64 && kHeadDimV >= 256;
|
| 212 |
-
APPENDKV_SWITCH(params.knew_ptr, AppendKV, [&] {
|
| 213 |
-
// Only use Cluster if number of tiles along seqlen_q is even and not varlen
|
| 214 |
-
CLUSTER_SWITCH(cutlass::ceil_div(params.seqlen_q * (!PackGQA ? 1 : params.h / params.h_k), kBlockM) % 2 == 0, Use_cluster, [&] {
|
| 215 |
-
static constexpr int ClusterM = Enable_cluster && Use_cluster ? 2 : 1;
|
| 216 |
-
run_flash_fwd<Arch, kHeadDim, kHeadDimV, ClusterM, T, T_out, Is_causal, Is_local, Has_softcap, Varlen, PagedKVNonTMA, AppendKV && Varlen, HasQv, PackGQA, Split, V_colmajor>(params, stream);
|
| 217 |
-
});
|
| 218 |
-
});
|
| 219 |
-
});
|
| 220 |
-
});
|
| 221 |
-
});
|
| 222 |
-
});
|
| 223 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/flash_prepare_scheduler.cu
DELETED
|
@@ -1,124 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#include "cutlass/fast_math.h"
|
| 6 |
-
#include "cutlass/barrier.h"
|
| 7 |
-
#include "cutlass/arch/barrier.h"
|
| 8 |
-
|
| 9 |
-
#include "cutlass/arch/grid_dependency_control.h"
|
| 10 |
-
|
| 11 |
-
#include "flash.h"
|
| 12 |
-
|
| 13 |
-
namespace flash {
|
| 14 |
-
|
| 15 |
-
__global__ void prepare_varlen_num_blocks_kernel(
|
| 16 |
-
int seqlen_q_static, int seqlen_k_static, int seqlen_k_new_static,
|
| 17 |
-
int const* const cu_seqlens_q, int const* const cu_seqlens_k, int const* const cu_seqlens_k_new,
|
| 18 |
-
int const* const seqused_q, int const* const seqused_k, int const* const leftpad_k_ptr,
|
| 19 |
-
int num_batch, int num_head, int qhead_per_khead, int num_sm, int num_splits_static,
|
| 20 |
-
cutlass::FastDivmod blockm_divmod, cutlass::FastDivmod blockn_divmod,
|
| 21 |
-
int* const tile_count_semaphore,
|
| 22 |
-
// int* const num_m_blocks_ptr,
|
| 23 |
-
int* const num_splits_dynamic_ptr,
|
| 24 |
-
bool enable_pdl) {
|
| 25 |
-
|
| 26 |
-
static constexpr int kNumBatchPerWarp = cutlass::NumThreadsPerWarp - 1;
|
| 27 |
-
static constexpr int kSmemSize = 1;
|
| 28 |
-
// Assume that there's only one block in the grid
|
| 29 |
-
__shared__ int total_blocks_smem[kSmemSize];
|
| 30 |
-
|
| 31 |
-
// There's only 1 block in the grid, so might as well start launching the main attn kernel
|
| 32 |
-
if (enable_pdl) { cutlass::arch::launch_dependent_grids(); }
|
| 33 |
-
|
| 34 |
-
if (threadIdx.x < kSmemSize) { total_blocks_smem[threadIdx.x] = 0; }
|
| 35 |
-
__syncthreads();
|
| 36 |
-
|
| 37 |
-
if (threadIdx.x == 0 && tile_count_semaphore) { *tile_count_semaphore = 0; }
|
| 38 |
-
|
| 39 |
-
int lane = threadIdx.x % cutlass::NumThreadsPerWarp;
|
| 40 |
-
|
| 41 |
-
auto get_num_m_blocks = [&](int bidb_start) {
|
| 42 |
-
int batch_idx = lane + bidb_start;
|
| 43 |
-
int seqlen;
|
| 44 |
-
if (seqused_q) {
|
| 45 |
-
seqlen = batch_idx < num_batch ? seqused_q[batch_idx] : 0;
|
| 46 |
-
} else if (cu_seqlens_q) {
|
| 47 |
-
int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_q[batch_idx] : 0;
|
| 48 |
-
int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
|
| 49 |
-
seqlen = next_cu_seqlen - cur_cu_seqlen;
|
| 50 |
-
} else {
|
| 51 |
-
seqlen = seqlen_q_static;
|
| 52 |
-
}
|
| 53 |
-
seqlen *= qhead_per_khead;
|
| 54 |
-
return batch_idx < num_batch && lane < kNumBatchPerWarp
|
| 55 |
-
? blockm_divmod.div(seqlen + blockm_divmod.divisor - 1) : 0;
|
| 56 |
-
};
|
| 57 |
-
|
| 58 |
-
auto get_num_n_blocks = [&](int bidb_start) {
|
| 59 |
-
int batch_idx = lane + bidb_start;
|
| 60 |
-
int leftpad_k = batch_idx < num_batch && leftpad_k_ptr != nullptr ? leftpad_k_ptr[batch_idx] : 0;
|
| 61 |
-
int seqlen;
|
| 62 |
-
if (seqused_k) {
|
| 63 |
-
seqlen = batch_idx < num_batch ? seqused_k[batch_idx] : 0;
|
| 64 |
-
} else if (cu_seqlens_k) {
|
| 65 |
-
int cur_cu_seqlen = batch_idx <= num_batch ? cu_seqlens_k[batch_idx] : 0;
|
| 66 |
-
int next_cu_seqlen = __shfl_down_sync(0xffffffff, cur_cu_seqlen, 1);
|
| 67 |
-
seqlen = next_cu_seqlen - cur_cu_seqlen;
|
| 68 |
-
} else {
|
| 69 |
-
seqlen = seqlen_k_static;
|
| 70 |
-
}
|
| 71 |
-
int seqlen_new;
|
| 72 |
-
if (cu_seqlens_k_new) {
|
| 73 |
-
int cur_cu_seqlen_new = batch_idx <= num_batch ? cu_seqlens_k_new[batch_idx] : 0;
|
| 74 |
-
int next_cu_seqlen_new = __shfl_down_sync(0xffffffff, cur_cu_seqlen_new, 1);
|
| 75 |
-
seqlen_new = next_cu_seqlen_new - cur_cu_seqlen_new;
|
| 76 |
-
} else {
|
| 77 |
-
seqlen_new = seqlen_k_new_static;
|
| 78 |
-
}
|
| 79 |
-
// if (threadIdx.x == 0) { printf("seqlen = %d, seqlen_new = %d, leftpad_k = %d\n", seqlen, seqlen_new, leftpad_k); }
|
| 80 |
-
seqlen = seqlen - leftpad_k + seqlen_new;
|
| 81 |
-
return batch_idx < num_batch && lane < kNumBatchPerWarp
|
| 82 |
-
? blockn_divmod.div(seqlen + blockn_divmod.divisor - 1) : 0;
|
| 83 |
-
};
|
| 84 |
-
|
| 85 |
-
int warp_idx = threadIdx.x / cutlass::NumThreadsPerWarp;
|
| 86 |
-
int bidb_start = kNumBatchPerWarp * warp_idx;
|
| 87 |
-
int num_m_blocks = get_num_m_blocks(bidb_start);
|
| 88 |
-
int num_n_blocks = get_num_n_blocks(bidb_start);
|
| 89 |
-
|
| 90 |
-
int total_blocks = num_m_blocks * num_n_blocks;
|
| 91 |
-
// Warp sum
|
| 92 |
-
#pragma unroll
|
| 93 |
-
for (int i = cutlass::NumThreadsPerWarp / 2; i >= 1; i /= 2) {
|
| 94 |
-
total_blocks += __shfl_down_sync(0xffffffff, total_blocks, i);
|
| 95 |
-
}
|
| 96 |
-
if (lane == 0) { atomicAdd(total_blocks_smem, total_blocks); }
|
| 97 |
-
__syncthreads();
|
| 98 |
-
total_blocks = total_blocks_smem[0];
|
| 99 |
-
// 10% margin
|
| 100 |
-
int blocks_per_sm = static_cast<int>(ceilf(float(total_blocks) * 1.1f * float(num_head) / float(num_sm)));
|
| 101 |
-
// blocks_per_sm = std::max(1, blocks_per_sm); // 1 is the minimum number of blocks per SM
|
| 102 |
-
int num_splits_dynamic = std::max(std::min((num_n_blocks + blocks_per_sm - 1) / blocks_per_sm, num_splits_static), 1);
|
| 103 |
-
if (bidb_start + lane < num_batch && lane < kNumBatchPerWarp) {
|
| 104 |
-
num_splits_dynamic_ptr[bidb_start + lane] = num_splits_dynamic;
|
| 105 |
-
// printf("idx = %d, num_m_blocks = %d, num_n_blocks = %d, num_split_static = %d, num_splits_dynamic = %d\n", bidb_start + lane, num_m_blocks_ptr[bidb_start + lane], num_n_blocks, num_splits_static, num_splits_dynamic);
|
| 106 |
-
}
|
| 107 |
-
}
|
| 108 |
-
|
| 109 |
-
} // flash
|
| 110 |
-
|
| 111 |
-
void prepare_varlen_num_blocks(Flash_fwd_params ¶ms, cudaStream_t stream, bool packgqa,
|
| 112 |
-
int blockM, int blockN, bool enable_pdl) {
|
| 113 |
-
// Only support batch <= 992 (32 warps, each with 31 batches)
|
| 114 |
-
int qhead_per_khead = !packgqa ? 1 : cutlass::ceil_div(params.h, params.h_k);
|
| 115 |
-
flash::prepare_varlen_num_blocks_kernel<<<1 /*grid*/, 1024 /*block*/, 0, stream>>>(
|
| 116 |
-
params.seqlen_q, params.seqlen_k, params.seqlen_knew,
|
| 117 |
-
params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
|
| 118 |
-
params.seqused_q, params.seqused_k, params.leftpad_k,
|
| 119 |
-
params.b, !packgqa ? params.h : params.h_k, qhead_per_khead, params.num_sm, params.num_splits,
|
| 120 |
-
cutlass::FastDivmod(blockM), cutlass::FastDivmod(blockN),
|
| 121 |
-
params.tile_count_semaphore,
|
| 122 |
-
// params.num_m_blocks_ptr,
|
| 123 |
-
params.num_splits_dynamic_ptr, enable_pdl);
|
| 124 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/heuristics.h
DELETED
|
@@ -1,59 +0,0 @@
|
|
| 1 |
-
/******************************************************************************
|
| 2 |
-
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 3 |
-
******************************************************************************/
|
| 4 |
-
|
| 5 |
-
#pragma once
|
| 6 |
-
|
| 7 |
-
#include <vector>
|
| 8 |
-
|
| 9 |
-
inline bool should_pack_gqa(bool varlen_q, int seqlen_q, int qhead_per_khead, int blockM) {
|
| 10 |
-
// If varlen, we don't actually know seqlen_q but only max_seqlen_q.
|
| 11 |
-
if (varlen_q) return true;
|
| 12 |
-
// Heuristic: PackGQA is a bit slower but can help if seqlen_q is small or not near a multiple of kBlockM
|
| 13 |
-
auto round_up = [](int a, int b) { return (a + b - 1) / b * b; };
|
| 14 |
-
float nopack_gqa_efficiency = float(seqlen_q) / float(round_up(seqlen_q, blockM));
|
| 15 |
-
float pack_gqa_efficiency = float(seqlen_q * qhead_per_khead) / float(round_up(seqlen_q * qhead_per_khead, blockM));
|
| 16 |
-
return nopack_gqa_efficiency < 0.9 * pack_gqa_efficiency;
|
| 17 |
-
};
|
| 18 |
-
|
| 19 |
-
// Find the number of splits that maximizes the occupancy. For example, if we have
|
| 20 |
-
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
|
| 21 |
-
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
|
| 22 |
-
// splits as that would incur more HBM reads/writes.
|
| 23 |
-
// So we find the best efficiency, then find the smallest number of splits that gets 85%
|
| 24 |
-
// of the best efficiency.
|
| 25 |
-
inline int num_splits_heuristic(int total_mblocks, int num_SMs, int num_n_blocks, int num_m_blocks, int size_one_kv_head, bool is_causal_or_local, int max_splits) {
|
| 26 |
-
// If we have enough to almost fill the SMs, then just use 1 split
|
| 27 |
-
// However, in the case of super long seqlen where each head of KV doesn't even fit into
|
| 28 |
-
// L2 (we assume that L2 size is 50MB), we want to split.
|
| 29 |
-
if (total_mblocks >= 0.8f * num_SMs) {
|
| 30 |
-
int const size_l2 = 50 * 1024 * 1024;
|
| 31 |
-
// Only split if there are enough queries to go over the KV at least twice
|
| 32 |
-
// Don't split if causal
|
| 33 |
-
if (size_one_kv_head > size_l2 && num_m_blocks >= num_SMs * 2 && !is_causal_or_local) {
|
| 34 |
-
return std::min((size_one_kv_head + size_l2 - 1) / size_l2, max_splits);
|
| 35 |
-
} else {
|
| 36 |
-
return 1;
|
| 37 |
-
}
|
| 38 |
-
}
|
| 39 |
-
// If num_n_blocks is too small, use 1 split. For example, we never split for hdim = 128 and seqlen_k = 512.
|
| 40 |
-
if (num_n_blocks <= 4) { return 1; }
|
| 41 |
-
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
| 42 |
-
float max_efficiency = 0.f;
|
| 43 |
-
std::vector<float> efficiency;
|
| 44 |
-
efficiency.reserve(max_splits);
|
| 45 |
-
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
| 46 |
-
float n_waves = float(total_mblocks * num_splits) / num_SMs;
|
| 47 |
-
float eff = n_waves / ceil(n_waves);
|
| 48 |
-
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
| 49 |
-
if (eff > max_efficiency) { max_efficiency = eff; }
|
| 50 |
-
efficiency.push_back(eff);
|
| 51 |
-
}
|
| 52 |
-
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
| 53 |
-
if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
|
| 54 |
-
// printf("num_splits chosen = %d\n", num_splits);
|
| 55 |
-
return num_splits;
|
| 56 |
-
}
|
| 57 |
-
}
|
| 58 |
-
return 1;
|
| 59 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim128<80, cutlass::bfloat16_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim128<86, cutlass::bfloat16_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_bf16_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim128<90, cutlass::bfloat16_t, false>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim128<80, cutlass::bfloat16_t, true>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim128<86, cutlass::bfloat16_t, true>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcap_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim128<90, cutlass::bfloat16_t, true>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_bf16_softcapall_sm90.cu
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_hdim128_bf16_sm90.cu"
|
| 6 |
-
#include "flash_bwd_hdim128_bf16_softcap_sm90.cu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim128<80, cutlass::half_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim128<86, cutlass::half_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_fp16_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::half_t, 128, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim128<90, cutlass::half_t, false>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim128<80, cutlass::half_t, true>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim128<86, cutlass::half_t, true>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcap_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM128
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::half_t, 128, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim128<90, cutlass::half_t, true>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim128_fp16_softcapall_sm90.cu
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_hdim128_fp16_sm90.cu"
|
| 6 |
-
#include "flash_bwd_hdim128_fp16_softcap_sm90.cu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim192<80, cutlass::bfloat16_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim192<86, cutlass::bfloat16_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_bf16_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim192<90, cutlass::bfloat16_t, false>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim192<80, cutlass::bfloat16_t, true>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim192<86, cutlass::bfloat16_t, true>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcap_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim192<90, cutlass::bfloat16_t, true>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_bf16_softcapall_sm90.cu
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_hdim192_bf16_sm90.cu"
|
| 6 |
-
#include "flash_bwd_hdim192_bf16_softcap_sm90.cu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim192<80, cutlass::half_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim192<86, cutlass::half_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_fp16_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::half_t, 192, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim192<90, cutlass::half_t, false>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim192<80, cutlass::half_t, true>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim192<86, cutlass::half_t, true>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcap_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM192
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::half_t, 192, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim192<90, cutlass::half_t, true>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim192_fp16_softcapall_sm90.cu
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_hdim192_fp16_sm90.cu"
|
| 6 |
-
#include "flash_bwd_hdim192_fp16_softcap_sm90.cu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim256<80, cutlass::bfloat16_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim256<86, cutlass::bfloat16_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_bf16_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim256<90, cutlass::bfloat16_t, false>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim256<80, cutlass::bfloat16_t, true>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim256<86, cutlass::bfloat16_t, true>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcap_sm90.cu
DELETED
|
@@ -1,12 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 8 |
-
template<>
|
| 9 |
-
void run_mha_bwd_<90, cutlass::bfloat16_t, 256, true>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 10 |
-
run_mha_bwd_hdim256<90, cutlass::bfloat16_t, true>(params, stream);
|
| 11 |
-
}
|
| 12 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_bf16_softcapall_sm90.cu
DELETED
|
@@ -1,6 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_hdim256_bf16_sm90.cu"
|
| 6 |
-
#include "flash_bwd_hdim256_bf16_softcap_sm90.cu"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
flash-attn/instantiations/flash_bwd_hdim256_fp16_sm80.cu
DELETED
|
@@ -1,18 +0,0 @@
|
|
| 1 |
-
// Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
|
| 2 |
-
// Splitting the different template instantiations to different files to speed up compilation.
|
| 3 |
-
// This file is auto-generated. See "generate_kernels.py"
|
| 4 |
-
|
| 5 |
-
#include "flash_bwd_launch_template.h"
|
| 6 |
-
|
| 7 |
-
#ifndef FLASHATTENTION_DISABLE_SM8x
|
| 8 |
-
#ifndef FLASHATTENTION_DISABLE_HDIM256
|
| 9 |
-
template<>
|
| 10 |
-
void run_mha_bwd_<80, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 11 |
-
run_mha_bwd_hdim256<80, cutlass::half_t, false>(params, stream);
|
| 12 |
-
}
|
| 13 |
-
template<>
|
| 14 |
-
void run_mha_bwd_<86, cutlass::half_t, 256, false>(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
| 15 |
-
run_mha_bwd_hdim256<86, cutlass::half_t, false>(params, stream);
|
| 16 |
-
}
|
| 17 |
-
#endif
|
| 18 |
-
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|