import random

import pytest
import torch

import activation

from .utils import assert_close, opcheck

DTYPES = [torch.float, torch.bfloat16, torch.half]
NUM_TOKENS = [7, 83, 256, 2048]  # Arbitrary values for testing
D = [1, 7, 512, 13824]  # Arbitrary values for testing
SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


def norm(x, eps: float) -> torch.Tensor:
    return x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
              eps: float) -> torch.Tensor:
    x = x.float()
    return (weight[0] * norm(x**3, eps) + weight[1] * norm(x**2, eps) +
            weight[2] * norm(x, eps) + bias).to(weight.dtype)


# Pure-PyTorch reference: naive poly_norm followed by the elementwise multiply.
def mul_poly_norm_all_naive(x: torch.Tensor, mul: torch.Tensor,
                            weight: torch.Tensor, bias: torch.Tensor,
                            eps: float) -> torch.Tensor:
    return poly_norm(x, weight, bias, eps) * mul


# Uses the poly_norm kernel for normalization; only the multiply stays in PyTorch.
def mul_poly_norm_partial_naive(x: torch.Tensor, mul: torch.Tensor,
                                weight: torch.Tensor, bias: torch.Tensor,
                                eps: float) -> torch.Tensor:
    return activation.poly_norm(x, weight, bias, eps) * mul


@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_fused_mul_poly_norm(
    num_tokens: int,
    d: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
) -> None:
    random.seed(seed)
    torch.manual_seed(seed)
    torch.set_default_device(device)

    x = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    mul = torch.randn(num_tokens, d, dtype=dtype, requires_grad=True)
    weight = torch.randn(3, dtype=dtype, requires_grad=True)
    bias = torch.randn(1, dtype=dtype, requires_grad=True)
    eps = 1e-05

    x.retain_grad()
    mul.retain_grad()
    weight.retain_grad()
    bias.retain_grad()

    # To separate gradient computation, clone the inputs
    x_ref = x.detach().clone().requires_grad_(True)
    mul_ref = mul.detach().clone().requires_grad_(True)
    weight_ref = weight.detach().clone().requires_grad_(True)
    bias_ref = bias.detach().clone().requires_grad_(True)

    x_ref2 = x.detach().clone().requires_grad_(True)
    mul_ref2 = mul.detach().clone().requires_grad_(True)
    weight_ref2 = weight.detach().clone().requires_grad_(True)
    bias_ref2 = bias.detach().clone().requires_grad_(True)

    torch_fn = mul_poly_norm_all_naive
    torch_fn2 = mul_poly_norm_partial_naive
    op = activation.ops.fused_mul_poly_norm
    fn = activation.fused_mul_poly_norm
    layer = activation.layers.FusedMulPolyNorm(eps)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)

    out = torch.empty(x.shape, dtype=x.dtype, device=x.device)
    opcheck(op, (out, x, mul, weight, bias, eps))

    out = fn(x, mul, weight, bias, eps)
    mod_out = layer(x, mul)
    ref_out = torch_fn(x_ref, mul_ref, weight_ref, bias_ref, eps)
    ref_out2 = torch_fn2(x_ref2, mul_ref2, weight_ref2, bias_ref2, eps)

    # Mul amplifies small numeric differences between naive poly_norm and the kernel.
    # When validating against all_naive, use a looser rtol/atol.
    assert_close(out, ref_out, atol=0.01, rtol=0.01)
    assert_close(out, ref_out2)
    assert_close(mod_out, out, atol=0.0, rtol=0.0)

    # test backward pass
    out_grad = torch.randn_like(out)
    out_grad = out_grad / out_grad.norm()

    ref_out.backward(out_grad)
    ref_out2.backward(out_grad)
    mod_out.backward(out_grad)

    # Gradients for x / mul flow from mod_out; weight / bias gradients land on the
    # layer's parameters, so compare layer.weight.grad / layer.bias.grad.
    assert_close(x.grad, x_ref.grad)
    assert_close(x.grad, x_ref2.grad)
    assert_close(mul.grad, mul_ref.grad)
    assert_close(mul.grad, mul_ref2.grad)
    assert_close(layer.bias.grad, bias_ref.grad, rtol=0.05)
    assert_close(layer.bias.grad, bias_ref2.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref.grad, rtol=0.05)
    assert_close(layer.weight.grad, weight_ref2.grad, rtol=0.05)
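

# --- Usage sketch (not part of the test assertions) ----------------------------
# A minimal illustration of how the layer under test is driven, mirroring the
# constructor/call signature exercised above (FusedMulPolyNorm(eps), layer(x, mul)).
# Assumptions: a CUDA device is available and the layer initializes its own
# weight/bias parameters when they are not overwritten as in the test.
def _fused_mul_poly_norm_example() -> torch.Tensor:
    device = "cuda"
    layer = activation.layers.FusedMulPolyNorm(1e-5).to(device)
    x = torch.randn(8, 512, device=device)    # activations to be poly-normalized
    mul = torch.randn(8, 512, device=device)  # tensor multiplied after normalization
    # Numerically close (up to kernel rounding) to:
    #   poly_norm(x, layer.weight, layer.bias, 1e-5) * mul
    return layer(x, mul)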