
Commit f0674dd

[Prim] fix clip_decomp with where instead of maximum (#71844)
* clip_decomp
* refine and add tests
* fix typos
* open tests && fix typos
1 parent 207e34a commit f0674dd

File tree

paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h
test/legacy_test/test_clip_op.py

2 files changed: +41 −11 lines changed


paddle/fluid/primitive/decomp_rule/decomp_rule/composite.h

Lines changed: 10 additions & 9 deletions
@@ -931,17 +931,17 @@ Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) {
   auto min_reshape = min;
   auto max_reshape = max;

-  if (x.shape().size() == 0) {
-    min_reshape = reshape<T>(min_reshape, {});
-    max_reshape = reshape<T>(max_reshape, {});
-  }
-
   if (has_dynamic_shape(x.shape())) {
     min_reshape = backend::expand<T>(min_reshape, shape64<T>(x));
     max_reshape = backend::expand<T>(max_reshape, shape64<T>(x));
   } else {
-    min_reshape = expand<T>(min_reshape, x.shape());
-    max_reshape = expand<T>(max_reshape, x.shape());
+    if (x.shape().size() == 0) {
+      min_reshape = reshape<T>(min_reshape, {});
+      max_reshape = reshape<T>(max_reshape, {});
+    } else {
+      min_reshape = expand<T>(min_reshape, x.shape());
+      max_reshape = expand<T>(max_reshape, x.shape());
+    }
   }
   if (min_reshape.dtype() != x.dtype()) {
     min_reshape = cast<T>(min_reshape, x.dtype());
@@ -950,8 +950,9 @@ Tensor clip_decomp(const Tensor& x, const Tensor& min, const Tensor& max) {
   if (max_reshape.dtype() != x.dtype()) {
     max_reshape = cast<T>(max_reshape, x.dtype());
   }
-
-  auto ans = maximum<T>(minimum<T>(x, max_reshape), min_reshape);
+  auto ans = where<T>(x <= max_reshape,
+                      where<T>(x >= min_reshape, x, min_reshape),
+                      max_reshape);
   return ans;
 }
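For context, the old maximum/minimum form and the new where form compute the same forward result; the difference presumably lies in how the decomposed gradient treats inputs that sit exactly on the clip bounds, which the new boundary-value test below exercises via check_prim_pir=True. A minimal NumPy sketch of the two forms (plain NumPy, not Paddle's primitive API; the helper names are made up for illustration):

import numpy as np

def clip_where(x, min_v, max_v):
    # New decomposition: where(x <= max, where(x >= min, x, min), max)
    return np.where(x <= max_v, np.where(x >= min_v, x, min_v), max_v)

def clip_min_max(x, min_v, max_v):
    # Old decomposition: maximum(minimum(x, max), min)
    return np.maximum(np.minimum(x, max_v), min_v)

min_v, max_v = 0.5, 1.0
# Inputs drawn exactly from {min, max}, mirroring the new TestCase6 below.
x = np.random.choice([min_v, max_v], (4, 8, 16)).astype(np.float32)

# Forward results of both forms agree with np.clip, bounds included.
assert np.array_equal(clip_where(x, min_v, max_v), np.clip(x, min_v, max_v))
assert np.array_equal(clip_min_max(x, min_v, max_v), np.clip(x, min_v, max_v))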

test/legacy_test/test_clip_op.py

Lines changed: 31 additions & 2 deletions
@@ -46,7 +46,7 @@ def setUp(self):
         else:
             max_v = self.attrs['max']

-        input = np.random.random(self.shape).astype(self.dtype)
+        input = self.generate_input()
         input[np.abs(input - min_v) < self.max_relative_error] = 0.5
         input[np.abs(input - max_v) < self.max_relative_error] = 0.5
         self.inputs['X'] = input
@@ -67,7 +67,7 @@ def test_check_output(self):

     def test_check_grad_normal(self):
         paddle.enable_static()
-        self.check_grad(['X'], 'Out', check_pir=True)
+        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)
         paddle.disable_static()

     def initTestCase(self):
@@ -78,6 +78,9 @@ def initTestCase(self):
         self.inputs['Max'] = np.array([0.8]).astype(self.dtype)
         self.inputs['Min'] = np.array([0.1]).astype(self.dtype)

+    def generate_input(self):
+        return np.random.random(self.shape).astype(self.dtype)
+

 class TestCase1(TestClipOp):
     def initTestCase(self):
@@ -121,6 +124,19 @@ def initTestCase(self):
         self.min = 0.5


+class TestCase6(TestClipOp):
+    def initTestCase(self):
+        self.dtype = np.float32
+        self.shape = (4, 8, 16)
+        self.max = 1.0
+        self.min = 0.5
+
+    def generate_input(self):
+        return np.random.choice([self.min, self.max], self.shape).astype(
+            self.dtype
+        )
+
+
 class TestFP16Case1(TestClipOp):
     def initTestCase(self):
         self.dtype = np.float16
@@ -163,6 +179,19 @@ def initTestCase(self):
         self.min = 0.5


+class TestFP16Case6(TestClipOp):
+    def initTestCase(self):
+        self.dtype = np.float16
+        self.shape = (4, 8, 16)
+        self.max = 1.0
+        self.min = 0.5
+
+    def generate_input(self):
+        return np.random.choice([self.min, self.max], self.shape).astype(
+            self.dtype
+        )
+
+
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
     or not core.is_bfloat16_supported(core.CUDAPlace(0)),
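The new TestCase6 and TestFP16Case6 draw every input element exactly from {min, max}, so the gradient check with check_prim_pir=True compares the decomposition's gradient against the reference right at the clip bounds. A small sketch of the expected behaviour, assuming the common convention that the bounds are included in the pass-through region (an assumption here, not stated in the patch):

import numpy as np

min_v, max_v = 0.5, 1.0
x = np.random.choice([min_v, max_v], (4, 8, 16)).astype(np.float32)

# Assumed reference gradient of clip: 1 where min <= x <= max, else 0.
grad_mask = ((x >= min_v) & (x <= max_v)).astype(np.float32)

# Every element lies on a bound, so the expected gradient is 1 everywhere;
# a decomposition that excluded the bounds would disagree on this input.
assert np.all(grad_mask == 1.0)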
