File size: 893 Bytes
e5e2eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from common.diff_engine import DiffCase

import activation


class RMS(DiffCase):
    """Diff case comparing ``torch.nn.RMSNorm`` with ``activation.layers.RMSNorm``.

    Both modules are built with the same ``dim`` / ``eps`` / ``dtype``.
    Each module gets its own fresh ``Parameter`` copied from the shared
    ``weight`` tensor, so the two can be compared on identical inputs.
    """

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        """Create the shared inputs for one configuration.

        Returns a dict holding:
          - ``x``: random tensor of shape ``(bs, sl, hidden)`` with
            ``requires_grad=True``;
          - ``weight``: all-ones tensor of length ``hidden``;
          - ``dim`` / ``eps`` / ``dtype``: constructor settings for the modules.
        """
        x = torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True)
        # NOTE(review): an all-ones weight leaves the weight multiply as an
        # identity, so a kernel that mishandles it may still match the
        # reference. Consider random weights if the harness tolerates them.
        weight = torch.ones(hidden, dtype=dtype)
        return {
            "x": x,
            "weight": weight,
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    @staticmethod
    def _install_weight(module, inputs):
        """Replace ``module.weight`` with a detached copy of ``inputs["weight"]``.

        A fresh ``Parameter`` is created per module, so the two implementations
        never share storage. The same module is returned.
        """
        module.weight = torch.nn.Parameter(inputs["weight"].detach().clone())
        return module

    def make_naive(self, I):
        """Build the reference ``torch.nn.RMSNorm`` with the shared weight."""
        reference = torch.nn.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        return self._install_weight(reference, I)

    def make_cuda(self, I):
        """Build ``activation.layers.RMSNorm`` (the ``make_cuda`` side) with the shared weight."""
        candidate = activation.layers.RMSNorm(I["dim"], I["eps"], dtype=I["dtype"])
        return self._install_weight(candidate, I)

    def forward(self, obj, I):
        """Run module ``obj`` on the shared input ``x`` and return its output."""
        x = I["x"]
        return obj(x)

    def grad_inputs(self, I):
        """Return the input tensors whose gradients are of interest (only ``x``).

        Presumably the harness compares gradients for exactly these tensors.
        TODO confirm against ``common.diff_engine``.
        """
        tracked = [I["x"]]
        return tracked


# Module-level instance of the RMS diff case. Presumably the diff harness
# discovers it through the `CASE` name. TODO confirm against common.diff_engine.
CASE = RMS()