Skip to content

Commit ae03a28

Browse files
committed
[Dy2St][CINN] Explicit use phi backend in more CINN uts
1 parent 8c5c40a commit ae03a28

File tree

5 files changed

+20
-4
lines changed

5 files changed

+20
-4
lines changed

test/ir/pir/cinn/sub_graphs/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
6161
)
6262
else:
6363
net = paddle.jit.to_static(
64-
net(), full_graph=True, input_spec=self.input_specs
64+
net(),
65+
backend=None,
66+
full_graph=True,
67+
input_spec=self.input_specs,
6568
)
6669
if self.with_train:
6770
net.train()

test/ir/pir/cinn/symbolic/test_decomp_inference_predictor_run.py

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def setUp(self):
5858
),
5959
],
6060
full_graph=True,
61+
backend=None,
6162
)
6263
paddle.jit.save(
6364
model,

test/ir/pir/cinn/symbolic/test_reshape_dyshape2stshape.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ def train(self, net, to_static, with_cinn=False):
5454
full_graph=True,
5555
)
5656
else:
57-
net = paddle.jit.to_static(net, full_graph=True)
57+
net = paddle.jit.to_static(
58+
net,
59+
backend=None,
60+
full_graph=True,
61+
)
5862
paddle.seed(123)
5963
outs = net(*self.inputs)
6064
return outs

test/ir/pir/cinn/symbolic/test_sub_graph_flatten.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,11 @@ def train(self, net, to_static, with_prim=False, with_cinn=False):
5151
full_graph=True,
5252
)
5353
else:
54-
net = paddle.jit.to_static(net, full_graph=True)
54+
net = paddle.jit.to_static(
55+
net,
56+
backend=None,
57+
full_graph=True,
58+
)
5559
paddle.seed(123)
5660
outs = net(*self.inputs)
5761
return outs

test/legacy_test/test_stack_op.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,11 @@ def setUp(self):
396396
def test_list_single_tensor(self):
397397
expect = paddle.stack(self.x)
398398
paddle.base.core._set_prim_all_enabled(True)
399-
st_model = paddle.jit.to_static(paddle.stack, full_graph=True)
399+
st_model = paddle.jit.to_static(
400+
paddle.stack,
401+
backend=None,
402+
full_graph=True,
403+
)
400404
actual = st_model(self.x)
401405
np.testing.assert_allclose(expect, actual)
402406
paddle.enable_static()

0 commit comments

Comments
 (0)