Skip to content

Commit 9b9b443

Browse files
committed
typing
1 parent d3a077c commit 9b9b443

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/model_constructor/convmixer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,29 @@
33
# Adopted from https://github.com/tmp-iclr/convmixer
44
# Home for convmixer: https://github.com/locuslab/convmixer
55
from collections import OrderedDict
6-
from typing import Callable, Optional, Union
6+
from typing import Callable, List, Optional, Union
7+
78
import torch.nn as nn
9+
from torch import TensorType
810

911

1012
class Residual(nn.Module):
11-
def __init__(self, fn):
13+
def __init__(self, fn: Callable[[TensorType], TensorType]):
1214
super().__init__()
1315
self.fn = fn
1416

15-
def forward(self, x):
17+
def forward(self, x: TensorType) -> TensorType:
1618
return self.fn(x) + x
1719

1820

1921
# As original version, act_fn as argument.
2022
def ConvMixerOriginal(
21-
dim, depth, kernel_size=9, patch_size=7, n_classes=1000, act_fn=nn.GELU()
23+
dim: int,
24+
depth: int,
25+
kernel_size: int = 9,
26+
patch_size: int = 7,
27+
n_classes: int = 1000,
28+
act_fn: nn.Module = nn.GELU(),
2229
):
2330
return nn.Sequential(
2431
nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
@@ -61,7 +68,7 @@ def __init__(
6168
pre_act: bool = False,
6269
):
6370

64-
conv_layer = [
71+
conv_layer: List[tuple[str, nn.Module]] = [
6572
(
6673
"conv",
6774
nn.Conv2d(
@@ -74,7 +81,10 @@ def __init__(
7481
),
7582
)
7683
]
77-
act_bn = [("act_fn", act_fn), ("bn", nn.BatchNorm2d(out_channels))]
84+
act_bn: List[tuple[str, nn.Module]] = [
85+
("act_fn", act_fn),
86+
("bn", nn.BatchNorm2d(out_channels)),
87+
]
7888
if bn_1st:
7989
act_bn.reverse()
8090
if pre_act:

0 commit comments

Comments
 (0)