-
Notifications
You must be signed in to change notification settings - Fork 5.7k
support testing when training and handle dropout and batch_norm operator in testing mode #5734
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
669ed1d
4ee8abc
984b56b
77c417a
8f05d71
ac49e8b
38e2766
369daae
9360bf8
a62e802
422b837
26429a4
ed34a76
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,8 @@ namespace framework { | |
|
||
const std::string kFeedOpType = "feed"; | ||
const std::string kFetchOpType = "fetch"; | ||
const std::string kDropOutOpType = "dropout"; | ||
const std::string kBatchNormOpType = "batch_norm"; | ||
|
||
bool HasDependentVar(const OpDesc& op_desc, | ||
const std::set<std::string>& dependent_vars) { | ||
|
@@ -46,7 +48,8 @@ bool IsTarget(const OpDesc& op_desc) { | |
return false; | ||
} | ||
|
||
void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) { | ||
void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id, | ||
bool is_test) { | ||
// TODO(tonyyang-svail): | ||
// - will change to use multiple blocks for RNN op and Cond Op | ||
|
||
|
@@ -99,11 +102,23 @@ void prune_impl(const ProgramDesc& input, ProgramDesc* output, int block_id) { | |
*op_field->Add() = input.blocks(block_id).ops(i); | ||
} | ||
} | ||
if (is_test) { | ||
for (auto& op_desc : *op_field) { | ||
if (op_desc.type() == kDropOutOpType || | ||
op_desc.type() == kBatchNormOpType) { | ||
for (auto& attr : *op_desc.mutable_attrs()) { | ||
if (attr.name() == "is_test") { | ||
attr.set_b(true); | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies | ||
void Prune(const ProgramDesc& input, ProgramDesc* output) { | ||
prune_impl(input, output, 0); | ||
void Prune(const ProgramDesc& input, ProgramDesc* output, bool is_test) { | ||
prune_impl(input, output, 0, is_test); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think setting attribute There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We now have
|
||
} | ||
|
||
} // namespace framework | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
break;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done