
Commit 3b70f87

Using smart pointers to optimize memory usage of DyGraph (#17768)

* for debug
* test=develop, memory optimize for dygraph using shared_ptr
* test=develop, fix Travis CI errors
* test=develop, fix bug for recurrent usage of VarBase
* test=develop, init VarBase when it needs to be added
1 parent 82358bf commit 3b70f87
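
The change below swaps raw VarBase* ownership (paired with manual delete) for std::shared_ptr<VarBase>: temporary gradient variables are released with reset() as soon as the backward pass is done with them, and an uninitialized destination grad simply takes over the source buffer via std::move instead of allocating and zero-filling a new one. The standalone sketch below illustrates that ownership pattern only; Var and AccumulateInto are hypothetical stand-ins, not Paddle's actual VarBase/AddTo API.

// Minimal sketch (hypothetical types, not Paddle's real API): gradient nodes
// are held by std::shared_ptr, so a temporary gradient is freed by reset()
// instead of a manual delete, and an uninitialized destination takes over the
// source buffer with std::move.
#include <cstddef>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>

struct Var {
  std::vector<float> buffer;  // stand-in for the LoDTensor held by a VarBase
  bool initialized = false;
};

// Roughly mirrors the shape of the new AddTo(): move the buffer when dst is
// uninitialized, otherwise accumulate element-wise.
void AccumulateInto(const std::shared_ptr<Var>& src,
                    const std::shared_ptr<Var>& dst) {
  if (!dst->initialized) {
    dst->buffer = std::move(src->buffer);  // take ownership, no copy
    dst->initialized = true;
    return;
  }
  for (std::size_t i = 0; i < dst->buffer.size(); ++i) {
    dst->buffer[i] += src->buffer[i];
  }
}

int main() {
  auto target = std::make_shared<Var>();

  // Temporary per-op gradients, analogous to the "%s@IGrad" outputs.
  auto g1 = std::make_shared<Var>();
  g1->buffer = {1.0f, 2.0f};
  g1->initialized = true;
  auto g2 = std::make_shared<Var>();
  g2->buffer = {3.0f, 4.0f};
  g2->initialized = true;
  std::vector<std::shared_ptr<Var>> temps{g1, g2};

  for (auto& tmp : temps) {
    AccumulateInto(tmp, target);
    tmp.reset();  // replaces "delete tmp; tmp = nullptr;" of the raw-pointer version
  }

  std::cout << target->buffer[0] << " " << target->buffer[1] << "\n";  // prints: 4 6
  return 0;
}

Because the last shared_ptr going out of scope frees the tensor right away, peak memory during backward drops; that is the optimization the commit title refers to.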

File tree

8 files changed (+277, -204 lines)


paddle/fluid/imperative/layer.cc

Lines changed: 89 additions & 93 deletions
@@ -14,6 +14,7 @@
 
 #include "paddle/fluid/imperative/layer.h"
 
+#include <algorithm>
 #include <deque>
 #include <limits>
 #include <map>
@@ -77,52 +78,63 @@ class TensorAddToFunctor : public boost::static_visitor<> {
 
 }  // namespace detail
 
-void AddTo(Variable* src, Variable* dst, platform::Place place) {
-  framework::Tensor* dst_tensor = dst->GetMutable<framework::LoDTensor>();
-  framework::Tensor* src_tensor = src->GetMutable<framework::LoDTensor>();
-
-  // FIXME(minqiyang): loss_grad op will pass a zero grad of label
-  // ugly fix for it
-  if (src_tensor->numel() == 0) {
+void AddTo(std::shared_ptr<VarBase> src, std::shared_ptr<VarBase> dst,
+           platform::Place place) {
+  if (!dst->IsInitialize()) {
+    VLOG(2) << "im here1";
+    PADDLE_ENFORCE(src->IsInitialize(), "Using uninitialized VarBase");
+    dst->var_ = std::move(src->var_);
+    dst->SetInitialize(true);
     return;
-  }
+  } else {
+    framework::Tensor* dst_tensor =
+        dst->var_->GetMutable<framework::LoDTensor>();
+    framework::Tensor* src_tensor =
+        src->var_->GetMutable<framework::LoDTensor>();
+
+    // FIXME(minqiyang): loss_grad op will pass a zero grad of label
+    // ugly fix for it
+    if (src_tensor->numel() == 0) {
+      return;
+    }
 
-  PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
-                 "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
-                 src_tensor->numel());
+    PADDLE_ENFORCE(dst_tensor->numel() == src_tensor->numel(),
+                   "dst_numel %lld vs. src_numel %lld", dst_tensor->numel(),
+                   src_tensor->numel());
 
-  detail::TensorAddToFunctor<float> func(
-      src_tensor->numel(), src_tensor->data<float>(),
-      dst_tensor->mutable_data<float>(place));
-  boost::apply_visitor(func, place);
+    detail::TensorAddToFunctor<float> func(
+        src_tensor->numel(), src_tensor->data<float>(),
+        dst_tensor->mutable_data<float>(place));
+    boost::apply_visitor(func, place);
+  }
 }
 
-void ZeroGrads(VarBase* vb, const platform::Place& place) {
+void ZeroGrads(const std::shared_ptr<imperative::VarBase> vb,
+               const platform::Place& place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
   auto grad_t = vb->var_->GetMutable<framework::LoDTensor>();
   operators::math::set_constant(*dev_ctx, grad_t, 0.0);
 }
 
-void AddGradBySort(BackwardSumMap* bck_map, VarBase* target) {
-  PADDLE_ENFORCE(bck_map->find(target) != bck_map->end(),
+void AddGradBySort(BackwardSumMap* bck_map,
+                   std::shared_ptr<imperative::VarBase> target) {
+  PADDLE_ENFORCE(bck_map->find(target.get()) != bck_map->end(),
                  "Can't find %s in backward grad map", target->Name());
-  std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>& current =
-      bck_map->at(target);
-  std::sort(
-      current.second.begin(), current.second.end(),
-      [](const std::pair<int, VarBase*>& a, const std::pair<int, VarBase*>& b) {
-        return a.first > b.first;
-      });
+  std::pair<platform::Place,
+            std::vector<std::pair<int, std::shared_ptr<imperative::VarBase>>>>&
+      current = bck_map->at(target.get());
+  std::sort(current.second.begin(), current.second.end(),
+            [](const std::pair<int, std::shared_ptr<imperative::VarBase>>& a,
+               const std::pair<int, std::shared_ptr<imperative::VarBase>>& b) {
+              return a.first > b.first;
+            });
   for (auto& var_pair : current.second) {
-    Variable* origin_grad = target->var_.get();
-    Variable* grad_to_add = var_pair.second->var_.get();
    VLOG(10) << "add origin_grad: " << target->Name();
    VLOG(10) << "added grad: " << var_pair.second->Name()
             << " trace id is: " << var_pair.first;
-    AddTo(grad_to_add, origin_grad, current.first);
-    delete var_pair.second;
-    var_pair.second = nullptr;
+    AddTo(var_pair.second, target, current.first);
+    var_pair.second.reset();
   }
 }
 
@@ -146,24 +158,22 @@ class Autograd {
     while (!ready.empty()) {
       OpBase* ready_op = ready.front();
       ready.pop_front();
-      std::map<std::string, std::vector<VarBase*>> input_grads =
+      std::vector<VarBasePtrMap> grads_outputs =
           ready_op->ApplyGrad(&bck_map, &grad_ref, bck_stratedy);
 
-      for (auto it = input_grads.rbegin(); it != input_grads.rend(); ++it) {
-        const std::vector<VarBase*>& ingrads = it->second;
-        for (size_t i = 0; i < ingrads.size(); ++i) {
-          if (!ingrads[i]) continue;
-          auto p = ready_op->input_vars_[it->first][i];
-
-          if (p->IsStopGradient()) continue;
-          OpBase* pre_op = ready_op->pre_ops_[it->first][i];
-          if (!pre_op) continue;
-
-          dep_counts[pre_op] -= 1;
-          PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
-          bool pre_op_ready = dep_counts[pre_op] == 0;
-          if (pre_op_ready) {
-            ready.push_back(pre_op);
+      for (const auto& map : grads_outputs) {
+        for (auto it = map.rbegin(); it != map.rend(); ++it) {
+          const std::vector<std::shared_ptr<VarBase>>& grad_outs = it->second;
+          for (size_t i = 0; i < grad_outs.size(); ++i) {
+            if (!grad_outs[i] || grad_outs[i]->IsStopGradient()) continue;
+            OpBase* pre_op = grad_outs[i]->PreOp();
+            if (!pre_op) continue;
+            dep_counts[pre_op] -= 1;
+            PADDLE_ENFORCE(dep_counts[pre_op] >= 0);
+            bool pre_op_ready = dep_counts[pre_op] == 0;
+            if (pre_op_ready) {
+              ready.push_back(pre_op);
+            }
           }
         }
       }
@@ -194,15 +204,15 @@ class Autograd {
         for (const auto& map : candidate->grad_output_vars_) {
           for (const auto& it : map) {
             for (const auto& vb : it.second) {
-              ++(*grad_ref)[vb];
+              ++(*grad_ref)[vb.get()];
            }
          }
        }
      }
      for (auto it : candidate->pre_ops_) {
        for (OpBase* pre_op : it.second) {
          if (!pre_op) continue;
-          VLOG(9) << "op dep " << candidate->Type() << " trace id "
+          VLOG(2) << "op dep " << candidate->Type() << " trace id "
                  << candidate->trace_id_ << " <---- " << it.first << " <---- "
                  << pre_op->Type() << " trace id " << pre_op->trace_id_;
          if (visited.find(pre_op) == visited.end()) {
@@ -254,7 +264,7 @@ framework::LoDTensor& VarBase::GradValue() {
   return *(grads_->var_->GetMutable<framework::LoDTensor>());
 }
 
-std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
+std::vector<VarBasePtrMap> OpBase::ApplyGrad(
     BackwardSumMap* bck_map, GradientRef* grad_ref,
     const detail::BackwardStrategy& bck_stratedy) {
   PADDLE_ENFORCE(!grad_op_descs_.empty(), "%s has no backward implementation",
@@ -274,17 +284,14 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_output_variable_map) {
       auto& outputs = tmp_grad_outputs[k][it.first];
       outputs.reserve(it.second.size());
-      for (VarBase* origin_grad_var_base : it.second) {
-        if (!origin_grad_var_base->IsInitialize()) {
-          origin_grad_var_base->InitBuffer();
-          ZeroGrads(origin_grad_var_base, place_);
-        }
+      for (const std::shared_ptr<imperative::VarBase>& origin_grad_var_base :
+           it.second) {
         // Allocate a new variable
-        VarBase* tmp_grad_var_base = new VarBase(
+        std::shared_ptr<imperative::VarBase> tmp_grad_var_base(new VarBase(
            string::Sprintf("%s@IGrad", origin_grad_var_base->Name()),
            origin_grad_var_base->DataType(), origin_grad_var_base->Dims(),
-            place_, true, false);
-        outputs.emplace_back(tmp_grad_var_base);
+            place_, true, false));
+        outputs.emplace_back(std::move(tmp_grad_var_base));
      }
    }
 
@@ -298,7 +305,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     auto& info = framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
     if (info.infer_var_type_) {
       RuntimeInferVarTypeContext infer_var_type_ctx(
-          &grad_input_vars_[k], &tmp_grad_outputs[k], &attrs_);
+          &grad_input_vars_[k], &tmp_grad_outputs[k], &(opbase->Attrs()));
       info.infer_var_type_(&infer_var_type_ctx);
     }
 
@@ -313,22 +320,22 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
     for (const auto& it : grad_input_vars_[k]) {
       auto& grad_invars = grad_invars_map[it.first];
       grad_invars.reserve(it.second.size());
-      for (VarBase* grad_inp : it.second) {
+      for (const std::shared_ptr<imperative::VarBase>& grad_inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_inp->var_, "op %s input %s nullptr",
                                grad_op_desc->Type(), grad_inp->Name());
        if (!grad_inp->IsInitialize()) {
          grad_inp->InitBuffer();
          ZeroGrads(grad_inp, place_);
        }
-        const VarBase* const_grad_inp = grad_inp;
+        const std::shared_ptr<imperative::VarBase>& const_grad_inp = grad_inp;
        grad_invars.emplace_back(const_grad_inp->var_.get());
      }
    }
 
    for (const auto& it : tmp_grad_outputs[k]) {
      auto& grad_outvars = grad_outvars_map[it.first];
      grad_outvars.reserve(it.second.size());
-      for (VarBase* grad_out : it.second) {
+      for (const std::shared_ptr<imperative::VarBase>& grad_out : it.second) {
        PADDLE_ENFORCE_NOT_NULL(grad_out->var_, "op %s output %s nullptr",
                                grad_op_desc->Type(), grad_out->Name());
 
@@ -355,56 +362,48 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad(
       for (size_t i = 0; i < outputs.size(); ++i) {
         // track outputs used by sum
         if (bck_stratedy.sorted_sum_gradient_) {
-#ifndef PADDLE_WITH_CUDA
-          VLOG(10) << "origin_outputs is : " << origin_outputs[i]->Name()
-                   << " ";
-          VLOG(10) << origin_outputs[i]
-                          ->var_->GetMutable<framework::LoDTensor>()
-                          ->data<float>()[0];
-          VLOG(10) << "outputs is : " << outputs[i]->Name() << " ";
-          VLOG(10) << outputs[i]
-                          ->var_->GetMutable<framework::LoDTensor>()
-                          ->data<float>()[0];
-#endif
-          if (bck_map->find(origin_outputs[i]) != bck_map->end()) {
+          if (bck_map->find(origin_outputs[i].get()) != bck_map->end()) {
            VLOG(10) << "add sub grad to " << origin_outputs[i]->Name();
-            bck_map->at(origin_outputs[i])
+            bck_map->at(origin_outputs[i].get())
                .second.emplace_back(
-                    std::pair<int, VarBase*>(this->trace_id_, outputs[i]));
+                    std::pair<int, std::shared_ptr<imperative::VarBase>>(
+                        this->trace_id_, std::move(outputs[i])));
          } else {
            VLOG(10) << "insert new map for " << origin_outputs[i]->Name();
-            std::pair<platform::Place, std::vector<std::pair<int, VarBase*>>>
-                tmp(place_, {std::make_pair(this->trace_id_, outputs[i])});
-            bck_map->insert(std::make_pair(origin_outputs[i], tmp));
+            std::pair<platform::Place,
+                      std::vector<
+                          std::pair<int, std::shared_ptr<imperative::VarBase>>>>
+                tmp(place_,
+                    {std::make_pair(this->trace_id_, std::move(outputs[i]))});
+            bck_map->insert(std::make_pair(origin_outputs[i].get(), tmp));
          }
 
-          PADDLE_ENFORCE(grad_ref->find(origin_outputs[i]) != grad_ref->end(),
-                         "Can't find %s in grad_reference count map",
-                         origin_outputs[i]->Name());
-          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i]) >= 1,
+          PADDLE_ENFORCE(
+              grad_ref->find(origin_outputs[i].get()) != grad_ref->end(),
+              "Can't find %s in grad_reference count map",
+              origin_outputs[i]->Name());
+          PADDLE_ENFORCE(grad_ref->at(origin_outputs[i].get()) >= 1,
                         "Backward error when calculate grad reference");
-          if (grad_ref->at(origin_outputs[i]) > 1) {
+          if (grad_ref->at(origin_outputs[i].get()) > 1) {
            VLOG(10) << "remove ref for " << origin_outputs[i]->Name();
-            grad_ref->at(origin_outputs[i])--;
+            grad_ref->at(origin_outputs[i].get())--;
          } else {
            VLOG(10) << "Add grad for: " << origin_outputs[i]->Name();
            AddGradBySort(bck_map, origin_outputs[i]);
-            grad_ref->at(origin_outputs[i])--;
+            grad_ref->at(origin_outputs[i].get())--;
          }
        } else {
-          framework::Variable* grad = outputs[i]->var_.get();
-          framework::Variable* orig_grad = origin_outputs[i]->var_.get();
          VLOG(10) << "AddTo Called with orig_grad is: "
                   << origin_outputs[i]->name_ << " Grad to be added is "
                   << outputs[i]->name_;
-          AddTo(grad, orig_grad, place_);
-          delete outputs[i];
+          AddTo(outputs[i], origin_outputs[i], place_);
+          outputs[i].reset();
        }
      }
    }
  }
 
-  return input_vars_;
+  return grad_output_vars_;
 }
 
 void OpBase::InvokeBackwardHooks() {
@@ -434,9 +433,6 @@ void VarBase::RunBackward(const detail::BackwardStrategy& bck_stratedy) {
           var_->GetMutable<framework::LoDTensor>()->place())),
       grads_t, 1.0);
 
-  PADDLE_ENFORCE(
-      grads_ ==
-      pre_op_->output_vars_[pre_op_out_name_][pre_op_out_idx_]->grads_);
   Autograd().RunBackward(this, bck_stratedy);
 }
 