import pytest
import torch

import flash_attn


# TODO: improve and add more tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires a CUDA device")
def test_flash_attn():
    # FlashAttention kernels expect fp16/bf16 tensors on the GPU, laid out as
    # (batch, seqlen, num_heads, head_dim).
    q = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    out = torch.empty(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    # Per-head ALiBi slopes (fp32); zero slopes leave the attention scores unbiased.
    alibi_slopes = torch.zeros(4, device="cuda", dtype=torch.float32)
    p_dropout = 0.1
    softmax_scale = 1.0  # usually 1 / sqrt(head_dim); any value works for a shape test
    is_causal = False
    window_size_left = -1  # -1 disables sliding-window (local) attention
    window_size_right = -1
    softcap = 0.0
    return_softmax = False
    gen = None

    out = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    # Some builds return the output together with softmax statistics; in that
    # case the attention output is the first element.
    if isinstance(out, (list, tuple)):
        out = out[0]
    assert out.shape == (2, 5, 4, 8)
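

# A sketch of a numerical check to go with the TODO above. It assumes that
# flash_attn.mha_fwd takes the same argument list as in test_flash_attn and
# returns the attention output (possibly as the first element of a list/tuple).
# With dropout, the causal mask, windowing and ALiBi all disabled, the result
# should closely match PyTorch's reference scaled-dot-product attention.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="FlashAttention requires a CUDA device")
def test_flash_attn_matches_reference():
    torch.manual_seed(0)
    q = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)
    alibi_slopes = torch.zeros(4, device="cuda", dtype=torch.float32)  # zero bias
    softmax_scale = 8 ** -0.5  # 1 / sqrt(head_dim), matching the SDPA default
    out = flash_attn.mha_fwd(
        q, k, v, out,
        alibi_slopes,
        0.0,     # p_dropout: disabled so the output is deterministic
        softmax_scale,
        False,   # is_causal
        -1, -1,  # window_size_left / window_size_right: no sliding window
        0.0,     # softcap: disabled
        False,   # return_softmax
        None,    # gen
    )
    if isinstance(out, (list, tuple)):
        out = out[0]
    # Reference attention expects (batch, num_heads, seqlen, head_dim).
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2).float(),
        k.transpose(1, 2).float(),
        v.transpose(1, 2).float(),
    ).transpose(1, 2)
    assert torch.allclose(out.float(), ref, atol=1e-2, rtol=1e-2)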