Skip to content

Commit de2166c

Browse files
authored
[XPU] fix unit test of test_pad3d_op_xpu. (#51962)
1 parent 153351e commit de2166c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

python/paddle/fluid/tests/unittests/xpu/test_pad3d_op_xpu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ def check_static_result_1(self, place):
189189
value = 100
190190
input_data = np.random.rand(*input_shape).astype(self.dtype)
191191
x = paddle.static.data(
192-
name="x", shape=input_shape, dtype="float32"
192+
name="x", shape=input_shape, dtype=self.dtype
193193
)
194194
result = F.pad(
195195
x=x, pad=pad, value=value, mode=mode, data_format="NCDHW"
@@ -212,7 +212,7 @@ def check_static_result_2(self, place):
212212
mode = "reflect"
213213
input_data = np.random.rand(*input_shape).astype(self.dtype)
214214
x = paddle.static.data(
215-
name="x", shape=input_shape, dtype="float32"
215+
name="x", shape=input_shape, dtype=self.dtype
216216
)
217217
result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
218218
result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")
@@ -240,7 +240,7 @@ def check_static_result_3(self, place):
240240
mode = "replicate"
241241
input_data = np.random.rand(*input_shape).astype(self.dtype)
242242
x = paddle.static.data(
243-
name="x", shape=input_shape, dtype="float32"
243+
name="x", shape=input_shape, dtype=self.dtype
244244
)
245245
result1 = F.pad(x=x, pad=pad, mode=mode, data_format="NCDHW")
246246
result2 = F.pad(x=x, pad=pad, mode=mode, data_format="NDHWC")

0 commit comments

Comments
 (0)