@@ -225,6 +225,7 @@ def interpolate(
225
225
data_format : (
226
226
DataLayout1DVariant | DataLayout2D | DataLayout3D | None
227
227
) = None ,
228
+ recompute_scale_factor : bool | None = None ,
228
229
name : str | None = None ,
229
230
) -> Tensor :
230
231
"""
@@ -493,6 +494,15 @@ def interpolate(
493
494
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
494
495
)
495
496
497
+ if (
498
+ recompute_scale_factor is not None
499
+ and recompute_scale_factor
500
+ and size is not None
501
+ ):
502
+ raise ValueError (
503
+ "recompute_scale_factor is not meaningful with an explicit size."
504
+ )
505
+
496
506
if resample == 'AREA' :
497
507
if isinstance (size , (list , tuple , Variable , paddle .pir .Value )):
498
508
if len (size ) == 0 :
@@ -645,7 +655,7 @@ def _is_list_or_tuple_(data):
645
655
attrs ['out_h' ] = out_shape [1 ]
646
656
attrs ['out_w' ] = out_shape [2 ]
647
657
648
- else :
658
+ elif scale is not None and recompute_scale_factor is not True :
649
659
if in_dynamic_mode () and isinstance (scale , Variable ):
650
660
if scale .shape == []:
651
661
scale = float (scale )
@@ -676,6 +686,51 @@ def _is_list_or_tuple_(data):
676
686
"Attr(scale)'s type should be float, int, list, tuple, or Tensor."
677
687
)
678
688
689
+ elif recompute_scale_factor is not None and recompute_scale_factor :
690
+ assert (
691
+ scale is not None
692
+ ), "scale_factor must not be None when recompute_scale_factor=True"
693
+
694
+ if in_dynamic_mode () and isinstance (scale , Variable ):
695
+ if scale .shape == []:
696
+ scale = float (scale )
697
+ else :
698
+ scale = list (scale .numpy ())
699
+
700
+ dim = len (x .shape ) - 2
701
+
702
+ if isinstance (scale , (float , int , numpy .ndarray )):
703
+ scale_list = [float (scale )] * dim
704
+ elif isinstance (scale , (list , tuple )):
705
+ if len (scale ) != dim :
706
+ raise ValueError (
707
+ f"scale_shape length should be { dim } for "
708
+ f"input { len (x .shape )} -D tensor."
709
+ )
710
+ scale_list = list (map (float , scale ))
711
+ else :
712
+ raise TypeError (
713
+ "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
714
+ )
715
+
716
+ out_shape = []
717
+ for i in range (dim ):
718
+ input_size = x .shape [i + 2 ]
719
+ output_size = int (numpy .floor (float (input_size ) * scale_list [i ]))
720
+ out_shape .append (output_size )
721
+
722
+ if len (x .shape ) == 3 :
723
+ attrs ['out_w' ] = out_shape [0 ]
724
+ elif len (x .shape ) == 4 :
725
+ attrs ['out_h' ] = out_shape [0 ]
726
+ attrs ['out_w' ] = out_shape [1 ]
727
+ elif len (x .shape ) == 5 :
728
+ attrs ['out_d' ] = out_shape [0 ]
729
+ attrs ['out_h' ] = out_shape [1 ]
730
+ attrs ['out_w' ] = out_shape [2 ]
731
+
732
+ scale = None
733
+
679
734
if in_dynamic_or_pir_mode ():
680
735
attr_list = []
681
736
for k , v in attrs .items ():
0 commit comments