Skip to content

Fix FLAGS_fuse_parameter_memory_size unit from Bytes to MBytes. #17924

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions paddle/fluid/framework/ir/alloc_continuous_space_for_grad_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_registry.h"

DEFINE_uint64(fuse_parameter_memory_size, 0, // Bytes
"fuse_parameter_memory_size is up limited memory size "
DEFINE_double(fuse_parameter_memory_size, -1.0, // MBytes
"fuse_parameter_memory_size is up limited memory size(MB)"
"of one group parameters' gradient which is the input "
"of communication calling(e.g NCCLAllReduce). "
"The default value is 0, it means that "
Expand All @@ -51,13 +51,11 @@ void SetFuseParameterGroupsSize(int group_size) {

// Returns the current value of the fuse_parameter_groups_size flag.
int GetFuseParameterGroupsSize() { return FLAGS_fuse_parameter_groups_size; }

void SetFuseParameterMemorySize(uint64_t memory_size) {
// Overwrites the fuse_parameter_memory_size flag with memory_size.
// The flag is declared with DEFINE_double and annotated "MBytes", so
// memory_size is expected in megabytes. SetGroupAccordingToMemorySize returns
// early when the stored value is <= 0.0.
void SetFuseParameterMemorySize(double memory_size) {
FLAGS_fuse_parameter_memory_size = memory_size;
}

// Pre-change form of the getter (removed by this diff): returns the
// fuse_parameter_memory_size flag as a uint64_t, matching the DEFINE_uint64
// declaration that was annotated "Bytes".
uint64_t GetFuseParameterMemorySize() {
return FLAGS_fuse_parameter_memory_size;
}
// Returns the current value of the fuse_parameter_memory_size flag, a double
// annotated "MBytes" at its DEFINE_double declaration.
double GetFuseParameterMemorySize() { return FLAGS_fuse_parameter_memory_size; }

// File-local dtype variable, initialized to the BOOL enumerator.
// NOTE(review): presumably a "dtype not yet determined" placeholder; its uses
// are not visible in this hunk, so confirm against the rest of the file.
static framework::proto::VarType::Type kDefaultDtype =
framework::proto::VarType::Type::VarType_Type_BOOL;
Expand Down Expand Up @@ -230,15 +228,16 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
}
VLOG(10) << out.str()
<< ", group size:" << group_grads_params->at(i).size()
<< ", group memory size:" << gps_size;
<< ", group memory size:"
<< static_cast<double>(gps_size) / 1048576.0 << "(MB)";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't use 1048576.0 directly.

static double kMB = 1048576.0;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Will polish later.

}
}

void SetGroupAccordingToMemorySize(
const std::unordered_map<std::string, ir::Node *> &var_nodes,
details::GroupGradsAndParams *group_grads_params) const {
const uint64_t group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size == 0) {
const double group_memory_size = GetFuseParameterMemorySize();
if (group_memory_size <= 0.0) {
return;
}
details::GroupGradsAndParams local_group_grads_params;
Expand Down Expand Up @@ -271,7 +270,8 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
break;
}

if (local_group_memory_size >= group_memory_size) {
if (static_cast<double>(local_group_memory_size) / 1048576.0 >=
group_memory_size) {
break;
}
}
Expand All @@ -280,7 +280,7 @@ class AllocContinuousSpaceForGradPass : public ir::Pass {
std::swap(*group_grads_params, local_group_grads_params);

VLOG(10) << string::Sprintf(
"SetGroupAccordingToMemorySize(memory_size: %d):", group_memory_size);
"SetGroupAccordingToMemorySize(memory_size: %f):", group_memory_size);

if (VLOG_IS_ON(10)) {
PrintGroupInfo(var_nodes, group_grads_params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ namespace ir {
// Setter and getter for the fuse-parameter group size, taken and returned as
// an int.
void SetFuseParameterGroupsSize(int group_size);
int GetFuseParameterGroupsSize();

// Pre-change declarations (removed by this diff): the memory size was
// exchanged as a uint64_t.
void SetFuseParameterMemorySize(uint64_t memory_size);
uint64_t GetFuseParameterMemorySize();
// Setter and getter for the fuse-parameter memory size, exchanged as a
// double. The pull request title states the unit is MBytes.
void SetFuseParameterMemorySize(double memory_size);
double GetFuseParameterMemorySize();

} // namespace ir
} // namespace framework
Expand Down