Skip to content

Commit 3482f93

Browse files
committed
added recompute_scale_factor in interpolate
modified: python/paddle/nn/functional/common.py modified: python/paddle/nn/layer/common.py new file: test/legacy_test/test_interp_recompute_scale_factor.py
1 parent efb2886 commit 3482f93

File tree

3 files changed

+631
-1
lines changed

3 files changed

+631
-1
lines changed

python/paddle/nn/functional/common.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def interpolate(
225225
data_format: (
226226
DataLayout1DVariant | DataLayout2D | DataLayout3D | None
227227
) = None,
228+
recompute_scale_factor: bool | None = None,
228229
name: str | None = None,
229230
) -> Tensor:
230231
"""
@@ -493,6 +494,15 @@ def interpolate(
493494
"align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
494495
)
495496

497+
if (
498+
recompute_scale_factor is not None
499+
and recompute_scale_factor
500+
and size is not None
501+
):
502+
raise ValueError(
503+
"recompute_scale_factor is not meaningful with an explicit size."
504+
)
505+
496506
if resample == 'AREA':
497507
if isinstance(size, (list, tuple, Variable, paddle.pir.Value)):
498508
if len(size) == 0:
@@ -645,7 +655,7 @@ def _is_list_or_tuple_(data):
645655
attrs['out_h'] = out_shape[1]
646656
attrs['out_w'] = out_shape[2]
647657

648-
else:
658+
elif scale is not None and recompute_scale_factor is not True:
649659
if in_dynamic_mode() and isinstance(scale, Variable):
650660
if scale.shape == []:
651661
scale = float(scale)
@@ -676,6 +686,51 @@ def _is_list_or_tuple_(data):
676686
"Attr(scale)'s type should be float, int, list, tuple, or Tensor."
677687
)
678688

689+
elif recompute_scale_factor is not None and recompute_scale_factor:
690+
assert (
691+
scale is not None
692+
), "scale_factor must not be None when recompute_scale_factor=True"
693+
694+
if in_dynamic_mode() and isinstance(scale, Variable):
695+
if scale.shape == []:
696+
scale = float(scale)
697+
else:
698+
scale = list(scale.numpy())
699+
700+
dim = len(x.shape) - 2
701+
702+
if isinstance(scale, (float, int, numpy.ndarray)):
703+
scale_list = [float(scale)] * dim
704+
elif isinstance(scale, (list, tuple)):
705+
if len(scale) != dim:
706+
raise ValueError(
707+
f"scale_shape length should be {dim} for "
708+
f"input {len(x.shape)}-D tensor."
709+
)
710+
scale_list = list(map(float, scale))
711+
else:
712+
raise TypeError(
713+
"Attr(scale)'s type should be float, int, list, tuple, or Tensor."
714+
)
715+
716+
out_shape = []
717+
for i in range(dim):
718+
input_size = x.shape[i + 2]
719+
output_size = int(numpy.floor(float(input_size) * scale_list[i]))
720+
out_shape.append(output_size)
721+
722+
if len(x.shape) == 3:
723+
attrs['out_w'] = out_shape[0]
724+
elif len(x.shape) == 4:
725+
attrs['out_h'] = out_shape[0]
726+
attrs['out_w'] = out_shape[1]
727+
elif len(x.shape) == 5:
728+
attrs['out_d'] = out_shape[0]
729+
attrs['out_h'] = out_shape[1]
730+
attrs['out_w'] = out_shape[2]
731+
732+
scale = None
733+
679734
if in_dynamic_or_pir_mode():
680735
attr_list = []
681736
for k, v in attrs.items():

python/paddle/nn/layer/common.py

+4
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,7 @@ class Upsample(Layer):
431431
align_corners: bool
432432
align_mode: int
433433
data_format: DataLayout1DVariant | DataLayout2D | DataLayout3D | None
434+
recompute_scale_factor: bool | None
434435
name: str | None
435436

436437
def __init__(
@@ -443,6 +444,7 @@ def __init__(
443444
data_format: (
444445
DataLayout1DVariant | DataLayout2D | DataLayout3D | None
445446
) = None,
447+
recompute_scale_factor: bool | None = None,
446448
name: str | None = None,
447449
) -> None:
448450
super().__init__()
@@ -452,6 +454,7 @@ def __init__(
452454
self.align_corners = align_corners
453455
self.align_mode = align_mode
454456
self.data_format = data_format
457+
self.recompute_scale_factor = recompute_scale_factor
455458
self.name = name
456459

457460
def forward(self, x: Tensor) -> Tensor:
@@ -475,6 +478,7 @@ def forward(self, x: Tensor) -> Tensor:
475478
align_corners=self.align_corners,
476479
align_mode=self.align_mode,
477480
data_format=self.data_format,
481+
recompute_scale_factor=self.recompute_scale_factor,
478482
name=self.name,
479483
)
480484

0 commit comments

Comments
 (0)