drbh committed
Commit 3099d65 · Parent(s): 809bfb7

fix: simplify and align signatures further

Files changed:
- .gitignore +5 -1
- flash_attn/flash_api.cpp +17 -33
- tests/test_flash_attn.py +56 -31
- torch-ext/torch_binding.cpp +4 -1
- torch-ext/torch_binding.h +9 -9
.gitignore
CHANGED

@@ -5,4 +5,8 @@ cmake
 result
 CMakeLists.txt
 setup.py
-pyproject.toml
+pyproject.toml
+.venv
+torch-ext/registration.h
+torch-ext/flash_attn/*.so
+torch-ext/flash_attn/_ops.py
flash_attn/flash_api.cpp
CHANGED

@@ -1475,40 +1475,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
 }
 } // namespace FLASH_NAMESPACE
 
-//
+// Prefer the most minimal wrapper possible to avoid unnecessary copies or conversions.
 std::vector<torch::Tensor>
-mha_fwd(
-    [old parameter list truncated in the rendered diff]
-
-    // Prepare the optional arguments as non-const references.
-    std::optional<at::Tensor> out = out_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(out_.value())) : std::nullopt;
-    std::optional<at::Tensor> alibi_slopes = alibi_slopes_.has_value() ? std::optional<at::Tensor>(const_cast<at::Tensor &>(alibi_slopes_.value())) : std::nullopt;
-
-    if (!out.has_value()){
-        out = torch::empty_like(q);
-    }
-
-    // Convert double to float and int64_t to int.
-    float p_dropout_float = static_cast<float>(p_dropout);
-    float softmax_scale_float = static_cast<float>(softmax_scale);
-    float softcap_float = static_cast<float>(softcap);
-    int window_size_left_int = static_cast<int>(window_size_left);
-    int window_size_right_int = static_cast<int>(window_size_right);
-
-    return FLASH_NAMESPACE::mha_fwd(const_cast<at::Tensor &>(q), k, v, out, alibi_slopes, p_dropout_float, softmax_scale_float, is_causal, window_size_left_int, window_size_right_int, softcap_float, return_softmax, gen);
+mha_fwd(torch::Tensor &q, const torch::Tensor &k, const torch::Tensor &v,
+        c10::optional<torch::Tensor> out_,
+        c10::optional<torch::Tensor> alibi_slopes_,
+        const double p_dropout, const double softmax_scale, bool is_causal,
+        const int64_t window_size_left, const int64_t window_size_right,
+        const double softcap, const bool return_softmax,
+        c10::optional<at::Generator> gen_) {
+
+    printf("Confirm this path is taken\n");
+    auto result = FLASH_NAMESPACE::mha_fwd(
+        q, k, v, out_, alibi_slopes_, static_cast<float>(p_dropout),
+        static_cast<float>(softmax_scale), is_causal,
+        static_cast<int>(window_size_left), static_cast<int>(window_size_right),
+        static_cast<float>(softcap), return_softmax, gen_);
+
+    return result;
 }
 
 std::vector<torch::Tensor>
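The removed wrapper allocated the output with torch::empty_like(q) when out_ was absent and narrowed the double/int64_t arguments into named locals; the new wrapper forwards the optionals unchanged and narrows inline with static_cast (schema "float"/"int" arguments arrive in C++ as double/int64_t). A rough Python-side sketch of the two calling styles, assuming the binding is exposed as the flash_attn.mha_fwd used in the test below; passing None for the optional tensors is inferred from the Tensor? schema and is not something this commit's test does:

import torch
import flash_attn  # extension module exercised by tests/test_flash_attn.py

shape = (1, 128, 8, 64)  # (batch, seqlen, num_heads, head_dim)
q = torch.randn(shape, device="cuda", dtype=torch.float16)
k = torch.randn(shape, device="cuda", dtype=torch.float16)
v = torch.randn(shape, device="cuda", dtype=torch.float16)

# Style 1: preallocate the output buffer, mirroring the argument values in the test.
out_buf = torch.empty(shape, device="cuda", dtype=torch.float16)
out, softmax_lse, p, rng_state = flash_attn.mha_fwd(
    q, k, v,
    out_buf,   # out_
    None,      # alibi_slopes_
    0.0,       # p_dropout
    1.0,       # softmax_scale
    False,     # is_causal
    0, 0,      # window_size_left, window_size_right
    0.0,       # softcap
    False,     # return_softmax
    None,      # gen_
)

# Style 2: pass None for out_ and let the kernel allocate the output
# (the removed wrapper did that allocation explicitly; the inner
# FLASH_NAMESPACE::mha_fwd is assumed to handle it now).
out2, *_ = flash_attn.mha_fwd(q, k, v, None, None, 0.0, 1.0, False, 0, 0, 0.0, False, None)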
tests/test_flash_attn.py
CHANGED

@@ -1,38 +1,63 @@
 import torch
-
 import flash_attn
 
+# make reproducible
+torch.manual_seed(0)
+
+def _attention_torch(query, key, value, *, backend):
+    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
+    with torch.nn.attention.sdpa_kernel(backend):
+        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
+    out = out.transpose(1, 2).contiguous()
+    return out
 
-# TODO: improve and add more tests
 def test_flash_attn():
-    [old test body truncated in the rendered diff]
+    # ===== Testing shape: (1, 4224, 24, 128) =====
+    batch_size = 1
+    seq_len = 4224
+    num_attention_heads = 24
+    attention_head_dim = 128
+
+    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
+
+    query = torch.randn(shape, device="cuda", dtype=torch.float16)
+    key = torch.randn(shape, device="cuda", dtype=torch.float16)
+    value = torch.randn(shape, device="cuda", dtype=torch.float16)
+
+    golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
+
+    print("Golden truth shape:", golden_truth.shape)
+
+    # print query sum
+    print("Query sum:", query.sum().item())
+
+    # now use the flash attention
+    out, softmax_lse, p, rng_state = flash_attn.mha_fwd(
+        query,
+        key,
+        value,
+        torch.empty(shape, device="cuda", dtype=torch.half),
+        torch.empty(num_attention_heads, device="cuda", dtype=torch.float32),
+        0.0,
+        1.0,
+        False,
+        0,
+        0,
+        0.0,
+        False,
+        None,
     )
 
-
+    print("Flash attention output shape:", out.shape)
+
+    # print query sum
+    print(query.sum().item())
+
+    # compare
+    diff = (out - golden_truth).abs().max()
+    print("Max absolute difference:", diff.item())
+
+    assert out.shape == (1, 4224, 24, 128)
+    assert diff < 1e-2
+
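The new test allocates CUDA float16 tensors, so it needs a GPU to run; something like "pytest -s tests/test_flash_attn.py" (with -s so the debug prints stay visible) is one way to exercise it, though no runner invocation is part of this commit.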
torch-ext/torch_binding.cpp
CHANGED

@@ -14,7 +14,10 @@
 // }
 
 TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
-  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? [rest of the old one-line schema truncated in the rendered diff]
+  ops.def("mha_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor? "
+          "alibi_slopes_, float p_dropout, float softmax_scale, bool "
+          "is_causal, int window_size_left, int window_size_right, float "
+          "softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
   ops.impl("mha_fwd", torch::kCUDA, &mha_fwd);
 
   ops.def("mha_varlen_fwd(Tensor! q, Tensor! k, Tensor! v, Tensor? out_, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor? seqused_k_, Tensor? leftpad_k_, Tensor? block_table_, Tensor? alibi_slopes_, int max_seqlen_q, int max_seqlen_k, float p_dropout, float softmax_scale, bool zero_tensors, bool is_causal, int window_size_left, int window_size_right, float softcap, bool return_softmax, Generator? gen_) -> Tensor[]");
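In this schema, Tensor? and Generator? parameters accept None from Python, and the "float"/"int" schema types arrive in C++ as double/int64_t, which is what the wrapper in flash_attn/flash_api.cpp narrows back down with static_cast. A rough sketch of calling the raw registered op, assuming the library is loaded under a namespace written here as _flash_attn_ops purely for illustration (the real name comes from TORCH_EXTENSION_NAME and is re-exported through torch-ext/flash_attn/_ops.py):

import torch

# Hypothetical namespace; substitute whatever TORCH_EXTENSION_NAME expands to.
mha_fwd = torch.ops._flash_attn_ops.mha_fwd

q = torch.randn(1, 64, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Tensor?/Generator? -> None is allowed; float -> Python float; int -> Python int.
result = mha_fwd(q, k, v, None, None, 0.0, 64 ** -0.5, False, 0, 0, 0.0, False, None)
out = result[0]  # schema returns Tensor[]; element 0 is the attention output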
torch-ext/torch_binding.h
CHANGED

@@ -3,19 +3,19 @@
 #include <torch/torch.h>
 
 std::vector<torch::Tensor>
-mha_fwd(
-    const torch::Tensor &k,
-    const torch::Tensor &v,
-    [two parameter lines truncated in the rendered diff]
+mha_fwd(torch::Tensor &q,
+        const torch::Tensor &k,
+        const torch::Tensor &v,
+        c10::optional<torch::Tensor> out_,
+        c10::optional<torch::Tensor> alibi_slopes_,
         const double p_dropout,
         const double softmax_scale,
-    bool is_causal
-    const int64_t window_size_left
+        bool is_causal,
+        const int64_t window_size_left,
         const int64_t window_size_right,
         const double softcap,
-    const bool return_softmax
-    [parameter line truncated in the rendered diff]
+        const bool return_softmax,
+        c10::optional<at::Generator> gen_);
 
 std::vector<torch::Tensor>
 mha_varlen_fwd(