Skip to content

Commit aa25ae2

Browse files
author
pfeatherstone
committed
test yolov11 and yolov12 models
1 parent 938b83a commit aa25ae2

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

src/test.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import onnxslim
99
from models import *
1010

11+
1112
class bcolors:
1213
HEADER = '\033[95m'
1314
OKBLUE = '\033[94m'
@@ -27,6 +28,7 @@ class bcolors:
2728
'yolov4-tiny' : 'https://github.com/AlexeyAB/darknet/releases/download/yolov4/yolov4-tiny.weights',
2829
}
2930

31+
3032
def download_if_not_exist(model_type: str, filepath: str):
3133
if not os.path.exists(filepath):
3234
torch.hub.download_url_to_file(weight_paths[model_type], filepath)
@@ -39,6 +41,19 @@ def swap_convs(cv1, cv2):
3941
cv2.load_state_dict(state1)
4042

4143

44+
def fuse_bias_v12(cv: nn.Conv2d, bn: nn.BatchNorm2d):
45+
if exists(cv.bias):
46+
b_conv = cv.bias.data
47+
gamma = bn.weight.data
48+
beta = bn.bias.data
49+
mean = bn.running_mean
50+
var = bn.running_var
51+
eps = bn.eps
52+
bn.bias.data = beta + gamma * (b_conv - mean) / torch.sqrt(var + eps)
53+
cv.bias = None # PyTorch requires explicit removal
54+
cv.register_parameter('bias', None)
55+
56+
4257
def load_from_darknet(net: Union[Yolov3, Yolov3Tiny, Yolov4, Yolov4Tiny], weights_path: str):
4358

4459
def params(net):
@@ -105,6 +120,7 @@ def params(net):
105120

106121
def load_from_ultralytics(net: Union[Yolov5, Yolov8, Yolov10, Yolov11]):
107122
from ultralytics import YOLO
123+
from ultralytics.nn.modules.block import AAttn
108124

109125
if isinstance(net, Yolov5):
110126
net2 = YOLO('yolov5{}u.pt'.format(net.v)).model.eval()
@@ -118,6 +134,12 @@ def load_from_ultralytics(net: Union[Yolov5, Yolov8, Yolov10, Yolov11]):
118134
elif isinstance(net, Yolov11):
119135
net2 = YOLO('yolo11{}.pt'.format(net.v)).model.eval()
120136
l0,l1 = 11,23
137+
elif isinstance(net, Yolov12):
138+
net2 = YOLO('yolo12{}.pt'.format(net.v)).model.eval()
139+
l0,l1 = 9,21
140+
for module in net2.modules():
141+
if isinstance(module, AAttn):
142+
fuse_bias_v12(module.pe.conv, module.pe.bn)
121143

122144
assert (nP1 := count_parameters(net)) == (nP2 := count_parameters(net2)), f'wrong number of parameters net {nP1} vs ultralytics {nP2}'
123145
copy_params(net.net, net2.model[0:l0])
@@ -169,7 +191,7 @@ def params(n):
169191

170192
def load_from_yolov7_official(net: Yolov7, weights_pt: str):
171193
def params1():
172-
for l in net.layers():
194+
for l in [net.net, net.fpn, net.head.cv[0][0], net.head.cv[1][0], net.head.cv[2][0], net.head.cv[0][1], net.head.cv[1][1], net.head.cv[2][1]]:
173195
for k, v in l.state_dict().items():
174196
if 'anchor' not in k:
175197
yield v
@@ -206,6 +228,7 @@ def get_model(model: str, variant: str = ''):
206228
case 'yolov8': net = Yolov8(variant, 80).eval()
207229
case 'yolov10': net = Yolov10(variant, 80).eval()
208230
case 'yolov11': net = Yolov11(variant, 80).eval()
231+
case 'yolov12': net = Yolov12(variant, 80).eval()
209232

210233
print(f"{model}{variant} has {count_parameters(net)} parameters")
211234

@@ -217,7 +240,7 @@ def get_model(model: str, variant: str = ''):
217240
download_if_not_exist(model, filepath)
218241
load_from_darknet(net, filepath)
219242

220-
if model in ['yolov5', 'yolov8', 'yolov10', 'yolov11']:
243+
if model in ['yolov5', 'yolov8', 'yolov10', 'yolov11', 'yolov12']:
221244
load_from_ultralytics(net)
222245
has_obj = False
223246

@@ -321,6 +344,11 @@ def export(model: str, variant: str = '', slim=True):
321344
test('yolov11', 'm')
322345
test('yolov11', 'l')
323346
test('yolov11', 'x')
347+
test('yolov12', 'n')
348+
test('yolov12', 's')
349+
test('yolov12', 'm')
350+
test('yolov12', 'l')
351+
test('yolov12', 'x')
324352

325353
# export('yolov3')
326354
# export('yolov3-spp')
@@ -351,4 +379,9 @@ def export(model: str, variant: str = '', slim=True):
351379
# export('yolov11', 's')
352380
# export('yolov11', 'm')
353381
# export('yolov11', 'l')
354-
# export('yolov11', 'x')
382+
# export('yolov11', 'x')
383+
# export('yolov12', 'n')
384+
# export('yolov12', 's')
385+
# export('yolov12', 'm')
386+
# export('yolov12', 'l')
387+
# export('yolov12', 'x')

0 commit comments

Comments
 (0)