
Support stream priority for standalone executor #49939


Merged
9 changes: 9 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -293,6 +293,7 @@ std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"impl_idx",
"is_recompute",
"execution_stream",
"stream_priority",
"scheduling_priority"};

OperatorDistAttr::OperatorDistAttr(const OpDesc& op) {
@@ -318,6 +319,8 @@ OperatorDistAttr& OperatorDistAttr::operator=(
std::swap(this->impl_idx_, tmp.impl_idx_);
std::swap(this->is_recompute_, tmp.is_recompute_);
std::swap(this->execution_stream_, tmp.execution_stream_);
std::swap(this->stream_priority_, tmp.stream_priority_);
std::swap(this->scheduling_priority_, tmp.scheduling_priority_);
std::swap(this->annotated_, tmp.annotated_);
// Note: Make sure all tensor dist attr has the same process_mesh
set_process_mesh(this->process_mesh_);
@@ -349,6 +352,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
impl_idx_ = 0;
is_recompute_ = false;
execution_stream_ = kDefault;
stream_priority_ = 0;
scheduling_priority_ = 0;
}

@@ -361,6 +365,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_impl_idx(dist_attr.impl_idx());
set_is_recompute(dist_attr.is_recompute());
set_execution_stream(dist_attr.execution_stream());
set_stream_priority(dist_attr.stream_priority());
set_scheduling_priority(dist_attr.scheduling_priority());
set_annotated(dist_attr.annotated());
}
@@ -599,6 +604,7 @@ std::string OperatorDistAttr::to_string() const {
str += "{impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "execution_stream: " + execution_stream_ + ", ";
str += "stream_priority: " + std::to_string(stream_priority_) + ", ";
str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", ";
str += "annotated: [" + str_join(annotated_) + "], ";
str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
@@ -684,6 +690,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.execution_stream() != rhs.execution_stream()) {
return false;
}
if (lhs.stream_priority() != rhs.stream_priority()) {
return false;
}
if (lhs.scheduling_priority() != rhs.scheduling_priority()) {
return false;
}
7 changes: 7 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.h
@@ -222,6 +222,12 @@ class OperatorDistAttr {
execution_stream_ = execution_stream;
}

int stream_priority() const { return stream_priority_; }

void set_stream_priority(int stream_priority) {
stream_priority_ = stream_priority;
}

int64_t scheduling_priority() const { return scheduling_priority_; }

void set_scheduling_priority(int64_t scheduling_priority) {
@@ -289,6 +295,7 @@ class OperatorDistAttr {
int64_t impl_idx_ = 0;
bool is_recompute_ = false;
std::string execution_stream_ = kDefault;
int stream_priority_ = 0; // lower value, higher priority
int64_t scheduling_priority_ = 0; // lower value, higher priority
std::map<std::string, bool> annotated_;
};
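Taken together, the dist_attr changes give every operator a stream_priority field next to execution_stream and scheduling_priority; it defaults to 0 and is swapped, copied, printed, and compared like the other fields. A minimal sketch of how the new setter could sit alongside the existing ones (the namespace qualification, the op_desc variable, and the stream name are illustrative assumptions, not part of this PR):

#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"

// Sketch only: op_desc is assumed to be an existing paddle::framework::OpDesc.
void TagOpForHighPriorityStream(const paddle::framework::OpDesc& op_desc) {
  paddle::distributed::auto_parallel::OperatorDistAttr dist_attr(op_desc);
  dist_attr.set_execution_stream("comm_stream");  // hypothetical custom stream name
  dist_attr.set_stream_priority(-1);              // lower value, higher priority
  dist_attr.set_scheduling_priority(0);
}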
paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc
@@ -46,7 +46,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);
if (ir::IsTopologySortOperationsUnique(*graph_)) {
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
paddle/fluid/framework/details/threaded_ssa_graph_executor.cc
@@ -41,7 +41,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);

if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
33 changes: 20 additions & 13 deletions paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc
@@ -666,21 +666,28 @@ bool BuildOpFuncList(const platform::Place& place,
op_func_node.output_index = outs_name2id;

const OperatorDistAttr* dist_attr = block.Op(i)->DistAttr();
if (dist_attr &&
dist_attr->execution_stream() != distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}

if (dist_attr) {
op_func_node.priority_ = dist_attr->scheduling_priority();
} else if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication less than that
// of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32 training).
op_func_node.priority_ = 1;
if (dist_attr->execution_stream() !=
distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}
op_func_node.stream_priority_ = dist_attr->stream_priority();
op_func_node.scheduling_priority_ = dist_attr->scheduling_priority();
} else {
if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication less than
// that of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32
// training).
op_func_node.scheduling_priority_ = 1;
}
}
VLOG(6) << "scheduling priority of " << op_type << " : "
<< op_func_node.priority_;

VLOG(6) << op_type
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
<< op_func_node.execution_stream_ << ", "
<< op_func_node.stream_priority_ << ", "
<< op_func_node.scheduling_priority_ << "]";

SingleStreamGuard single_stream_guard(ops[i]);

20 changes: 15 additions & 5 deletions paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc
@@ -39,7 +39,9 @@ class ContextManager {
}

std::shared_future<std::unique_ptr<DeviceContext>> Get(
const std::string& type, const platform::Place& place) {
const std::string& type,
const platform::Place& place,
int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get dev_ctx for " << type << " - " << place;

@@ -48,7 +50,8 @@
platform::EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place];
}
@@ -142,6 +145,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
const std::string& execution_stream = op_func_node.execution_stream_;
const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance();

// only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
@@ -152,15 +156,21 @@
<< ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) {
return ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream, place_)
.Get(std::string(kCustomStream) + "-" + execution_stream,
place_,
stream_priority)
.get()
.get();
}

if (op_type == interpreter::kMemcpyD2H) {
return ctx_manager.Get(std::string(kD2HStream), place_).get().get();
return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get()
.get();
} else if (op_type == interpreter::kMemcpyH2D) {
return ctx_manager.Get(std::string(kH2DStream), place_).get().get();
return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get()
.get();
}

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
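For GPU places, the stream_priority threaded through ContextManager::Get into EmplaceDeviceContexts presumably selects the priority of the CUDA stream created for the custom, D2H, and H2D contexts. As a reminder of the underlying semantics the field maps onto, here is a plain CUDA runtime sketch (standard CUDA API, not Paddle code; the exact priority range is device dependent):

#include <cuda_runtime.h>

int main() {
  // CUDA streams follow the same "lower value, higher priority" convention.
  int least_priority = 0, greatest_priority = 0;
  cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority);
  // Typically least_priority == 0 and greatest_priority is negative (e.g. -5).

  cudaStream_t high_priority_stream;
  cudaStreamCreateWithPriority(&high_priority_stream,
                               cudaStreamNonBlocking,
                               greatest_priority);  // numerically lowest value
  cudaStreamDestroy(high_priority_stream);
  return 0;
}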
16 changes: 9 additions & 7 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -139,13 +139,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
}
var_scope_.SetLocalScope(local_scope_);

instruction_prority_less = [this](size_t lhs, size_t rhs) {
Priority lhs_prority = vec_instruction_[lhs].GetPriority();
Priority rhs_prority = vec_instruction_[rhs].GetPriority();
if (lhs_prority == rhs_prority) {
instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
SchedulingPriority lhs_scheduling_priority =
vec_instruction_[lhs].GetSchedulingPriority();
SchedulingPriority rhs_scheduling_priority =
vec_instruction_[rhs].GetSchedulingPriority();
if (lhs_scheduling_priority == rhs_scheduling_priority) {
return lhs < rhs;
}
return lhs_prority > rhs_prority;
return lhs_scheduling_priority > rhs_scheduling_priority;
};

PrepareForCUDAGraphCapture();
@@ -1089,7 +1091,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
// scheduling, the priority order involved cross-thread scheduling is not
// guaranteed. Only Ops scheduled by the same AddTask call have the guarantee
// of priority order.
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);
ready_ops.push(instr_id);
while (!ready_ops.empty()) {
instr_id = ready_ops.top();
@@ -1427,7 +1429,7 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
};

std::vector<size_t> trace_order;
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);

for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) {
if (dependecy_count_[instr_id] == 0) {
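SchedulingQueue is a std::priority_queue, i.e. a max-heap over instruction ids: because the "less" functor returns greater-than on the scheduling priorities, the instruction with the numerically smallest scheduling_priority surfaces first, and equal priorities fall back to comparing the ids themselves. A standalone sketch of the same pattern (names and values are illustrative only):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <queue>
#include <vector>

int main() {
  // scheduling_priority per instruction id; lower value = higher priority.
  std::vector<int64_t> scheduling_priority = {0, 1, 0};

  auto less = [&](size_t lhs, size_t rhs) {
    if (scheduling_priority[lhs] == scheduling_priority[rhs]) {
      return lhs < rhs;  // tie-break on instruction id
    }
    return scheduling_priority[lhs] > scheduling_priority[rhs];
  };

  std::priority_queue<size_t, std::vector<size_t>, decltype(less)> ready_ops(less);
  ready_ops.push(0);
  ready_ops.push(1);
  ready_ops.push(2);

  while (!ready_ops.empty()) {
    std::cout << ready_ops.top() << " ";  // prints "2 0 1": both priority-0 ops
    ready_ops.pop();                      // run before the priority-1 op
  }
  return 0;
}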
8 changes: 5 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore.h
@@ -79,9 +79,11 @@ class InterpreterCore {
const platform::Place& GetPlace() const { return place_; }

private:
using InstructionPriorityLess = std::function<bool(size_t, size_t)>;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t, std::vector<size_t>, InstructionPriorityLess>;
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;

// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
@@ -181,7 +183,7 @@ class InterpreterCore {
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;

InstructionPriorityLess instruction_prority_less;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
};

std::shared_ptr<InterpreterCore> CreateInterpreterCore(
11 changes: 7 additions & 4 deletions paddle/fluid/framework/new_executor/new_executor_defs.h
@@ -32,12 +32,12 @@ namespace framework {

using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

using Priority = int64_t;
using SchedulingPriority = int64_t;

constexpr const char* kCoalesceTensor = "coalesce_tensor";

// stream types
constexpr const char* kCustomStream = "CustromStream";
constexpr const char* kCustomStream = "CustomStream";
constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";
@@ -263,6 +263,7 @@ enum class OpFuncType {
class RuntimeInferShapeContext;

struct OpFuncNode {
int stream_priority_{0}; // lower value, higher priority
// fit for phi kernel
phi::Kernel* phi_kernel_{nullptr}; // not owned
platform::DeviceContext* dev_ctx_; // not owned
@@ -279,7 +280,7 @@ struct OpFuncNode {
OpFuncType type_;
OpKernelComputeFunc kernel_func_;

Priority priority_{0}; // lower value, higher priority
SchedulingPriority scheduling_priority_{0}; // lower value, higher priority
};

class Instruction {
@@ -369,7 +370,9 @@ class Instruction {

void ClearInplace();

Priority GetPriority() const { return op_func_node_.priority_; }
SchedulingPriority GetSchedulingPriority() const {
return op_func_node_.scheduling_priority_;
}

private:
bool is_artificial_; // Instruction is artificial means that it is only used
Expand Down