@@ -357,7 +357,8 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
357
357
const paddle::framework::BlockDesc *backward_global_block,
358
358
const std::vector<std::string> output_names,
359
359
const std::vector<paddle::Tensor> &x,
360
- const std::vector<paddle::Tensor> ¶ms) {
360
+ const std::vector<paddle::Tensor> ¶ms,
361
+ const phi::Place &place) {
361
362
auto ir_ctx = ::pir::IrContext::Instance ();
362
363
auto program = std::make_unique<::pir::Program>(ir_ctx);
363
364
@@ -381,29 +382,29 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
381
382
if (block->FindVarRecursive (name) == nullptr ) {
382
383
continue ;
383
384
}
384
- auto place = in_t .place ().GetType ();
385
+ auto p = in_t .place ().GetType ();
385
386
386
387
auto op_desc = block->PrependOp ();
387
388
op_desc->SetType (" data" );
388
389
op_desc->SetAttr (" shape" , std::vector<int64_t >());
389
390
// TODO(phlrain) : using tensor dtype
390
391
op_desc->SetAttr (" dtype" , 0 );
391
- op_desc->SetAttr (" place" , static_cast <int >(place ));
392
+ op_desc->SetAttr (" place" , static_cast <int >(p ));
392
393
op_desc->SetAttr (" name" , name);
393
394
op_desc->SetOutput (" out" , {name});
394
395
}
395
396
396
397
std::set<std::string> input_param_names;
397
398
for (auto ¶m : params) {
398
399
auto &name = param.name ();
399
- auto place = param.place ().GetType ();
400
+ auto p = param.place ().GetType ();
400
401
401
402
auto op_desc = local_program.MutableBlock (0 )->PrependOp ();
402
403
op_desc->SetType (" data" );
403
404
op_desc->SetAttr (" shape" , std::vector<int64_t >());
404
405
// TODO(phlrain) : using tensor dtype
405
406
op_desc->SetAttr (" dtype" , 0 );
406
- op_desc->SetAttr (" place" , static_cast <int >(place ));
407
+ op_desc->SetAttr (" place" , static_cast <int >(p ));
407
408
op_desc->SetAttr (" name" , name);
408
409
op_desc->SetOutput (" out" , {name});
409
410
@@ -445,7 +446,7 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
445
446
446
447
program_translator.Translate ();
447
448
448
- auto ir_res = paddle::dialect::PdOpLowerToKernelPass (program.get ());
449
+ auto ir_res = paddle::dialect::PdOpLowerToKernelPass (program.get (), place );
449
450
450
451
if (FLAGS_new_ir_apply_inplace_pass) {
451
452
::pir::PassManager pm (::pir::IrContext::Instance (), 3 );
0 commit comments