|
#pragma once |
|
|
|
#include <torch/torch.h> |
|
|
|
/// Presumably the forward pass of PolyNorm (inferred from the name; the
/// definition is not in this header).
///
/// @param out     Result destination. Taken by non-const reference, so the
///                implementation may write into it.
///                NOTE(review): presumably must be preallocated by the caller
///                with the shape/dtype of `input`; confirm against the definition.
/// @param input   Tensor to normalize (read-only).
/// @param weights Scale parameters (read-only). The plural name suggests
///                several weights; TODO confirm the expected count/shape.
/// @param bias    Bias parameter (read-only).
/// @param eps     Epsilon, presumably for numerical stability of the norm.
void poly_norm(torch::Tensor &out, const torch::Tensor &input,
               const torch::Tensor &weights, const torch::Tensor &bias,
               double eps);
|
/// Presumably the backward pass of `poly_norm` (inferred from the name).
///
/// The non-const reference parameters are gradient destinations. The const
/// references are read-only inputs.
///
/// @param input_grad  Destination for the gradient w.r.t. `input`.
/// @param weight_grad Destination for the gradient w.r.t. the weights.
/// @param bias_grad   Destination for the gradient w.r.t. the bias.
/// @param output_grad Upstream gradient of the forward output.
/// @param input       Forward-pass input.
/// @param weight      Forward-pass weights.
/// @param eps         Epsilon; presumably must match the forward call.
///
/// NOTE(review): unlike `poly_norm`, no `bias` tensor is passed here.
/// Presumably the bias does not enter the input/weight gradients; confirm.
void poly_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                        torch::Tensor &bias_grad,
                        const torch::Tensor &output_grad,
                        const torch::Tensor &input, const torch::Tensor &weight,
                        double eps);
|
|
|
/// Presumably the forward pass of RMSNorm (inferred from the name; the
/// definition is not in this header).
///
/// @param out     Result destination. Taken by non-const reference, so the
///                implementation may write into it.
///                NOTE(review): presumably preallocated by the caller; confirm.
/// @param input   Tensor to normalize (read-only).
/// @param weights Scale parameter (read-only).
/// @param eps     Epsilon, presumably for numerical stability.
void rms_norm(torch::Tensor &out, const torch::Tensor &input,
              const torch::Tensor &weights, double eps);
|
/// Presumably the backward pass of `rms_norm` (inferred from the name).
///
/// @param input_grad  Destination for the gradient w.r.t. `input`.
/// @param weight_grad Destination for the gradient w.r.t. `weight`.
/// @param output_grad Upstream gradient of the forward output (read-only).
/// @param input       Forward-pass input (read-only).
/// @param weight      Forward-pass scale parameter (read-only).
/// @param eps         Epsilon; presumably must match the forward call.
void rms_norm_backward(torch::Tensor &input_grad, torch::Tensor &weight_grad,
                       const torch::Tensor &output_grad,
                       const torch::Tensor &input, const torch::Tensor &weight,
                       double eps);
|
|
|
/// Presumably PolyNorm fused with an element-wise multiply by `mul`
/// (inferred from the name; the exact formula is defined elsewhere).
///
/// @param out     Result destination. Taken by non-const reference, so the
///                implementation may write into it.
/// @param input   Tensor to normalize (read-only).
/// @param mul     Multiplier tensor (read-only).
///                TODO confirm its broadcasting/shape requirements.
/// @param weights Scale parameters (read-only).
/// @param bias    Bias parameter (read-only).
/// @param eps     Epsilon, presumably for numerical stability.
void fused_mul_poly_norm(torch::Tensor &out, const torch::Tensor &input,
                         const torch::Tensor &mul, const torch::Tensor &weights,
                         const torch::Tensor &bias, double eps);
|
/// Presumably the backward pass of `fused_mul_poly_norm` (inferred from the name).
///
/// @param input_grad  Destination for the gradient w.r.t. `input`.
/// @param mul_grad    Destination for the gradient w.r.t. `mul`.
/// @param weight_grad Destination for the gradient w.r.t. the weights.
/// @param bias_grad   Destination for the gradient w.r.t. the bias.
/// @param output_grad Upstream gradient of the forward output (read-only).
/// @param input       Forward-pass input (read-only).
/// @param mul         Forward-pass multiplier (read-only).
/// @param weight      Forward-pass weights (read-only).
/// @param bias        Forward-pass bias (read-only). Presumably needed to
///                    recompute the normalized value for `mul_grad`; confirm.
/// @param eps         Epsilon; presumably must match the forward call.
void fused_mul_poly_norm_backward(
    torch::Tensor &input_grad, torch::Tensor &mul_grad,
    torch::Tensor &weight_grad, torch::Tensor &bias_grad,
    const torch::Tensor &output_grad, const torch::Tensor &input,
    const torch::Tensor &mul, const torch::Tensor &weight,
    const torch::Tensor &bias, double eps);
|
|
|
/// Presumably a residual add fused with RMSNorm (inferred from the name).
///
/// @param out      Destination for the normalized result (non-const reference).
/// @param add_out  Second destination (non-const reference). Presumably
///                 receives the un-normalized sum of `input` and `residual`;
///                 confirm against the definition.
/// @param input    First addend (read-only).
/// @param residual Second addend (read-only).
/// @param weight   Scale parameter (read-only).
/// @param eps      Epsilon, presumably for numerical stability.
void fused_add_rms_norm(torch::Tensor &out, torch::Tensor &add_out,
                        const torch::Tensor &input,
                        const torch::Tensor &residual,
                        const torch::Tensor &weight, double eps);
|
/// Presumably the backward pass of `fused_add_rms_norm` (inferred from the name).
///
/// @param input_grad      Destination for the input gradient. No separate
///                        residual gradient is produced; presumably the
///                        caller reuses this one for `residual` (addition
///                        passes gradients through unchanged). Confirm.
/// @param weight_grad     Destination for the gradient w.r.t. `weight`.
/// @param output_grad     Upstream gradient of the normalized output (read-only).
/// @param add_output_grad Upstream gradient of the `add_out` output (read-only).
/// @param input           Read-only forward tensor.
///                        TODO confirm whether this is the pre-add input or
///                        the saved sum (`add_out`).
/// @param weight          Forward-pass scale parameter (read-only).
/// @param eps             Epsilon; presumably must match the forward call.
void fused_add_rms_norm_backward(torch::Tensor &input_grad,
                                 torch::Tensor &weight_grad,
                                 const torch::Tensor &output_grad,
                                 const torch::Tensor &add_output_grad,
                                 const torch::Tensor &input,
                                 const torch::Tensor &weight, double eps);
|
|