@@ -72,6 +72,11 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
72
72
std::vector<uint64_t > node_id_array[task_pool_size_];
73
73
std::vector<paddle::framework::GpuPsFeaInfo>
74
74
node_fea_info_array[task_pool_size_];
75
+ slot_feature_num_map_.resize (slot_num);
76
+ for (int k = 0 ; k < slot_num; ++k) {
77
+ slot_feature_num_map_[k] = 0 ;
78
+ }
79
+
75
80
for (size_t i = 0 ; i < bags.size (); i++) {
76
81
if (bags[i].size () > 0 ) {
77
82
tasks.push_back (_shards_task_pool[i]->enqueue ([&, i, this ]() -> int {
@@ -92,13 +97,17 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
92
97
int total_feature_size = 0 ;
93
98
for (int k = 0 ; k < slot_num; ++k) {
94
99
v->get_feature_ids (k, &feature_ids);
95
- total_feature_size += feature_ids.size ();
100
+ int feature_ids_size = feature_ids.size ();
101
+ if (slot_feature_num_map_[k] < feature_ids_size) {
102
+ slot_feature_num_map_[k] = feature_ids_size;
103
+ }
104
+ total_feature_size += feature_ids_size;
96
105
if (!feature_ids.empty ()) {
97
106
feature_array[i].insert (feature_array[i].end (),
98
107
feature_ids.begin (),
99
108
feature_ids.end ());
100
109
slot_id_array[i].insert (
101
- slot_id_array[i].end (), feature_ids. size () , k);
110
+ slot_id_array[i].end (), feature_ids_size , k);
102
111
}
103
112
}
104
113
x.feature_size = total_feature_size;
@@ -111,6 +120,13 @@ paddle::framework::GpuPsCommGraphFea GraphTable::make_gpu_ps_graph_fea(
111
120
}
112
121
}
113
122
for (int i = 0 ; i < (int )tasks.size (); i++) tasks[i].get ();
123
+
124
+ std::stringstream ss;
125
+ for (int k = 0 ; k < slot_num; ++k) {
126
+ ss << slot_feature_num_map_[k] << " " ;
127
+ }
128
+ VLOG (0 ) << " slot_feature_num_map: " << ss.str ();
129
+
114
130
paddle::framework::GpuPsCommGraphFea res;
115
131
uint64_t tot_len = 0 ;
116
132
for (int i = 0 ; i < task_pool_size_; i++) {
@@ -1849,9 +1865,14 @@ int GraphTable::parse_feature(int idx,
1849
1865
// "")
1850
1866
thread_local std::vector<paddle::string::str_ptr> fields;
1851
1867
fields.clear ();
1852
- const char c = feature_separator_ .at (0 );
1868
+ char c = slot_feature_separator_ .at (0 );
1853
1869
paddle::string::split_string_ptr (feat_str, len, c, &fields);
1854
1870
1871
+ thread_local std::vector<paddle::string::str_ptr> fea_fields;
1872
+ fea_fields.clear ();
1873
+ c = feature_separator_.at (0 );
1874
+ paddle::string::split_string_ptr (fields[1 ].ptr , fields[1 ].len , c, &fea_fields);
1875
+
1855
1876
std::string name = fields[0 ].to_string ();
1856
1877
auto it = feat_id_map[idx].find (name);
1857
1878
if (it != feat_id_map[idx].end ()) {
@@ -1862,26 +1883,26 @@ int GraphTable::parse_feature(int idx,
1862
1883
// string_vector_2_string(fields.begin() + 1, fields.end(), ' ',
1863
1884
// fea_ptr);
1864
1885
FeatureNode::parse_value_to_bytes<uint64_t >(
1865
- fields .begin () + 1 , fields .end (), fea_ptr);
1886
+ fea_fields .begin (), fea_fields .end (), fea_ptr);
1866
1887
return 0 ;
1867
1888
} else if (dtype == " string" ) {
1868
- string_vector_2_string (fields .begin () + 1 , fields .end (), ' ' , fea_ptr);
1889
+ string_vector_2_string (fea_fields .begin (), fea_fields .end (), ' ' , fea_ptr);
1869
1890
return 0 ;
1870
1891
} else if (dtype == " float32" ) {
1871
1892
FeatureNode::parse_value_to_bytes<float >(
1872
- fields .begin () + 1 , fields .end (), fea_ptr);
1893
+ fea_fields .begin (), fea_fields .end (), fea_ptr);
1873
1894
return 0 ;
1874
1895
} else if (dtype == " float64" ) {
1875
1896
FeatureNode::parse_value_to_bytes<double >(
1876
- fields .begin () + 1 , fields .end (), fea_ptr);
1897
+ fea_fields .begin (), fea_fields .end (), fea_ptr);
1877
1898
return 0 ;
1878
1899
} else if (dtype == " int32" ) {
1879
1900
FeatureNode::parse_value_to_bytes<int32_t >(
1880
- fields .begin () + 1 , fields .end (), fea_ptr);
1901
+ fea_fields .begin (), fea_fields .end (), fea_ptr);
1881
1902
return 0 ;
1882
1903
} else if (dtype == " int64" ) {
1883
1904
FeatureNode::parse_value_to_bytes<uint64_t >(
1884
- fields .begin () + 1 , fields .end (), fea_ptr);
1905
+ fea_fields .begin (), fea_fields .end (), fea_ptr);
1885
1906
return 0 ;
1886
1907
}
1887
1908
} else {
@@ -2118,6 +2139,10 @@ void GraphTable::set_feature_separator(const std::string &ch) {
2118
2139
feature_separator_ = ch;
2119
2140
}
2120
2141
2142
+ void GraphTable::set_slot_feature_separator (const std::string &ch) {
2143
+ slot_feature_separator_ = ch;
2144
+ }
2145
+
2121
2146
int32_t GraphTable::get_server_index_by_id (uint64_t id) {
2122
2147
return id % shard_num / shard_num_per_server;
2123
2148
}
0 commit comments