Skip to content

【GPUPS】add env for gpups and fix cache table #70301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,9 @@ int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) {
return 0;
}

std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id) {
std::future<int32_t> BrpcPsClient::PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) {
size_t request_call_num = _server_channels.size();
DownpourBrpcClosure *closure = new DownpourBrpcClosure(
request_call_num, [request_call_num, table_id](void *done) {
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/brpc_ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ class BrpcPsClient : public PSClient {
size_t num,
bool is_training);

virtual std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold);

virtual std::future<int32_t> Barrier(size_t table_id, uint32_t barrier_type);

Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,9 @@ class PSClient {
return fut;
}

virtual std::future<int32_t> PrintTableStat(uint32_t table_id) = 0;
virtual std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) = 0;
virtual std::future<int32_t> SaveCacheTable(uint32_t table_id UNUSED,
uint16_t pass_id UNUSED,
size_t threshold UNUSED) {
Expand Down
9 changes: 8 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,18 @@ ::std::future<int32_t> PsLocalClient::PullSparsePtr(
return done();
}

::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id) {
::std::future<int32_t> PsLocalClient::PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold) {
auto* table_ptr = GetTable(table_id);
std::pair<int64_t, int64_t> ret = table_ptr->PrintTableStat();
VLOG(0) << "table id: " << table_id << ", feasign size: " << ret.first
<< ", mf size: " << ret.second;
// > 50亿,40%内存
if (static_cast<size_t>(ret.first) > threshold) {
VLOG(0) << "run cache table";
table_ptr->CacheTable(pass_id);
}
return done();
}

Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/service/ps_local_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,9 @@ class PsLocalClient : public PSClient {
const std::vector<std::unordered_map<uint64_t, uint32_t>>& keys2rank_vec,
const uint16_t& dim_id = 0);

virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id);
virtual ::std::future<int32_t> PrintTableStat(uint32_t table_id,
uint16_t pass_id,
size_t threshold);

