Skip to content

Commit b9f44b6

Browse files
committed
remove condition dygraph grad maker, modify local name; test=develop
1 parent 13c4a60 commit b9f44b6

File tree

4 files changed

+11
-13
lines changed

4 files changed

+11
-13
lines changed

paddle/fluid/imperative/layer.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ void OpBase::CreateOperatorBase() {
264264
auto input_name_map = CreateVarNameMap(info, type_, ins_, true);
265265
auto output_name_map = CreateVarNameMap(info, type_, outs_, false);
266266
op_ = framework::OpRegistry::CreateOp(type_, std::move(input_name_map),
267-
std::move(output_name_map),
268-
std::move(attrs_));
267+
std::move(output_name_map), attrs_);
269268
}
270269

271270
void OpBase::Run(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {

paddle/fluid/imperative/layer.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,8 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
384384
return grad_pending_ops_;
385385
}
386386

387-
void SetGradPendingOps(std::vector<OpBase*>* vec_temp) {
388-
grad_pending_ops_.swap(*vec_temp);
387+
void SetGradPendingOps(std::vector<OpBase*> vec_temp) {
388+
grad_pending_ops_.swap(vec_temp);
389389
}
390390

391391
void InsertGradPendingOps(OpBase* op) { grad_pending_ops_.emplace_back(op); }
@@ -426,11 +426,11 @@ class OpBase : public std::enable_shared_from_this<OpBase> {
426426
void SetType(const std::string& type) { type_ = type; }
427427
void SetInput(const std::string& name,
428428
std::vector<std::shared_ptr<VarBase>> vec_var_base) {
429-
ins_.emplace(name, vec_var_base);
429+
ins_[name] = std::move(vec_var_base);
430430
}
431431
void SetOutput(const std::string& name,
432432
std::vector<std::shared_ptr<VarBase>> vec_var_base) {
433-
outs_.emplace(name, vec_var_base);
433+
outs_[name] = std::move(vec_var_base);
434434
}
435435
void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; }
436436
void SetAttr(const std::string& name, const framework::Attribute& v) {

paddle/fluid/imperative/tracer.cc

+5-5
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
9494
size_t grad_op_num = grad_op_bases_.size();
9595

9696
std::set<VarBase*> set_input_vars;
97-
for (auto& grad_in_it : ins) {
98-
for (auto& var_base_it : grad_in_it.second) {
97+
for (auto& fwd_in_it : ins) {
98+
for (auto& var_base_it : fwd_in_it.second) {
9999
set_input_vars.insert(var_base_it.get());
100100
}
101101
}
102102

103-
for (auto& grad_out_it : outs) {
104-
for (auto& var_base_it : grad_out_it.second) {
103+
for (auto& fwd_out_it : outs) {
104+
for (auto& var_base_it : fwd_out_it.second) {
105105
set_input_vars.insert(var_base_it.get());
106106
}
107107
}
@@ -148,7 +148,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
148148
std::vector<OpBase*> vec_preceding_ops(visited_preceding_ops.begin(),
149149
visited_preceding_ops.end());
150150

151-
grad_op->SetGradPendingOps(&vec_preceding_ops);
151+
grad_op->SetGradPendingOps(std::move(vec_preceding_ops));
152152

153153
// this OpBase* is just used to manage op's life time
154154
engine_->InsertOp(grad_op.get(), grad_op);

paddle/fluid/operators/controlflow/conditional_block_op.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpMaker<T> {
204204
namespace ops = paddle::operators;
205205
REGISTER_OPERATOR(conditional_block, ops::ConditionalBlockOp,
206206
ops::ConditionalBlockOpProtoMaker,
207-
ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>,
208-
ops::ConditionalBlockGradMaker<paddle::imperative::OpBase>);
207+
ops::ConditionalBlockGradMaker<paddle::framework::OpDesc>);
209208
REGISTER_OPERATOR(conditional_block_grad, ops::ConditionalBlockGradOp,
210209
ops::ConditionalBlockGradInferShape);

0 commit comments

Comments
 (0)