Commit 0bae8f2

fix: mnv5 GELU with approximate=tanh
1 parent cd57234 commit 0bae8f2
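This commit switches MobileNetV5's default activation from exact GELU to PyTorch's tanh approximation, nn.GELU(approximate='tanh'). The two variants are close but not bit-identical, so the choice matters when matching pretrained weights. A minimal sketch, independent of this repo, comparing the two (the input range is illustrative only):

# Sketch only: compares PyTorch's exact GELU with the tanh approximation
# that this commit makes the MobileNetV5 default.
import torch
import torch.nn as nn

exact = nn.GELU()                     # 0.5 * x * (1 + erf(x / sqrt(2)))
approx = nn.GELU(approximate='tanh')  # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))

x = torch.linspace(-4.0, 4.0, steps=101)
# Small but nonzero difference; a model's outputs shift slightly if the
# activation variant does not match the one used during training.
print((exact(x) - approx(x)).abs().max())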

File tree: 1 file changed

timm/models/mobilenetv5.py

Lines changed: 9 additions & 8 deletions
@@ -7,20 +7,21 @@
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from timm.layers import (
-    SelectAdaptivePool2d, Linear, LayerType, PadType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer,
-    to_2tuple
+    SelectAdaptivePool2d, Linear, LayerType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer, to_2tuple
 )
 from ._builder import build_model_with_cfg
 from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
-    round_channels, resolve_act_layer
+    round_channels
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
@@ -56,7 +57,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -154,7 +155,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -411,7 +412,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -761,7 +762,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
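The new _GELU alias works because a functools.partial over a module class can be instantiated exactly like the class itself, so it drops into every act_layer slot unchanged. A small sketch of that pattern in plain PyTorch (make_block is an illustrative stand-in, not a timm function):

# Sketch only: partial(nn.GELU, approximate='tanh') behaves like a layer class.
from functools import partial

import torch
import torch.nn as nn

_GELU = partial(nn.GELU, approximate='tanh')

def make_block(act_layer=_GELU):
    # Callers instantiate act_layer() exactly as they would nn.GELU().
    return nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), act_layer())

block = make_block()
print(block[1])                                # GELU(approximate='tanh')
print(block(torch.randn(1, 8, 16, 16)).shape)  # torch.Size([1, 8, 16, 16])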
