@@ -3073,7 +3073,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes(
3073
3073
std::shared_ptr<phi::Allocation> &size_list,
3074
3074
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
3075
3075
std::shared_ptr<phi::Allocation>& feature_list,
3076
- std::shared_ptr<phi::Allocation>& slot_list) {
3076
+ std::shared_ptr<phi::Allocation>& slot_list,
3077
+ bool sage_mode) {
3077
3078
if (node_num == 0 ) {
3078
3079
return 0 ;
3079
3080
}
@@ -3086,7 +3087,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes(
3086
3087
} else {
3087
3088
if (FLAGS_enable_graph_multi_node_sampling) {
3088
3089
all_fea_num = get_feature_info_of_nodes_all2all (gpu_id, d_nodes, node_num, size_list,
3089
- size_list_prefix_sum, feature_list, slot_list);
3090
+ size_list_prefix_sum, feature_list, slot_list,
3091
+ sage_mode);
3090
3092
}
3091
3093
}
3092
3094
} else {
@@ -3104,7 +3106,8 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
3104
3106
std::shared_ptr<phi::Allocation> &size_list,
3105
3107
std::shared_ptr<phi::Allocation> &size_list_prefix_sum,
3106
3108
std::shared_ptr<phi::Allocation>& feature_list,
3107
- std::shared_ptr<phi::Allocation>& slot_list) {
3109
+ std::shared_ptr<phi::Allocation>& slot_list,
3110
+ bool sage_mode) {
3108
3111
if (node_num == 0 ) {
3109
3112
return 0 ;
3110
3113
}
@@ -3241,7 +3244,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
3241
3244
reinterpret_cast <uint64_t *>(feature_list_ptr),
3242
3245
reinterpret_cast <uint64_t *>(inter_feature_list_ptr),
3243
3246
stream,
3244
- false ,
3247
+ sage_mode ,
3245
3248
true );
3246
3249
VLOG (2 ) << " end send feature list" ;
3247
3250
@@ -3260,7 +3263,7 @@ int GpuPsGraphTable::get_feature_info_of_nodes_all2all(
3260
3263
reinterpret_cast <uint8_t *>(slot_list_ptr),
3261
3264
reinterpret_cast <uint8_t *>(inter_slot_list_ptr),
3262
3265
stream,
3263
- false ,
3266
+ sage_mode ,
3264
3267
true );
3265
3268
VLOG (2 ) << " end send slot list" ;
3266
3269
0 commit comments