Skip to content

Commit 938b83a

Browse files
author
pfeatherstone
committed
A bit of refactoring. Don't like base class and inheritance but i need it to disguish all the different types when loading weights from official repos (for now)
1 parent 92ad6f9 commit 938b83a

File tree

1 file changed

+65
-122
lines changed

1 file changed

+65
-122
lines changed

src/models.py

Lines changed: 65 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
ANCHORS_V3 = [[(10,13), (16,30), (33,23)], [(30,61), (62,45), (59,119)], [(116,90), (156,198), (373,326)]]
3232
ANCHORS_V3_TINY = [[(10,14), (23,27), (37,58)], [(81,82), (135,169), (344,319)]]
3333
ANCHORS_V4 = [[(12,16), (19,36), (40,28)], [(36,75), (76,55), (72, 146)], [(142,110), (192,243), (459,401)]]
34-
ANCHORS_V7 = ANCHORS_V4
3534

3635
actV3 = nn.LeakyReLU(negative_slope=0.1, inplace=True)
3736
actV4 = nn.Mish(inplace=True)
@@ -1046,7 +1045,7 @@ def __init__(self, nc=80, ch=()):
10461045
self.one2one_cv3 = deepcopy(self.cv3)
10471046
self.max_det = 100
10481047

1049-
def forward(self, x):
1048+
def forward(self, x, targets=None):
10501049
# TODO: implement all the topk stuff. I think yolov10 doesn't need NMS. But you can you use it in inference mode for now.
10511050
return self.forward_private(x, self.one2one_cv2, self.one2one_cv3)
10521051

@@ -1100,152 +1099,96 @@ def forward(self, xs, targets=None):
11001099

11011100
return pred if not exists(targets) else (pred, {'iou': loss_iou+loss_iou_distill, 'dfl': loss_dfl, 'cls': loss_cls})
11021101

1103-
class Yolov3(nn.Module):
1104-
def __init__(self, nclasses, spp):
1102+
class YoloBase(nn.Module):
1103+
def __init__(self, net, fpn, head, variant=None):
11051104
super().__init__()
1106-
self.net = Darknet53()
1107-
self.fpn = HeadV3(spp)
1108-
self.head = DetectV3(nclasses, [8,16,32], ANCHORS_V3, [1,1,1], ch=[(128, 256), (256, 512), (512, 1024)])
1105+
self.v = variant
1106+
self.net = net
1107+
self.fpn = fpn
1108+
self.head = head
11091109

11101110
def forward(self, x, targets=None):
11111111
x = self.net(x)
11121112
x = self.fpn(*x)
11131113
return self.head(x, targets=targets)
11141114

1115-
class Yolov3Tiny(nn.Module):
1115+
class Yolov3(YoloBase):
1116+
def __init__(self, nclasses, spp):
1117+
super().__init__(Darknet53(),
1118+
HeadV3(spp),
1119+
DetectV3(nclasses, [8,16,32], ANCHORS_V3, [1,1,1], ch=[(128, 256), (256, 512), (512, 1024)]))
1120+
1121+
class Yolov3Tiny(YoloBase):
11161122
def __init__(self, nclasses):
1117-
super().__init__()
1118-
self.net = BackboneV3Tiny()
1119-
self.fpn = HeadV3Tiny(1024)
1120-
self.head = DetectV3(nclasses, [16,32], ANCHORS_V3_TINY, [1,1], ch=[(384, 256), (256, 512)])
1121-
1122-
def forward(self, x, targets=None):
1123-
x = self.net(x)
1124-
x = self.fpn(*x)
1125-
return self.head(x, targets=targets)
1123+
super().__init__(BackboneV3Tiny(),
1124+
HeadV3Tiny(1024),
1125+
DetectV3(nclasses, [16,32], ANCHORS_V3_TINY, [1,1], ch=[(384, 256), (256, 512)]))
11261126

