 
 #include "paddle/fluid/imperative/layer.h"
 
+#include <algorithm>
 #include <deque>
 #include <limits>
 #include <map>
@@ -77,52 +78,63 @@ class TensorAddToFunctor : public boost::static_visitor<> {
 
 }  // namespace detail
 
-void AddTo(Variable* src, Variable* dst, platform::Place place) {
-  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
-  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
-
-  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
-  // ugly fix for it
-  if (src_tensor->numel() == 0) {
+void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
+           platform::Place place) {
+  if (!dst->IsInitialize()) {
+    VLOG(2) << "im here1";
+    PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
+    dst->var_ = std::move(src->var_);
+    dst->SetInitialize(true);
     return;
-  }
+  } else {
+    framework::Tensor* dst_tensor =
+        dst->var_->GetMutable<framework::LoDTensor>();
+    framework::Tensor* src_tensor =
+        src->var_->GetMutable<framework::LoDTensor>();
+
+    // FIXME(minqiyang): loss_grad op will pass a zero grad of label
+    // ugly fix for it
+    if (src_tensor->numel() == 0) {
+      return;
+    }
 
-  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
-                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
-                 src_tensor->numel());
+    PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
+                   "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
+                   src_tensor->numel());
 
-  detail::TensorAddToFunctor<float> func(
-      src_tensor->numel(), src_tensor->data<float>(),
-      dst_tensor->mutable_data<float>(place));
-  boost::apply_visitor(func, place);
+    detail::TensorAddToFunctor<float> func(
+        src_tensor->numel(), src_tensor->data<float>(),
+        dst_tensor->mutable_data<float>(place));
+    boost::apply_visitor(func, place);
+  }
 }
 
-void ZeroGrads(VarBase* vb, const platform::Place& place) {
+void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
+               const platform::Place& place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
   auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
   operators::math::set_constant(*dev_ctx, grad_t, 0.0);
 }
 
-void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
-  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
+void AddGradBySort(BackwardSumMap* bck_map,
+                   std::shared_ptr<imperative::VarBase> target) {
+  PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
-  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
-      bck_map->at(target);
-  std::sort(
-      current.second.begin(), current.second.end(),
-      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
-        return a.first > b.first;
-      });
+  std::pair<platform::Place,
+            std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
+      current = bck_map->at(target.get());
+  std::sort(current.second.begin(), current.second.end(),
+            [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
+               const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
+              return a.first > b.first;
+            });
   for (auto& var_pair : current.second) {
-    Variable* origin_grad = target->var_.get();
-    Variable* grad_to_add = var_pair.second->var_.get();
     VLOG(10) << "add origin_grad: " << target->Name();
     VLOG(10) << "added grad: " << var_pair.second->Name()
              << " trace id is: " << var_pair.first;
-    AddTo(grad_to_add, origin_grad, current.first);
-    delete var_pair.second;
-    var_pair.second = nullptr;
+    AddTo(var_pair.second, target, current.first);
+    var_pair.second.reset();
   }
 }
 
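The rewritten AddTo above works on shared_ptr-owned VarBase handles: when the destination gradient has no buffer yet, it takes over the source buffer instead of allocating zeros and adding into them; otherwise it accumulates element-wise, and temporaries are released through reference counting (reset()) rather than manual delete. Below is a minimal standalone sketch of that pattern, assuming only the standard library; ToyVar and AddToSketch are illustrative stand-ins, not Paddle types.

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct ToyVar {
  bool initialized = false;
  std::vector<float> data;
};

// If dst has no buffer yet, steal src's buffer instead of adding into zeros;
// otherwise accumulate element-wise.
void AddToSketch(std::shared_ptr<ToyVar> src, std::shared_ptr<ToyVar> dst) {
  if (!dst->initialized) {
    assert(src->initialized);
    dst->data = std::move(src->data);
    dst->initialized = true;
    return;
  }
  assert(src->data.size() == dst->data.size());
  for (std::size_t i = 0; i < dst->data.size(); ++i) {
    dst->data[i] += src->data[i];
  }
}

int main() {
  auto dst = std::make_shared<ToyVar>();
  auto src1 = std::make_shared<ToyVar>();
  src1->initialized = true;
  src1->data = {1.f, 2.f};
  AddToSketch(src1, dst);  // first contribution: buffer is moved, nothing added
  auto src2 = std::make_shared<ToyVar>();
  src2->initialized = true;
  src2->data = {3.f, 4.f};
  AddToSketch(src2, dst);  // later contribution: element-wise accumulation
  src2.reset();            // released by reference counting, no manual delete
  assert(dst->data[0] == 4.f && dst->data[1] == 6.f);
  return 0;
}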
@@ -146,24 +158,22 @@ class Autograd {
     while (!ready.empty()) {
       OpBase* ready_op = ready.front();
       ready.pop_front();
-      std::map<std::string, std::vector<VarBase*>> input_grads =
+      std::vector<VarBasePtrMap> grads_outputs =
           ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
 
-      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
-        const std::vector<VarBase*>& ingrads = it->second;
-        for (size_t i = 0; i < ingrads.size(); ++i) {
-          if (!ingrads[i]) continue;
-          auto p = ready_op->input_vars_[it->first][i];
-
-          if (p->IsStopGradient()) continue;
-          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
-          if (!pre_op) continue;
-
-          dep_counts[pre_op] -= 1;
-          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
-          bool pre_op_ready = dep_counts[pre_op] == 0;
-          if (pre_op_ready) {
-            ready.push_back(pre_op);
+      for (const auto& map : grads_outputs) {
+        for (auto it = map.rbegin(); it != map.rend(); ++it) {
+          const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
+          for (size_t i = 0; i < grad_outs.size(); ++i) {
+            if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
+            OpBase* pre_op = grad_outs[i]->PreOp();
+            if (!pre_op) continue;
+            dep_counts[pre_op] -= 1;
+            PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
+            bool pre_op_ready = dep_counts[pre_op] == 0;
+            if (pre_op_ready) {
+              ready.push_back(pre_op);
+            }
           }
         }
       }
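The loop above schedules the backward pass with a dependency-count ready queue: a grad op becomes ready only when every op that consumes its outputs has already produced its gradient contributions, which each finished op signals by decrementing the counters of its predecessors (PreOp()). A self-contained sketch of the scheduling idea, with a hypothetical Node standing in for OpBase:

#include <deque>
#include <iostream>
#include <map>
#include <vector>

struct Node {
  int id;
  std::vector<Node*> pre_ops;  // forward predecessors of this op
};

// Run grad ops in reverse topological order: an op is enqueued once every op
// that consumes its outputs has finished its own gradient computation.
void RunReadyQueue(Node* start, std::map<Node*, int>* dep_counts) {
  std::deque<Node*> ready{start};
  while (!ready.empty()) {
    Node* op = ready.front();
    ready.pop_front();
    std::cout << "run grad of op " << op->id << "\n";
    for (Node* pre : op->pre_ops) {
      if (--(*dep_counts)[pre] == 0) {
        ready.push_back(pre);
      }
    }
  }
}

int main() {
  Node a{0, {}}, b{1, {&a}}, c{2, {&a}}, d{3, {&b, &c}};
  // dep count = number of consumers whose grads must arrive first.
  std::map<Node*, int> deps{{&a, 2}, {&b, 1}, {&c, 1}, {&d, 0}};
  RunReadyQueue(&d, &deps);  // order: d, then b and c, then a
  return 0;
}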
@@ -194,15 +204,15 @@ class Autograd {
         for (const auto& map : candidate->grad_output_vars_) {
           for (const auto& it : map) {
             for (const auto& vb : it.second) {
-              ++(*grad_ref)[vb];
+              ++(*grad_ref)[vb.get()];
             }
           }
         }
       }
       for (auto it : candidate->pre_ops_) {
         for (OpBase* pre_op : it.second) {
           if (!pre_op) continue;
-          VLOG(9) << "op dep " << candidate->Type() << " trace id "
+          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                   << candidate->trace_id_ << " <---- " << it.first << " <---- "
                   << pre_op->Type() << " trace id " << pre_op->trace_id_;
           if (visited.find(pre_op) == visited.end()) {
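Note the convention used here and by BackwardSumMap: ownership lives in shared_ptr handles, while bookkeeping maps such as grad_ref are keyed by the raw pointer obtained with .get(), so the maps themselves never extend a variable's lifetime. A tiny standalone illustration of that convention (Var is a placeholder type, not Paddle's):

#include <cassert>
#include <map>
#include <memory>

struct Var {};

int main() {
  auto v = std::make_shared<Var>();
  std::map<Var*, int> grad_ref;       // non-owning keys
  ++grad_ref[v.get()];                // one pending gradient contribution
  ++grad_ref[v.get()];                // a second consumer of the same grad var
  assert(grad_ref.at(v.get()) == 2);  // lookups stay valid while v is alive
  return 0;
}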
@@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() {
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
 
-std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
+std::vector<VarBasePtrMap> OpBase::ApplyGrad(
     BackwardSumMap* bck_map, GradientRef* grad_ref,
     const detail::BackwardStrategy& bck_stratedy) {
   PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
@@ -274,17 +284,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_output_variable_map) {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
-      for (VarBase* origin_grad_var_base : it.second) {
-        if (!origin_grad_var_base->IsInitialize()) {
-          origin_grad_var_base->InitBuffer();
-          ZeroGrads(origin_grad_var_base, place_);
-        }
+      for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
+           it.second) {
         // Allocate a new variable
-        VarBase* tmp_grad_var_base = new VarBase(
+        std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
             string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
             origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
-            place_, true, false);
-        outputs.emplace_back(tmp_grad_var_base);
+            place_, true, false));
+        outputs.emplace_back(std::move(tmp_grad_var_base));
       }
     }
 
@@ -298,7 +305,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
     if (info.infer_var_type_) {
       RuntimeInferVarTypeContext infer_var_type_ctx(
-          &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+          &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
       info.infer_var_type_(&infer_var_type_ctx);
     }
 
@@ -313,22 +320,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_input_vars_[k]) {
       auto& grad_invars = grad_invars_map[it.first];
       grad_invars.reserve(it.second.size());
-      for (VarBase* grad_inp : it.second) {
+      for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
         PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                 grad_op_desc->Type(), grad_inp->Name());
         if (!grad_inp->IsInitialize()) {
           grad_inp->InitBuffer();
           ZeroGrads(grad_inp, place_);
         }
-        const VarBase* const_grad_inp = grad_inp;
+        const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
         grad_invars.emplace_back(const_grad_inp->var_.get());
       }
     }
 
     for (const auto& it : tmp_grad_outputs[k]) {
       auto& grad_outvars = grad_outvars_map[it.first];
       grad_outvars.reserve(it.second.size());
-      for (VarBase* grad_out : it.second) {
+      for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
         PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                 grad_op_desc->Type(), grad_out->Name());
 
@@ -355,56 +362,48 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
       for (size_t i = 0; i < outputs.size(); ++i) {
         // track outputs used by sum
         if (bck_stratedy.sorted_sum_gradient_) {
-#ifndef PADDLE_WITH_CUDA
-          VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
-                   << " ";
-          VLOG(10) << origin_outputs[i]
-                          ->var_->GetMutable<framework::LoDTensor>()
-                          ->data<float>()[0];
-          VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
-          VLOG(10) << outputs[i]
-                          ->var_->GetMutable<framework::LoDTensor>()
-                          ->data<float>()[0];
-#endif
-          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
+          if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
             VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
-            bck_map->at(origin_outputs[i])
+            bck_map->at(origin_outputs[i].get())
                 .second.emplace_back(
-                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
+                    std::pair<int, std::shared_ptr<imperative::VarBase>>(
+                        this->trace_id_, std::move(outputs[i])));
           } else {
             VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
-            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
-                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
-            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
+            std::pair<platform::Place,
+                      std::vector<
+                          std::pair<int, std::shared_ptr<imperative::VarBase>>>>
+                tmp(place_,
+                    {std::make_pair(this->trace_id_, std::move(outputs[i]))});
+            bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
           }
 
-          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
-                         "Can't find %s in grad_reference count map",
-                         origin_outputs[i]->Name());
-          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
+          PADDLE_ENFORCE(
+              grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
+              "Can't find %s in grad_reference count map",
+              origin_outputs[i]->Name());
+          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
                          "Backward error when calculate grad reference");
-          if (grad_ref->at(origin_outputs[i]) > 1) {
+          if (grad_ref->at(origin_outputs[i].get()) > 1) {
             VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
-            grad_ref->at(origin_outputs[i])--;
+            grad_ref->at(origin_outputs[i].get())--;
           } else {
             VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
             AddGradBySort(bck_map, origin_outputs[i]);
-            grad_ref->at(origin_outputs[i])--;
+            grad_ref->at(origin_outputs[i].get())--;
           }
         } else {
-          framework::Variable* grad = outputs[i]->var_.get();
-          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
           VLOG(10) << "AddTo Called with orig_grad is: "
                    << origin_outputs[i]->name_ << " Grad to be added is "
                    << outputs[i]->name_;
-          AddTo(grad, orig_grad, place_);
-          delete outputs[i];
+          AddTo(outputs[i], origin_outputs[i], place_);
+          outputs[i].reset();
         }
       }
     }
   }
 
-  return input_vars_;
+  return grad_output_vars_;
 }
 
 void OpBase::InvokeBackwardHooks() {
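Under sorted_sum_gradient_, each partial gradient is buffered in bck_map together with the trace id of the op that produced it; once the reference count of the target variable drops to one, AddGradBySort sorts the buffered pairs by descending trace id and adds them in that fixed order, making the summation deterministic. A standalone sketch of that accumulation, with plain floats standing in for gradient tensors:

#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // (trace_id, partial gradient) pairs collected during the backward pass.
  std::vector<std::pair<int, float>> pending{{3, 0.5f}, {1, 0.25f}, {2, 0.25f}};
  // Sort by descending trace id so the accumulation order is deterministic.
  std::sort(pending.begin(), pending.end(),
            [](const std::pair<int, float>& a, const std::pair<int, float>& b) {
              return a.first > b.first;
            });
  float target_grad = 0.f;
  for (const auto& p : pending) {
    target_grad += p.second;
  }
  std::cout << "accumulated grad: " << target_grad << "\n";  // prints 1
  return 0;
}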
@@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
           var_->GetMutable<framework::LoDTensor>()->place())),
       grads_t, 1.0);
 
-  PADDLE_ENFORCE(
-      grads_ ==
-      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
   Autograd().RunBackward(this, bck_stratedy);
 }
 
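RunBackward above seeds the chain rule by filling the root gradient with ones via set_constant(..., grads_t, 1.0) before handing control to Autograd. A trivial standalone illustration of that seeding step, with std::vector<float> standing in for the LoDTensor:

#include <algorithm>
#include <iostream>
#include <vector>

int main() {
  std::vector<float> loss_grad(4);
  // d(loss)/d(loss) == 1 for every element, so the backward pass starts
  // from a gradient tensor filled with ones.
  std::fill(loss_grad.begin(), loss_grad.end(), 1.0f);
  for (float v : loss_grad) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}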