@@ -546,65 +546,6 @@ def forward(self, x):
     check_variable_and_dtype,
 )
 
-
-def group_norm(
-    input,
-    groups,
-    epsilon=1e-05,
-    weight=None,
-    bias=None,
-    act=None,
-    data_layout="NCHW",
-    name=None,
-):
-    helper = LayerHelper("group_norm", **locals())
-    dtype = helper.input_dtype()
-    check_variable_and_dtype(
-        input,
-        "input",
-        ["float16", "uint16", "float32", "float64"],
-        "group_norm",
-    )
-    # create intput and parameters
-    inputs = {"X": input}
-    input_shape = input.shape
-    if len(input_shape) < 2:
-        raise ValueError(
-            f"The dimensions of Op(static.nn.group_norm)'s input should be more than 1. But received {len(input_shape)}"
-        )
-    if data_layout != "NCHW" and data_layout != "NHWC":
-        raise ValueError(
-            "Param(data_layout) of Op(static.nn.group_norm) got wrong value: received "
-            + data_layout
-            + " but only NCHW or NHWC supported."
-        )
-    channel_num = input_shape[1] if data_layout == "NCHW" else input_shape[-1]
-    param_shape = [channel_num]
-    inputs["Scale"] = weight
-    inputs["Bias"] = bias
-    # create output
-    mean_out = helper.create_variable(dtype=dtype, stop_gradient=True)
-    variance_out = helper.create_variable(dtype=dtype, stop_gradient=True)
-    group_norm_out = helper.create_variable(dtype=dtype)
-
-    helper.append_op(
-        type="group_norm",
-        inputs=inputs,
-        outputs={
-            "Y": group_norm_out,
-            "Mean": mean_out,
-            "Variance": variance_out,
-        },
-        attrs={
-            "epsilon": epsilon,
-            "groups": groups,
-            "data_layout": data_layout,
-        },
-    )
-
-    return helper.append_activation(group_norm_out)
-
-
 class GroupNormAct(nn.GroupNorm):
     # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
     def __init__(
@@ -630,9 +571,7 @@ def __init__(
         self.act = nn.Identity()
 
     def forward(self, x):
-        x = group_norm(
-            x, self._num_groups, self._epsilon, weight=self.weight, bias=self.bias
-        )
+        x = F.group_norm(x, num_groups=self._num_groups, weight=self.weight, bias=self.bias, eps=self._epsilon)
         x = self.act(x)
         return x
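The removed helper assembled a Paddle static-graph `group_norm` op by hand through `LayerHelper`; the rewritten forward delegates to PyTorch's `torch.nn.functional.group_norm`, which performs the same normalize-then-affine computation. Below is a minimal sketch of what that call computes, checked against an explicit reference; the tensor shapes and random affine parameters are illustrative assumptions for the demo, not values from this commit.

```python
import torch
import torch.nn.functional as F

# Illustrative setup (assumed for the demo): NCHW input, 8 channels, 4 groups.
x = torch.randn(2, 8, 16, 16)
num_groups, eps = 4, 1e-5
weight = torch.randn(8)  # per-channel scale, like nn.GroupNorm's affine weight
bias = torch.randn(8)    # per-channel shift

# F.group_norm normalizes each group of C / num_groups channels (together with
# all spatial positions), then applies the per-channel affine transform.
y = F.group_norm(x, num_groups, weight=weight, bias=bias, eps=eps)

# Reference computation written out explicitly.
g = x.reshape(2, num_groups, -1)
mean = g.mean(dim=-1, keepdim=True)
var = g.var(dim=-1, unbiased=False, keepdim=True)  # biased variance, as in group norm
ref = ((g - mean) / (var + eps).sqrt()).reshape_as(x)
ref = ref * weight.view(1, -1, 1, 1) + bias.view(1, -1, 1, 1)

assert torch.allclose(y, ref, atol=1e-5)
```

Note that `F.group_norm` defaults to `eps=1e-05`, so passing the layer's stored `self._epsilon` explicitly (as in the rewritten forward above) keeps behavior unchanged even when a non-default epsilon was configured.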