
Commit 1e5d96d

fix: mnv5 with conv_stem bias and GELU approx tanh
1 parent 446e8a8 commit 1e5d96d
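
The commit switches the default MobileNetV5 activation to a tanh-approximated GELU via a module-level _GELU = partial(nn.GELU, approximate='tanh') and enables the conv_stem bias by default (see the diff below). As a minimal sketch, not part of the commit, the following shows how PyTorch's tanh approximation differs slightly from the default erf-based GELU:

from functools import partial

import torch
import torch.nn as nn

# Sketch only: the tanh-approximated GELU selected by the commit's _GELU helper.
_GELU = partial(nn.GELU, approximate='tanh')

x = torch.linspace(-3, 3, steps=7)
exact = nn.GELU()(x)    # default: exact, erf-based GELU
approx = _GELU()(x)     # tanh approximation
print((exact - approx).abs().max())  # small but nonzero difference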

File tree

1 file changed (+19 -9 lines)

timm/models/mobilenetv5.py

Lines changed: 19 additions & 9 deletions
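
The diff below also flips the stem_bias default from False to True. A hypothetical illustration, not timm code, of what that default means for the stem convolution:

import torch.nn as nn

# Sketch only: with bias enabled, the stem conv carries a learnable bias
# parameter, which must exist for checkpoints that ship a conv_stem bias.
stem_with_bias = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=True)
stem_without_bias = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
print('bias' in dict(stem_with_bias.named_parameters()))     # True
print('bias' in dict(stem_without_bias.named_parameters()))  # False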
@@ -28,11 +28,13 @@
 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -526,6 +528,7 @@ def forward_intermediates(
         feat_idx = 0  # stem is index 0
         x = self.conv_stem(x)
         if feat_idx in take_indices:
+            print("conv_stem is captured")
             intermediates.append(x)
         if feat_idx in self.msfa_indices:
             msfa_intermediates.append(x)
@@ -537,9 +540,16 @@ def forward_intermediates(
 
         for blk in blocks:
             feat_idx += 1
-            x = blk(x)
-            if feat_idx in take_indices:
-                intermediates.append(x)
+            # DO NOT SUBMIT: Revert to only the else condition after verification.
+            if isinstance(blk, nn.Sequential):
+                for subblk in blk:
+                    x = subblk(x)
+                    if feat_idx in take_indices:
+                        intermediates.append(x)
+            else:
+                x = blk(x)
+                if feat_idx in take_indices:
+                    intermediates.append(x)
             if feat_idx in self.msfa_indices:
                 msfa_intermediates.append(x)
 
@@ -777,7 +787,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
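
The forward_intermediates change is marked DO NOT SUBMIT and exists only to verify features at sub-block granularity. A minimal sketch, not timm code, showing that stepping through an nn.Sequential stage sub-block by sub-block yields the same final output as calling the stage directly, while exposing each sub-block's activation:

import torch
import torch.nn as nn

# Sketch only: a stand-in for one MobileNetV5 stage built as nn.Sequential.
stage = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.GELU(approximate='tanh'),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)

x = torch.randn(1, 8, 16, 16)
y_direct = stage(x)

y_step = x
sub_outputs = []
for subblk in stage:
    y_step = subblk(y_step)
    sub_outputs.append(y_step)  # per-sub-block intermediates, as captured in the diff

print(torch.allclose(y_direct, y_step))  # True
print(len(sub_outputs))                  # 3, one activation per sub-block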
