diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ab4b5084f1..4ce7da9009 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -27,7 +27,7 @@
 import math
 from collections import OrderedDict
 from functools import partial
-from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Callable, Dict, Optional, Sequence, Set, Tuple, Type, Union, Literal, List
 
 import torch
 import torch.nn as nn
@@ -55,14 +55,14 @@ class Attention(nn.Module):
 
     def __init__(
             self,
-            dim,
-            num_heads=8,
-            qkv_bias=False,
-            qk_norm=False,
-            attn_drop=0.,
-            proj_drop=0.,
-            norm_layer=nn.LayerNorm,
-    ):
+            dim: int,
+            num_heads: int = 8,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            attn_drop: float = 0.,
+            proj_drop: float = 0.,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+    ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
@@ -77,7 +77,7 @@ def __init__(
         self.proj = nn.Linear(dim, dim)
         self.proj_drop = nn.Dropout(proj_drop)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
@@ -102,32 +102,36 @@ def forward(self, x):
 
 
 class LayerScale(nn.Module):
-    def __init__(self, dim, init_values=1e-5, inplace=False):
+    def __init__(
+            self,
+            dim: int,
+            init_values: float = 1e-5,
+            inplace: bool = False,
+    ) -> None:
         super().__init__()
         self.inplace = inplace
         self.gamma = nn.Parameter(init_values * torch.ones(dim))
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x.mul_(self.gamma) if self.inplace else x * self.gamma
 
 
 class Block(nn.Module):
-
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            qk_norm=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            init_values=None,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            mlp_layer=Mlp,
-    ):
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
+    ) -> None:
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(
@@ -152,29 +156,28 @@ def __init__(
         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
         x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
 
 class ResPostBlock(nn.Module):
-
     def __init__(
             self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            qk_norm=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            init_values=None,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            mlp_layer=Mlp,
-    ):
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
+    ) -> None:
         super().__init__()
         self.init_values = init_values
 
@@ -201,13 +204,13 @@ def __init__(
 
         self.init_weights()
 
-    def init_weights(self):
+    def init_weights(self) -> None:
         # NOTE this init overrides that base model init with specific changes for the block type
         if self.init_values is not None:
             nn.init.constant_(self.norm1.weight, self.init_values)
             nn.init.constant_(self.norm2.weight, self.init_values)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + self.drop_path1(self.norm1(self.attn(x)))
         x = x + self.drop_path2(self.norm2(self.mlp(x)))
         return x
@@ -222,19 +225,19 @@ class ParallelScalingBlock(nn.Module):
     """
     def __init__(
            self,
-            dim,
-            num_heads,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            qk_norm=False,
-            proj_drop=0.,
-            attn_drop=0.,
-            init_values=None,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            mlp_layer=None,  # NOTE: not used
-    ):
+            dim: int,
+            num_heads: int,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            init_values: Optional[float] = None,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Optional[Type[nn.Module]] = None,  # NOTE: not used
+    ) -> None:
         super().__init__()
         assert dim % num_heads == 0, 'dim should be divisible by num_heads'
         self.num_heads = num_heads
@@ -266,7 +269,7 @@ def __init__(
         self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, N, C = x.shape
 
         # Combined MLP fc1 & qkv projections
@@ -315,20 +318,20 @@ class ParallelThingsBlock(nn.Module):
     """
     def __init__(
             self,
-            dim,
-            num_heads,
-            num_parallel=2,
-            mlp_ratio=4.,
-            qkv_bias=False,
-            qk_norm=False,
-            init_values=None,
-            proj_drop=0.,
-            attn_drop=0.,
-            drop_path=0.,
-            act_layer=nn.GELU,
-            norm_layer=nn.LayerNorm,
-            mlp_layer=Mlp,
-    ):
+            dim: int,
+            num_heads: int,
+            num_parallel: int = 2,
+            mlp_ratio: float = 4.,
+            qkv_bias: bool = False,
+            qk_norm: bool = False,
+            init_values: Optional[float] = None,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            act_layer: Type[nn.Module] = nn.GELU,
+            norm_layer: Type[nn.Module] = nn.LayerNorm,
+            mlp_layer: Type[nn.Module] = Mlp,
+    ) -> None:
         super().__init__()
         self.num_parallel = num_parallel
         self.attns = nn.ModuleList()
@@ -360,18 +363,18 @@ def __init__(
                 ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
             ])))
 
-    def _forward_jit(self, x):
+    def _forward_jit(self, x: torch.Tensor) -> torch.Tensor:
         x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
         x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
         return x
 
     @torch.jit.ignore
-    def _forward(self, x):
+    def _forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x + sum(attn(x) for attn in self.attns)
         x = x + sum(ffn(x) for ffn in self.ffns)
         return x
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             return self._forward_jit(x)
         else:
@@ -392,7 +395,7 @@ def __init__(
             patch_size: Union[int, Tuple[int, int]] = 16,
             in_chans: int = 3,
             num_classes: int = 1000,
-            global_pool: str = 'token',
+            global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
             embed_dim: int = 768,
             depth: int = 12,
             num_heads: int = 12,
@@ -413,13 +416,13 @@ def __init__(
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            weight_init: str = '',
+            weight_init: Literal['skip', 'jax', 'jax_nlhb', 'moco', ''] = '',
             embed_layer: Callable = PatchEmbed,
             norm_layer: Optional[LayerType] = None,
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
-    ):
+    ) -> None:
         """
         Args:
             img_size: Input image size.
@@ -530,7 +533,7 @@ def __init__(
         if weight_init != 'skip':
             self.init_weights(weight_init)
 
-    def init_weights(self, mode=''):
+    def init_weights(self, mode: Literal['jax', 'jax_nlhb', 'moco', ''] = '') -> None:
         assert mode in ('jax', 'jax_nlhb', 'moco', '')
         head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
         trunc_normal_(self.pos_embed, std=.02)
@@ -538,34 +541,34 @@
             nn.init.normal_(self.cls_token, std=1e-6)
         named_apply(get_init_weights_vit(mode, head_bias), self)
 
-    def _init_weights(self, m):
+    def _init_weights(self, m: nn.Module) -> None:
         # this fn left here for compat with downstream users
         init_weights_vit_timm(m)
 
     @torch.jit.ignore()
-    def load_pretrained(self, checkpoint_path, prefix=''):
+    def load_pretrained(self, checkpoint_path: str, prefix: str = '') -> None:
         _load_weights(self, checkpoint_path, prefix)
 
     @torch.jit.ignore
-    def no_weight_decay(self):
+    def no_weight_decay(self) -> Set:
         return {'pos_embed', 'cls_token', 'dist_token'}
 
     @torch.jit.ignore
-    def group_matcher(self, coarse=False):
+    def group_matcher(self, coarse: bool = False) -> Dict:
         return dict(
             stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
             blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
         )
 
     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable=True):
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
         self.grad_checkpointing = enable
 
     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head
 
-    def reset_classifier(self, num_classes: int, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('', 'avg', 'token', 'map')
@@ -576,7 +579,7 @@ def reset_classifier(self, num_classes: int, global_pool=None):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
-    def _pos_embed(self, x):
+    def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.dynamic_img_size:
             B, H, W, C = x.shape
             pos_embed = resample_abs_pos_embed(
@@ -613,7 +616,7 @@ def _intermediate_layers(
             self,
             x: torch.Tensor,
             n: Union[int, Sequence] = 1,
-    ):
+    ) -> List[torch.Tensor]:
         outputs, num_blocks = [], len(self.blocks)
         take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)
 
@@ -658,7 +661,7 @@ def get_intermediate_layers(
             return tuple(zip(outputs, prefix_tokens))
         return tuple(outputs)
 
-    def forward_features(self, x):
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
         x = self._pos_embed(x)
         x = self.patch_drop(x)
@@ -670,7 +673,7 @@ def forward_features(self, x):
         x = self.norm(x)
         return x
 
-    def forward_head(self, x, pre_logits: bool = False):
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
         if self.attn_pool is not None:
             x = self.attn_pool(x)
         elif self.global_pool == 'avg':
@@ -681,13 +684,13 @@ def forward_head(self, x, pre_logits: bool = False):
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
 
 
-def init_weights_vit_timm(module: nn.Module, name: str = ''):
+def init_weights_vit_timm(module: nn.Module, name: str = '') -> None:
     """ ViT weight initialization, original timm impl (for reproducibility) """
     if isinstance(module, nn.Linear):
         trunc_normal_(module.weight, std=.02)
@@ -697,7 +700,7 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''):
         module.init_weights()
 
 
-def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
+def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.0) -> None:
     """ ViT weight initialization, matching JAX (Flax) impl """
     if isinstance(module, nn.Linear):
         if name.startswith('head'):
@@ -715,7 +718,7 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
         module.init_weights()
 
 
-def init_weights_vit_moco(module: nn.Module, name: str = ''):
+def init_weights_vit_moco(module: nn.Module, name: str = '') -> None:
    """ ViT weight initialization, matching moco-v3 impl minus fixed PatchEmbed """
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
@@ -730,7 +733,7 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''):
         module.init_weights()
 
 
-def get_init_weights_vit(mode='jax', head_bias: float = 0.):
+def get_init_weights_vit(mode: str = 'jax', head_bias: float = 0.0) -> Callable:
     if 'jax' in mode:
         return partial(init_weights_vit_jax, head_bias=head_bias)
     elif 'moco' in mode:
@@ -740,13 +743,13 @@ def get_init_weights_vit(mode='jax', head_bias: float = 0.):
 
 
 def resize_pos_embed(
-        posemb,
-        posemb_new,
-        num_prefix_tokens=1,
-        gs_new=(),
-        interpolation='bicubic',
-        antialias=False,
-):
+        posemb: torch.Tensor,
+        posemb_new: torch.Tensor,
+        num_prefix_tokens: int = 1,
+        gs_new: Tuple[int, ...] = (),
+        interpolation: str = 'bicubic',
+        antialias: bool = False,
+) -> torch.Tensor:
     """ Rescale the grid of position embeddings when loading from state_dict.
 
     *DEPRECATED* This function is being deprecated in favour of resample_abs_pos_embed
@@ -773,7 +776,7 @@ def resize_pos_embed(
 
 
 @torch.no_grad()
-def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
+def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = '') -> None:
     """ Load weights from .npz checkpoints for official Google Brain Flax implementation
     """
     import numpy as np
@@ -905,12 +908,23 @@ def _n2p(w, t=True):
             getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias']))
 
 
-def _convert_openai_clip(state_dict, model, prefix='visual.'):
+def _convert_openai_clip(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        prefix: str = 'visual.',
+) -> Dict[str, torch.Tensor]:
     out_dict = {}
     swaps = [
-        ('conv1', 'patch_embed.proj'), ('positional_embedding', 'pos_embed'),
-        ('transformer.resblocks.', 'blocks.'), ('ln_pre', 'norm_pre'), ('ln_post', 'norm'), ('ln_', 'norm'),
-        ('in_proj_', 'qkv.'), ('out_proj', 'proj'), ('mlp.c_fc', 'mlp.fc1'), ('mlp.c_proj', 'mlp.fc2'),
+        ('conv1', 'patch_embed.proj'),
+        ('positional_embedding', 'pos_embed'),
+        ('transformer.resblocks.', 'blocks.'),
+        ('ln_pre', 'norm_pre'),
+        ('ln_post', 'norm'),
+        ('ln_', 'norm'),
+        ('in_proj_', 'qkv.'),
+        ('out_proj', 'proj'),
+        ('mlp.c_fc', 'mlp.fc1'),
+        ('mlp.c_proj', 'mlp.fc2'),
     ]
     for k, v in state_dict.items():
         if not k.startswith(prefix):
@@ -940,7 +954,10 @@ def _convert_openai_clip(state_dict, model, prefix='visual.'):
     return out_dict
 
 
-def _convert_dinov2(state_dict, model):
+def _convert_dinov2(
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+) -> Dict[str, torch.Tensor]:
     import re
     out_dict = {}
     state_dict.pop("mask_token", None)
@@ -961,12 +978,12 @@
 
 
 def checkpoint_filter_fn(
-        state_dict,
-        model,
-        adapt_layer_scale=False,
-        interpolation='bicubic',
-        antialias=True,
-):
+        state_dict: Dict[str, torch.Tensor],
+        model: VisionTransformer,
+        adapt_layer_scale: bool = False,
+        interpolation: str = 'bicubic',
+        antialias: bool = True,
+) -> Dict[str, torch.Tensor]:
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
     import re
     out_dict = {}
@@ -1031,17 +1048,22 @@ def checkpoint_filter_fn(
     return out_dict
 
 
-def _cfg(url='', **kwargs):
+def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
     return {
         'url': url,
-        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
-        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
-        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
-        **kwargs
+        'num_classes': 1000,
+        'input_size': (3, 224, 224),
+        'pool_size': None,
+        'crop_pct': 0.9,
+        'interpolation': 'bicubic',
+        'fixed_input_size': True,
+        'mean': IMAGENET_INCEPTION_MEAN,
+        'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'patch_embed.proj',
+        'classifier': 'head',
+        **kwargs,
     }
-
 
 
 default_cfgs = {
     # re-finetuned augreg 21k FT on in1k weights
@@ -1708,7 +1730,7 @@
 
 default_cfgs = generate_default_cfgs(default_cfgs)
 
 
-def _create_vision_transformer(variant, pretrained=False, **kwargs):
+def _create_vision_transformer(variant: str, pretrained: bool = False, **kwargs) -> VisionTransformer:
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
@@ -1735,7 +1757,7 @@
 
 
 @register_model
-def vit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Tiny (Vit-Ti/16) """ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) @@ -1744,7 +1766,7 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_tiny_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_tiny_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) @@ -1753,7 +1775,7 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/32) """ model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) @@ -1762,7 +1784,7 @@ def vit_small_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/32) at 384x384. """ model_args = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) @@ -1771,7 +1793,7 @@ def vit_small_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/16) """ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) @@ -1780,7 +1802,7 @@ def vit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/16) """ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) @@ -1789,7 +1811,7 @@ def vit_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch8_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small (ViT-S/8) """ model_args = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) @@ -1798,7 +1820,7 @@ def vit_small_patch8_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. """ @@ -1808,7 +1830,7 @@ def vit_base_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
""" @@ -1818,7 +1840,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ @@ -1828,7 +1850,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ @@ -1838,7 +1860,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch8_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch8_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ @@ -1848,7 +1870,7 @@ def vit_base_patch8_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch32_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ model_args = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) @@ -1857,7 +1879,7 @@ def vit_large_patch32_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch32_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ @@ -1867,7 +1889,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ @@ -1877,7 +1899,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 
""" @@ -1887,7 +1909,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) """ model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) @@ -1896,7 +1918,7 @@ def vit_large_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) @@ -1905,7 +1927,7 @@ def vit_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_giant_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_giant_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_args = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) @@ -1914,7 +1936,7 @@ def vit_giant_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_gigantic_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_gigantic_patch14_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ model_args = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) @@ -1924,7 +1946,7 @@ def vit_gigantic_patch14_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_224_miil(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_224_miil(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ @@ -1935,7 +1957,7 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_medium_patch16_gap_240(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_gap_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 240x240 """ model_args = dict( @@ -1947,7 +1969,7 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_medium_patch16_gap_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 256x256 """ model_args = dict( @@ -1959,7 +1981,7 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Medium (ViT-M/16) w/o class token, w/ avg-pool @ 384x384 """ model_args = dict( @@ -1971,7 +1993,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224 """ model_args = dict( @@ -1982,7 +2004,7 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool """ model_args = dict( @@ -1993,7 +2015,7 @@ def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch16_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448 """ model_args = dict( @@ -2004,7 +2026,7 @@ def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool """ model_args = dict( @@ -2016,7 +2038,7 @@ def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 224x224 """ model_args = dict( @@ -2027,7 +2049,7 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch32_clip_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_clip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 256x256 """ model_args = dict( @@ -2038,7 +2060,7 @@ def vit_base_patch32_clip_256(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def 
vit_base_patch32_clip_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 384x384 """ model_args = dict( @@ -2049,7 +2071,7 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch32_clip_448(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_clip_448(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 448x448 """ model_args = dict( @@ -2060,7 +2082,7 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower """ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2070,7 +2092,7 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_clip_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_clip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower @ 384x384 """ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2080,7 +2102,7 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower """ model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2090,7 +2112,7 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2100,7 +2122,7 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower. 
""" model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2110,7 +2132,7 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_clip_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2120,7 +2142,7 @@ def vit_huge_patch14_clip_336(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_clip_378(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_clip_378(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 """ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) @@ -2130,7 +2152,7 @@ def vit_huge_patch14_clip_378(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_giant_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_giant_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 Pretrained weights from CLIP image tower. """ @@ -2142,7 +2164,7 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_gigantic_patch14_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 Pretrained weights from CLIP image tower. 
""" @@ -2154,7 +2176,7 @@ def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/32 CLIP image tower @ 224x224 """ model_args = dict( @@ -2166,7 +2188,7 @@ def vit_base_patch32_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra @register_model -def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/16 CLIP image tower w/ QuickGELU act """ model_args = dict( @@ -2178,7 +2200,7 @@ def vit_base_patch16_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra @register_model -def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower w/ QuickGELU act """ from timm.layers import get_act_layer @@ -2191,7 +2213,7 @@ def vit_large_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTr @register_model -def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_clip_quickgelu_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 w/ QuickGELU act """ model_args = dict( @@ -2203,7 +2225,7 @@ def vit_large_patch14_clip_quickgelu_336(pretrained=False, **kwargs) -> VisionTr @register_model -def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_clip_quickgelu_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower w/ QuickGELU act. 
""" model_args = dict( @@ -2215,7 +2237,7 @@ def vit_huge_patch14_clip_quickgelu_224(pretrained=False, **kwargs) -> VisionTra @register_model -def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_clip_quickgelu_378(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) CLIP image tower @ 378x378 w/ QuickGELU act """ model_args = dict( @@ -2229,7 +2251,7 @@ def vit_huge_patch14_clip_quickgelu_378(pretrained=False, **kwargs) -> VisionTra # Experimental models below @register_model -def vit_base_patch32_plus_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch32_plus_256(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/32+) """ model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) @@ -2239,7 +2261,7 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_plus_240(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_plus_240(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16+) """ model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) @@ -2249,7 +2271,7 @@ def vit_base_patch16_plus_240(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_rpn_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_rpn_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base (ViT-B/16) w/ residual post-norm """ model_args = dict( @@ -2261,7 +2283,7 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch16_36x1_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch16_36x1_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. @@ -2273,7 +2295,7 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch16_18x2_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. @@ -2286,7 +2308,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_18x2_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_18x2_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. 
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ @@ -2298,7 +2320,7 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def eva_large_patch14_196(pretrained=False, **kwargs) -> VisionTransformer: +def eva_large_patch14_196(pretrained: bool = False, **kwargs) -> VisionTransformer: """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') model = _create_vision_transformer( @@ -2307,7 +2329,7 @@ def eva_large_patch14_196(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def eva_large_patch14_336(pretrained=False, **kwargs) -> VisionTransformer: +def eva_large_patch14_336(pretrained: bool = False, **kwargs) -> VisionTransformer: """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs)) @@ -2315,7 +2337,7 @@ def eva_large_patch14_336(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def flexivit_small(pretrained=False, **kwargs) -> VisionTransformer: +def flexivit_small(pretrained: bool = False, **kwargs) -> VisionTransformer: """ FlexiViT-Small """ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True) @@ -2324,7 +2346,7 @@ def flexivit_small(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def flexivit_base(pretrained=False, **kwargs) -> VisionTransformer: +def flexivit_base(pretrained: bool = False, **kwargs) -> VisionTransformer: """ FlexiViT-Base """ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True) @@ -2333,7 +2355,7 @@ def flexivit_base(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def flexivit_large(pretrained=False, **kwargs) -> VisionTransformer: +def flexivit_large(pretrained: bool = False, **kwargs) -> VisionTransformer: """ FlexiViT-Large """ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) @@ -2342,7 +2364,7 @@ def flexivit_large(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch16_xp_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. """ model_args = dict( @@ -2355,7 +2377,7 @@ def vit_base_patch16_xp_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled. """ model_args = dict( @@ -2368,7 +2390,7 @@ def vit_large_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_huge_patch14_xp_224(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled. 
""" model_args = dict( @@ -2381,7 +2403,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-S/14 for DINOv2 """ model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518) @@ -2391,7 +2413,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/14 for DINOv2 """ model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518) @@ -2401,7 +2423,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-L/14 for DINOv2 """ model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518) @@ -2411,7 +2433,7 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_giant_patch14_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-G/14 for DINOv2 """ # The hidden_features of SwiGLU is calculated by: @@ -2428,7 +2450,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: @register_model -def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_small_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-S/14 for DINOv2 w/ 4 registers """ model_args = dict( @@ -2441,7 +2463,7 @@ def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-B/14 for DINOv2 w/ 4 registers """ model_args = dict( @@ -2454,7 +2476,7 @@ def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransforme @register_model -def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-L/14 for DINOv2 w/ 4 registers """ model_args = dict( @@ -2467,7 +2489,7 @@ def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: +def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionTransformer: """ ViT-G/14 for DINOv2 """ # The hidden_features of SwiGLU is calculated by: @@ -2484,7 +2506,7 @@ def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', ) @@ -2494,7 +2516,7 @@ def 
vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer @register_model -def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', ) @@ -2504,7 +2526,7 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer @register_model -def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', ) @@ -2514,7 +2536,7 @@ def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer @register_model -def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer: +def vit_base_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map', ) @@ -2524,7 +2546,7 @@ def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer @register_model -def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', ) @@ -2534,7 +2556,7 @@ def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransforme @register_model -def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map', ) @@ -2544,7 +2566,7 @@ def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransforme @register_model -def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: +def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', ) @@ -2554,7 +2576,7 @@ def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer: +def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map', ) @@ -2564,7 +2586,7 @@ def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransform @register_model -def vit_medium_patch16_reg4_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer: model_args = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=True, no_embed_class=True, reg_tokens=4, @@ -2575,7 +2597,7 @@ def vit_medium_patch16_reg4_256(pretrained=False, **kwargs) -> VisionTransformer @register_model -def vit_medium_patch16_reg4_gap_256(pretrained=False, **kwargs) -> VisionTransformer: +def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, 
**kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, no_embed_class=True,
         reg_tokens=4, global_pool='avg',
@@ -2586,7 +2608,7 @@
 
 
 @register_model
-def vit_base_patch16_reg8_gap_256(pretrained=False, **kwargs) -> VisionTransformer:
+def vit_base_patch16_reg8_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, no_embed_class=True,
         global_pool='avg', reg_tokens=8,
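
To make the annotated surface concrete, here is a minimal usage sketch against the typed entry points. The model name, tensor shapes, and class count are illustrative; timm.create_model, forward_features, and reset_classifier are the existing timm APIs that the hunks above annotate.

import torch
import timm

# The factory resolves to _create_vision_transformer, which now declares -> VisionTransformer.
model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=10)

x = torch.randn(1, 3, 224, 224)
feats = model.forward_features(x)   # torch.Tensor of shape (1, 197, 768): 196 patches + class token
logits = model(x)                   # torch.Tensor of shape (1, 10)

# global_pool's accepted values are now spelled out by Literal['', 'avg', 'token', 'map'].
model.reset_classifier(num_classes=0, global_pool='avg')
pooled = model(x)                   # head is nn.Identity(), so this returns (1, 768) pooled features

With num_classes=0 the classifier head collapses to nn.Identity(), so the final call returns the pooled embedding rather than logits, which is consistent with the torch.Tensor return annotation on forward_head above.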