|
# Build configuration for the `activation` kernel package.
[general]
name = "activation"
universal = false

# Sources for the PyTorch binding layer.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]

# ROCm build of the activation kernels.
[kernel.activation]
backend = "rocm"
# AMD GPU architectures to compile for.
rocm-archs = ["gfx90a", "gfx942"]
# Identical to the src list of [kernel.activation_cuda]. TOML cannot share one
# array between tables, so keep the two lists in sync by hand.
src = [
    "activation/poly_norm.cu",
    "activation/fused_mul_poly_norm.cu",
    "activation/rms_norm.cu",
    "activation/fused_add_rms_norm.cu",
    "activation/cuda_compat.h",
    "activation/dispatch_utils.h",
    "activation/assert_utils.h",
    "activation/atomic_utils.h",
]
depends = ["torch"]

# CUDA build of the activation kernels.
[kernel.activation_cuda]
backend = "cuda"
# Identical to the src list of [kernel.activation]. TOML cannot share one
# array between tables, so keep the two lists in sync by hand.
src = [
    "activation/poly_norm.cu",
    "activation/fused_mul_poly_norm.cu",
    "activation/rms_norm.cu",
    "activation/fused_add_rms_norm.cu",
    "activation/cuda_compat.h",
    "activation/dispatch_utils.h",
    "activation/assert_utils.h",
    "activation/atomic_utils.h",
]
depends = ["torch"]