Skip to content

Commit 9c28f50

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into symbolic_infer
2 parents ab4790e + 04bceca commit 9c28f50

20 files changed

+278
-75
lines changed

paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu

+4-4
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
295295
sum_exp_logits = ctx.AllocateTmpTensor<T, phi::GPUContext>({N, 1}, dev_ctx);
296296
sum_exp_logits.mutable_data<T>(place);
297297

298-
auto eigen_sum_exp_logits =
299-
phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
300-
eigen_sum_exp_logits.device(*dev_ctx.eigen_device()) =
301-
eigen_softmax.sum(along_axis);
298+
phi::SumKernel<T, phi::GPUContext>(
299+
dev_ctx, softmax_2d, {-1}, softmax_2d.dtype(), true, &sum_exp_logits);
302300

303301
if (comm_ctx) {
304302
comm_ctx->AllReduce(&sum_exp_logits, sum_exp_logits, ncclSum, stream);
@@ -333,6 +331,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
333331
N);
334332
}
335333

334+
auto eigen_sum_exp_logits =
335+
phi::funcs::EigenMatrix<T>::From(sum_exp_logits);
336336
eigen_softmax.device(*dev_ctx.eigen_device()) =
337337
(eigen_softmax *
338338
eigen_sum_exp_logits.inverse().broadcast(one_by_class));

paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc

