MekkCyber committed
Commit 317612c · 1 parent: 9967514

some updates

torch-ext/rmsnorm_kernel/__init__.py CHANGED
@@ -1,21 +1,13 @@
 import torch
-import torch.nn as nn
 
 from ._ops import ops
 
+from . import layers
 
-class LlamaRMSNorm(nn.Module):
-    weight: torch.Tensor
-    variance_epsilon: float
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        return ops.rmsnorm_forward(
-            hidden_states,
-            self.weight,
-            bias=None,
-            residual=None,
-            eps=self.variance_epsilon,
-            dropout_p=0.0,
-            prenorm=False,
-            residual_in_fp32=False,
-        )
+def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
+    return ops.rmsnorm_forward(x, weight)
+
+__all__ = [
+    "layers",
+    "rmsnorm_forward",
+]
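With this change the package's top-level API becomes the functional rmsnorm_forward(x, weight) plus the new layers submodule. For reference, RMS normalization scales each vector along the last dimension by the reciprocal of its root mean square and then applies the learned weight; below is a minimal pure-PyTorch sketch of that computation, not the kernel itself. The eps default is an assumption, since the updated op signature no longer takes an explicit epsilon.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm: divide by the root mean square over the last dimension,
    # then scale by the per-channel weight. The eps value here is an
    # assumption; the kernel's built-in epsilon is not visible in this diff.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

x = torch.randn(2, 64)
w = torch.ones(64)
print(rmsnorm_reference(x, w).shape)  # torch.Size([2, 64])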
torch-ext/rmsnorm_kernel/layers.py ADDED
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+
+from ._ops import ops
+
+
+class LlamaRMSNorm(nn.Module):
+    weight: torch.Tensor
+    variance_epsilon: float
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return ops.rmsnorm_forward(
+            hidden_states,
+            self.weight,
+        )
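LlamaRMSNorm declares weight and variance_epsilon as class-level annotations but defines no __init__: in the kernel-layers pattern this class is mapped onto an existing eager module that already owns those attributes, so forward only has to call into the op. A hedged stand-alone sketch follows; the attribute setup is illustrative, the import path assumes the built package is importable as rmsnorm_kernel, and running it requires the compiled extension on a supported device. Note the committed forward passes only hidden_states and weight, so the epsilon presumably falls back to the kernel's default.

import torch
from rmsnorm_kernel.layers import LlamaRMSNorm  # import path is an assumption

norm = LlamaRMSNorm()  # inherited nn.Module.__init__ runs; no attributes are created here
# Normally the layer-replacement machinery grafts this class onto an
# existing RMSNorm module; stand-alone, we supply the attributes the
# annotations promise.
norm.weight = torch.nn.Parameter(torch.ones(64, device="cuda", dtype=torch.float16))
norm.variance_epsilon = 1e-6  # declared on the class, but unused by the new forward

x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)
print(norm(x).shape)  # torch.Size([2, 8, 64])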