
fix: mnv5 conv_stem bias and GELU with approximate=tanh #2533


Merged: 1 commit, Jul 7, 2025
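In short, the diff below switches the MobileNetV5 default activation from plain nn.GELU to a tanh-approximate GELU and enables a bias on the stem convolution. As a standalone sketch (not part of the diff itself), the helper being introduced and its numerical effect look like this, assuming a current PyTorch where nn.GELU accepts approximate='tanh':

import torch
from functools import partial
from torch import nn

# Tanh-approximate GELU, as the diff's module-level _GELU helper defines it.
_GELU = partial(nn.GELU, approximate='tanh')

x = torch.linspace(-4.0, 4.0, steps=9)
exact = nn.GELU()(x)   # erf-based GELU (PyTorch default, approximate='none')
approx = _GELU()(x)    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))

# The two differ by a small amount; weights load cleanly either way, but
# activations match the reference implementation only with the 'tanh' variant.
print((exact - approx).abs().max())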
15 changes: 9 additions & 6 deletions timm/models/mobilenetv5.py
@@ -28,11 +28,13 @@
)
from ._features import feature_take_indices
from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['MobileNetV5', 'MobileNetV5Encoder']

+_GELU = partial(nn.GELU, approximate='tanh')


@register_notrace_module
class MobileNetV5MultiScaleFusionAdapter(nn.Module):
@@ -68,7 +70,7 @@ def __init__(
self.layer_scale_init_value = layer_scale_init_value
self.noskip = noskip

-act_layer = act_layer or nn.GELU
+act_layer = act_layer or _GELU
norm_layer = norm_layer or RmsNorm2d
self.ffn = UniversalInvertedResidual(
in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
global_pool: Type of pooling to use for global pooling features of the FC head.
"""
super().__init__()
-act_layer = act_layer or nn.GELU
+act_layer = act_layer or _GELU
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
block_args: BlockArgs,
in_chans: int = 3,
stem_size: int = 64,
-stem_bias: bool = False,
+stem_bias: bool = True,
fix_stem: bool = False,
pad_type: str = '',
msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@
layer_scale_init_value: Optional[float] = None,
):
super().__init__()
-act_layer = act_layer or nn.GELU
+act_layer = act_layer or _GELU
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
se_layer = se_layer or SqueezeExcite
self.num_classes = 0 # Exists to satisfy ._hub module APIs.
@@ -526,6 +528,7 @@ def forward_intermediates(
feat_idx = 0 # stem is index 0
x = self.conv_stem(x)
if feat_idx in take_indices:
print("conv_stem is captured")
intermediates.append(x)
if feat_idx in self.msfa_indices:
msfa_intermediates.append(x)
@@ -777,7 +780,7 @@ def _gen_mobilenet_v5(
fix_stem=channel_multiplier < 1.0,
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=RmsNorm2d,
-act_layer=nn.GELU,
+act_layer=_GELU,
layer_scale_init_value=1e-5,
)
model_kwargs = dict(model_kwargs, **kwargs)
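As a follow-up illustration of the stem_bias default flip (a hypothetical standalone sketch, not code from this PR): with bias=True the stem convolution carries a conv_stem.bias parameter, which is what lets checkpoints whose stem conv includes a bias term load without a missing-key mismatch. The channel counts and kernel settings below are illustrative only, not the actual MobileNetV5 stem configuration.

from torch import nn

# Hypothetical stand-ins for the old (bias=False) and new (bias=True) defaults.
stem_old = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
stem_new = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=True)

print([n for n, _ in stem_old.named_parameters()])  # ['weight']
print([n for n, _ in stem_new.named_parameters()])  # ['weight', 'bias']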