
Commit 33912da

Add the recompute_scale_factor parameter to the Upsample and interpolate functions (#71997)
1 parent 370e790 commit 33912da

4 files changed: +693, -27 lines

python/paddle/nn/functional/common.py

+84 -27
@@ -225,6 +225,7 @@ def interpolate(
     data_format: (
         DataLayout1DVariant | DataLayout2D | DataLayout3D | None
     ) = None,
+    recompute_scale_factor: bool | None = None,
     name: str | None = None,
 ) -> Tensor:
     """
@@ -397,6 +398,12 @@ def interpolate(
         When it is `"NCHW"`, the data should be stored in the order of:
         `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the
         data should be stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
+        recompute_scale_factor (bool, optional): Whether to recompute the scale factor used for the
+            interpolation calculation. When set to `True`, `scale_factor` must be provided, and the function uses
+            it together with the input tensor shape to compute the output tensor shape, then recomputes the scale
+            factor from the output and input tensor shapes. This is particularly useful when `scale_factor` is a
+            floating-point value. When set to `False`, `size` or `scale_factor` is used directly for interpolation
+            without recomputation. Default: None.
         name(str, optional): The default value is None.
             Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`
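
A minimal usage sketch of the documented semantics (assuming this patch is applied; exact interpolated values depend on the backend kernel):

    import paddle
    import paddle.nn.functional as F

    x = paddle.arange(9, dtype='float32').reshape([1, 1, 3, 3])

    # With recompute_scale_factor=True the output extent is floor(3 * 1.7) = 5,
    # and the op then samples with the recomputed ratio 5/3 instead of raw 1.7.
    y_recomputed = F.interpolate(
        x, scale_factor=1.7, mode='bilinear', recompute_scale_factor=True
    )
    # With recompute_scale_factor=False the scale 1.7 is passed through directly.
    y_direct = F.interpolate(
        x, scale_factor=1.7, mode='bilinear', recompute_scale_factor=False
    )
    print(y_recomputed.shape)  # [1, 1, 5, 5]
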
@@ -552,6 +559,11 @@ def _is_list_or_tuple_(data):
     if out_shape is not None and scale is not None:
         raise ValueError("Only one of size or scale_factor should be defined.")
     if out_shape is not None:
+        if recompute_scale_factor:
+            raise ValueError(
+                "recompute_scale_factor is not meaningful with an explicit size."
+            )
+
         if (
             isinstance(out_shape, (Variable, paddle.pir.Value))
             and not in_dynamic_mode()
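
The new guard is easy to exercise; a sketch of the expected failure, assuming this patch is applied:

    import paddle
    import paddle.nn.functional as F

    x = paddle.rand([1, 3, 8, 8])
    try:
        # An explicit size and recompute_scale_factor=True are mutually exclusive.
        F.interpolate(x, size=[16, 16], recompute_scale_factor=True)
    except ValueError as e:
        print(e)  # recompute_scale_factor is not meaningful with an explicit size.
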
@@ -644,36 +656,81 @@ def _is_list_or_tuple_(data):
             attrs['out_h'] = out_shape[1]
             attrs['out_w'] = out_shape[2]
 
-    else:
-        if in_dynamic_mode() and isinstance(scale, Variable):
-            if scale.shape == []:
-                scale = float(scale)
-            else:
-                scale = list(scale.numpy())
-        if isinstance(scale, (Variable, paddle.pir.Value)):
-            scale.stop_gradient = True
-            inputs["Scale"] = scale
-        elif isinstance(scale, (float, int, numpy.ndarray)):
-            if scale <= 0:
-                raise ValueError("Attr(scale) should be greater than zero.")
-            scale_list = []
-            for i in range(len(x.shape) - 2):
-                scale_list.append(scale)
-            attrs['scale'] = list(map(float, scale_list))
-        elif isinstance(scale, (list, tuple)):
-            if len(scale) != len(x.shape) - 2:
-                raise ValueError(
-                    f"scale_shape length should be {len(x.shape) - 2} for "
-                    f"input {len(x.shape)}-D tensor."
-                )
-            for value in scale:
-                if value <= 0:
-                    raise ValueError("Attr(scale) should be greater than zero.")
-            attrs['scale'] = list(map(float, scale))
-        else:
-            raise TypeError(
-                "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
-            )
+    elif scale is not None:
+        if recompute_scale_factor:
+            if in_dynamic_mode() and isinstance(scale, Variable):
+                if scale.shape == []:
+                    scale = float(scale)
+                else:
+                    scale = list(scale.numpy())
+
+            dim = len(x.shape) - 2
+
+            if isinstance(scale, (float, int, numpy.ndarray)):
+                scale_list = [float(scale)] * dim
+            elif isinstance(scale, (list, tuple)):
+                if len(scale) != dim:
+                    raise ValueError(
+                        f"scale_shape length should be {dim} for "
+                        f"input {len(x.shape)}-D tensor."
+                    )
+                scale_list = list(map(float, scale))
+            else:
+                raise TypeError(
+                    "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
+                )
+
+            out_shape = []
+            for i in range(dim):
+                input_size = x.shape[i + 2]
+                output_size = int(
+                    numpy.floor(float(input_size) * scale_list[i])
+                )
+                out_shape.append(output_size)
+
+            if len(x.shape) == 3:
+                attrs['out_w'] = out_shape[0]
+            elif len(x.shape) == 4:
+                attrs['out_h'] = out_shape[0]
+                attrs['out_w'] = out_shape[1]
+            elif len(x.shape) == 5:
+                attrs['out_d'] = out_shape[0]
+                attrs['out_h'] = out_shape[1]
+                attrs['out_w'] = out_shape[2]
+
+            scale = None
+        else:
+            if in_dynamic_mode() and isinstance(scale, Variable):
+                if scale.shape == []:
+                    scale = float(scale)
+                else:
+                    scale = list(scale.numpy())
+            if isinstance(scale, (Variable, paddle.pir.Value)):
+                scale.stop_gradient = True
+                inputs["Scale"] = scale
+            elif isinstance(scale, (float, int, numpy.ndarray)):
+                if scale <= 0:
+                    raise ValueError("Attr(scale) should be greater than zero.")
+                scale_list = []
+                for i in range(len(x.shape) - 2):
+                    scale_list.append(scale)
+                attrs['scale'] = list(map(float, scale_list))
+            elif isinstance(scale, (list, tuple)):
+                if len(scale) != len(x.shape) - 2:
+                    raise ValueError(
+                        f"scale_shape length should be {len(x.shape) - 2} for "
+                        f"input {len(x.shape)}-D tensor."
+                    )
+                for value in scale:
+                    if value <= 0:
+                        raise ValueError(
+                            "Attr(scale) should be greater than zero."
+                        )
+                attrs['scale'] = list(map(float, scale))
+            else:
+                raise TypeError(
+                    "Attr(scale)'s type should be float, int, list, tuple, or Tensor."
+                )
 
     if in_dynamic_or_pir_mode():
         attr_list = []
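
The recompute branch above boils down to one shape rule; a standalone sketch with a hypothetical helper (recomputed_out_shape is not part of the patch):

    import numpy

    def recomputed_out_shape(spatial_shape, scale):
        # Each spatial extent becomes floor(input_size * scale); the op is then
        # given explicit output sizes and the original scale is dropped (scale = None).
        dim = len(spatial_shape)
        if isinstance(scale, (int, float)):
            scale_list = [float(scale)] * dim
        else:
            scale_list = list(map(float, scale))
        return [int(numpy.floor(float(s) * r)) for s, r in zip(spatial_shape, scale_list)]

    print(recomputed_out_shape([3, 3], 1.7))         # [5, 5]
    print(recomputed_out_shape([4, 6], [0.5, 1.5]))  # [2, 9]
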

python/paddle/nn/layer/common.py

+10 -0
@@ -405,6 +405,12 @@ class Upsample(Layer):
         When it is `"NCHW"`, the data should be stored in the order of:
         `[batch_size, input_channels, input_height, input_width]`. When it is `"NCDHW"`, the
         data should be stored in the order of: `[batch_size, input_channels, input_depth, input_height, input_width]`.
+        recompute_scale_factor (bool, optional): Whether to recompute the scale factor used for the
+            interpolation calculation. When set to `True`, `scale_factor` must be provided, and the layer uses
+            it together with the input tensor shape to compute the output tensor shape, then recomputes the scale
+            factor from the output and input tensor shapes. This is particularly useful when `scale_factor` is a
+            floating-point value. When set to `False`, `size` or `scale_factor` is used directly for interpolation
+            without recomputation. Default: None.
         name(str|None, optional): The default value is None.
             Normally there is no need for user to set this property.
             For more information, please refer to :ref:`api_guide_Name`
@@ -431,6 +437,7 @@ class Upsample(Layer):
     align_corners: bool
     align_mode: int
     data_format: DataLayout1DVariant | DataLayout2D | DataLayout3D | None
+    recompute_scale_factor: bool | None
     name: str | None
 
     def __init__(
@@ -443,6 +450,7 @@ def __init__(
         data_format: (
             DataLayout1DVariant | DataLayout2D | DataLayout3D | None
         ) = None,
+        recompute_scale_factor: bool | None = None,
         name: str | None = None,
     ) -> None:
         super().__init__()
@@ -452,6 +460,7 @@ def __init__(
         self.align_corners = align_corners
         self.align_mode = align_mode
         self.data_format = data_format
+        self.recompute_scale_factor = recompute_scale_factor
         self.name = name
 
     def forward(self, x: Tensor) -> Tensor:
@@ -475,6 +484,7 @@ def forward(self, x: Tensor) -> Tensor:
             align_corners=self.align_corners,
             align_mode=self.align_mode,
             data_format=self.data_format,
+            recompute_scale_factor=self.recompute_scale_factor,
             name=self.name,
         )
 
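
Layer-level usage mirrors the functional API, since the module simply stores the flag and forwards it to interpolate. A sketch, assuming this patch is applied:

    import paddle
    import paddle.nn as nn

    # floor(10 * 2.5) = 25 along each spatial dimension.
    upsample = nn.Upsample(
        scale_factor=2.5, mode='bilinear', recompute_scale_factor=True
    )
    x = paddle.rand([1, 3, 10, 10])
    print(upsample(x).shape)  # [1, 3, 25, 25]
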

test/legacy_test/CMakeLists.txt

+1 -0
@@ -833,6 +833,7 @@ set_tests_properties(test_multiprocess_dataloader_iterable_dataset_static
 set_tests_properties(test_lstm_cudnn_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_stack_op PROPERTIES TIMEOUT 120)
 set_tests_properties(test_bilinear_interp_v2_op PROPERTIES TIMEOUT 120)
+set_tests_properties(test_interp_recompute_scale_factor PROPERTIES TIMEOUT 60)
 set_tests_properties(test_cond PROPERTIES TIMEOUT 240)
 set_tests_properties(test_norm_nn_grad PROPERTIES TIMEOUT 180)
 set_tests_properties(test_matrix_nms_op PROPERTIES TIMEOUT 120)
