Skip to content

Commit fbc3021

Browse files
committed
refine WhileGradOp code
1 parent 8f962f7 commit fbc3021

File tree

1 file changed

+7
-17
lines changed

1 file changed

+7
-17
lines changed

paddle/operators/while_op.cc

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -219,18 +219,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
219219

220220
auto *grad_block = this->grad_block_[0];
221221
auto *fwd_block = grad_block->ParentBlock();
222-
// auto *parent_block = fwd_block->ParentBlock();
223222

224223
// Not all of IGs will be generated by inner gradient operators of while op.
225224
// Ignore IGs that is not generated by the inside block.
226225
std::unordered_set<std::string> inner_op_outputs;
227-
LOG(INFO) << "FUCK1";
228226
for (const auto *op : grad_block->AllOps()) {
229227
for (auto &oname : op->OutputArgumentNames()) {
230228
inner_op_outputs.insert(oname);
231229
}
232230
}
233-
LOG(INFO) << "FUCK2";
234231
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
235232
for (auto &each_ig : igs) {
236233
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
@@ -243,11 +240,13 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
243240
// OG should be re-calculated by step blocks, since many outputs of while op
244241
// do not need to calculate gradients.
245242
std::unordered_set<std::string> block_ins;
246-
std::copy(Input(kX).begin(), Input(kX).end(),
247-
std::inserter(block_ins, block_ins.end()));
248-
std::copy(Output(kOutputs).begin(), Output(kOutputs).end(),
249-
std::inserter(block_ins, block_ins.end()));
250-
243+
block_ins.reserve(Input(kX).size() + Output(kOutputs).size());
244+
for (auto &p : Input(kX)) {
245+
block_ins.insert(p);
246+
}
247+
for (auto &o : Output(kOutputs)) {
248+
block_ins.insert(o);
249+
}
251250
std::unordered_set<std::string> extra_inputs;
252251
for (const auto *op : grad_block->AllOps()) {
253252
for (auto &input_name : op->InputArgumentNames()) {
@@ -257,15 +256,6 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
257256
fwd_block->FindVar(input_name) != nullptr) {
258257
continue;
259258
}
260-
261-
/*
262-
if (parent_block->FindVarRecursive(input_name) == nullptr) {
263-
VLOG(5) << "WARNING! Variable '" << input_name
264-
<< "' is the input of '" << op->Type()
265-
<< "'. But can not be found in any block.";
266-
continue;
267-
}
268-
*/
269259
extra_inputs.insert(input_name);
270260
}
271261
for (auto &output_name : op->OutputArgumentNames()) {

0 commit comments

Comments
 (0)