Skip to content

Commit e661334

Browse files
DesmonDay authored and danleifeng committed
fix slot sage hang (PaddlePaddle#300)
1 parent b877559 commit e661334

File tree

5 files changed

+20
-11
lines changed

5 files changed

+20
-11
lines changed

paddle/fluid/framework/data_feed.cu

+2-1
Original file line number | Diff line number | Diff line change
@@ -1459,7 +1459,8 @@ int GraphDataGenerator::FillSlotFeature(uint64_t *d_walk, size_t key_num, int te
14591459
d_feature_size_list_buf_,
14601460
d_feature_size_prefixsum_buf_,
14611461
d_feature_list,
1462-
d_slot_list);
1462+
d_slot_list,
1463+
conf_.sage_mode);
14631464
// num of slot feature
14641465
int slot_num = conf_.slot_num - float_slot_num_;
14651466
int64_t *slot_tensor_ptr_[slot_num];

paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h

+4-2
Original file line number | Diff line number | Diff line change
@@ -205,7 +205,8 @@ class GpuPsGraphTable
205205
std::shared_ptr<phi::Allocation> &size_list,
206206
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
207207
std::shared_ptr<phi::Allocation> &feature_list, // NOLINT
208-
std::shared_ptr<phi::Allocation> &slot_list); // NOLINT
208+
std::shared_ptr<phi::Allocation> &slot_list,
209+
bool sage_mode = false); // NOLINT
209210
int get_float_feature_info_of_nodes(
210211
int gpu_id,
211212
uint64_t *d_nodes,
@@ -231,7 +232,8 @@ class GpuPsGraphTable
231232
std::shared_ptr<phi::Allocation> &size_list,
232233
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
233234
std::shared_ptr<phi::Allocation> &feature_list,
234-
std::shared_ptr<phi::Allocation> &slot_list);
235+
std::shared_ptr<phi::Allocation> &slot_list,
236+
bool sage_mode = false);
235237

236238

237239
NodeQueryResult query_node_list(int gpu_id,

paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table_inl.cu

+8-5
Original file line number | Diff line number | Diff line change
@@ -3073,7 +3073,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes(
30733073
std::shared_ptr<phi::Allocation> &size_list,
30743074
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
30753075
std::shared_ptr<phi::Allocation>& feature_list,
3076-
std::shared_ptr<phi::Allocation>& slot_list) {
3076+
std::shared_ptr<phi::Allocation>& slot_list,
3077+
bool sage_mode) {
30773078
if (node_num == 0) {
30783079
return 0;
30793080
}
@@ -3086,7 +3087,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes(
30863087
} else {
30873088
if (FLAGS_enable_graph_multi_node_sampling) {
30883089
all_fea_num = get_feature_info_of_nodes_all2all(gpu_id, d_nodes, node_num, size_list,
3089-
size_list_prefix_sum, feature_list, slot_list);
3090+
size_list_prefix_sum, feature_list, slot_list,
3091+
sage_mode);
30903092
}
30913093
}
30923094
} else {
@@ -3104,7 +3106,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
31043106
std::shared_ptr<phi::Allocation> &size_list,
31053107
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
31063108
std::shared_ptr<phi::Allocation>& feature_list,
3107-
std::shared_ptr<phi::Allocation>& slot_list) {
3109+
std::shared_ptr<phi::Allocation>& slot_list,
3110+
bool sage_mode) {
31083111
if (node_num == 0) {
31093112
return 0;
31103113
}
@@ -3241,7 +3244,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
32413244
reinterpret_cast<uint64_t*>(feature_list_ptr),
32423245
reinterpret_cast<uint64_t*>(inter_feature_list_ptr),
32433246
stream,
3244-
false,
3247+
sage_mode,
32453248
true);
32463249
VLOG(2) << "end send feature list";
32473250
@@ -3260,7 +3263,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
32603263
reinterpret_cast<uint8_t*>(slot_list_ptr),
32613264
reinterpret_cast<uint8_t*>(inter_slot_list_ptr),
32623265
stream,
3263-
false,
3266+
sage_mode,
32643267
true);
32653268
VLOG(2) << "end send slot list";
32663269

paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu

+4-2
Original file line number | Diff line number | Diff line change
@@ -997,7 +997,8 @@ int GraphGpuWrapper::get_feature_info_of_nodes(
997997
std::shared_ptr<phi::Allocation> &size_list,
998998
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
999999
std::shared_ptr<phi::Allocation> &feature_list,
1000-
std::shared_ptr<phi::Allocation> &slot_list) {
1000+
std::shared_ptr<phi::Allocation> &slot_list,
1001+
bool sage_mode) {
10011002
platform::CUDADeviceGuard guard(gpu_id);
10021003
PADDLE_ENFORCE_NOT_NULL(graph_table,
10031004
paddle::platform::errors::InvalidArgument(
@@ -1009,7 +1010,8 @@ int GraphGpuWrapper::get_feature_info_of_nodes(
10091010
size_list,
10101011
size_list_prefix_sum,
10111012
feature_list,
1012-
slot_list);
1013+
slot_list,
1014+
sage_mode);
10131015
}
10141016

10151017
int GraphGpuWrapper::get_float_feature_info_of_nodes(

paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,8 @@ class GraphGpuWrapper {
186186
std::shared_ptr<phi::Allocation>& size_list,
187187
std::shared_ptr<phi::Allocation>& size_list_prefix_sum,
188188
std::shared_ptr<phi::Allocation>& feature_list, // NOLINT
189-
std::shared_ptr<phi::Allocation>& slot_list); // NOLINT
189+
std::shared_ptr<phi::Allocation>& slot_list,
190+
bool sage_mode = false); // NOLINT
190191
int get_float_feature_info_of_nodes(
191192
int gpu_id,
192193
uint64_t *d_nodes,

0 commit comments

Comments
 (0)