drbh committed
Commit · c743a32
1 Parent(s): 4762963

fix: align kernel source with latest reference source
- .gitignore +8 -1
- build.toml +6 -20
- flash_attn/flash_api.cpp +5 -5
- flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu +0 -14
- flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu +0 -14
- flash_attn/src/flash_bwd_launch_template.h +1 -21
- flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu +0 -14
- flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +0 -14
- flash_attn/src/flash_fwd_launch_template.h +2 -31
- flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu +0 -11
- flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +0 -11
- flash_attn/src/generate_kernels.py +1 -1
- flash_attn/src/static_switch.h +0 -3
.gitignore CHANGED
@@ -1 +1,8 @@
-.bak
+.bak
+__pycache__
+build-ext
+cmake
+result
+CMakeLists.txt
+setup.py
+pyproject.toml
build.toml CHANGED
@@ -1,10 +1,12 @@
 [general]
 name = "flash_attn"
+universal=false
 
 [torch]
 src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"]
 
 [kernel.flash_attn]
+backend = "cuda"
 cuda-capabilities = [
   "8.0",
   "9.0",
@@ -13,6 +15,7 @@ cuda-capabilities = [
 ]
 src = [
   "flash_attn/flash_api.cpp",
+
   "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/namespace_config.h",
   "flash_attn/src/hardware_info.h",
@@ -21,29 +24,18 @@ src = [
   "flash_attn/src/alibi.h",
   "flash_attn/src/block_info.h",
   "flash_attn/src/dropout.h",
-  "flash_attn/src/flash.h",
-  "flash_attn/src/generate_kernels.py",
-  "flash_attn/src/hardware_info.h",
   "flash_attn/src/kernel_traits.h",
   "flash_attn/src/mask.h",
-  "flash_attn/src/namespace_config.h",
   "flash_attn/src/philox.cuh",
-  "flash_attn/src/philox_unpack.cuh",
   "flash_attn/src/rotary.h",
   "flash_attn/src/softmax.h",
-  "flash_attn/src/static_switch.h",
   "flash_attn/src/utils.h",
 
-
-
+  # bwd kernels
   "flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
@@ -73,10 +65,6 @@ src = [
   "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
@@ -99,14 +87,12 @@ src = [
   "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_kernel.h",
   "flash_attn/src/flash_fwd_launch_template.h",
+
+  # split kernels
   "flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
-  "flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
   "flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
flash_attn/flash_api.cpp CHANGED
@@ -432,7 +432,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
@@ -644,7 +644,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -831,7 +831,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
@@ -1048,7 +1048,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
 
@@ -1321,7 +1321,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 
     auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
     const int head_size = round_multiple(head_size_og, 8);
-    const int head_size_rounded = head_size <=
+    const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
     const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
     const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
 
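Note on the head_size_rounded change above: with the new rule, head sizes up to 128 round to a multiple of 32 and larger head sizes to a multiple of 64, so a head size of 160 now rounds up to 192, which is consistent with this commit dropping the dedicated hdim160 kernels. A small standalone sketch (the lambda is copied from the hunks above; the loop and sample head sizes are only illustrative):

#include <cstdio>

int main() {
    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    for (int head_size : {64, 96, 128, 160, 192, 224, 256}) {
        // head_size <= 128 -> round to a multiple of 32, otherwise to a multiple of 64
        const int head_size_rounded = round_multiple(head_size, head_size <= 128 ? 32 : 64);
        std::printf("head_size %3d -> head_size_rounded %3d\n", head_size, head_size_rounded);
    }
    // Among the printed lines: head_size 160 -> head_size_rounded 192,
    // so requests with head dim 160 can be served by the 192 kernels.
    return 0;
}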
flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::bfloat16_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, true>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_bwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_bwd_<cutlass::half_t, 160, false>(Flash_bwd_params &params, cudaStream_t stream) {
-    run_mha_bwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_bwd_launch_template.h CHANGED
@@ -102,7 +102,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+    auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap>;
     // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
     if (smem_size_dq_dk_dv >= 48 * 1024) {
         C10_CUDA_CHECK(cudaFuncSetAttribute(
@@ -261,26 +261,6 @@ void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T, bool Is_causal>
-void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    int device;
-    cudaGetDevice(&device);
-    int max_smem_per_block;
-    cudaError status_ = cudaDeviceGetAttribute(
-        &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
-    if (status_ != cudaSuccess) {
-        C10_CUDA_CHECK(status_);
-    }
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        if (max_smem_per_block >= 116 * 1024) {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        } else {
-            run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        }
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
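Aside on the `!Has_alibi` terms added to the kernel template arguments above: each independent boolean template parameter multiplies the number of kernel instantiations the compiler must build, so collapsing the evenness flags to false whenever alibi is enabled prunes combinations the alibi path never uses. A minimal sketch of the idea, with placeholder names rather than the real launcher:

// Illustrative stand-in for the dispatch pattern, not the flash-attention code.
// When has_alibi is true the evenness flags are forced to false, so the
// (even, alibi) template combinations are never instantiated at all.
template <bool IsEvenMN, bool IsEvenK, bool HasAlibi>
void kernel_stub() { /* placeholder for the templated CUDA kernel */ }

void launch(bool even_mn, bool even_k, bool has_alibi) {
    if (has_alibi) {
        kernel_stub<false, false, true>();
    } else if (even_mn && even_k) {
        kernel_stub<true, true, false>();
    } else if (even_k) {
        kernel_stub<false, true, false>();
    } else {
        kernel_stub<false, false, false>();
    }
}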
flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu DELETED
@@ -1,14 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template<>
-void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream) {
-    run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream);
-}
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_launch_template.h CHANGED
@@ -76,7 +76,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If return_softmax, set IsEvenMNConst to false to reduce number of templates
    // If head dim > 128, set IsEvenMNConst to false to reduce number of templates
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
+    auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !Has_alibi && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !ReturnSoftmaxConst && !Has_alibi, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>;
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
     // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
     // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
@@ -117,7 +117,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
     // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
     // If Is_local, set Is_causal to false
-    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>;
+    auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && !Has_alibi && Kernel_traits::kHeadDim <= 128, IsEvenKConst && !Has_alibi, Is_softcap, Split, Append_KV>;
     // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
     // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
     if (smem_size >= 48 * 1024) {
@@ -165,7 +165,6 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream)
     constexpr static int kBlockM = 64;  // Fixed for all head dimensions
     // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
     // and for headdim 192 with block size 64 x 128.
-    // Also for headdim 160 with block size 64 x 128 after the rotary addition.
     constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream);
 }
@@ -257,34 +256,6 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     });
 }
 
-template<typename T, bool Is_causal>
-void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
-    constexpr static int Headdim = 160;
-    auto [cc_major, cc_minor] = get_compute_capability(get_current_device());
-    bool is_sm8x = cc_major == 8 && cc_minor > 0;
-    DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
-        // For A100, H100, 128 x 32 is the fastest.
-        // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
-        // and 128 x 64 with 8 warps is the fastest for non-causal.
-        if (is_sm8x) {
-            if constexpr(!Is_causal) {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            } else {
-                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-            }
-        } else {
-            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        }
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
-        // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
-    });
-}
-
 template<typename T, bool Is_causal>
 void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 192;
flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu DELETED
@@ -1,11 +0,0 @@
-// Copyright (c) 2024, Tri Dao.
-// Splitting the different head dimensions to different files to speed up compilation.
-// This file is auto-generated. See "generate_kernels.py"
-#include "namespace_config.h"
-#include "flash_fwd_launch_template.h"
-
-namespace FLASH_NAMESPACE {
-
-template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params &params, cudaStream_t stream);
-
-} // namespace FLASH_NAMESPACE
flash_attn/src/generate_kernels.py CHANGED
@@ -10,7 +10,7 @@ DTYPE_MAP = {
 }
 
 SM = [80]  # Sm80 kernels support up to
-HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256]
+HEAD_DIMENSIONS = [32, 64, 96, 128, 192, 256]
 IS_CAUSAL = ["false", "true"]
 NAMESPACE_INCLUDE = '#include "namespace_config.h"\n'
 
flash_attn/src/static_switch.h CHANGED
@@ -101,9 +101,6 @@
   } else if (HEADDIM <= 128) { \
     constexpr static int kHeadDim = 128; \
     return __VA_ARGS__(); \
-  } else if (HEADDIM <= 160) { \
-    constexpr static int kHeadDim = 160; \
-    return __VA_ARGS__(); \
   } else if (HEADDIM <= 192) { \
     constexpr static int kHeadDim = 192; \
     return __VA_ARGS__(); \