 import torch.nn as nn
 from flashinfer.norm import (
     fused_add_rmsnorm,
-    # gemma_fused_add_rmsnorm,
-    # gemma_rmsnorm,
     rmsnorm,
 )
 from scratchpad.model_executor.custom_op import CustomOp
@@ -50,45 +48,4 @@ def forward_native(
         if residual is None:
             return x
         else:
-            return x, residual
-
-
-# class GemmaRMSNorm(CustomOp):
-#     def __init__(
-#         self,
-#         hidden_size: int,
-#         eps: float = 1e-6,
-#     ) -> None:
-#         super().__init__()
-#         self.weight = nn.Parameter(torch.zeros(hidden_size))
-#         self.variance_epsilon = eps
-
-#     def forward_native(
-#         self,
-#         x: torch.Tensor,
-#         residual: Optional[torch.Tensor] = None,
-#     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-#         orig_dtype = x.dtype
-#         if residual is not None:
-#             x = x + residual
-#             residual = x
-
-#         x = x.float()
-#         variance = x.pow(2).mean(dim=-1, keepdim=True)
-#         x = x * torch.rsqrt(variance + self.variance_epsilon)
-#         x = x * (1.0 + self.weight.float())
-#         x = x.to(orig_dtype)
-#         return x if residual is None else (x, residual)
-
-#     def forward_cuda(
-#         self,
-#         x: torch.Tensor,
-#         residual: Optional[torch.Tensor] = None,
-#     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-#         if residual is not None:
-#             gemma_fused_add_rmsnorm(
-#                 x, residual, self.weight.data, self.variance_epsilon
-#             )
-#             return x, residual
-#         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
-#         return out
+            return x, residual
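For context, the class kept in this file presumably dispatches to the flashinfer kernels that remain imported (`fused_add_rmsnorm`, `rmsnorm`). The sketch below is an assumption modeled on the removed GemmaRMSNorm block, not this file's actual code; the class name `RMSNormSketch` and the ones-initialized weight are illustrative only.

```python
# Minimal sketch (assumed, not the file's actual code): a standard RMSNorm-style
# module that calls flashinfer's rmsnorm / fused_add_rmsnorm kernels, following
# the same call pattern as the removed gemma_* variants in the diff above.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from flashinfer.norm import fused_add_rmsnorm, rmsnorm


class RMSNormSketch(nn.Module):  # hypothetical stand-in for the CustomOp subclass
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Standard RMSNorm scales by `weight` directly, so it is initialized to ones
        # (unlike the Gemma variant, which stores weight - 1 and uses zeros).
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward_cuda(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # The fused kernel updates x and residual in place, mirroring the
            # removed gemma_fused_add_rmsnorm usage.
            fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
            return x, residual
        return rmsnorm(x, self.weight.data, self.variance_epsilon)
```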