File size: 334 Bytes
			
			[general]
# Name of the kernel package produced by this build.
name = "batch_invariant"
# NOTE(review): presumably `universal = true` is reserved for packages with no
# compiled code; this package compiles a CUDA kernel (backend = "cuda"), so it
# stays false. Confirm against the build tool's documentation.
universal = false
# C++ source and header files that bind the kernel(s) to PyTorch.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]
# CUDA kernel target `batch_invariant_matmul`, built from the .cu source below.
[kernel.batch_invariant_matmul]
backend = "cuda"
# NOTE(review): presumably makes this kernel build against the PyTorch
# dependency declared for the package. Confirm the semantics of `depends`.
depends = ["torch"]
src = [
    "csrc/batch_invariant.cu",
]
