diff --git a/timm/models/mobilenetv5.py b/timm/models/mobilenetv5.py
index acc37ad32b..eca210b24d 100644
--- a/timm/models/mobilenetv5.py
+++ b/timm/models/mobilenetv5.py
@@ -28,11 +28,13 @@
 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -777,7 +779,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
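
Note (not part of the patch): the _GELU alias introduced above switches the default activation from the exact, erf-based GELU to its tanh approximation. Below is a minimal standalone sketch, assuming only PyTorch, of how the two variants compare numerically; the input range and the magnitude mentioned in the comments are illustrative and not taken from the timm code.

from functools import partial

import torch
import torch.nn as nn

# Same module-level alias the patch adds to mobilenetv5.py.
_GELU = partial(nn.GELU, approximate='tanh')

x = torch.linspace(-4.0, 4.0, steps=101)
exact = nn.GELU()(x)    # erf-based GELU (previous default act_layer)
approx = _GELU()(x)     # tanh-approximate GELU (new default act_layer)

# The two curves are close (max absolute difference on the order of 1e-3
# or smaller over this range), but weights trained with the tanh variant
# are reproduced exactly only when inference uses the same approximation.
print((exact - approx).abs().max().item())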