@@ -262,9 +262,9 @@ def CIB(c1, c2, shortcut=True, e=0.5, lk=False):
262
262
Conv (c2 , c2 , 3 , g = c2 ))
263
263
return Residual (net ) if shortcut else net
264
264
265
- def C2fCIB (c1 , c2 , n = 1 , shortcut = False , lk = False , e = 0.5 ):
265
+ def C2fCIB (c1 , c2 , n = 1 , shortcut = False , e = 0.5 , cib = True , lk = False ):
266
266
net = C2f (c1 , c2 , n , shortcut , e )
267
- net .m = nn .ModuleList (CIB (net .c_ , net .c_ , shortcut , e = 1.0 , lk = lk ) for _ in range (n ))
267
+ if cib : net .m = nn .ModuleList (CIB (net .c_ , net .c_ , shortcut , e = 1.0 , lk = lk ) for _ in range (n ))
268
268
return net
269
269
270
270
class Spp (nn .Module ):
@@ -512,16 +512,16 @@ def forward(self, x):
512
512
class BackboneV5 (nn .Module ):
513
513
def __init__ (self , w , r , d ):
514
514
super ().__init__ ()
515
- self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 6 , s = 2 , p = 2 )
516
- self .b1 = Conv (int (64 * w ), int (128 * w ), k = 3 , s = 2 )
517
- self .b2 = C3 (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ))
518
- self .b3 = Conv (int (128 * w ), int (256 * w ), k = 3 , s = 2 )
519
- self .b4 = C3 (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ))
520
- self .b5 = Conv (int (256 * w ), int (512 * w ), k = 3 , s = 2 )
521
- self .b6 = C3 (c1 = int (512 * w ), c2 = int (512 * w ), n = round (9 * d ))
522
- self .b7 = Conv (int (512 * w ), int (512 * w * r ), k = 3 , s = 2 )
523
- self .b8 = C3 (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ))
524
- self .b9 = SPPF (int (512 * w * r ), int (512 * w * r ))
515
+ self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 6 , s = 2 , p = 2 )
516
+ self .b1 = Conv (c1 = int (64 * w ), c2 = int (128 * w ), k = 3 , s = 2 )
517
+ self .b2 = C3 (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ))
518
+ self .b3 = Conv (c1 = int (128 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
519
+ self .b4 = C3 (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ))
520
+ self .b5 = Conv (c1 = int (256 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
521
+ self .b6 = C3 (c1 = int (512 * w ), c2 = int (512 * w ), n = round (9 * d ))
522
+ self .b7 = Conv (c1 = int (512 * w ), c2 = int (512 * w * r ), k = 3 , s = 2 )
523
+ self .b8 = C3 (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ))
524
+ self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ))
525
525
526
526
def forward (self , x ):
527
527
x4 = self .b4 (self .b3 (self .b2 (self .b1 (self .b0 (x ))))) # 4 P3/8
@@ -532,16 +532,16 @@ def forward(self, x):
532
532
class BackboneV8 (nn .Module ):
533
533
def __init__ (self , w , r , d ):
534
534
super ().__init__ ()
535
- self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
536
- self .b1 = Conv (int (64 * w ), int (128 * w ), k = 3 , s = 2 )
537
- self .b2 = C2f (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ), shortcut = True )
538
- self .b3 = Conv (int (128 * w ), int (256 * w ), k = 3 , s = 2 )
539
- self .b4 = C2f (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ), shortcut = True )
540
- self .b5 = Conv (int (256 * w ), int (512 * w ), k = 3 , s = 2 )
541
- self .b6 = C2f (c1 = int (512 * w ), c2 = int (512 * w ), n = round (6 * d ), shortcut = True )
542
- self .b7 = Conv (int (512 * w ), int (512 * w * r ), k = 3 , s = 2 )
543
- self .b8 = C2f (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True )
544
- self .b9 = SPPF (int (512 * w * r ), int (512 * w * r ))
535
+ self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
536
+ self .b1 = Conv (c1 = int (64 * w ), c2 = int (128 * w ), k = 3 , s = 2 )
537
+ self .b2 = C2f (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ), shortcut = True )
538
+ self .b3 = Conv (c1 = int (128 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
539
+ self .b4 = C2f (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ), shortcut = True )
540
+ self .b5 = Conv (c1 = int (256 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
541
+ self .b6 = C2f (c1 = int (512 * w ), c2 = int (512 * w ), n = round (6 * d ), shortcut = True )
542
+ self .b7 = Conv (c1 = int (512 * w ), c2 = int (512 * w * r ), k = 3 , s = 2 )
543
+ self .b8 = C2f (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True )
544
+ self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ))
545
545
546
546
def forward (self , x ):
547
547
x4 = self .b4 (self .b3 (self .b2 (self .b1 (self .b0 (x ))))) # 4 P3/8
@@ -552,21 +552,16 @@ def forward(self, x):
552
552
class BackboneV10 (nn .Module ):
553
553
def __init__ (self , w , r , d , variant ):
554
554
super ().__init__ ()
555
- self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
556
- self .b1 = Conv (int (64 * w ), int (128 * w ), k = 3 , s = 2 )
557
- self .b2 = C2f (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ), shortcut = True )
558
- self .b3 = Conv (int (128 * w ), int (256 * w ), k = 3 , s = 2 )
559
- self .b4 = C2f (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ), shortcut = True )
560
- self .b5 = SCDown (int (256 * w ), int (512 * w ), k = 3 , s = 2 )
561
- match variant :
562
- case 'x' : self .b6 = C2fCIB (c1 = int (512 * w ), c2 = int (512 * w ), n = round (6 * d ), shortcut = True )
563
- case _ : self .b6 = C2f (c1 = int (512 * w ), c2 = int (512 * w ), n = round (6 * d ), shortcut = True )
564
- self .b7 = SCDown (int (512 * w ), int (512 * w * r ), k = 3 , s = 2 )
565
- match variant :
566
- case 'n' : self .b8 = C2f (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True )
567
- case 's' : self .b8 = C2fCIB (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , lk = True )
568
- case _ : self .b8 = C2fCIB (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , lk = False )
569
- self .b9 = SPPF (int (512 * w * r ), int (512 * w * r ))
555
+ self .b0 = Conv (c1 = 3 , c2 = int (64 * w ), k = 3 , s = 2 )
556
+ self .b1 = Conv (c1 = int (64 * w ), c2 = int (128 * w ), k = 3 , s = 2 )
557
+ self .b2 = C2f (c1 = int (128 * w ), c2 = int (128 * w ), n = round (3 * d ), shortcut = True )
558
+ self .b3 = Conv (c1 = int (128 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
559
+ self .b4 = C2f (c1 = int (256 * w ), c2 = int (256 * w ), n = round (6 * d ), shortcut = True )
560
+ self .b5 = SCDown (c1 = int (256 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
561
+ self .b6 = C2fCIB (c1 = int (512 * w ), c2 = int (512 * w ), n = round (6 * d ), shortcut = True , cib = variant == 'x' )
562
+ self .b7 = SCDown (c1 = int (512 * w ), c2 = int (512 * w * r ), k = 3 , s = 2 )
563
+ self .b8 = C2fCIB (c1 = int (512 * w * r ), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , cib = not variant == 'n' , lk = variant == 's' )
564
+ self .b9 = SPPF (c1 = int (512 * w * r ), c2 = int (512 * w * r ))
570
565
self .b10 = PSA (int (512 * w * r ))
571
566
572
567
def forward (self , x ):
@@ -773,18 +768,12 @@ class HeadV10(nn.Module):
773
768
def __init__ (self , w , r , d , variant ):
774
769
super ().__init__ ()
775
770
self .up = nn .Upsample (scale_factor = 2 )
776
- match variant :
777
- case 'n' | 's' | 'm' : self .n1 = C2f (c1 = int (512 * w * (1 + r )), c2 = int (512 * w ), n = round (3 * d ))
778
- case _ : self .n1 = C2fCIB (c1 = int (512 * w * (1 + r )), c2 = int (512 * w ), n = round (3 * d ), shortcut = True )
779
- self .n2 = C2f (c1 = int (768 * w ), c2 = int (256 * w ), n = round (3 * d ))
780
- self .n3 = Conv (c1 = int (256 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
781
- match variant :
782
- case 'n' | 's' : self .n4 = C2f (c1 = int (768 * w ), c2 = int (512 * w ), n = round (3 * d ))
783
- case _ : self .n4 = C2fCIB (c1 = int (768 * w ), c2 = int (512 * w ), n = round (3 * d ), shortcut = True )
784
- self .n5 = SCDown (c1 = int (512 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
785
- match variant :
786
- case 'n' | 's' : self .n6 = C2fCIB (c1 = int (512 * w * (1 + r )), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , lk = True )
787
- case _ : self .n6 = C2fCIB (c1 = int (512 * w * (1 + r )), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , lk = False )
771
+ self .n1 = C2fCIB (c1 = int (512 * w * (1 + r )), c2 = int (512 * w ), n = round (3 * d ), shortcut = True , cib = variant in "blx" )
772
+ self .n2 = C2f (c1 = int (768 * w ), c2 = int (256 * w ), n = round (3 * d ))
773
+ self .n3 = Conv (c1 = int (256 * w ), c2 = int (256 * w ), k = 3 , s = 2 )
774
+ self .n4 = C2fCIB (c1 = int (768 * w ), c2 = int (512 * w ), n = round (3 * d ), shortcut = True , cib = variant in "mblx" )
775
+ self .n5 = SCDown (c1 = int (512 * w ), c2 = int (512 * w ), k = 3 , s = 2 )
776
+ self .n6 = C2fCIB (c1 = int (512 * w * (1 + r )), c2 = int (512 * w * r ), n = round (3 * d ), shortcut = True , cib = True , lk = variant in "ns" )
788
777
789
778
def forward (self , x4 , x6 , x10 ):
790
779
x13 = self .n1 (torch .cat ([self .up (x10 ),x6 ], 1 )) # 13
0 commit comments