virtual ::std::future<int32_t> SaveCacheTable(uint32_t table_id,
uint16_t pass_id,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ void CtrDymfAccessor::UpdateStatAfterSave(float* value, int param) {
int32_t CtrDymfAccessor::Create(float** values, size_t num) {
for (size_t value_item = 0; value_item < num; ++value_item) {
float* value = values[value_item];
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
common_feature_value.UnseenDays(value) = 0;
common_feature_value.PassId(value) = 0;
#else
Expand Down Expand Up @@ -385,7 +385,7 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {

int CtrDymfAccessor::ParseFromString(const std::string& str, float* value) {
auto ret = paddle::string::str_to_float(str.data(), value);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
float unseen_day = value[common_feature_value.UnseenDaysIndex()];
common_feature_value.UnseenDays(value) = (uint16_t)(unseen_day);
common_feature_value.PassId(value) = 0;
Expand Down Expand Up @@ -437,7 +437,7 @@ void CtrDymfAccessor::UpdateTimeDecay(float* value, bool is_update_seen_day) {
}
}

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
bool CtrDymfAccessor::SaveMemCache(float* value,
int param,
double global_cache_threshold,
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/ps/table/ctr_dymf_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class CtrDymfAccessor : public ValueAccessor {
// Total number of bytes computed from mf_dim: Dim(mf_dim) floats.
int Size(int mf_dim) {
  return Dim(mf_dim) * sizeof(float);
}

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
uint16_t& PassId(float* val) {
uint16_t* int16_val =
reinterpret_cast<uint16_t*>(val + UnseenDaysIndex());
Expand Down Expand Up @@ -258,7 +258,7 @@ class CtrDymfAccessor : public ValueAccessor {

void SetDayId(int day_id) override;

#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
// Decide whether to cache a value to SSD based on pass_id and the show_threshold threshold
bool SaveMemCache(float* value,
int param,
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/distributed/ps/table/ssd_sparse_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
}

_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
int pull_data_idx = cur_ctx->batch_index[idx];
Expand All @@ -280,7 +280,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
ret = itr.value_ptr();
// int pull_data_idx = keys[i].second;
_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
pull_values[i] = reinterpret_cast<char*>(ret);
Expand Down Expand Up @@ -332,7 +332,7 @@ int32_t SSDSparseTable::PullSparsePtr(int shard_id,
ret = &feature_value;
}
_value_accessor->UpdateTimeDecay(ret->data(), true);
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(ret->data(), pass_id);
#endif
int pull_data_idx = cur_ctx->batch_index[idx];
Expand Down Expand Up @@ -2945,7 +2945,7 @@ int32_t SSDSparseTable::LoadWithBinary(const std::string& path, int param) {
abort();
}
last_k = k;
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(convert_value, 0);
#endif
rocksdb::Status status = sst_writer.Put(
Expand All @@ -2963,7 +2963,7 @@ int32_t SSDSparseTable::LoadWithBinary(const std::string& path, int param) {
}
} else {
auto& feature_value = shard[k];
#ifdef PADDLE_WITH_PSLIB
#if defined(PADDLE_WITH_PSLIB) || defined(PADDLE_WITH_HETERPS)
_value_accessor->UpdatePassId(convert_value, 0);
#endif
feature_value.resize(dim);
Expand Down Expand Up @@ -3051,7 +3051,7 @@ std::pair<int64_t, int64_t> SSDSparseTable::PrintTableStat() {

int32_t SSDSparseTable::CacheTable(uint16_t pass_id) {
std::lock_guard<std::mutex> guard(_table_mutex);
VLOG(0) << "cache_table";
VLOG(0) << "cache_table, pass_id:" << pass_id;
std::atomic<uint32_t> count{0};
std::vector<std::future<int>> tasks;

Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/distributed/ps/wrapper/fleet.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,8 +815,10 @@ void FleetWrapper::RecvAndSaveTable(const uint64_t table_id,
}
}

void FleetWrapper::PrintTableStat(const uint64_t table_id) {
auto ret = worker_ptr_->PrintTableStat(table_id);
void FleetWrapper::PrintTableStat(const uint64_t table_id,
uint32_t pass_id,
size_t threshold) {
auto ret = worker_ptr_->PrintTableStat(table_id, pass_id, threshold);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/distributed/ps/wrapper/fleet.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ class FleetWrapper {
// barrier with barrier table
void BarrierWithTable(uint32_t barrier_type);

void PrintTableStat(const uint64_t table_id);
void PrintTableStat(const uint64_t table_id,
uint32_t pass_id,
size_t threshold);
void SaveCacheTable(const uint64_t table_id,
uint16_t pass_id,
size_t threshold);
Expand Down
18 changes: 17 additions & 1 deletion paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ limitations under the License. */

#pragma once
#ifdef PADDLE_WITH_HETERPS

#include <google/protobuf/text_format.h>
#include <stdlib.h>
#include <atomic>
#include <ctime>
#include <map>
Expand Down Expand Up @@ -390,6 +390,22 @@ class PSGPUWrapper {
if (s_instance_ != NULL && is_initialized_ == false) {
VLOG(3) << "PSGPUWrapper Begin InitializeGPU";
is_initialized_ = true;
#if defined(PADDLE_WITH_PSCORE) && defined(PADDLE_WITH_HETERPS) && \
defined(PADDLE_WITH_NCCL)
const char* launch_mode = std::getenv("NCCL_LAUNCH_MODE");
if (launch_mode != nullptr) {
if (std::string(launch_mode) == "PARALLEL") {
PADDLE_THROW(common::errors::Unavailable(
"on heterps-mode you must export NCCL_LAUNCH_MODE=GROUP for no "
"hang, but received [%s]",
launch_mode));
}
} else {
PADDLE_THROW(
common::errors::Unavailable("on heterps-mode you must export "
"NCCL_LAUNCH_MODE=GROUP for no hang"));
}
#endif
resource_ = std::make_shared<HeterPsResource>(dev_ids);
resource_->enable_p2p();
keys_tensor.resize(resource_->total_device());
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ void BindDistFleetWrapper(py::module* m) {
.def("pull_fl_strategy", &FleetWrapper::PullFlStrategy)
.def("revert", &FleetWrapper::Revert)
.def("set_date", &FleetWrapper::SetDate)
.def("print_table_stat", &FleetWrapper::PrintTableStat)
.def("check_save_pre_patch_done", &FleetWrapper::CheckSavePrePatchDone);
}

Expand Down
1 change: 1 addition & 0 deletions python/paddle/distributed/fleet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
load_inference_model = fleet.load_inference_model
load_one_table = fleet.load_one_table
set_date = fleet.set_date
print_table_stat = fleet.print_table_stat
minimize = fleet.minimize
distributed_model = distributed_model
shrink = fleet.shrink
Expand Down
21 changes: 21 additions & 0 deletions python/paddle/distributed/fleet/fleet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,6 +1419,27 @@ def set_date(self, table_id: int, day_id: str) -> None:
"""
self._runtime_handle._set_date(table_id, str(day_id))

@is_non_distributed_check
@inited_runtime_handler
def print_table_stat(self, table_id: int, pass_id: int, threshold: int) -> None:
    """
    Print stat info of table_id for gpups table, format: tableid, feasign size, mf size.

    With the local PS client, the table is additionally cached for ``pass_id``
    when its feasign size is greater than ``threshold``.

    Args:

        table_id (int): The id of table.
        pass_id (int): The id of pass, forwarded to the cache-table step.
        threshold (int): The feasign size above which the cache-table step is
            run. It is passed to an unsigned integer (``size_t``) argument of
            the C++ binding, so it must be a non-negative integer, not a float.

    Examples:

        .. code-block:: text

            fleet.print_table_stat(0,6,600000)

    """
    self._runtime_handle._print_table_stat(table_id, pass_id, threshold)

@is_non_distributed_check
@inited_runtime_handler
def shrink(self, threshold: int | None = None) -> None:
Expand Down
6 changes: 6 additions & 0 deletions python/paddle/distributed/ps/the_one_ps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1760,6 +1760,12 @@ def _set_date(self, table_id, day_id):
self._worker.set_date(table_id, day_id)
fleet.util.barrier()

def _print_table_stat(self, table_id, pass_id, threshold):
    # Forward a print-table-stat request to the worker, from the first worker
    # only, with a barrier on each side of the call.
    #
    # table_id, pass_id and threshold are passed through unchanged to
    # self._worker.print_table_stat.

    # Barrier before the request is issued.
    fleet.util.barrier()
    # Only the worker for which _is_first_worker() is true issues the request.
    if self.role_maker._is_first_worker():
        self._worker.print_table_stat(table_id, pass_id, threshold)
    # Barrier after the first worker's call has returned.
    fleet.util.barrier()

def _shrink(self, threshold=None):
if threshold is not None:
warnings.warn(
Expand Down