8
8
import onnxslim
9
9
from models import *
10
10
11
+
11
12
class bcolors :
12
13
HEADER = '\033 [95m'
13
14
OKBLUE = '\033 [94m'
@@ -27,6 +28,7 @@ class bcolors:
27
28
'yolov4-tiny' : 'https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights' ,
28
29
}
29
30
31
+
30
32
def download_if_not_exist (model_type : str , filepath : str ):
31
33
if not os .path .exists (filepath ):
32
34
torch .hub .download_url_to_file (weight_paths [model_type ], filepath )
@@ -39,6 +41,19 @@ def swap_convs(cv1, cv2):
39
41
cv2 .load_state_dict (state1 )
40
42
41
43
44
+ def fuse_bias_v12 (cv : nn .Conv2d , bn : nn .BatchNorm2d ):
45
+ if exists (cv .bias ):
46
+ b_conv = cv .bias .data
47
+ gamma = bn .weight .data
48
+ beta = bn .bias .data
49
+ mean = bn .running_mean
50
+ var = bn .running_var
51
+ eps = bn .eps
52
+ bn .bias .data = beta + gamma * (b_conv - mean ) / torch .sqrt (var + eps )
53
+ cv .bias = None # PyTorch requires explicit removal
54
+ cv .register_parameter ('bias' , None )
55
+
56
+
42
57
def load_from_darknet (net : Union [Yolov3 , Yolov3Tiny , Yolov4 , Yolov4Tiny ], weights_path : str ):
43
58
44
59
def params (net ):
@@ -105,6 +120,7 @@ def params(net):
105
120
106
121
def load_from_ultralytics (net : Union [Yolov5 , Yolov8 , Yolov10 , Yolov11 ]):
107
122
from ultralytics import YOLO
123
+ from ultralytics .nn .modules .block import AAttn
108
124
109
125
if isinstance (net , Yolov5 ):
110
126
net2 = YOLO ('yolov5{}u.pt' .format (net .v )).model .eval ()
@@ -118,6 +134,12 @@ def load_from_ultralytics(net: Union[Yolov5, Yolov8, Yolov10, Yolov11]):
118
134
elif isinstance (net , Yolov11 ):
119
135
net2 = YOLO ('yolo11{}.pt' .format (net .v )).model .eval ()
120
136
l0 ,l1 = 11 ,23
137
+ elif isinstance (net , Yolov12 ):
138
+ net2 = YOLO ('yolo12{}.pt' .format (net .v )).model .eval ()
139
+ l0 ,l1 = 9 ,21
140
+ for module in net2 .modules ():
141
+ if isinstance (module , AAttn ):
142
+ fuse_bias_v12 (module .pe .conv , module .pe .bn )
121
143
122
144
assert (nP1 := count_parameters (net )) == (nP2 := count_parameters (net2 )), f'wrong number of parameters net { nP1 } vs ultralytics { nP2 } '
123
145
copy_params (net .net , net2 .model [0 :l0 ])
@@ -169,7 +191,7 @@ def params(n):
169
191
170
192
def load_from_yolov7_official (net : Yolov7 , weights_pt : str ):
171
193
def params1 ():
172
- for l in net .layers () :
194
+ for l in [ net .net , net . fpn , net . head . cv [ 0 ][ 0 ], net . head . cv [ 1 ][ 0 ], net . head . cv [ 2 ][ 0 ], net . head . cv [ 0 ][ 1 ], net . head . cv [ 1 ][ 1 ], net . head . cv [ 2 ][ 1 ]] :
173
195
for k , v in l .state_dict ().items ():
174
196
if 'anchor' not in k :
175
197
yield v
@@ -206,6 +228,7 @@ def get_model(model: str, variant: str = ''):
206
228
case 'yolov8' : net = Yolov8 (variant , 80 ).eval ()
207
229
case 'yolov10' : net = Yolov10 (variant , 80 ).eval ()
208
230
case 'yolov11' : net = Yolov11 (variant , 80 ).eval ()
231
+ case 'yolov12' : net = Yolov12 (variant , 80 ).eval ()
209
232
210
233
print (f"{ model } { variant } has { count_parameters (net )} parameters" )
211
234
@@ -217,7 +240,7 @@ def get_model(model: str, variant: str = ''):
217
240
download_if_not_exist (model , filepath )
218
241
load_from_darknet (net , filepath )
219
242
220
- if model in ['yolov5' , 'yolov8' , 'yolov10' , 'yolov11' ]:
243
+ if model in ['yolov5' , 'yolov8' , 'yolov10' , 'yolov11' , 'yolov12' ]:
221
244
load_from_ultralytics (net )
222
245
has_obj = False
223
246
@@ -321,6 +344,11 @@ def export(model: str, variant: str = '', slim=True):
321
344
test ('yolov11' , 'm' )
322
345
test ('yolov11' , 'l' )
323
346
test ('yolov11' , 'x' )
347
+ test ('yolov12' , 'n' )
348
+ test ('yolov12' , 's' )
349
+ test ('yolov12' , 'm' )
350
+ test ('yolov12' , 'l' )
351
+ test ('yolov12' , 'x' )
324
352
325
353
# export('yolov3')
326
354
# export('yolov3-spp')
@@ -351,4 +379,9 @@ def export(model: str, variant: str = '', slim=True):
351
379
# export('yolov11', 's')
352
380
# export('yolov11', 'm')
353
381
# export('yolov11', 'l')
354
- # export('yolov11', 'x')
382
+ # export('yolov11', 'x')
383
+ # export('yolov12', 'n')
384
+ # export('yolov12', 's')
385
+ # export('yolov12', 'm')
386
+ # export('yolov12', 'l')
387
+ # export('yolov12', 'x')
0 commit comments