+111-11
Original file line numberDiff line numberDiff line change
@@ -287,20 +287,30 @@ std::vector<std::vector<pir::OpResult>> IfOp::Vjp(
287287
void WhileOp::Build(pir::Builder &builder, // NOLINT
288288
pir::OperationArgument &argument, // NOLINT
289289
pir::Value cond,
290-
const std::vector<pir::Value> &inputs) {
290+
const std::vector<pir::Value> &inputs,
291+
bool construct_body) {
291292
argument.AddInput(cond);
292293
argument.AddInputs(inputs);
293-
auto &body = argument.AddRegion().emplace_back();
294294
std::vector<pir::Attribute> outs_stop_gradient;
295-
for (auto val : inputs) {
296-
argument.AddOutput(val.type());
297-
auto arg = body.AddArgument(val.type());
298-
299-
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
300-
arg.set_attribute(kStopGradientAttrName,
301-
bool_attr ? bool_attr : builder.bool_attr(false));
302-
outs_stop_gradient.push_back(bool_attr ? bool_attr
303-
: builder.bool_attr(false));
295+
if (construct_body) {
296+
auto &body = argument.AddRegion().emplace_back();
297+
for (auto val : inputs) {
298+
argument.AddOutput(val.type());
299+
auto arg = body.AddArgument(val.type());
300+
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
301+
outs_stop_gradient.push_back(bool_attr ? bool_attr
302+
: builder.bool_attr(false));
303+
arg.set_attribute(kStopGradientAttrName,
304+
bool_attr ? bool_attr : builder.bool_attr(false));
305+
}
306+
} else {
307+
argument.AddRegion(nullptr);
308+
for (auto val : inputs) {
309+
argument.AddOutput(val.type());
310+
auto bool_attr = val.attribute<pir::BoolAttribute>(kStopGradientAttrName);
311+
outs_stop_gradient.push_back(bool_attr ? bool_attr
312+
: builder.bool_attr(false));
313+
}
304314
}
305315

306316
argument.AddAttribute(
@@ -343,6 +353,96 @@ void WhileOp::Print(pir::IrPrinter &printer) {
343353
os << "\n }";
344354
}
345355

356+
void WhileOp::VerifySig() {
357+
VLOG(4) << "Start Verifying inputs, outputs and attributes for: WhileOp.";
358+
auto input_size = num_operands();
359+
PADDLE_ENFORCE_GE(
360+
input_size,
361+
1u,
362+
phi::errors::PreconditionNotMet(
363+
"The size %d of inputs must be greater or equal to 1.", input_size));
364+
365+
if (auto cond_type = operand_type(0).dyn_cast<pir::DenseTensorType>()) {
366+
PADDLE_ENFORCE_EQ(
367+
cond_type.dtype().isa<pir::BoolType>(),
368+
true,
369+
phi::errors::PreconditionNotMet(
370+
"Type validation failed for the 0th input, it should be a "
371+
"bool DenseTensorType."));
372+
} else if (auto cond_type =
373+
operand_type(0).dyn_cast<AllocatedDenseTensorType>()) {
374+
PADDLE_ENFORCE_EQ(
375+
cond_type.dtype().isa<pir::BoolType>(),
376+
true,
377+
phi::errors::PreconditionNotMet(
378+
"Type validation failed for the 0th input, it should be a "
379+
"bool DenseTensorType."));
380+
} else {
381+
PADDLE_THROW(phi::errors::PreconditionNotMet(
382+
"Currently, the while op cond input only support bool dense_tensor "
383+
"and bool allocated_dense_tensor."));
384+
}
385+
PADDLE_ENFORCE_EQ((*this)->num_regions(),
386+
1u,
387+
phi::errors::PreconditionNotMet(
388+
"The size %d of regions must be equal to 1.",
389+
(*this)->num_regions()));
390+
auto output_size = num_results();
391+
PADDLE_ENFORCE_EQ(output_size + 1,
392+
input_size,
393+
phi::errors::PreconditionNotMet(
394+
"The result size (%d) not equal to input size(%d) + 1.",
395+
num_results(),
396+
input_size));
397+
for (size_t index = 0; index < output_size; ++index) {
398+
PADDLE_ENFORCE_EQ(
399+
operand_type(index + 1),
400+
result_type(index),
401+
phi::errors::PreconditionNotMet(
402+
"The (%d) result and operand type is not equal.", index));
403+
}
404+
}
405+
406+
void WhileOp::VerifyRegion() {
407+
VLOG(4) << "Start verifying sub regions for: WhileOp.";
408+
PADDLE_ENFORCE_EQ(
409+
(*this)->region(0).size(),
410+
1u,
411+
phi::errors::PreconditionNotMet("The size %d of body_region must be 1.",
412+
(*this)->region(0).size()));
413+
auto &body_block = body();
414+
auto output_size = num_results();
415+
PADDLE_ENFORCE_EQ(
416+
body_block.args_size(),
417+
output_size,
418+
phi::errors::PreconditionNotMet(
419+
"The result size (%d) not equal to block args size(%d) + 1.",
420+
output_size,
421+
body_block.args_size()));
422+
423+
PADDLE_ENFORCE_EQ(
424+
body_block.empty(),
425+
false,
426+
phi::errors::PreconditionNotMet("The body block is empty."));
427+
428+
auto yield_op = body_block.back().dyn_cast<pir::YieldOp>();
429+
auto input_size = num_operands();
430+
PADDLE_ENFORCE_EQ(
431+
yield_op && yield_op.num_operands() == input_size,
432+
true,
433+
phi::errors::PreconditionNotMet(
434+
"The body block yield size not equal to operands size."));
435+
// Todo: fix other bugs and make the following code work.
436+
// for (size_t index = 0; index < input_size; ++index) {
437+
// PADDLE_ENFORCE_EQ(
438+
// operand_type(index),
439+
// yield_op.operand_type(index),
440+
// phi::errors::PreconditionNotMet(
441+
// "The (%d) operand and block yield type is not equal.", index));
442+
// }
443+
VLOG(4) << "Successful end verifying sub regions for: WhileOp.";
444+
}
445+
346446
std::vector<std::vector<pir::OpResult>> WhileOp::Vjp(
347447
pir::Operation *op,
348448
const std::vector<std::vector<pir::Value>> &inputs,

paddle/fluid/pir/dialect/operator/ir/control_flow_op.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,14 @@ class WhileOp : public pir::Op<WhileOp, VjpInterface> {
7777
static void Build(pir::Builder &builder, // NOLINT
7878
pir::OperationArgument &argument, // NOLINT
7979
pir::Value cond,
80-
const std::vector<pir::Value> &inputs);
80+
const std::vector<pir::Value> &inputs,
81+
bool construct_body = true);
8182
TEST_API pir::Block &body();
8283
pir::Value cond();
8384
const pir::Block::ArgListType &block_args() { return body().args(); }
8485
void Print(pir::IrPrinter &printer); // NOLINT
85-
void VerifySig() {}
86-
void VerifyRegion() {}
86+
void VerifySig();
87+
void VerifyRegion();
8788
static std::vector<std::vector<pir::OpResult>> Vjp(
8889
pir::Operation *op,
8990
const std::vector<std::vector<pir::Value>> &inputs_,

paddle/fluid/pir/dialect/operator/ir/op_dialect.cc

+2-4
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,11 @@ OperatorDialect::OperatorDialect(pir::IrContext *ctx)
6161
ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>();
6262
auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name());
6363
info.AttachInterface(std::move(
64-
pir::InterfaceValue::
65-
Get<pir::TuplePushOp, VjpInterface, TuplePushOpVjpInterfaceModel>()));
64+
pir::InterfaceValue::Get<VjpInterface, TuplePushOpVjpInterfaceModel>()));
6665

6766
info = ctx->GetRegisteredOpInfo(pir::CombineOp::name());
6867
info.AttachInterface(std::move(
69-
pir::InterfaceValue::Get<pir::CombineOp,
70-
InferSymbolicShapeInterface,
68+
pir::InterfaceValue::Get<InferSymbolicShapeInterface,
7169
CombineOpInferSymbolicShapeInterfaceModel>()));
7270
}
7371

paddle/fluid/pybind/control_flow_api.cc

+74-11
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ using paddle::dialect::AssertOp;
4040
using paddle::dialect::HasElementsOp;
4141
using paddle::dialect::IfOp;
4242
using paddle::dialect::WhileOp;
43+
using paddle::pybind::PyIfOp;
44+
using paddle::pybind::PyWhileOp;
4345
using pir::Block;
4446
using pir::Builder;
4547
using pir::Operation;
@@ -51,8 +53,6 @@ using pir::Type;
5153
using pir::Value;
5254
using pir::YieldOp;
5355
using pybind11::return_value_policy;
54-
55-
using paddle::pybind::PyIfOp;
5656
namespace {
5757

5858
void BindIfOp(py::module* m) {
@@ -79,22 +79,24 @@ void BindIfOp(py::module* m) {
7979
}
8080

8181
void BindWhileOp(py::module* m) {
82-
m->def("build_while_op", [](Value cond, py::list loop_vars) {
82+
m->def("build_while_op", [](Value cond, py::list loop_vars) -> PyWhileOp {
8383
std::vector<Value> loop_values;
8484
for (auto var : loop_vars) {
8585
loop_values.push_back(var.cast<Value>());
8686
}
87-
return ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond,
88-
loop_values);
87+
return PyWhileOp(
88+
ApiBuilder::Instance().GetBuilder()->Build<WhileOp>(cond, loop_values));
8989
});
90-
py::class_<WhileOp> while_op(*m, "WhileOp", R"DOC(
90+
py::class_<PyWhileOp> while_op(*m, "WhileOp", R"DOC(
9191
WhileOp in python api.
9292
)DOC");
93-
while_op.def("body", &WhileOp::body, return_value_policy::reference)
94-
.def("as_operation", &WhileOp::operation, return_value_policy::reference)
93+
while_op.def("body", &PyWhileOp::body, return_value_policy::reference)
94+
.def(
95+
"as_operation", &PyWhileOp::operation, return_value_policy::reference)
9596
.def("block_arguments",
9697
&WhileOp::block_args,
97-
return_value_policy::reference);
98+
return_value_policy::reference)
99+
.def("optimize_update", &PyWhileOp::OptimizeUpdate);
98100
}
99101

100102
void BindAssertOp(py::module* m) {
@@ -183,7 +185,7 @@ PyIfOp::PyIfOp(IfOp if_op) : IfOp(if_op) {
183185

184186
void PyIfOp::UpdateOutput() {
185187
PADDLE_ENFORCE_NOT_NULL(
186-
*this,
188+
operation_,
187189
paddle::platform::errors::InvalidArgument(
188190
"The if_op in PyIfOp used to update output can't be nullptr"));
189191
auto block = parent();
@@ -197,7 +199,68 @@ void PyIfOp::UpdateOutput() {
197199
cond(), true_region().TakeBack(), false_region().TakeBack());
198200
block->Assign(iter, new_if_op);
199201
IfOp::operator=(new_if_op);
200-
VerifyRegion();
202+
operation_->Verify();
203+
}
204+
205+
// Wraps an existing WhileOp handle. The wrapped operation pointer must be
// non-null; otherwise an InvalidArgument error is raised.
PyWhileOp::PyWhileOp(WhileOp while_op) : WhileOp(while_op) {
  PADDLE_ENFORCE_NOT_NULL(
      operation_,
      paddle::platform::errors::InvalidArgument(
          "The while_op used to construct PyWhileOp can't be nullptr"));
}
211+
212+
std::vector<Value> PyWhileOp::OptimizeUpdate() {
213+
PADDLE_ENFORCE_NOT_NULL(operation_,
214+
paddle::platform::errors::InvalidArgument(
215+
"The while_op in PyWhileOp used to remove unused "
216+
"loop vars can't be nullptr"));
217+
auto parent_block = parent();
218+
PADDLE_ENFORCE_NOT_NULL(
219+
parent_block,
220+
paddle::platform::errors::InvalidArgument(
221+
"The parent block of while_op which used to remove "
222+
"unused loop vars can't be nullptr"));
223+
224+
operation_->Verify();
225+
auto& body_block = body();
226+
auto yield_op = body_block.back().dyn_cast<YieldOp>();
227+
auto operand_num = operation_->num_operands();
228+
bool no_change = true;
229+
std::vector<size_t> index_vec;
230+
std::vector<Value> res, new_input, new_yield_val{yield_op.operand_source(0)};
231+
for (uint32_t i = 0; i < num_results(); ++i) {
232+
res.push_back(result(i));
233+
}
234+
for (size_t operand_index = 1u, arg_index = 0u; operand_index < operand_num;
235+
++operand_index) {
236+
if (yield_op.operand_source(operand_index) == body_block.arg(arg_index)) {
237+
body_block.arg(arg_index).ReplaceAllUsesWith(
238+
operand_source(operand_index));
239+
body_block.EraseArgument(arg_index);
240+
no_change = false;
241+
res[operand_index - 1u] = operand_source(operand_index);
242+
} else {
243+
new_input.push_back(operand_source(operand_index));
244+
index_vec.push_back(operand_index - 1u);
245+
new_yield_val.push_back(yield_op.operand_source(operand_index));
246+
++arg_index;
247+
}
248+
}
249+
if (no_change) return res;
250+
Block::Iterator iter = **this;
251+
Builder builder(ir_context(), false);
252+
auto new_while_op = builder.Build<WhileOp>(cond(), new_input, false);
253+
new_while_op->region(0).swap(std::move(operation_->region(0)));
254+
parent_block->Assign(iter, new_while_op);
255+
WhileOp::operator=(new_while_op);
256+
body_block.pop_back();
257+
builder.SetInsertionPointToBlockEnd(&body_block);
258+
builder.Build<YieldOp>(new_yield_val);
259+
operation_->Verify();
260+
for (size_t result_index = 0; result_index < num_results(); ++result_index) {
261+
res[index_vec[result_index]] = result(result_index);
262+
}
263+
return res;
201264
}
202265

203266
void BindControlFlowApi(py::module* m) {

paddle/fluid/pybind/control_flow_api.h

+16
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@ class PyIfOp : public dialect::IfOp {
2525
void UpdateOutput();
2626
};
2727

28+
/// \brief Thin subclass of dialect::WhileOp that adds OptimizeUpdate().
class PyWhileOp : public dialect::WhileOp {
29+
public:
30+
/// \brief Wrap an existing dialect::WhileOp handle.
explicit PyWhileOp(dialect::WhileOp while_op);
31+
32+
///
33+
/// \brief Construct a new while_op to replace the original while_op. The
34+
/// input, output, and parameters of the new while_op no longer contain the
35+
/// variables that have not been modified in the loop. The size of the return
36+
/// value is equal to the output size of the original while_op, where the
37+
/// value of the read-only loop variable is the corresponding operand of the
38+
/// original while_op, and the value of the non-read-only loop variable is the
39+
/// corresponding output of the new while_op.
40+
///
41+
std::vector<pir::Value> OptimizeUpdate();
42+
};
43+
2844
void BindControlFlowApi(pybind11::module *m);
2945
} // namespace pybind
3046
} // namespace paddle

paddle/fluid/pybind/pir.cc

+2-8
Original file line numberDiff line numberDiff line change
@@ -527,14 +527,8 @@ void BindOperation(py::module *m) {
527527
})
528528
.def("as_if_op",
529529
[](Operation &self) { return PyIfOp(self.dyn_cast<IfOp>()); })
530-
.def("as_while_op", [](Operation &self) -> WhileOp {
531-
auto while_op = self.dyn_cast<WhileOp>();
532-
if (!while_op) {
533-
PADDLE_THROW(phi::errors::InvalidArgument(
534-
"Can't cast non-while type Operation to WhileOp."));
535-
}
536-
return while_op;
537-
});
530+
.def("as_while_op",
531+
[](Operation &self) { return PyWhileOp(self.dyn_cast<WhileOp>()); });
538532
py::class_<Operation::BlockContainer> block_container(
539533
*m, "Operation_BlockContainer", R"DOC(
540534
The Operation_BlockContainer only use to walk all blocks in the operation.

paddle/phi/infermeta/unary.cc

+1
Original file line numberDiff line numberDiff line change
@@ -1859,6 +1859,7 @@ void IncrementInferMeta(const MetaTensor& x, float value, MetaTensor* out) {
18591859
product(x.dims())));
18601860
out->set_dims(x.dims());
18611861
out->share_lod(x);
1862+
out->set_layout(x.layout());
18621863
out->set_dtype(x.dtype());
18631864
}
18641865

0 commit comments

Comments
 (0)