1127-
class Yolov4(nn.Module):
1127+
class Yolov4(YoloBase):
11281128
def __init__(self, nclasses, act=actV4):
1129-
super().__init__()
1130-
self.net = BackboneV4(act)
1131-
self.fpn = HeadV4(actV3)
1132-
self.head = DetectV3(nclasses, [8,16,32], ANCHORS_V4, [1.2, 1.1, 1.05], ch=[(128, 256), (256, 512), (512, 1024)])
1133-
1134-
def forward(self, x, targets=None):
1135-
x = self.net(x)
1136-
x = self.fpn(*x)
1137-
return self.head(x, targets=targets)
1129+
super().__init__(BackboneV4(act),
1130+
HeadV4(actV3),
1131+
DetectV3(nclasses, [8,16,32], ANCHORS_V4, [1.2, 1.1, 1.05], ch=[(128, 256), (256, 512), (512, 1024)]))
11381132

1139-
class Yolov4Tiny(nn.Module):
1133+
class Yolov4Tiny(YoloBase):
11401134
def __init__(self, nclasses):
1141-
super().__init__()
1142-
self.net = BackboneV4Tiny()
1143-
self.fpn = HeadV3Tiny(512)
1144-
self.head = DetectV3(nclasses, [16,32], ANCHORS_V3_TINY, [1.05,1.5], ch=[(384, 256), (256, 512)])
1145-
1146-
def forward(self, x, targets=None):
1147-
x = self.net(x)
1148-
x = self.fpn(*x)
1149-
return self.head(x, targets=targets)
1135+
super().__init__(BackboneV4Tiny(),
1136+
HeadV3Tiny(512),
1137+
DetectV3(nclasses, [16,32], ANCHORS_V3_TINY, [1.05,1.5], ch=[(384, 256), (256, 512)]))
11501138

1151-
class Yolov7(nn.Module):
1139+
class Yolov7(YoloBase):
11521140
def __init__(self, nclasses):
1153-
super().__init__()
1154-
ch = [(128,256), (256,512), (512,1024)]
1155-
self.net = BackboneV7()
1156-
self.fpn = HeadV7()
1157-
self.head = DetectV3(nclasses, [8,16,32], ANCHORS_V7, [2,2,2], ch=ch, is_v7=True)
1158-
1159-
def layers(self):
1160-
return [self.net, self.fpn, self.head.cv[0][0], self.head.cv[1][0], self.head.cv[2][0], self.head.cv[0][1], self.head.cv[1][1], self.head.cv[2][1]]
1161-
1162-
def forward(self, x, targets=None):
1163-
x = self.net(x)
1164-
x = self.fpn(*x)
1165-
return self.head(x, targets=targets)
1141+
super().__init__(BackboneV7(),
1142+
HeadV7(),
1143+
DetectV3(nclasses, [8,16,32], ANCHORS_V4, [2,2,2], ch=[(128,256), (256,512), (512,1024)], is_v7=True))
11661144

1167-
class Yolov5(nn.Module):
1145+
class Yolov5(YoloBase):
11681146
def __init__(self, variant, num_classes):
1169-
super().__init__()
1170-
self.v = variant
1171-
d, w, r = get_variant_multiplesV5(variant)
1172-
self.net = BackboneV5(w, r, d)
1173-
self.fpn = HeadV5(w, r, d)
1174-
self.head = Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*2)))
1175-
1176-
def forward(self, x, targets=None):
1177-
x = self.net(x)
1178-
x = self.fpn(*x)
1179-
return self.head(x, targets)
1180-
1181-
class Yolov8(nn.Module):
1147+
d, w, r = get_variant_multiplesV5(variant)
1148+
super().__init__(BackboneV5(w, r, d),
1149+
HeadV5(w, r, d),
1150+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*2))),
1151+
variant)
1152+
1153+
class Yolov8(YoloBase):
11821154
def __init__(self, variant, num_classes):
1183-
super().__init__()
1184-
self.v = variant
1185-
d, w, r = get_variant_multiplesV8(variant)
1186-
self.net = BackboneV8(w, r, d)
1187-
self.fpn = HeadV8(w, r, d)
1188-
self.head = Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)))
1155+
d, w, r = get_variant_multiplesV8(variant)
1156+
super().__init__(BackboneV8(w, r, d),
1157+
HeadV8(w, r, d),
1158+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r))),
1159+
variant)
11891160

