@@ -287,20 +287,30 @@ std::vector<std::vector<pir::OpResult>> IfOp::Vjp(
287
287
void WhileOp::Build (pir::Builder &builder, // NOLINT
288
288
pir::OperationArgument &argument, // NOLINT
289
289
pir::Value cond,
290
- const std::vector<pir::Value> &inputs) {
290
+ const std::vector<pir::Value> &inputs,
291
+ bool construct_body) {
291
292
argument.AddInput (cond);
292
293
argument.AddInputs (inputs);
293
- auto &body = argument.AddRegion ().emplace_back ();
294
294
std::vector<pir::Attribute> outs_stop_gradient;
295
- for (auto val : inputs) {
296
- argument.AddOutput (val.type ());
297
- auto arg = body.AddArgument (val.type ());
298
-
299
- auto bool_attr = val.attribute <pir::BoolAttribute>(kStopGradientAttrName );
300
- arg.set_attribute (kStopGradientAttrName ,
301
- bool_attr ? bool_attr : builder.bool_attr (false ));
302
- outs_stop_gradient.push_back (bool_attr ? bool_attr
303
- : builder.bool_attr (false ));
295
+ if (construct_body) {
296
+ auto &body = argument.AddRegion ().emplace_back ();
297
+ for (auto val : inputs) {
298
+ argument.AddOutput (val.type ());
299
+ auto arg = body.AddArgument (val.type ());
300
+ auto bool_attr = val.attribute <pir::BoolAttribute>(kStopGradientAttrName );
301
+ outs_stop_gradient.push_back (bool_attr ? bool_attr
302
+ : builder.bool_attr (false ));
303
+ arg.set_attribute (kStopGradientAttrName ,
304
+ bool_attr ? bool_attr : builder.bool_attr (false ));
305
+ }
306
+ } else {
307
+ argument.AddRegion (nullptr );
308
+ for (auto val : inputs) {
309
+ argument.AddOutput (val.type ());
310
+ auto bool_attr = val.attribute <pir::BoolAttribute>(kStopGradientAttrName );
311
+ outs_stop_gradient.push_back (bool_attr ? bool_attr
312
+ : builder.bool_attr (false ));
313
+ }
304
314
}
305
315
306
316
argument.AddAttribute (
@@ -343,6 +353,96 @@ void WhileOp::Print(pir::IrPrinter &printer) {
343
353
os << " \n }" ;
344
354
}
345
355
356
+ void WhileOp::VerifySig () {
357
+ VLOG (4 ) << " Start Verifying inputs, outputs and attributes for: WhileOp." ;
358
+ auto input_size = num_operands ();
359
+ PADDLE_ENFORCE_GE (
360
+ input_size,
361
+ 1u ,
362
+ phi::errors::PreconditionNotMet (
363
+ " The size %d of inputs must be greater or equal to 1." , input_size));
364
+
365
+ if (auto cond_type = operand_type (0 ).dyn_cast <pir::DenseTensorType>()) {
366
+ PADDLE_ENFORCE_EQ (
367
+ cond_type.dtype ().isa <pir::BoolType>(),
368
+ true ,
369
+ phi::errors::PreconditionNotMet (
370
+ " Type validation failed for the 0th input, it should be a "
371
+ " bool DenseTensorType." ));
372
+ } else if (auto cond_type =
373
+ operand_type (0 ).dyn_cast <AllocatedDenseTensorType>()) {
374
+ PADDLE_ENFORCE_EQ (
375
+ cond_type.dtype ().isa <pir::BoolType>(),
376
+ true ,
377
+ phi::errors::PreconditionNotMet (
378
+ " Type validation failed for the 0th input, it should be a "
379
+ " bool DenseTensorType." ));
380
+ } else {
381
+ PADDLE_THROW (phi::errors::PreconditionNotMet (
382
+ " Currently, the while op cond input only support bool dense_tensor "
383
+ " and bool allocated_dense_tensor." ));
384
+ }
385
+ PADDLE_ENFORCE_EQ ((*this )->num_regions (),
386
+ 1u ,
387
+ phi::errors::PreconditionNotMet (
388
+ " The size %d of regions must be equal to 1." ,
389
+ (*this )->num_regions ()));
390
+ auto output_size = num_results ();
391
+ PADDLE_ENFORCE_EQ (output_size + 1 ,
392
+ input_size,
393
+ phi::errors::PreconditionNotMet (
394
+ " The result size (%d) not equal to input size(%d) + 1." ,
395
+ num_results (),
396
+ input_size));
397
+ for (size_t index = 0 ; index < output_size; ++index ) {
398
+ PADDLE_ENFORCE_EQ (
399
+ operand_type (index + 1 ),
400
+ result_type (index ),
401
+ phi::errors::PreconditionNotMet (
402
+ " The (%d) result and operand type is not equal." , index ));
403
+ }
404
+ }
405
+
406
+ void WhileOp::VerifyRegion () {
407
+ VLOG (4 ) << " Start verifying sub regions for: WhileOp." ;
408
+ PADDLE_ENFORCE_EQ (
409
+ (*this )->region (0 ).size (),
410
+ 1u ,
411
+ phi::errors::PreconditionNotMet (" The size %d of body_region must be 1." ,
412
+ (*this )->region (0 ).size ()));
413
+ auto &body_block = body ();
414
+ auto output_size = num_results ();
415
+ PADDLE_ENFORCE_EQ (
416
+ body_block.args_size (),
417
+ output_size,
418
+ phi::errors::PreconditionNotMet (
419
+ " The result size (%d) not equal to block args size(%d) + 1." ,
420
+ output_size,
421
+ body_block.args_size ()));
422
+
423
+ PADDLE_ENFORCE_EQ (
424
+ body_block.empty (),
425
+ false ,
426
+ phi::errors::PreconditionNotMet (" The body block is empty." ));
427
+
428
+ auto yield_op = body_block.back ().dyn_cast <pir::YieldOp>();
429
+ auto input_size = num_operands ();
430
+ PADDLE_ENFORCE_EQ (
431
+ yield_op && yield_op.num_operands () == input_size,
432
+ true ,
433
+ phi::errors::PreconditionNotMet (
434
+ " The body block yield size not equal to operands size." ));
435
+ // Todo: fix other bugs and make the following code work.
436
+ // for (size_t index = 0; index < input_size; ++index) {
437
+ // PADDLE_ENFORCE_EQ(
438
+ // operand_type(index),
439
+ // yield_op.operand_type(index),
440
+ // phi::errors::PreconditionNotMet(
441
+ // "The (%d) operand and block yield type is not equal.", index));
442
+ // }
443
+ VLOG (4 ) << " Successful end verifying sub regions for: WhileOp." ;
444
+ }
445
+
346
446
std::vector<std::vector<pir::OpResult>> WhileOp::Vjp (
347
447
pir::Operation *op,
348
448
const std::vector<std::vector<pir::Value>> &inputs,
0 commit comments