
Commit 3e00f80

[PIR save/load] Fix bug in PirInterpreterEngine and open jit layer test (#72101)
* fix bug in PirInterpreterEngine and open jit layer test
* refine
1 parent 8dd22d9 commit 3e00f80

File tree: 2 files changed, +30 -35 lines changed

paddle/fluid/jit/engine/pir_interpreter_engine.cc

Lines changed: 1 addition & 1 deletion
@@ -35,6 +35,7 @@ PirInterpreterEngine::PirInterpreterEngine(
       common::errors::PreconditionNotMet(
           "There is no operator in ProgramDesc."));
   utils::ShareParamsIntoScope(info_->ParamNames(), params_dict_, &scope_);
+  prog_ = paddle::dialect::PdOpLowerToKernelPass(prog_.get(), place_);
   CreateInterpreterCore();
 }
 
@@ -60,7 +61,6 @@ std::vector<Tensor> PirInterpreterEngine::operator()(
 
 std::vector<DenseTensor> PirInterpreterEngine::operator()(
     const std::vector<DenseTensor> &inputs) {
-  prog_ = paddle::dialect::PdOpLowerToKernelPass(prog_.get(), place_);
   utils::ShareIntoScope(info_->InputArgNames(), inputs, &scope_);
 
   // the latter can be moved to python side.
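
What the two hunks do together: PdOpLowerToKernelPass moves out of operator() and into the constructor, so the PIR program is lowered to the kernel dialect once, when the engine is built, rather than on every call. A plausible reading of the bug (inferred from the diff; the commit message does not spell it out) is that repeated invocations re-lowered a program that was already in kernel form. The sketch below exercises that double-call pattern from Python; it reuses the save/load flow from the test file changed in this commit, with a hypothetical stand-in Net, since the real definition lives outside this diff:

import os
import tempfile

import numpy as np
import paddle
from paddle.jit.layer import Layer  # the loader class the test below uses


class Net(paddle.nn.Layer):
    # Hypothetical stand-in for the Net defined in test_jit_layer.py.
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(x)


tmp = tempfile.TemporaryDirectory()
model_path = os.path.join(tmp.name, 'multi_program')
paddle.jit.save(
    Net(),
    model_path,
    input_spec=[paddle.static.InputSpec([None, 4], 'float32')],
    combine_params=True,
)

jit_layer = Layer()
jit_layer.load(model_path, paddle.CPUPlace())

x = paddle.full([2, 4], 2.0)
out1 = jit_layer.forward(x)
# Before this fix, a second call re-ran PdOpLowerToKernelPass on a program
# that operator() had already lowered; now lowering happens once, in the
# constructor, and reruns are plain executions.
out2 = jit_layer.forward(x)
np.testing.assert_allclose(out1[0].numpy(), out2[0].numpy())
tmp.cleanup()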

test/legacy_test/test_jit_layer.py

Lines changed: 29 additions & 34 deletions
@@ -68,26 +68,23 @@ def tearDown(self):
         self.temp_dir.cleanup()
 
     def test_multi_load(self):
-        with paddle.pir_utils.OldIrGuard():
-            paddle.disable_static()
-            x = paddle.full([2, 4], 2)
-            model = Net()
-            with enable_to_static_guard(False):
-                forward_out1 = model.forward(x)
-                infer_out1 = model.infer(x)
-            model_path = os.path.join(self.temp_dir.name, 'multi_program')
-            paddle.jit.save(model, model_path, combine_params=True)
-            place = paddle.CPUPlace()
-            if paddle.is_compiled_with_cuda():
-                place = paddle.CUDAPlace(0)
-            jit_layer = Layer()
-            jit_layer.load(model_path, place)
-            forward_out2 = jit_layer.forward(x)
-            infer_out2 = jit_layer.infer(x)
-            np.testing.assert_allclose(
-                forward_out1, forward_out2[0], rtol=1e-05
-            )
-            np.testing.assert_allclose(infer_out1, infer_out2[0], rtol=1e-05)
+        paddle.disable_static()
+        x = paddle.full([2, 4], 2)
+        model = Net()
+        with enable_to_static_guard(False):
+            forward_out1 = model.forward(x)
+            infer_out1 = model.infer(x)
+        model_path = os.path.join(self.temp_dir.name, 'multi_program')
+        paddle.jit.save(model, model_path, combine_params=True)
+        place = paddle.CPUPlace()
+        if paddle.is_compiled_with_cuda():
+            place = paddle.CUDAPlace(0)
+        jit_layer = Layer()
+        jit_layer.load(model_path, place)
+        forward_out2 = jit_layer.forward(x)
+        infer_out2 = jit_layer.infer(x)
+        np.testing.assert_allclose(forward_out1, forward_out2[0], rtol=1e-05)
+        np.testing.assert_allclose(infer_out1, infer_out2[0], rtol=1e-05)
 
     def test_multi_jit_load(self):
         x = paddle.full([2, 4], 2)
@@ -127,20 +124,18 @@ def tearDown(self):
         self.temp_dir.cleanup()
 
     def test_mkl_output(self):
-        with paddle.pir_utils.OldIrGuard():
-            paddle.disable_static()
-            with _dygraph_place_guard(place=paddle.CPUPlace()):
-                net = SaveLinear()
-                model_path = os.path.join(self.temp_dir.name, 'save_linear')
-                paddle.jit.save(net, model_path, combine_params=True)
-
-                layer = Layer()
-                print("load ", model_path)
-                layer.load(model_path, paddle.CPUPlace())
-                x = paddle.ones([498, 80])
-                out = layer.forward(x)
-                out = paddle.unsqueeze(out[0], 0)
-                np.testing.assert_equal(out.shape, [1, 498, 80])
+        paddle.disable_static()
+        with _dygraph_place_guard(place=paddle.CPUPlace()):
+            net = SaveLinear()
+            model_path = os.path.join(self.temp_dir.name, 'save_linear')
+            paddle.jit.save(net, model_path, combine_params=True)
+
+            layer = Layer()
+            layer.load(model_path, paddle.CPUPlace())
+            x = paddle.ones([498, 80])
+            out = layer.forward(x)
+            out = paddle.unsqueeze(out[0], 0)
+            np.testing.assert_equal(out.shape, [1, 498, 80])
 
     def test_mkl_jit_output(self):
         with _dygraph_place_guard(place=paddle.CPUPlace()):
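
On the test side, both rewritten tests drop the paddle.pir_utils.OldIrGuard() wrapper (test_mkl_output also loses a stray debug print), so they no longer force the legacy IR and instead run under the process default; that is what the commit title means by "open jit layer test". A minimal sketch of what the removed guard controls, assuming paddle.framework.use_pir_api (the flag check used elsewhere in Paddle's tests) is available:

import paddle

# Inside the guard the legacy (pre-PIR) IR is forced; before this commit
# the whole body of test_multi_load and test_mkl_output ran in here.
with paddle.pir_utils.OldIrGuard():
    print(paddle.framework.use_pir_api())  # expected: False

# Outside the guard the process default applies, so with PIR enabled the
# jit save/load round trip exercises PirInterpreterEngine.
print(paddle.framework.use_pir_api())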
