@@ -219,18 +219,15 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
219
219
220
220
auto *grad_block = this ->grad_block_ [0 ];
221
221
auto *fwd_block = grad_block->ParentBlock ();
222
- // auto *parent_block = fwd_block->ParentBlock();
223
222
224
223
// Not all of IGs will be generated by inner gradient operators of while op.
225
224
// Ignore IGs that is not generated by the inside block.
226
225
std::unordered_set<std::string> inner_op_outputs;
227
- LOG (INFO) << " FUCK1" ;
228
226
for (const auto *op : grad_block->AllOps ()) {
229
227
for (auto &oname : op->OutputArgumentNames ()) {
230
228
inner_op_outputs.insert (oname);
231
229
}
232
230
}
233
- LOG (INFO) << " FUCK2" ;
234
231
auto igs = InputGrad (kX , /* do not drop empty gradient*/ false );
235
232
for (auto &each_ig : igs) {
236
233
if (inner_op_outputs.find (each_ig) == inner_op_outputs.end ()) {
@@ -243,11 +240,13 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
243
240
// OG should be re-calculated by step blocks, since many outputs of while op
244
241
// do not need to calculate gradients.
245
242
std::unordered_set<std::string> block_ins;
246
- std::copy (Input (kX ).begin (), Input (kX ).end (),
247
- std::inserter (block_ins, block_ins.end ()));
248
- std::copy (Output (kOutputs ).begin (), Output (kOutputs ).end (),
249
- std::inserter (block_ins, block_ins.end ()));
250
-
243
+ block_ins.reserve (Input (kX ).size () + Output (kOutputs ).size ());
244
+ for (auto &p : Input (kX )) {
245
+ block_ins.insert (p);
246
+ }
247
+ for (auto &o : Output (kOutputs )) {
248
+ block_ins.insert (o);
249
+ }
251
250
std::unordered_set<std::string> extra_inputs;
252
251
for (const auto *op : grad_block->AllOps ()) {
253
252
for (auto &input_name : op->InputArgumentNames ()) {
@@ -257,15 +256,6 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
257
256
fwd_block->FindVar (input_name) != nullptr ) {
258
257
continue ;
259
258
}
260
-
261
- /*
262
- if (parent_block->FindVarRecursive(input_name) == nullptr) {
263
- VLOG(5) << "WARNING! Variable '" << input_name
264
- << "' is the input of '" << op->Type()
265
- << "'. But can not be found in any block.";
266
- continue;
267
- }
268
- */
269
259
extra_inputs.insert (input_name);
270
260
}
271
261
for (auto &output_name : op->OutputArgumentNames ()) {
0 commit comments