1190-
def forward(self, x, targets=None):
1191-
x = self.net(x)
1192-
x = self.fpn(*x)
1193-
return self.head(x, targets)
1194-
1195-
class Yolov10(nn.Module):
1161+
class Yolov10(YoloBase):
11961162
def __init__(self, variant, num_classes):
1197-
super().__init__()
1198-
self.v = variant
1199-
d, w, r = get_variant_multiplesV10(variant)
1200-
self.net = BackboneV10(w, r, d, variant)
1201-
self.fpn = HeadV10(w, r, d, variant)
1202-
self.head = DetectV10(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)))
1203-
1204-
def forward(self, x):
1205-
x = self.net(x)
1206-
x = self.fpn(*x)
1207-
return self.head(x)
1163+
d, w, r = get_variant_multiplesV10(variant)
1164+
super().__init__(BackboneV10(w, r, d, variant),
1165+
HeadV10(w, r, d, variant),
1166+
DetectV10(num_classes, ch=(int(256*w), int(512*w), int(512*w*r))),
1167+
variant)
12081168

1209-
class Yolov11(nn.Module):
1169+
class Yolov11(YoloBase):
12101170
def __init__(self, variant, num_classes):
1211-
super().__init__()
1212-
self.v = variant
1213-
d, w, r = get_variant_multiplesV11(variant)
1214-
self.net = BackboneV11(w, r, d, variant)
1215-
self.fpn = HeadV11(w, r, d, variant)
1216-
self.head = Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True)
1217-
1218-
def forward(self, x):
1219-
x = self.net(x)
1220-
x = self.fpn(*x)
1221-
return self.head(x)
1171+
d, w, r = get_variant_multiplesV11(variant)
1172+
super().__init__(BackboneV11(w, r, d, variant),
1173+
HeadV11(w, r, d, variant),
1174+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True),
1175+
variant)
12221176

1223-
class Yolov12(nn.Module):
1177+
class Yolov12(YoloBase):
12241178
def __init__(self, variant, num_classes):
1225-
super().__init__()
1226-
self.v = variant
1227-
d, w, r = get_variant_multiplesV12(variant)
1228-
self.net = BackboneV12(w, r, d, variant)
1229-
self.fpn = HeadV12(w, r, d)
1230-
self.head = Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True)
1231-
1232-
def forward(self, x):
1233-
x = self.net(x)
1234-
x = self.fpn(*x)
1235-
return self.head(x)
1179+
d, w, r = get_variant_multiplesV12(variant)
1180+
super().__init__(BackboneV12(w, r, d, variant),
1181+
HeadV12(w, r, d),
1182+
Detect(num_classes, ch=(int(256*w), int(512*w), int(512*w*r)), v11=True),
1183+
variant)
12361184

1237-
class Yolov6(nn.Module):
1185+
class Yolov6(YoloBase):
12381186
def __init__(self, variant, num_classes):
1239-
super().__init__()
12401187
d, w, csp, csp_e, distill = get_variant_multiplesV6(variant)
1241-
self.net = CSPBepBackbone(w, d, csp_e=csp_e) if csp else EfficientRep(w, d, cspsppf=True)
1242-
self.fpn = CSPRepBiFPANNeck(w, d, csp_e=csp_e) if csp else RepBiFPANNeck(w, d)
1243-
self.head = DetectV6(num_classes, ch=(int(128*w), int(256*w), int(512*w)), use_dfl=True, distill=distill)
1244-
1245-
def forward(self, x, targets=None):
1246-
x = self.net(x)
1247-
x = self.fpn(*x)
1248-
return self.head(x, targets)
1188+
super().__init__(CSPBepBackbone(w, d, csp_e=csp_e) if csp else EfficientRep(w, d, cspsppf=True),
1189+
CSPRepBiFPANNeck(w, d, csp_e=csp_e) if csp else RepBiFPANNeck(w, d),
1190+
DetectV6(num_classes, ch=(int(128*w), int(256*w), int(512*w)), use_dfl=True, distill=distill),
1191+
variant)
12491192

12501193
@torch.no_grad()
12511194
def nms(preds: torch.Tensor, conf_thresh: float, nms_thresh: float , has_objectness: bool):

0 commit comments

Comments
 (0)