@@ -225,6 +225,7 @@ def interpolate(
     data_format: (
         DataLayout1DVariant | DataLayout2D | DataLayout3D | None
     ) = None,
+    recompute_scale_factor: bool | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -397,6 +398,12 @@ def interpolate(
         When it is `"NCHW"`, the data should be stored in the order of:
         `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the
         data should be stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
+        recompute_scale_factor (bool, optional): Whether to recompute the scale factor used for the
+            interpolation calculation. When set to `True`, `scale_factor` must be provided; it is used
+            together with the input tensor shape to compute the output tensor shape, and the scale factor
+            is then recalculated from the output and input shapes. This is mainly useful when
+            `scale_factor` is a floating-point value. When set to `False`, either `size` or `scale_factor`
+            is used directly for interpolation without recalculation. Default: None.
         name(str, optional): The default value is None.
         Normally there is no need for user to set this property.
         For more information, please refer to :ref:`api_guide_Name`
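The documented behavior is easiest to see with a fractional `scale_factor`. A minimal usage sketch of the new flag, assuming this patch is applied (the input shape and values below are illustrative, not from the patch):

```python
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 10, 10])  # NCHW input, H = W = 10

# With recompute_scale_factor=True, the output shape is computed first
# (floor(10 * 2.5) = 25), and the backend then derives its effective
# scale from the 25/10 ratio rather than using the raw 2.5 directly.
y = F.interpolate(x, scale_factor=2.5, recompute_scale_factor=True)
print(y.shape)  # [1, 3, 25, 25]
```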
@@ -552,6 +559,11 @@ def _is_list_or_tuple_(data):
     if out_shape is not None and scale is not None:
         raise ValueError("Only one of size or scale_factor should be defined.")
     if out_shape is not None:
+        if recompute_scale_factor:
+            raise ValueError(
+                "recompute_scale_factor is not meaningful with an explicit size."
+            )
+
         if (
             isinstance(out_shape, (Variable, paddle.pir.Value))
             and not in_dynamic_mode()
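The guard above exists because `size` already fixes the output shape, so there is no scale factor left to recompute; the combination is rejected early rather than silently ignored. A quick sketch of the failure mode under this patch (illustrative input only):

```python
import paddle
import paddle.nn.functional as F

x = paddle.rand([1, 3, 10, 10])

try:
    # An explicit size conflicts with recompute_scale_factor=True:
    # the output shape is already fully determined.
    F.interpolate(x, size=[20, 20], recompute_scale_factor=True)
except ValueError as e:
    print(e)  # recompute_scale_factor is not meaningful with an explicit size.
```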
@@ -644,36 +656,81 @@ def _is_list_or_tuple_(data):
             attrs['out_h'] = out_shape[1]
             attrs['out_w'] = out_shape[2]

-    else:
-        if in_dynamic_mode() and isinstance(scale, Variable):
-            if scale.shape == []:
-                scale = float(scale)
+    elif scale is not None:
+        if recompute_scale_factor:
+            if in_dynamic_mode() and isinstance(scale, Variable):
+                if scale.shape == []:
+                    scale = float(scale)
+                else:
+                    scale = list(scale.numpy())
+
+            dim = len(x.shape) - 2
+
+            if isinstance(scale, (float, int, numpy.ndarray)):
+                scale_list = [float(scale)] * dim
+            elif isinstance(scale, (list, tuple)):
+                if len(scale) != dim:
+                    raise ValueError(
+                        f"scale_shape length should be {dim} for "
+                        f"input {len(x.shape)}-D tensor."
+                    )
+                scale_list = list(map(float, scale))
             else:
-                scale = list(scale.numpy())
-        if isinstance(scale, (Variable, paddle.pir.Value)):
-            scale.stop_gradient = True
-            inputs["Scale"] = scale
-        elif isinstance(scale, (float, int, numpy.ndarray)):
-            if scale <= 0:
-                raise ValueError("Attr(scale) should be greater than zero.")
-            scale_list = []
-            for i in range(len(x.shape) - 2):
-                scale_list.append(scale)
-            attrs['scale'] = list(map(float, scale_list))
-        elif isinstance(scale, (list, tuple)):
-            if len(scale) != len(x.shape) - 2:
-                raise ValueError(
-                    f"scale_shape length should be {len(x.shape) - 2} for "
-                    f"input {len(x.shape)}-D tensor."
+                raise TypeError(
+                    "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
                 )
-            for value in scale:
-                if value <= 0:
-                    raise ValueError("Attr(scale) should be greater than zero.")
-            attrs['scale'] = list(map(float, scale))
+
+            out_shape = []
+            for i in range(dim):
+                input_size = x.shape[i + 2]
+                output_size = int(
+                    numpy.floor(float(input_size) * scale_list[i])
+                )
+                out_shape.append(output_size)
+
+            if len(x.shape) == 3:
+                attrs['out_w'] = out_shape[0]
+            elif len(x.shape) == 4:
+                attrs['out_h'] = out_shape[0]
+                attrs['out_w'] = out_shape[1]
+            elif len(x.shape) == 5:
+                attrs['out_d'] = out_shape[0]
+                attrs['out_h'] = out_shape[1]
+                attrs['out_w'] = out_shape[2]
+
+            scale = None
         else:
-            raise TypeError(
-                "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
-            )
+            if in_dynamic_mode() and isinstance(scale, Variable):
+                if scale.shape == []:
+                    scale = float(scale)
+                else:
+                    scale = list(scale.numpy())
+            if isinstance(scale, (Variable, paddle.pir.Value)):
+                scale.stop_gradient = True
+                inputs["Scale"] = scale
+            elif isinstance(scale, (float, int, numpy.ndarray)):
+                if scale <= 0:
+                    raise ValueError("Attr(scale) should be greater than zero.")
+                scale_list = []
+                for i in range(len(x.shape) - 2):
+                    scale_list.append(scale)
+                attrs['scale'] = list(map(float, scale_list))
+            elif isinstance(scale, (list, tuple)):
+                if len(scale) != len(x.shape) - 2:
+                    raise ValueError(
+                        f"scale_shape length should be {len(x.shape) - 2} for "
+                        f"input {len(x.shape)}-D tensor."
+                    )
+                for value in scale:
+                    if value <= 0:
+                        raise ValueError(
+                            "Attr(scale) should be greater than zero."
+                        )
+                attrs['scale'] = list(map(float, scale))
+            else:
+                raise TypeError(
+                    "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
+                )

     if in_dynamic_or_pir_mode():
         attr_list = []
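The core of the recompute path is the floor-based output-shape computation over the spatial dimensions, followed by setting `scale = None` so the backend re-derives its effective scale from the `out_d`/`out_h`/`out_w` attributes instead of the raw float. A standalone numpy sketch of that arithmetic, mirroring the loop in the hunk above with made-up shapes:

```python
import numpy

def recomputed_out_shape(in_shape, scale):
    # Spatial dims start at index 2 (NCW / NCHW / NCDHW layouts);
    # each output extent is floor(input_extent * scale), as in the patch.
    dim = len(in_shape) - 2
    if isinstance(scale, (float, int)):
        scale_list = [float(scale)] * dim
    else:
        scale_list = list(map(float, scale))
    return [int(numpy.floor(in_shape[i + 2] * scale_list[i])) for i in range(dim)]

print(recomputed_out_shape([1, 3, 10, 10], 2.5))                 # [25, 25]
print(recomputed_out_shape([1, 3, 7, 7], 0.5))                   # [3, 3], i.e. floor(3.5)
print(recomputed_out_shape([1, 3, 4, 10, 10], [1.5, 2.5, 2.5]))  # [6, 25, 25]
```

Dropping `scale` after the shape is fixed is the design choice that makes floating-point scale factors deterministic: the kernel sees only integer output extents, so it cannot disagree with the Python-side rounding.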