 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal
+
+from typing_extensions import TypeAlias
 
 from paddle.incubate.nn import functional as F
 from paddle.nn import Layer
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle._typing import ParamAttrLike
+
+_ActTypeLiteral: TypeAlias = Literal[
+    'gelu',
+    'relu',
+]
+
 
 class FusedEcMoe(Layer):
     r"""A FusedEcMoe Layer.
@@ -24,11 +38,11 @@ class FusedEcMoe(Layer):
         inter_size (int): The dim size of feed forward network.
         num_expert (int): The number of experts.
         act_type (string): The activation type. Currently only support `gelu`, `relu`.
-        weight_attr (ParamAttr, optional): The attribute for the learnable
+        weight_attr (ParamAttr|None, optional): The attribute for the learnable
             weight of this layer. The default value is None and the weight will be
             initialized to zero. For detailed information, please refer to
             paddle.ParamAttr.
-        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
+        bias_attr (ParamAttr|bool|None, optional): The attribute for the learnable bias
            of this layer. If it is set to False, no bias will be added to the output.
            If it is set to None or one kind of ParamAttr, a bias parameter will
            be created according to ParamAttr. For detailed information, please refer
@@ -58,16 +72,21 @@ class FusedEcMoe(Layer):
            >>> print(y.shape)
            [10, 128, 1024]
     """
+    bmm_weight0: Tensor
+    bmm_bias0: Tensor
+    bmm_weight1: Tensor
+    bmm_bias1: Tensor
+    act_type: _ActTypeLiteral
 
     def __init__(
         self,
-        hidden_size,
-        inter_size,
-        num_experts,
-        act_type,
-        weight_attr=None,
-        bias_attr=None,
-    ):
+        hidden_size: int,
+        inter_size: int,
+        num_experts: int,
+        act_type: _ActTypeLiteral,
+        weight_attr: ParamAttrLike | None = None,
+        bias_attr: ParamAttrLike | None = None,
+    ) -> None:
         super().__init__()
         weight0_shape = [num_experts, hidden_size, inter_size]
         bias0_shape = [num_experts, 1, inter_size]
@@ -91,7 +110,7 @@ def __init__(
         if self.act_type not in ["gelu", "relu"]:
             raise NotImplementedError("Currently only support `gelu`, `relu`. ")
 
-    def forward(self, x, gate):
+    def forward(self, x: Tensor, gate: Tensor) -> Tensor:
         return F.fused_ec_moe(
             x,
             gate,
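For reference, a minimal usage sketch of the annotated layer, mirroring the shapes from the docstring example above. The import path, the inter_size value, and the gate layout [batch_size, seq_len, num_experts] are assumptions for illustration and are not taken from this diff.

import paddle
from paddle.incubate.nn import FusedEcMoe

x = paddle.randn([10, 128, 1024])   # [batch_size, seq_len, hidden_size]
gate = paddle.randn([10, 128, 8])   # [batch_size, seq_len, num_experts] (assumed layout)

# act_type is now typed as _ActTypeLiteral, so only 'gelu' or 'relu' pass type checking.
moe = FusedEcMoe(
    hidden_size=1024,
    inter_size=4096,                 # assumed value for illustration
    num_experts=8,
    act_type='gelu',
)
y = moe(x, gate)                     # forward(x: Tensor, gate: Tensor) -> Tensor
print(y.shape)                       # [10, 128, 1024], matching the docstring example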