Skip to content

Commit 93c9f05

Browse files
123malin and seiriosPlus
committed
test=develop, save/load, shrink (#30625)
* test=develop, save/load, shrink Co-authored-by: seiriosPlus <tangwei12@baidu.com>
1 parent 2961e59 commit 93c9f05

File tree

19 files changed

+206
-106
lines changed

19 files changed

+206
-106
lines changed

paddle/fluid/distributed/fleet.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,9 +479,15 @@ void FleetWrapper::PrintTableStat(const uint64_t table_id) {
479479
}
480480
}
481481

482-
void FleetWrapper::ShrinkSparseTable(int table_id) {
483-
auto ret = pserver_ptr_->_worker_ptr->shrink(table_id);
482+
void FleetWrapper::ShrinkSparseTable(int table_id, int threshold) {
483+
auto* communicator = Communicator::GetInstance();
484+
auto ret =
485+
communicator->_worker_ptr->shrink(table_id, std::to_string(threshold));
484486
ret.wait();
487+
int32_t err_code = ret.get();
488+
if (err_code == -1) {
489+
LOG(ERROR) << "shrink sparse table stat failed";
490+
}
485491
}
486492

487493
void FleetWrapper::ClearModel() {

paddle/fluid/distributed/fleet.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class FleetWrapper {
207207
// clear one table
208208
void ClearOneTable(const uint64_t table_id);
209209
// shrink sparse table
210-
void ShrinkSparseTable(int table_id);
210+
void ShrinkSparseTable(int table_id, int threshold);
211211
// shrink dense table
212212
void ShrinkDenseTable(int table_id, Scope* scope,
213213
std::vector<std::string> var_list, float decay,

paddle/fluid/distributed/service/brpc_ps_client.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,9 @@ std::future<int32_t> BrpcPsClient::send_save_cmd(
353353
return fut;
354354
}
355355

356-
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id) {
357-
return send_cmd(table_id, PS_SHRINK_TABLE, {std::string("1")});
356+
// Sends a PS_SHRINK_TABLE command for `table_id`, passing `threshold` as the
// command's single string parameter, and returns the future produced by
// send_cmd.
std::future<int32_t> BrpcPsClient::shrink(uint32_t table_id,
357+
const std::string threshold) {
358+
return send_cmd(table_id, PS_SHRINK_TABLE, {threshold});
358359
}
359360

360361
std::future<int32_t> BrpcPsClient::load(const std::string &epoch,

paddle/fluid/distributed/service/brpc_ps_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ class BrpcPsClient : public PSClient {
102102
}
103103
virtual int32_t create_client2client_connection(
104104
int pserver_timeout_ms, int pserver_connect_timeout_ms, int max_retry);
105-
virtual std::future<int32_t> shrink(uint32_t table_id) override;
105+
virtual std::future<int32_t> shrink(uint32_t table_id,
106+
const std::string threshold) override;
106107
virtual std::future<int32_t> load(const std::string &epoch,
107108
const std::string &mode) override;
108109
virtual std::future<int32_t> load(uint32_t table_id, const std::string &epoch,

paddle/fluid/distributed/service/brpc_ps_server.cc

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ int32_t BrpcPsService::save_one_table(Table *table,
460460
table->flush();
461461

462462
int32_t feasign_size = 0;
463+
464+
VLOG(0) << "save one table " << request.params(0) << " " << request.params(1);
463465
feasign_size = table->save(request.params(0), request.params(1));
464466
if (feasign_size < 0) {
465467
set_response_code(response, -1, "table save failed");
@@ -491,10 +493,18 @@ int32_t BrpcPsService::shrink_table(Table *table,
491493
PsResponseMessage &response,
492494
brpc::Controller *cntl) {
493495
CHECK_TABLE_EXIST(table, request, response)
496+
if (request.params_size() < 1) {
497+
set_response_code(
498+
response, -1,
499+
"PsRequestMessage.datas is requeired at least 1, threshold");
500+
return -1;
501+
}
494502
table->flush();
495-
if (table->shrink() != 0) {
503+
if (table->shrink(request.params(0)) != 0) {
496504
set_response_code(response, -1, "table shrink failed");
505+
return -1;
497506
}
507+
VLOG(0) << "Pserver Shrink Finished";
498508
return 0;
499509
}
500510

paddle/fluid/distributed/service/ps_client.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ class PSClient {
6969
int max_retry) = 0;
7070

7171
// Trigger eviction of table data
72-
virtual std::future<int32_t> shrink(uint32_t table_id) = 0;
72+
virtual std::future<int32_t> shrink(uint32_t table_id,
73+
const std::string threshold) = 0;
7374

7475
// Load data for all tables
7576
virtual std::future<int32_t> load(const std::string &epoch,

paddle/fluid/distributed/table/common_dense_table.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ class CommonDenseTable : public DenseTable {
5858
}
5959

6060
virtual int32_t flush() override { return 0; }
61-
virtual int32_t shrink() override { return 0; }
61+
// No-op override: ignores `param` and always returns 0.
virtual int32_t shrink(const std::string& param) override { return 0; }
6262
virtual void clear() override { return; }
6363

6464
protected:

paddle/fluid/distributed/table/common_sparse_table.cc

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,12 @@
2222
#include "paddle/fluid/string/string_helper.h"
2323

2424
#define PSERVER_SAVE_SUFFIX "_txt"
25+
2526
namespace paddle {
2627
namespace distributed {
2728

29+
enum SaveMode { all, base, delta };
30+
2831
struct Meta {
2932
std::string param;
3033
int shard_id;
@@ -94,12 +97,9 @@ struct Meta {
9497

9598
void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
9699
std::vector<std::vector<float>>* values) {
97-
PADDLE_ENFORCE_EQ(columns.size(), 2,
98-
paddle::platform::errors::InvalidArgument(
99-
"The data format does not meet the requirements. It "
100-
"should look like feasign_id \t params."));
101-
102-
auto load_values = paddle::string::split_string<std::string>(columns[1], ",");
100+
auto colunmn_size = columns.size();
101+
auto load_values =
102+
paddle::string::split_string<std::string>(columns[colunmn_size - 1], ",");
103103
values->reserve(meta.names.size());
104104

105105
int offset = 0;
@@ -121,11 +121,18 @@ void ProcessALine(const std::vector<std::string>& columns, const Meta& meta,
121121

122122
int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
123123
const int mode) {
124+
int64_t not_save_num = 0;
124125
for (auto value : block->values_) {
126+
if (mode == SaveMode::delta && !value.second->need_save_) {
127+
not_save_num++;
128+
continue;
129+
}
130+
125131
auto* vs = value.second->data_.data();
126132
std::stringstream ss;
127133
auto id = value.first;
128-
ss << id << "\t";
134+
ss << id << "\t" << value.second->count_ << "\t"
135+
<< value.second->unseen_days_ << "\t" << value.second->is_entry_ << "\t";
129136

130137
for (int i = 0; i < block->value_length_; i++) {
131138
ss << vs[i];
@@ -135,9 +142,13 @@ int64_t SaveToText(std::ostream* os, std::shared_ptr<ValueBlock> block,
135142
ss << "\n";
136143

137144
os->write(ss.str().c_str(), sizeof(char) * ss.str().size());
145+
146+
if (mode == SaveMode::base || mode == SaveMode::delta) {
147+
value.second->need_save_ = false;
148+
}
138149
}
139150

140-
return block->values_.size();
151+
return block->values_.size() - not_save_num;
141152
}
142153

143154
int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
@@ -165,8 +176,21 @@ int64_t LoadFromText(const std::string& valuepath, const std::string& metapath,
165176

166177
std::vector<std::vector<float>> kvalues;
167178
ProcessALine(values, meta, &kvalues);
168-
// warning: need fix
169-
block->Init(id);
179+
180+
block->Init(id, false);
181+
182+
auto value_instant = block->GetValue(id);
183+
if (values.size() == 5) {
184+
value_instant->count_ = std::stoi(values[1]);
185+
value_instant->unseen_days_ = std::stoi(values[2]);
186+
value_instant->is_entry_ = static_cast<bool>(std::stoi(values[3]));
187+
}
188+
189+
std::vector<float*> block_values = block->Get(id, meta.names, meta.dims);
190+
auto blas = GetBlas<float>();
191+
for (int x = 0; x < meta.names.size(); ++x) {
192+
blas.VCOPY(meta.dims[x], kvalues[x].data(), block_values[x]);
193+
}
170194
}
171195

172196
return 0;
@@ -393,7 +417,7 @@ int32_t CommonSparseTable::pull_sparse(float* pull_values, const uint64_t* keys,
393417
for (int i = 0; i < offsets.size(); ++i) {
394418
auto offset = offsets[i];
395419
auto id = keys[offset];
396-
auto* value = block->InitFromInitializer(id);
420+
auto* value = block->Init(id);
397421
std::copy_n(value + param_offset_, param_dim_,
398422
pull_values + param_dim_ * offset);
399423
}
@@ -488,9 +512,10 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
488512
for (int i = 0; i < offsets.size(); ++i) {
489513
auto offset = offsets[i];
490514
auto id = keys[offset];
491-
auto* value = block->InitFromInitializer(id);
515+
auto* value = block->Init(id, false);
492516
std::copy_n(values + param_dim_ * offset, param_dim_,
493517
value + param_offset_);
518+
block->SetEntry(id, true);
494519
}
495520
return 0;
496521
});
@@ -505,10 +530,20 @@ int32_t CommonSparseTable::push_sparse_param(const uint64_t* keys,
505530

506531
// No-op: performs no work and always returns 0.
int32_t CommonSparseTable::flush() { return 0; }
507532

508-
int32_t CommonSparseTable::shrink() {
509-
VLOG(0) << "shrink coming soon";
533+
int32_t CommonSparseTable::shrink(const std::string& param) {
534+
rwlock_->WRLock();
535+
int threshold = std::stoi(param);
536+
VLOG(0) << "sparse table shrink: " << threshold;
537+
538+
for (int shard_id = 0; shard_id < task_pool_size_; ++shard_id) {
539+
// shrink
540+
VLOG(0) << shard_id << " " << task_pool_size_ << " begin shrink";
541+
shard_values_[shard_id]->Shrink(threshold);
542+
}
543+
rwlock_->UNLock();
510544
return 0;
511545
}
546+
512547
// Placeholder: only emits a log line saying clear is not implemented yet.
void CommonSparseTable::clear() { VLOG(0) << "clear coming soon"; }
513548

514549
} // namespace distributed

paddle/fluid/distributed/table/common_sparse_table.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ class CommonSparseTable : public SparseTable {
7373

7474
virtual int32_t pour();
7575
virtual int32_t flush();
76-
virtual int32_t shrink();
76+
virtual int32_t shrink(const std::string& param);
7777
virtual void clear();
7878

7979
protected:

paddle/fluid/distributed/table/common_table.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class DenseTable : public Table {
108108
int32_t push_dense_param(const float *values, size_t num) override {
109109
return 0;
110110
}
111-
int32_t shrink() override { return 0; }
111+
// No-op override: ignores `param` and always returns 0.
int32_t shrink(const std::string &param) override { return 0; }
112112
};
113113

114114
class BarrierTable : public Table {
@@ -133,7 +133,7 @@ class BarrierTable : public Table {
133133
int32_t push_dense_param(const float *values, size_t num) override {
134134
return 0;
135135
}
136-
int32_t shrink() override { return 0; }
136+
// No-op override: ignores `param` and always returns 0.
int32_t shrink(const std::string &param) override { return 0; }
137137
virtual void clear(){};
138138
virtual int32_t flush() { return 0; };
139139
virtual int32_t load(const std::string &path, const std::string &param) {

0 commit comments

Comments
 (0)