import torch
from common.diff_engine import DiffCase

import activation


class FusedMulPolyNorm(torch.nn.Module):
    """Naive reference module: PolyNorm activation followed by an
    elementwise multiply with a second input."""

    def __init__(self, eps=1e-6, dtype: torch.dtype = torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(3, dtype=dtype) / 3)
        self.bias = torch.nn.Parameter(torch.zeros(1, dtype=dtype))
        self.eps = eps

    def forward(self, x, mul):
        output = activation.poly_norm(x, self.weight, self.bias, self.eps)
        return output * mul
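

# A hypothetical pure-PyTorch sketch of what ``activation.poly_norm`` is
# assumed to compute (an assumption for illustration; the library's actual
# kernel may order or fuse terms differently): each power of x is
# RMS-normalized over the last dimension, then combined with the three
# learned weights and the scalar bias.
def poly_norm_reference(x, weight, bias, eps=1e-6):
    def _norm(u):
        # RMS normalization over the hidden (last) dimension.
        return u * torch.rsqrt(u.pow(2).mean(dim=-1, keepdim=True) + eps)

    return (weight[0] * _norm(x)
            + weight[1] * _norm(x ** 2)
            + weight[2] * _norm(x ** 3)
            + bias)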


class MulPoly(DiffCase):
    """Diff case comparing the naive PyTorch module against the fused
    CUDA ``FusedMulPolyNorm`` layer, on outputs and input gradients."""

    def build_inputs(self, bs, sl, hidden, dtype, eps):
        return {
            "x": torch.randn(bs, sl, hidden, dtype=dtype, requires_grad=True),
            "mul": torch.randn(bs, sl, hidden, dtype=dtype,
                               requires_grad=True),
            "weight": torch.ones(3, dtype=dtype),
            "bias": torch.ones(1, dtype=dtype),
            "dim": hidden,
            "eps": eps,
            "dtype": dtype,
        }

    def make_naive(self, I):
        m = FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def make_cuda(self, I):
        m = activation.layers.FusedMulPolyNorm(I["eps"], dtype=I["dtype"])
        m.weight = torch.nn.Parameter(I["weight"].detach().clone())
        m.bias = torch.nn.Parameter(I["bias"].detach().clone())
        return m

    def forward(self, obj, I):
        return obj(I["x"], I["mul"])

    def grad_inputs(self, I):
        return [I["x"], I["mul"]]


CASE = MulPoly()