Skip to content

【Prim】Refactor prim flags system #49930

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/eager/auto_code_generator/generator/eager_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,7 +1841,7 @@ def GenerateHigherOrderNodeCreationCode(self):

if is_composite_grad_api and next_grad_node_creation_str != '':
next_grad_node_creation_str = f"""
if (!paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (!paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{next_grad_node_creation_str}
}}
"""
Expand Down Expand Up @@ -2261,7 +2261,7 @@ def GenerateNodeDefinition(
# TODO(Ruting):using composite only when we don't have backward kernel in the future.
elif is_composite_grad_api:
grad_function_call_str = f"""
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {{
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {{
{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str});
VLOG(4) << "Composite api {composite_grad_api_name} is called ";
}}else{{
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string.h>
#include <memory>
#include <sstream>
#include <string>
Expand Down Expand Up @@ -166,7 +167,16 @@ Tensor full<DescTensor>(const IntArray& shape,
phi::errors::InvalidArgument(
"We only support float32/float16 for full, but we got data type: %s",
phi::DataTypeToString(dtype)));
op->SetAttr("value", value.to<float>());
if (dtype == phi::DataType::FLOAT32) {
op->SetAttr("value", value.to<float>());
} else if (dtype == phi::DataType::FLOAT64) {
op->SetAttr("str_value", std::to_string(value.to<double>()));
} else if (dtype == phi::DataType::FLOAT16) {
op->SetAttr("str_value", std::to_string(value.to<float>()));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"We only support float64/float32/float16 for full"));
}
op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
op->SetOutput(
"Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ void divide_grad(const Tensor& x,
} // indicate we will compute dy
if (dx) {
// dx = (1/y) * dout
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0);
auto one_tensor = full<T>(phi::vectorize(y.dims()), 1.0, y.dtype());
auto tmp0 = divide<T>(one_tensor, y);
auto dx_res = multiply<T>(tmp0, out_grad);
if (y.dims() != x.dims()) {
Expand Down
16 changes: 8 additions & 8 deletions paddle/fluid/prim/tests/test_eager_prim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,16 +68,16 @@ TEST(EagerPrim, TanhBackwardTest) {
paddle::experimental::Tensor out0 = tanh_ad_func(tensor0);
std::vector<paddle::experimental::Tensor> outs0 = {out0};
// Disable prim
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
egr::Backward(outs0, {}, false);

paddle::experimental::Tensor out1 = tanh_ad_func(tensor1);
std::vector<paddle::experimental::Tensor> outs1 = {out1};
// Disable prim
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
// 4. Run Backward
::egr::Backward(outs1, {}, false);
VLOG(7)
Expand All @@ -99,10 +99,10 @@ TEST(EagerPrim, TanhBackwardTest) {
}

TEST(EagerPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

} // namespace prim
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/prim/tests/test_static_prim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,10 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
}

TEST(StaticPrim, TestFlags) {
PrimCommonUtils::SetPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(true);
ASSERT_TRUE(PrimCommonUtils::IsBwdPrimEnabled());
PrimCommonUtils::SetBwdPrimEnabled(false);
ASSERT_FALSE(PrimCommonUtils::IsBwdPrimEnabled());
}

} // namespace prim
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/prim/utils/static/static_global_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ namespace paddle {
namespace prim {
StaticCompositeContext* StaticCompositeContext::static_composite_context_ =
new StaticCompositeContext();
thread_local bool StaticCompositeContext::enable_prim_ = false;
thread_local bool StaticCompositeContext::enable_bwd_prim_ = false;
thread_local bool StaticCompositeContext::enable_fwd_prim_ = false;
} // namespace prim
} // namespace paddle
16 changes: 13 additions & 3 deletions paddle/fluid/prim/utils/static/static_global_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,27 @@ class StaticCompositeContext {
return generator_->Generate(key);
}

void SetPrimEnabled(bool enable_prim) { enable_prim_ = enable_prim; }
void SetBwdPrimEnabled(bool enable_prim) { enable_bwd_prim_ = enable_prim; }

bool IsPrimEnabled() { return enable_prim_; }
bool IsBwdPrimEnabled() { return enable_bwd_prim_; }

void SetFwdPrimEnabled(bool enable_prim) { enable_fwd_prim_ = enable_prim; }

bool IsFwdPrimEnabled() { return enable_fwd_prim_; }

void SetAllPrimEnabled(bool enable_prim) {
enable_fwd_prim_ = enable_prim;
enable_bwd_prim_ = enable_prim;
}

private:
StaticCompositeContext()
: current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {}

framework::BlockDesc* current_block_desc_;
std::unique_ptr<UniqueNameGenerator> generator_;
static thread_local bool enable_prim_;
static thread_local bool enable_bwd_prim_;
static thread_local bool enable_fwd_prim_;
static StaticCompositeContext* static_composite_context_;
DISABLE_COPY_AND_ASSIGN(StaticCompositeContext);
};
Expand Down
20 changes: 16 additions & 4 deletions paddle/fluid/prim/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@
PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
namespace paddle {
namespace prim {
bool PrimCommonUtils::IsPrimEnabled() {
return StaticCompositeContext::Instance().IsPrimEnabled();
bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsBwdPrimEnabled();
}

void PrimCommonUtils::SetPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetPrimEnabled(enable_prim);
void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim);
}

bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
}

void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim);
}

void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) {
return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim);
}
} // namespace prim
} // namespace paddle
7 changes: 5 additions & 2 deletions paddle/fluid/prim/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@ namespace paddle {
namespace prim {
class PrimCommonUtils {
public:
static bool IsPrimEnabled();
static void SetPrimEnabled(bool enabled);
static bool IsBwdPrimEnabled();
static void SetBwdPrimEnabled(bool enabled);
static bool IsFwdPrimEnabled();
static void SetFwdPrimEnabled(bool enabled);
static void SetAllPrimEnabled(bool enabled);
};
} // namespace prim
} // namespace paddle
15 changes: 12 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,16 @@ PYBIND11_MODULE(libpaddle, m) {
return oss.str();
});

m.def("set_prim_enabled", &paddle::prim::PrimCommonUtils::SetPrimEnabled);
m.def("is_prim_enabled", &paddle::prim::PrimCommonUtils::IsPrimEnabled);
m.def("__set_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetBwdPrimEnabled);
m.def("_is_bwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsBwdPrimEnabled);
m.def("__set_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::SetFwdPrimEnabled);
m.def("_is_fwd_prim_enabled",
&paddle::prim::PrimCommonUtils::IsFwdPrimEnabled);
m.def("__set_all_prim_enabled",
&paddle::prim::PrimCommonUtils::SetAllPrimEnabled);
m.def("set_num_threads", &platform::SetNumThreads);

m.def("disable_signal_handler", &DisableSignalHandler);
Expand Down Expand Up @@ -1264,8 +1272,9 @@ All parameter, weight, gradient are variables in Paddle.
// priority of GradCompOpMaker is less than GradCompMaker for better
// performance.
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
if (paddle::prim::PrimCommonUtils::IsPrimEnabled()) {
if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) {
if (grad_comp_op_maker != nullptr) {
VLOG(3) << "Runing composite fun for " << op_desc.Type();
grad_op_descs = grad_comp_op_maker(op_desc,
no_grad_set,
&grad_to_var,
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
kernel :
func : add_grad
no_need_buffer : x, y
composite : add_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : add_grad(x, y, out_grad, axis)
backward : add_double_grad
inplace : (out_grad -> x_grad)

Expand Down Expand Up @@ -390,7 +390,7 @@
param : [x, y]
kernel :
func : divide_grad
composite : divide_grad(Tensor x, Tensor y, Tensor out, Tensor out_grad, int axis = -1)
composite : divide_grad(x, y, out, out_grad, -1)
backward : divide_double_grad

- backward_op : dropout_grad
Expand Down Expand Up @@ -1319,7 +1319,7 @@
kernel :
func : subtract_grad
no_need_buffer : x, y
composite : subtract_grad(Tensor x, Tensor y, Tensor out_grad, int axis)
composite : subtract_grad(x, y, out_grad, axis)
backward : subtract_double_grad
inplace : (out_grad -> x_grad)

Expand Down
5 changes: 3 additions & 2 deletions python/paddle/fluid/backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,14 +1493,15 @@ def update_distop_context(

# remove some backward ops
# TODO(Jiabin): Support this in prime later, it will prune add_grad, fix this problem
if not core.is_prim_enabled():
if not core._is_bwd_prim_enabled():
not_need_ops = _find_not_need_ops(
grad_op_descs, ops, input_grad_names_set
)

grad_op_descs = [
op_desc for op_desc in grad_op_descs if op_desc not in not_need_ops
]
else:
logging.debug("Runing backward composite and disable find_not_need_ops")

# append op_desc in grad_op_descs to target_block
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
Expand Down
Loading