Skip to content

Commit 5337272

Browse files
fix ir default place (PaddlePaddle#57243)
1 parent 68fa80a commit 5337272

File tree

3 files changed

+16
-9
lines changed

3 files changed

+16
-9
lines changed

paddle/fluid/eager/to_static/run_program_op_node.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,13 @@ inline void RunProgramAPI(
403403

404404
if (FLAGS_enable_new_ir_in_executor) {
405405
// build new ir program
406-
auto ir_program = paddle::framework::ConstructFowardIrProgram(
407-
forward_global_block, backward_global_block, output_names, x, params);
406+
auto ir_program =
407+
paddle::framework::ConstructFowardIrProgram(forward_global_block,
408+
backward_global_block,
409+
output_names,
410+
x,
411+
params,
412+
place);
408413
interpreter_core =
409414
paddle::framework::CreateNewIRInterpreterCoreInfoToCache(
410415
std::move(ir_program),

paddle/fluid/framework/executor_cache.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,8 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
357357
const paddle::framework::BlockDesc *backward_global_block,
358358
const std::vector<std::string> output_names,
359359
const std::vector<paddle::Tensor> &x,
360-
const std::vector<paddle::Tensor> &params) {
360+
const std::vector<paddle::Tensor> &params,
361+
const phi::Place &place) {
361362
auto ir_ctx = ::pir::IrContext::Instance();
362363
auto program = std::make_unique<::pir::Program>(ir_ctx);
363364

@@ -381,29 +382,29 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
381382
if (block->FindVarRecursive(name) == nullptr) {
382383
continue;
383384
}
384-
auto place = in_t.place().GetType();
385+
auto p = in_t.place().GetType();
385386

386387
auto op_desc = block->PrependOp();
387388
op_desc->SetType("data");
388389
op_desc->SetAttr("shape", std::vector<int64_t>());
389390
// TODO(phlrain) : using tensor dtype
390391
op_desc->SetAttr("dtype", 0);
391-
op_desc->SetAttr("place", static_cast<int>(place));
392+
op_desc->SetAttr("place", static_cast<int>(p));
392393
op_desc->SetAttr("name", name);
393394
op_desc->SetOutput("out", {name});
394395
}
395396

396397
std::set<std::string> input_param_names;
397398
for (auto &param : params) {
398399
auto &name = param.name();
399-
auto place = param.place().GetType();
400+
auto p = param.place().GetType();
400401

401402
auto op_desc = local_program.MutableBlock(0)->PrependOp();
402403
op_desc->SetType("data");
403404
op_desc->SetAttr("shape", std::vector<int64_t>());
404405
// TODO(phlrain) : using tensor dtype
405406
op_desc->SetAttr("dtype", 0);
406-
op_desc->SetAttr("place", static_cast<int>(place));
407+
op_desc->SetAttr("place", static_cast<int>(p));
407408
op_desc->SetAttr("name", name);
408409
op_desc->SetOutput("out", {name});
409410

@@ -445,7 +446,7 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
445446

446447
program_translator.Translate();
447448

448-
auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get());
449+
auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place);
449450

450451
if (FLAGS_new_ir_apply_inplace_pass) {
451452
::pir::PassManager pm(::pir::IrContext::Instance(), 3);

paddle/fluid/framework/executor_cache.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,8 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
254254
const paddle::framework::BlockDesc* backward_global_block,
255255
const std::vector<std::string> output_names,
256256
const std::vector<paddle::Tensor>& x,
257-
const std::vector<paddle::Tensor>& params);
257+
const std::vector<paddle::Tensor>& params,
258+
const phi::Place& place);
258259

259260
std::unique_ptr<::pir::Program> ConstructBackwardIrProgram(
260261
const paddle::framework::BlockDesc* backward_global_block,

0 commit comments

Comments
 (0)