
Commit 32e0785

[Typing][C-83] Add type annotations for python/paddle/incubate/nn/layer/fused_ec_moe.py (#67143)
* fix
* fix
1 parent 8305020 commit 32e0785
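
For context, a brief sketch (not part of the commit) of what the new `Literal`-based `_ActTypeLiteral` alias provides: a static type checker can reject unsupported activation strings before the code runs. The checker behaviour noted in the comments is an assumption about how tools such as mypy handle `Literal` arguments.

from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe

# Accepted: "gelu" matches Literal['gelu', 'relu'].
ok = FusedEcMoe(1024, 4096, 8, act_type="gelu")

# A checker such as mypy is expected to flag this argument as incompatible with
# _ActTypeLiteral; at runtime the constructor raises NotImplementedError.
bad = FusedEcMoe(1024, 4096, 8, act_type="swish")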

File tree

1 file changed: +29 -10 lines changed


python/paddle/incubate/nn/layer/fused_ec_moe.py

Lines changed: 29 additions & 10 deletions
@@ -11,10 +11,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Literal
+
+from typing_extensions import TypeAlias
 
 from paddle.incubate.nn import functional as F
 from paddle.nn import Layer
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle._typing import ParamAttrLike
+
+_ActTypeLiteral: TypeAlias = Literal[
+    'gelu',
+    'relu',
+]
+
 
 class FusedEcMoe(Layer):
     r"""A FusedEcMoe Layer.
@@ -24,11 +38,11 @@ class FusedEcMoe(Layer):
         inter_size (int): The dim size of feed forward network.
         num_expert (int): The number of experts.
         act_type (string): The activation type. Currently only support `gelu`, `relu`.
-        weight_attr (ParamAttr, optional): The attribute for the learnable
+        weight_attr (ParamAttr|None, optional): The attribute for the learnable
             weight of this layer. The default value is None and the weight will be
             initialized to zero. For detailed information, please refer to
             paddle.ParamAttr.
-        bias_attr (ParamAttr|bool, optional): The attribute for the learnable bias
+        bias_attr (ParamAttr|bool|None, optional): The attribute for the learnable bias
             of this layer. If it is set to False, no bias will be added to the output.
             If it is set to None or one kind of ParamAttr, a bias parameter will
             be created according to ParamAttr. For detailed information, please refer
@@ -58,16 +72,21 @@ class FusedEcMoe(Layer):
             >>> print(y.shape)
             [10, 128, 1024]
     """
+    bmm_weight0: Tensor
+    bmm_bias0: Tensor
+    bmm_weight1: Tensor
+    bmm_bias1: Tensor
+    act_type: _ActTypeLiteral
 
     def __init__(
         self,
-        hidden_size,
-        inter_size,
-        num_experts,
-        act_type,
-        weight_attr=None,
-        bias_attr=None,
-    ):
+        hidden_size: int,
+        inter_size: int,
+        num_experts: int,
+        act_type: _ActTypeLiteral,
+        weight_attr: ParamAttrLike | None = None,
+        bias_attr: ParamAttrLike | None = None,
+    ) -> None:
         super().__init__()
         weight0_shape = [num_experts, hidden_size, inter_size]
         bias0_shape = [num_experts, 1, inter_size]
@@ -91,7 +110,7 @@ def __init__(
         if self.act_type not in ["gelu", "relu"]:
             raise NotImplementedError("Currently only support `gelu`, `relu`. ")
 
-    def forward(self, x, gate):
+    def forward(self, x: Tensor, gate: Tensor) -> Tensor:
         return F.fused_ec_moe(
             x,
             gate,
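
For reference, a minimal usage sketch consistent with the docstring example shown above (the `[10, 128, 1024]` output). The gate shape and `inter_size` value are assumptions, and the fused kernel generally expects a GPU build of Paddle.

import paddle
from paddle.incubate.nn.layer.fused_ec_moe import FusedEcMoe

x = paddle.randn([10, 128, 1024])   # [batch, seq_len, hidden_size]
gate = paddle.randn([10, 128, 8])   # [batch, seq_len, num_experts] (shape assumed)
moe_layer = FusedEcMoe(
    hidden_size=1024,
    inter_size=4096,                 # feed-forward dim; value assumed
    num_experts=8,
    act_type="gelu",                 # must be 'gelu' or 'relu'
)
y = moe_layer(x, gate)
print(y.shape)                       # [10, 128, 1024]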
