From 91d661f9094ad4360d8933d7724ea07cc929425d Mon Sep 17 00:00:00 2001 From: "zhangzechao.zzc" Date: Wed, 29 Oct 2025 20:04:10 +0800 Subject: [PATCH 1/6] feat[accl-barex]: add barex_transport by build with USE_BAREX --- mooncake-common/common.cmake | 5 + mooncake-integration/allocator.py | 2 +- .../transfer_engine/transfer_engine_py.cpp | 33 +- .../example/transfer_engine_bench.cpp | 16 +- mooncake-transfer-engine/include/config.h | 1 + .../include/transfer_engine.h | 3 + .../include/transfer_metadata.h | 2 + .../include/transfer_metadata_plugin.h | 2 +- .../transport/barex_transport/barex_context.h | 194 +++ .../barex_transport/barex_transport.h | 175 +++ .../include/transport/transport.h | 5 + mooncake-transfer-engine/src/CMakeLists.txt | 6 +- mooncake-transfer-engine/src/config.cpp | 11 + .../src/multi_transport.cpp | 40 + .../src/transfer_engine.cpp | 76 +- .../src/transfer_metadata.cpp | 21 +- .../src/transfer_metadata_plugin.cpp | 26 +- .../src/transport/CMakeLists.txt | 5 + .../transport/barex_transport/CMakeLists.txt | 5 + .../barex_transport/barex_context.cpp | 179 +++ .../barex_transport/barex_transport.cpp | 1322 +++++++++++++++++ scripts/build_wheel.sh | 11 +- 22 files changed, 2117 insertions(+), 23 deletions(-) create mode 100644 mooncake-transfer-engine/include/transport/barex_transport/barex_context.h create mode 100644 mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h create mode 100644 mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt create mode 100644 mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp create mode 100644 mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp diff --git a/mooncake-common/common.cmake b/mooncake-common/common.cmake index f117c4b83..042bc000d 100644 --- a/mooncake-common/common.cmake +++ b/mooncake-common/common.cmake @@ -60,6 +60,7 @@ option(USE_CUDA "option for enabling gpu features" OFF) option(USE_MUSA "option for enabling 
Moore Threads gpu features by leveraging MUSA (Meta-computing Unified System Architecture)" OFF) option(USE_NVMEOF "option for using NVMe over Fabric" OFF) option(USE_TCP "option for using TCP transport" ON) +option(USE_BAREX "option for using accl-barex transport" OFF) option(USE_ASCEND "option for using npu with HCCL" OFF) option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" OFF) option(USE_ASCEND_HETEROGENEOUS "option for transferring between ascend npu and gpu" OFF) @@ -123,6 +124,10 @@ if (USE_TCP) add_compile_definitions(USE_TCP) endif() +if (USE_BAREX) + add_compile_definitions(USE_BAREX) +endif() + if (USE_ASCEND OR USE_ASCEND_DIRECT) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DOPEN_BUILD_PROJECT ") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DOPEN_BUILD_PROJECT ") diff --git a/mooncake-integration/allocator.py b/mooncake-integration/allocator.py index 185d7535f..af56cbc45 100644 --- a/mooncake-integration/allocator.py +++ b/mooncake-integration/allocator.py @@ -94,6 +94,6 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator: if device not in cls._instances: so_path = cls._get_so_path() cls._instances[device] = CUDAPluggableAllocator( - so_path, "u2mm_alloc_wrapper", "u2mm_free_wrapper" + so_path, "u2mm_alloc_wrapper_with_stream", "u2mm_free_wrapper_with_stream" ) return cls._instances[device] diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 316a4d109..4dc00de62 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -132,7 +132,17 @@ int TransferEnginePy::initializeExt(const char *local_hostname, free_list_.resize(kSlabSizeKBTabLen); #if !defined(USE_ASCEND) && !defined(USE_ASCEND_DIRECT) && \ !defined(USE_ASCEND_HETEROGENEOUS) - doBuddyAllocate(kMaxClassId); + bool pass_alloc = false; + const char *pass_alloc_env = std::getenv("PASS_ALLOC"); + if 
(pass_alloc_env) { + int val = atoi(pass_alloc_env); + if (val != 0) { + pass_alloc = true; + } + } + if (!pass_alloc) { + doBuddyAllocate(kMaxClassId); + } #endif return 0; } @@ -266,6 +276,7 @@ int TransferEnginePy::transferSync(const char *target_hostname, if (handle_map_.count(target_hostname)) { handle = handle_map_[target_hostname]; } else { + LOG(INFO) << "transferSync, cache not found, openSegment with target " << target_hostname; handle = engine_->openSegment(target_hostname); if (handle == (Transport::SegmentHandle)-1) return -1; handle_map_[target_hostname] = handle; @@ -300,7 +311,17 @@ int TransferEnginePy::transferSync(const char *target_hostname, batch_id, {entry}, TransferMetadata::NotifyDesc{notify->name, notify->msg}) : engine_->submitTransfer(batch_id, {entry}); - if (!s.ok()) return -1; + if (!s.ok()) { + Status segment_status = engine_->CheckSegmentStatus(handle); + if (!segment_status.ok()) { + LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment"; + std::lock_guard guard(mutex_); + engine_->closeSegment(handle); + engine_->getMetadata()->removeSegmentDesc(target_hostname); + handle_map_.erase(target_hostname); + } + return -1; + } TransferStatus status; bool completed = false; @@ -387,6 +408,14 @@ int TransferEnginePy::batchTransferSync( : engine_->submitTransfer(batch_id, entries); if (!s.ok()) { engine_->freeBatchID(batch_id); + Status segment_status = engine_->CheckSegmentStatus(handle); + if (!segment_status.ok()) { + LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment"; + std::lock_guard guard(mutex_); + engine_->closeSegment(handle); + engine_->getMetadata()->removeSegmentDesc(target_hostname); + handle_map_.erase(target_hostname); + } return -1; } diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index 
528d37d18..2cc00c779 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -74,7 +74,7 @@ DEFINE_string(mode, "initiator", "data blocks from target node"); DEFINE_string(operation, "read", "Operation type: read or write"); -DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|tcp"); +DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|barex|tcp"); DEFINE_string(device_name, "mlx5_2", "Device name to use, valid if protocol=rdma"); @@ -317,6 +317,12 @@ int initiator() { args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; xport = engine->installTransport("rdma", args); + } else if (FLAGS_protocol == "barex") { + auto nic_priority_matrix = loadNicPriorityMatrix(); + void **args = (void **)malloc(2 * sizeof(void *)); + args[0] = (void *)nic_priority_matrix.c_str(); + args[1] = nullptr; + xport = engine->installTransport("barex", args); } else if (FLAGS_protocol == "tcp") { xport = engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvlink") { @@ -436,7 +442,13 @@ int target() { void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - engine->installTransport("rdma", args); + engine->installTransport("rdma", args); + } else if (FLAGS_protocol == "barex") { + auto nic_priority_matrix = loadNicPriorityMatrix(); + void **args = (void **)malloc(2 * sizeof(void *)); + args[0] = (void *)nic_priority_matrix.c_str(); + args[1] = nullptr; + engine->installTransport("barex", args); } else if (FLAGS_protocol == "tcp") { engine->installTransport("tcp", nullptr); } else if (FLAGS_protocol == "nvlink") { diff --git a/mooncake-transfer-engine/include/config.h b/mooncake-transfer-engine/include/config.h index 33e71322e..930342469 100644 --- a/mooncake-transfer-engine/include/config.h +++ b/mooncake-transfer-engine/include/config.h @@ -50,6 +50,7 @@ struct GlobalConfig { bool use_ipv6 = false; size_t 
fragment_limit = 16384; bool enable_dest_device_affinity = false; + size_t eic_max_block_size = 64UL * 1024 * 1024; }; void loadGlobalConfig(GlobalConfig &config); diff --git a/mooncake-transfer-engine/include/transfer_engine.h b/mooncake-transfer-engine/include/transfer_engine.h index 0807ef690..7253c5f9c 100644 --- a/mooncake-transfer-engine/include/transfer_engine.h +++ b/mooncake-transfer-engine/include/transfer_engine.h @@ -95,6 +95,8 @@ class TransferEngine { SegmentHandle openSegment(const std::string &segment_name); + Status CheckSegmentStatus(SegmentID sid); + int closeSegment(SegmentHandle handle); int removeLocalSegment(const std::string &segment_name); @@ -249,6 +251,7 @@ class TransferEngine { // Set it to false only for testing. bool auto_discover_; std::vector filter_; + bool use_barex_ = false; #ifdef WITH_METRICS ylt::metric::counter_t transferred_bytes_counter_{ diff --git a/mooncake-transfer-engine/include/transfer_metadata.h b/mooncake-transfer-engine/include/transfer_metadata.h index 70f15c8d4..6c6511ca6 100644 --- a/mooncake-transfer-engine/include/transfer_metadata.h +++ b/mooncake-transfer-engine/include/transfer_metadata.h @@ -103,12 +103,14 @@ class TransferMetadata { struct RpcMetaDesc { std::string ip_or_host_name; uint16_t rpc_port; + uint16_t barex_port; int sockfd; // local cache }; struct HandShakeDesc { std::string local_nic_path; std::string peer_nic_path; + uint16_t barex_port; std::vector qp_num; std::string reply_msg; // on error }; diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h index 44b5610f0..22c361161 100644 --- a/mooncake-transfer-engine/include/transfer_metadata_plugin.h +++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h @@ -69,7 +69,7 @@ struct HandShakePlugin { std::vector findLocalIpAddresses(); -uint16_t findAvailableTcpPort(int &sockfd); +uint16_t findAvailableTcpPort(int &sockfd, bool set_range=false); } // 
namespace mooncake diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h new file mode 100644 index 000000000..15c8f8e4d --- /dev/null +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h @@ -0,0 +1,194 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef BAREX_CONTEXT_H_ +#define BAREX_CONTEXT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common.h" +#include "transport/transport.h" + +#ifdef USE_BAREX +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +namespace mooncake { + +#ifdef USE_BAREX + +using namespace accl::barex; +using XChannel = accl::barex::XChannel; +using SegmentID = Transport::SegmentID; +using XContext = accl::barex::XContext; +using BarexResult = accl::barex::BarexResult; + +class ChannelCache { +public: + // 添加一个 channel 到指定 key & nic_id + void put(SegmentID key, int nic_id, XChannel* channel) { + RWSpinlock::WriteGuard guard(lock_); + auto& channels = cache_[key]; + auto& vec = channels[nic_id]; + status_map_[key] = true; + vec.push_back(channel); + } + + // 获取 sid 下指定 nic_id 和 idx 的 channel + XChannel* find(SegmentID key, int nic_id, int idx) { + RWSpinlock::ReadGuard guard(lock_); + auto it = cache_.find(key); + if (it == cache_.end()) return 
nullptr; + auto& channels = it->second; + auto ch_it = channels.find(nic_id); + if (ch_it == channels.end()) return nullptr; + auto& vec = ch_it->second; + if (idx >= 0 && idx < static_cast(vec.size())) { + return vec[idx]; + } + return nullptr; + } + + // 删除某个 channel(通过id和idx) + bool erase(SegmentID key, int nic_id, int idx) { + RWSpinlock::WriteGuard guard(lock_); + auto it = cache_.find(key); + if (it == cache_.end()) return false; + + auto& channels = it->second; + auto ch_it = channels.find(nic_id); + if (ch_it == channels.end()) return false; + + auto& vec = ch_it->second; + if (idx < 0 || idx >= static_cast(vec.size())) return false; + + vec.erase(vec.begin() + idx); + status_map_[key] = false; + if (vec.empty()) { + channels.erase(ch_it); + if (channels.empty()) { + cache_.erase(it); + } + } + return true; + } + + // 查询某个 SegmentID 下的 channel 状态 + bool CheckAllChannels(SegmentID segment_id) { + RWSpinlock::WriteGuard guard(lock_); + auto it = cache_.find(segment_id); + if (it == cache_.end()) { + return false; + } + auto& inner_map = it->second; + for (auto& pair : inner_map) { + auto& channels = pair.second; + for (XChannel* channel : channels) { + if (!channel->IsActive()) { + return false; + } + } + } + return true; + } + + // 检查并删除某个 SegmentID 下的异常channel,并返回删除的数量 + int RemoveInvalidChannels(SegmentID segment_id) { + RWSpinlock::WriteGuard guard(lock_); + auto it = cache_.find(segment_id); + if (it == cache_.end()) { + return 0; + } + + int invalid_count = 0; + auto& inner_map = it->second; + + for (auto& pair : inner_map) { + auto& channels = pair.second; + auto new_end = std::remove_if(channels.begin(), channels.end(), + [](XChannel* channel) { + return !channel->IsActive(); + }); + invalid_count += std::distance(new_end, channels.end()); + channels.erase(new_end, channels.end()); + } + return invalid_count; + } + + // 将所有的 channel 以 vector 形式返回 + std::vector copyAll() { + RWSpinlock::WriteGuard guard(lock_); + std::vector result; + for (const auto& 
[key, channels] : cache_) { + for (const auto& [nic_id, vec] : channels) { + result.insert(result.end(), vec.begin(), vec.end()); + } + } + return result; + } + +private: + std::unordered_map>> cache_; + std::unordered_map status_map_; + RWSpinlock lock_; +}; +class BarexContext { + public: + int submitPostSend(const std::vector &slice_list); + int addChannel(SegmentID sid, int device_id, XChannel *ch); + XChannel* getChannel(SegmentID sid, int device_id, int idx); + int checkStatus(SegmentID sid); + XContext* getCtx(); + // int ClearAllChannel(); + std::vector getAllChannel(); + bool active() const { return active_; } + void setQpNum(int qp_num) { qp_num_per_ctx_ = qp_num; } + int getQpNum() const { return qp_num_per_ctx_; } + + public: + BarexContext(XContext* xcontext, bool use_cpu, int device_id); + + ~BarexContext(); + + XContext* xcontext_; + bool barex_use_cpu_; + int barex_local_device_; + + private: + ChannelCache channel_cache_; + bool active_ = true; + int qp_num_per_ctx_ = 2; + +}; +#endif +} // namespace mooncake + +#endif // BAREX_CONTEXT_H_ \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h new file mode 100644 index 000000000..2a9d80c25 --- /dev/null +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h @@ -0,0 +1,175 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef BAREX_TRANSPORT_H_ +#define BAREX_TRANSPORT_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "topology.h" +#include "transfer_metadata.h" +#include "transport/transport.h" +#include "transport/barex_transport/barex_context.h" + +namespace mooncake { + +using TransferRequest = Transport::TransferRequest; +using TransferStatus = Transport::TransferStatus; +using TransferStatusEnum = Transport::TransferStatusEnum; +using SegmentID = Transport::SegmentID; +using BatchID = Transport::BatchID; + + +class TransferMetadata; +class CountDownLatch { +private: + int count_; + std::mutex mtx; + std::condition_variable cv; + +public: + CountDownLatch(int count) : count_(count){}; + + void CountDown() { + std::unique_lock lk(mtx); + count_--; + if (count_ <= 0) { + cv.notify_all(); + } + } + + void Wait() { + std::unique_lock lk(mtx); + cv.wait(lk, [this] { return count_ <= 0; }); + } +}; +class BarexTransport : public Transport { + public: + using BufferDesc = TransferMetadata::BufferDesc; + using SegmentDesc = TransferMetadata::SegmentDesc; + using HandShakeDesc = TransferMetadata::HandShakeDesc; + + public: + BarexTransport(); + + ~BarexTransport(); + + int install(std::string &local_server_name, + std::shared_ptr meta, + std::shared_ptr topo) override; + + const char *getName() const override { return "rdma"; } + + void setLocalPort(int port) { local_port_ = port; } + + void setPeerPort(int port) { peer_port_ = port; } + + int getLocalPort() { return local_port_; } + + int getPeerPort() { return peer_port_; } + + int registerLocalMemory(void *addr, size_t length, + const std::string &location, bool remote_accessible, + bool update_metadata) override; + + int registerLocalMemoryBase(void *addr, size_t length, + const std::string &location, bool remote_accessible, + bool 
update_metadata, bool is_gpu); + + int unregisterLocalMemory(void *addr, bool update_metadata = true) override; + + int registerLocalMemoryBatch(const std::vector &buffer_list, + const std::string &location) override; + + int unregisterLocalMemoryBatch( + const std::vector &addr_list) override; + + // TRANSFER + + Status submitTransfer(BatchID batch_id, + const std::vector &entries) override; + + Status submitTransferTask( + const std::vector &task_list) override; + + Status getTransferStatus(BatchID batch_id, + std::vector &status); + + Status getTransferStatus(BatchID batch_id, size_t task_id, + TransferStatus &status) override; + + SegmentID getSegmentID(const std::string &segment_name); + + Status OpenChannel(const std::string &segment_name, SegmentID sid) override; + Status CheckStatus(SegmentID sid) override; + + private: + int allocateLocalSegmentID(); + + public: + int onSetupRdmaConnections(const HandShakeDesc &peer_desc, + HandShakeDesc &local_desc); + + int sendHandshake(const std::string &peer_server_name, + const HandShakeDesc &local_desc, + HandShakeDesc &peer_desc) { + return metadata_->sendHandshake(peer_server_name, local_desc, + peer_desc); + } + + private: + int initializeRdmaResources(); + + int startHandshakeDaemon(std::string &local_server_name); + + public: + static int selectDevice(SegmentDesc *desc, uint64_t offset, size_t length, + int &buffer_id, int &device_id, int retry_cnt = 0); + + private: +#ifdef USE_BAREX + std::vector> server_context_list_; + std::vector> client_context_list_; + std::shared_ptr server_threadpool_; + std::shared_ptr client_threadpool_; + std::shared_ptr mempool_; + std::shared_ptr listerner_; + std::shared_ptr connector_; +#endif + std::shared_ptr local_topology_; + std::mutex buf_mutex_; + std::map> buf_length_map_; + bool use_random_dev_ = false; + bool barex_use_cpu_ = false; + int barex_local_device_ = 0; + int local_port_ = 8089; + int peer_port_ = 8089; + std::random_device rd; +}; + +} // namespace mooncake 
+ +#endif // BAREX_TRANSPORT_H_ \ No newline at end of file diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index 1aaac3cf7..c4fda5bc6 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -91,6 +91,7 @@ class Transport { std::string peer_nic_path; SliceStatus status; TransferTask *task; + std::vector dest_rkeys; bool from_cache; union { @@ -98,6 +99,7 @@ class Transport { uint64_t dest_addr; uint32_t source_lkey; uint32_t dest_rkey; + int lkey_index; int rkey_index; volatile int *qp_depth; uint32_t retry_cnt; @@ -258,6 +260,9 @@ class Transport { size_t length; }; + virtual Status OpenChannel(const std::string &segment_name, SegmentID sid) { return Status::OK(); } + virtual Status CheckStatus(SegmentID sid) { return Status::OK(); } + protected: virtual int install(std::string &local_server_name, std::shared_ptr meta, diff --git a/mooncake-transfer-engine/src/CMakeLists.txt b/mooncake-transfer-engine/src/CMakeLists.txt index 769359bfc..b1f2e7746 100644 --- a/mooncake-transfer-engine/src/CMakeLists.txt +++ b/mooncake-transfer-engine/src/CMakeLists.txt @@ -9,7 +9,7 @@ if (BUILD_SHARED_LIBS) install(TARGETS transfer_engine DESTINATION lib) endif() -add_compile_definitions(transfer_engine PUBLIC MOONCAKE_USE_ETCD) +add_compile_definitions(transfer_engine PUBLIC MOONCAKE_USE_ETCD CMAKE_INCLUDE) if (USE_ETCD) if (USE_ETCD_LEGACY) if (USE_STATIC_ETCD_CPP_API) @@ -39,6 +39,10 @@ target_link_libraries( base transport rdma_transport ibverbs glog::glog gflags::gflags pthread JsonCpp::JsonCpp numa yalantinglibs::yalantinglibs ) +if (USE_BAREX) + target_link_libraries(transfer_engine PUBLIC barex_transport) +endif() + if (USE_CUDA) target_include_directories(transfer_engine PRIVATE /usr/local/cuda/include) target_link_libraries(transfer_engine PUBLIC cuda cudart rt) diff --git a/mooncake-transfer-engine/src/config.cpp 
b/mooncake-transfer-engine/src/config.cpp index 9bc4b76fb..44b00f0b3 100644 --- a/mooncake-transfer-engine/src/config.cpp +++ b/mooncake-transfer-engine/src/config.cpp @@ -167,6 +167,17 @@ void loadGlobalConfig(GlobalConfig &config) { << "Ignore value from environment variable MC_SLICE_SIZE"; } + const char *min_reg_size_env = std::getenv("MC_MIN_REG_SIZE"); + if (min_reg_size_env) { + size_t val = atoll(min_reg_size_env); + if (val > 0) { + config.eic_max_block_size = val; + LOG(INFO) << "Barex set MC_MIN_REG_SIZE=" << val; + } else + LOG(WARNING) + << "Ignore value from environment variable MC_MIN_REG_SIZE"; + } + const char *retry_cnt_env = std::getenv("MC_RETRY_CNT"); if (retry_cnt_env) { size_t val = atoi(retry_cnt_env); diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index 9c24836a2..1e16da43f 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -17,6 +17,9 @@ #include "config.h" #include "transport/rdma_transport/rdma_transport.h" +#ifdef USE_BAREX +#include "transport/barex_transport/barex_transport.h" +#endif #ifdef USE_TCP #include "transport/tcp_transport/tcp_transport.h" #endif @@ -202,6 +205,11 @@ Transport *MultiTransport::installTransport(const std::string &proto, if (std::string(proto) == "rdma") { transport = new RdmaTransport(); } +#ifdef USE_BAREX + else if (std::string(proto) == "barex") { + transport = new BarexTransport(); + } +#endif #ifdef USE_TCP else if (std::string(proto) == "tcp") { transport = new TcpTransport(); @@ -244,6 +252,38 @@ Transport *MultiTransport::installTransport(const std::string &proto, return nullptr; } +#ifdef USE_BAREX + bool use_eic = false; + for (auto& dev : topo->getHcaList()) { + if (dev.find("soe") != std::string::npos || dev.find("solar") != std::string::npos) { + use_eic = true; + } + } + + if (std::string(proto) == "barex") { + std::string nics; + for (auto& dev : 
topo->getHcaList()) { + if (use_eic) { + if (dev.find("soe") == std::string::npos && dev.find("solar") == std::string::npos) { + // ignore no eic nics + continue; + } + } + nics += dev; + nics += ","; + } + + // 移除最后一个多余的逗号 + if (!nics.empty()) { + nics.pop_back(); + } + + if (!nics.empty()) { + LOG(INFO) << "ACCL_USE_NICS is set to " << nics; + setenv("ACCL_USE_NICS", nics.c_str(), 1); + } + } +#endif if (transport->install(local_server_name_, metadata_, topo)) { return nullptr; } diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index 32b46b13f..c86ccce81 100644 --- a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -25,6 +25,7 @@ #include "transfer_metadata_plugin.h" #include "transport/transport.h" +#include "transport/barex_transport/barex_transport.h" namespace mooncake { @@ -72,6 +73,15 @@ int TransferEngine::init(const std::string &metadata_conn_string, "files are opened."; } // Set resources to the maximum value +#ifdef USE_BAREX + const char *use_barex_env = std::getenv("USE_BAREX"); + if (use_barex_env) { + int val = atoi(use_barex_env); + if (val != 0) { + use_barex_ = true; + } + } +#endif #ifdef USE_ASCEND // The only difference in initializing the Ascend Transport is that the @@ -99,7 +109,18 @@ int TransferEngine::init(const std::string &metadata_conn_string, desc.ip_or_host_name = host_name; desc.rpc_port = port; desc.sockfd = -1; - +#ifdef USE_BAREX + if (use_barex_) { + int tmp_fd = -1; + desc.barex_port = findAvailableTcpPort(tmp_fd, true); + if (desc.barex_port == 0) { + LOG(ERROR) << "Barex: No valid port found for local barex service."; + return -1; + } + close(tmp_fd); + tmp_fd = -1; + } +#endif if (metadata_conn_string == P2PHANDSHAKE) { rpc_binding_method = "P2P handshake"; desc.rpc_port = findAvailableTcpPort(desc.sockfd); @@ -145,7 +166,8 @@ int TransferEngine::init(const std::string &metadata_conn_string, LOG(INFO) << 
"Transfer Engine RPC using " << rpc_binding_method << ", listening on " << desc.ip_or_host_name << ":" - << desc.rpc_port; + << desc.rpc_port + << (use_barex_ ? ", barex use port:" + std::to_string(desc.barex_port) : ""); metadata_ = std::make_shared(metadata_conn_string); #ifdef USE_ASCEND @@ -228,11 +250,22 @@ int TransferEngine::init(const std::string &metadata_conn_string, if (local_topology_->getHcaList().size() > 0 && !getenv("MC_FORCE_TCP")) { // only install RDMA transport when there is at least one HCA - Transport *rdma_transport = - multi_transports_->installTransport("rdma", local_topology_); - if (!rdma_transport) { - LOG(ERROR) << "Failed to install RDMA transport"; + Transport* rdma_transport = nullptr; + if (use_barex_) { +#ifdef USE_BAREX + rdma_transport = multi_transports_->installTransport("barex", local_topology_); +#else + LOG(ERROR) << "Set USE BAREX while barex not compiled"; + return -1; +#endif + } else { + rdma_transport = multi_transports_->installTransport("rdma", local_topology_); + } + if (rdma_transport == nullptr) { + LOG(ERROR) << "Failed to install RDMA transport, type=" << (use_barex_ ? "barex" : "rdma"); return -1; + } else { + LOG(INFO) << "installTransport, type=" << (use_barex_ ? 
"barex" : "rdma"); } } else { Transport *tcp_transport = @@ -328,7 +361,36 @@ Transport::SegmentHandle TransferEngine::openSegment( while (!trimmed_segment_name.empty() && trimmed_segment_name[0] == '/') trimmed_segment_name.erase(0, 1); if (trimmed_segment_name.empty()) return ERR_INVALID_ARGUMENT; - return metadata_->getSegmentID(trimmed_segment_name); + SegmentID sid = metadata_->getSegmentID(trimmed_segment_name); +#ifdef USE_BAREX + if (use_barex_) { + Transport* transport = multi_transports_->getTransport("barex"); + if (!transport) { + LOG(ERROR) << "Barex proto not installed"; + return (Transport::SegmentHandle)-1; + } + Status s = transport->OpenChannel(segment_name, sid); + if (!s.ok()) { + LOG(ERROR) << "openSegment, OpenChannel failed"; + return (Transport::SegmentHandle)-1; + } + } +#endif + return sid; +} + +Status TransferEngine::CheckSegmentStatus(SegmentID sid) { +#ifdef USE_BAREX + if (use_barex_) { + Transport* transport = multi_transports_->getTransport("barex"); + BarexTransport* barex_transport = dynamic_cast(transport); + return barex_transport->CheckStatus(sid); + } else { + return Status::OK(); + } +#else + return Status::OK(); +#endif } int TransferEngine::closeSegment(Transport::SegmentHandle handle) { return 0; } diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 04c6964ac..0cfb669e9 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -58,6 +58,7 @@ struct TransferHandshakeUtil { Json::Value root; root["local_nic_path"] = desc.local_nic_path; root["peer_nic_path"] = desc.peer_nic_path; + root["barex_port"] = desc.barex_port; Json::Value qpNums(Json::arrayValue); for (const auto &qp : desc.qp_num) qpNums.append(qp); root["qp_num"] = qpNums; @@ -68,6 +69,7 @@ struct TransferHandshakeUtil { static int decode(Json::Value root, TransferMetadata::HandShakeDesc &desc) { desc.local_nic_path = 
root["local_nic_path"].asString(); desc.peer_nic_path = root["peer_nic_path"].asString(); + desc.barex_port = root["barex_port"].asInt(); for (const auto &qp : root["qp_num"]) desc.qp_num.push_back(qp.asUInt()); desc.reply_msg = root["reply_msg"].asString(); @@ -126,7 +128,7 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, segmentJSON["tcp_data_port"] = desc.tcp_data_port; segmentJSON["timestamp"] = getCurrentDateTime(); - if (segmentJSON["protocol"] == "rdma") { + if (segmentJSON["protocol"] == "rdma" || segmentJSON["protocol"] == "barex") { Json::Value devicesJSON(Json::arrayValue); for (const auto &device : desc.devices) { Json::Value deviceJSON; @@ -255,6 +257,14 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { if (p2p_handshake_mode_) { + auto iter = segment_name_to_id_map_.find(segment_name); + if (iter != segment_name_to_id_map_.end()){ + LOG(INFO) << "removeSegmentDesc " << segment_name << " finish"; + segment_id_to_desc_map_.erase(iter->second); + segment_name_to_id_map_.erase(iter); + } else { + LOG(INFO) << "removeSegmentDesc " << segment_name << " not found, already removed maybe"; + } return 0; } if (!storage_plugin_->remove(getFullMetadataKey(segment_name))) { @@ -275,7 +285,7 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, if (segmentJSON.isMember("timestamp")) desc->timestamp = segmentJSON["timestamp"].asString(); - if (desc->protocol == "rdma") { + if (desc->protocol == "rdma" || desc->protocol == "barex") { for (const auto &deviceJSON : segmentJSON["devices"]) { DeviceDesc device; device.name = deviceJSON["name"].asString(); @@ -302,7 +312,12 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, buffer.rkey.empty() || buffer.rkey.size() != buffer.lkey.size()) { LOG(WARNING) << "Corrupted segment descriptor, name " - << segment_name << " protocol " << desc->protocol; + << segment_name << " protocol " 
<< desc->protocol + << ", " << buffer.name + << ", " << buffer.addr + << ", " << buffer.length + << ", " << buffer.rkey.size() + << ", " << buffer.lkey.size(); return nullptr; } desc->buffers.push_back(buffer); diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index 1e686c711..34b2d5ffb 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -1135,11 +1135,31 @@ std::vector findLocalIpAddresses() { return ips; } -uint16_t findAvailableTcpPort(int &sockfd) { +uint16_t findAvailableTcpPort(int &sockfd, bool set_range) { static std::random_device rand_gen; std::uniform_int_distribution rand_dist; - const int min_port = globalConfig().rpc_min_port; - const int max_port = globalConfig().rpc_max_port; + int min_port = globalConfig().rpc_min_port;; + int max_port = globalConfig().rpc_max_port;; +#ifdef USE_BAREX + if (set_range) { + min_port = 17000; + max_port = 35000; + const char *min_port_env = std::getenv("ACCL_MIN_PORT"); + const char *max_port_env = std::getenv("ACCL_MAX_PORT"); + if (min_port_env) { + int val = atoi(min_port_env); + if (val > 1024 && val < 65536) { + min_port = val; + } + } + if (max_port_env) { + int val = atoi(max_port_env); + if (val > 1024 && val < 65536 && val > min_port) { + max_port = val; + } + } + } +#endif const int max_attempts = 500; bool use_ipv6 = globalConfig().use_ipv6; diff --git a/mooncake-transfer-engine/src/transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/CMakeLists.txt index 5517a5ddc..026d75fe9 100644 --- a/mooncake-transfer-engine/src/transport/CMakeLists.txt +++ b/mooncake-transfer-engine/src/transport/CMakeLists.txt @@ -9,6 +9,11 @@ if (USE_TCP) target_sources(transport PUBLIC $) endif() +if (USE_BAREX) + add_subdirectory(barex_transport) + target_sources(transport PUBLIC $) +endif() + if (USE_NVMEOF) add_subdirectory(nvmeof_transport) 
target_sources(transport PUBLIC $) diff --git a/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt b/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt new file mode 100644 index 000000000..f5df7ce9b --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/CMakeLists.txt @@ -0,0 +1,5 @@ +file(GLOB BAREX_SOURCES "*.cpp") + +add_library(barex_transport OBJECT ${BAREX_SOURCES}) +target_link_libraries(barex_transport PRIVATE pthread accl_barex) +target_compile_definitions(barex_transport PRIVATE CMAKE_INCLUDE=1) \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp new file mode 100644 index 000000000..3f65ea668 --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp @@ -0,0 +1,179 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "transport/barex_transport/barex_context.h" + +namespace mooncake { + +using namespace accl::barex; + +BarexContext::BarexContext(XContext* xcontext, bool use_cpu, int device_id) : xcontext_(xcontext), barex_use_cpu_(use_cpu), barex_local_device_(device_id) {} + + +BarexContext::~BarexContext() { + if (xcontext_) { + xcontext_->Shutdown(); + xcontext_->WaitStop(); + delete xcontext_; + } +} + +int BarexContext::addChannel(SegmentID sid, int device_id, XChannel* ch) { + channel_cache_.put(sid, device_id, ch); + return 0; +} + +XChannel* BarexContext::getChannel(SegmentID sid, int device_id, int idx) { + XChannel* channel = channel_cache_.find(sid, device_id, idx); + return channel; +} + +int BarexContext::checkStatus(SegmentID sid) { + return channel_cache_.RemoveInvalidChannels(sid); +} + +XContext* BarexContext::getCtx() { + return xcontext_; +} + +std::vector BarexContext::getAllChannel() { + return channel_cache_.copyAll(); +} + +int BarexContext::submitPostSend( + const std::vector &slice_list) { + std::unordered_map>> sid_dev_data_map; + std::unordered_map>> sid_dev_slice_map; + for (auto slice : slice_list) { + accl::barex::rw_memp_t w_m; + w_m.sg.addr = (uint64_t)slice->source_addr; + w_m.sg.length = (uint32_t)slice->length; + w_m.sg.lkey = slice->rdma.source_lkey; + w_m.data.d_type = barex_use_cpu_ ? 
CPU : GPU; + w_m.data.device_id = barex_local_device_; + w_m.r_addr = slice->rdma.dest_addr; + w_m.r_key = slice->rdma.dest_rkey; + w_m.r_ttl_ms = UINT64_MAX; + auto& dev_map = sid_dev_data_map[slice->target_id]; + int lkey_index = slice->rdma.lkey_index; + dev_map[lkey_index].push_back(w_m); + auto& slice_map = sid_dev_slice_map[slice->target_id]; + slice_map[lkey_index].push_back(slice); + } + for (auto& pair : sid_dev_data_map) { + SegmentID sid = pair.first; + auto& dev_map = pair.second; + + for (auto& dev_pair : dev_map) { + int dev = dev_pair.first; + std::vector& data_vec = dev_pair.second; + std::vector& slice_vec = sid_dev_slice_map[sid][dev]; + size_t data_size = data_vec.size(); + int qp_in_use = qp_num_per_ctx_; + if (data_size < (size_t)qp_num_per_ctx_) { + qp_in_use = data_size; + } + size_t begin_idx = 0; + size_t end_idx = 0; + size_t batch_size = data_size / qp_in_use; + size_t reminder = data_size % qp_in_use; + + int retry_cnt = 5; + for (int i=0; i < qp_in_use; i++) { + XChannel* channel = nullptr; + for (int j=0; j < retry_cnt; j++) { + channel = channel_cache_.find(sid, dev, i); + if (!channel) { + LOG(ERROR) << "Write fail, sid " << sid << ", dev " << dev << ", id " << i << " not found, retry " << j << "/" << retry_cnt; + break; + } + if (!channel->IsActive()) { + LOG(WARNING) << "Write fail, channel status error " << channel << " retry " << j << "/" << retry_cnt; + channel_cache_.erase(sid, dev, i); + continue; + } + } + if (!channel) { + LOG(ERROR) << "Write fail, no channel found"; + return -1; + } + + end_idx += batch_size; + if (i == qp_in_use - 1) { + end_idx += reminder; + } + int peer_nic_id = channel->GetPeerNicId(); + auto data_chunk_read = std::make_shared>(); + auto data_chunk_write = std::make_shared>(); + auto slice_chunk_read = std::make_shared>(); + auto slice_chunk_write = std::make_shared>(); + for (size_t idx = begin_idx; idx < end_idx; idx++) { + data_vec[idx].r_key = slice_vec[idx]->dest_rkeys[peer_nic_id]; + if 
(slice_vec[idx]->opcode == Transport::TransferRequest::READ) { + data_chunk_read->emplace_back(data_vec[idx]); + slice_chunk_read->emplace_back(slice_vec[idx]); + } else { + data_chunk_write->emplace_back(data_vec[idx]); + slice_chunk_write->emplace_back(slice_vec[idx]); + } + } + + if (!data_chunk_write->empty()) { + BarexResult r = channel->WriteBatch( + data_chunk_write, + [slice_chunk_write](accl::barex::Status s) { + if(!s.IsOk()) { + LOG(ERROR) << "WriteBatch fail, " << s.ErrMsg().c_str(); + for (auto slice : *slice_chunk_write) { + slice->markFailed(); + } + } else { + for (auto slice : *slice_chunk_write) { + slice->markSuccess(); + } + } + }, true); + if (r != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "WriteBatch fail, ret " << r; + return -2; + } + } + if (!data_chunk_read->empty()) { + BarexResult r = channel->ReadBatch( + data_chunk_read, + [slice_chunk_read](accl::barex::Status s) { + if(!s.IsOk()) { + LOG(ERROR) << "ReadBatch fail, " << s.ErrMsg().c_str(); + for (auto slice : *slice_chunk_read) { + slice->markFailed(); + } + } else { + for (auto slice : *slice_chunk_read) { + slice->markSuccess(); + } + } + }, true); + if (r != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "ReadBatch fail, ret " << r; + return -2; + } + } + begin_idx += batch_size; + } + } + } + return 0; +} + +} // namespace mooncake \ No newline at end of file diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp new file mode 100644 index 000000000..0ffdb52ed --- /dev/null +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp @@ -0,0 +1,1322 @@ +// Copyright 2024 KVCache.AI +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "transport/barex_transport/barex_transport.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "common.h" +#include "config.h" +#include "memory_location.h" +#include "topology.h" +// #include "transport/rdma_transport/rdma_context.h" +// #include "transport/rdma_transport/rdma_endpoint.h" + +namespace mooncake { +using namespace accl::barex; + +class EmptyCallback : public XChannelCallback { +public: + void OnRecvCall(XChannel *channel, char *buf, size_t len, x_msg_header header) {} +}; + +BarexTransport::BarexTransport() {} + +BarexTransport::~BarexTransport() { +#ifdef CONFIG_USE_BATCH_DESC_SET + for (auto &entry : batch_desc_set_) delete entry.second; + batch_desc_set_.clear(); +#endif + for (auto ctx : client_context_list_) { + std::vector chs = ctx->getAllChannel(); + for (auto ch : chs) { + BarexResult ret = connector_->CloseChannel(ch, [&, ch](accl::barex::Status s) { + LOG(INFO) << "CloseChannel() finished, s.IsOk=" << s.IsOk(); + ch->Destroy(); + }); + if (ret != accl::barex::BAREX_SUCCESS) { + LOG(ERROR) << "CloseChannel() failed, ret " << ret; + } + } + } + client_context_list_.clear(); + server_context_list_.clear(); + metadata_->removeSegmentDesc(local_server_name_); + batch_desc_set_.clear(); + connector_->Shutdown(); + connector_->WaitStop(); + listerner_->Shutdown(); + listerner_->WaitStop(); + server_threadpool_->Shutdown(); + server_threadpool_->WaitStop(); + client_threadpool_->Shutdown(); + client_threadpool_->WaitStop(); + mempool_->Shutdown(); + mempool_->WaitStop(); +} + +int 
BarexTransport::install(std::string &local_server_name, + std::shared_ptr meta, + std::shared_ptr topo) { + if (topo == nullptr) { + LOG(ERROR) << "BarexTransport: missing topology"; + return ERR_INVALID_ARGUMENT; + } + + metadata_ = meta; + local_server_name_ = local_server_name; + local_topology_ = topo; + + const char *barex_random_dev_env = std::getenv("BAREX_USE_RANDOM_DEV"); + if (barex_random_dev_env) { + int val = atoi(barex_random_dev_env); + if (val != 0) { + LOG(INFO) << "BarexTransport: use random rdma device"; + use_random_dev_ = true; + } + } + + const char *barex_use_cpu_env = std::getenv("ACCL_USE_CPU"); + if (barex_use_cpu_env) { + int val = atoi(barex_use_cpu_env); + if (val != 0) { + LOG(INFO) << "BarexTransport: use_cpu"; + barex_use_cpu_ = true; + } + } + + const char *barex_local_device_env = std::getenv("ACCL_LOCAL_DEVICE"); + if (barex_local_device_env) { + int val = atoi(barex_local_device_env); + LOG(INFO) << "BarexTransport: set local device id " << val; + barex_local_device_ = val; + } + + auto ret = initializeRdmaResources(); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot initialize RDMA resources"; + return ret; + } + + ret = allocateLocalSegmentID(); + if (ret) { + LOG(ERROR) << "Transfer engine cannot be initialized: cannot " + "allocate local segment"; + return ret; + } + + ret = startHandshakeDaemon(local_server_name); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot start handshake daemon"; + return ret; + } + + ret = metadata_->updateLocalSegmentDesc(); + if (ret) { + LOG(ERROR) << "BarexTransport: cannot publish segments"; + return ret; + } + + return 0; +} + +int BarexTransport::registerLocalMemory(void *addr, size_t length, + const std::string &name, + bool remote_accessible, + bool update_metadata) { + auto &config = globalConfig(); + size_t buffer_size = config.eic_max_block_size; + size_t remaining = length; + void *current_ptr = addr; + device_type dtype; + + if (name.find("cuda") != std::string::npos || name == 
kWildcardLocation) { + dtype = GPU; + } else if (name.find("cpu") != std::string::npos) { + dtype = CPU; + } else { + LOG(ERROR) << "BarexTransport: registerLocalMemory, cannot recognize: name " << name + << ", need include cpu or cuda in name"; + return ERR_INVALID_ARGUMENT; + } + + bool is_gpu = dtype == GPU ? true : false; + + while (remaining > 0) { + size_t buffer_len = std::min(buffer_size, remaining); + int ret = registerLocalMemoryBase(current_ptr, buffer_len, name, remote_accessible, update_metadata, is_gpu); + if (ret) { + LOG(ERROR) << "registerLocalMemoryBase failed, ret " << ret; + return -1; + } + current_ptr = static_cast(current_ptr) + buffer_len; + remaining -= buffer_len; + } + + std::lock_guard guard(buf_mutex_); + if (dtype == CPU) { + buf_length_map_.emplace(addr, std::make_pair(length, 0)); + } else { + buf_length_map_.emplace(addr, std::make_pair(length, 1)); + } + + return 0; +} + +int BarexTransport::registerLocalMemoryBase(void *addr, size_t length, + const std::string &name, + bool remote_accessible, + bool update_metadata, + bool is_gpu) { + (void)remote_accessible; + BufferDesc buffer_desc; + memp_t mem; + BarexResult result; + device_type dtype = is_gpu ? GPU : CPU; + result = mempool_->RegUserMr(mem, addr, length, dtype); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "BarexTransport: registerLocalMemory failed" + << ", result " << result << ", addr " << addr + << ", length " << length << ", name "<< name; + return ERR_ADDRESS_NOT_REGISTERED; + } else { + for (auto &mr : mem.mrs) { + buffer_desc.lkey.push_back(mr.second->lkey); + buffer_desc.rkey.push_back(mr.second->rkey); + } + } + + // Get the memory location automatically after registered MR(pinned), + // when the name is kWildcardLocation("*"). 
+ if (name == kWildcardLocation) { + bool only_first_page = true; + const std::vector entries = + getMemoryLocation(addr, length, only_first_page); + for (auto &entry : entries) { + buffer_desc.name = entry.location; + buffer_desc.addr = entry.start; + buffer_desc.length = entry.len; + int rc = + metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (rc) return rc; + } + } else { + buffer_desc.name = name; + buffer_desc.addr = (uint64_t)addr; + buffer_desc.length = length; + int rc = metadata_->addLocalMemoryBuffer(buffer_desc, update_metadata); + if (rc) return rc; + } + + return 0; +} + +int BarexTransport::unregisterLocalMemory(void *addr, bool update_metadata) { + int rc = metadata_->removeLocalMemoryBuffer(addr, update_metadata); + if (rc) return rc; + + auto &config = globalConfig(); + size_t buffer_size = config.eic_max_block_size; + void *current_ptr = addr; + device_type dtype; + BarexResult result; + size_t remaining = 0; + memp_t mem; + { + std::lock_guard guard(buf_mutex_); + auto iter = buf_length_map_.find(addr); + if (iter != buf_length_map_.end()) { + remaining = iter->second.first; + dtype = iter->second.second ? 
GPU : CPU; + buf_length_map_.erase(iter); + } + } + + while (remaining > 0) { + size_t buffer_len = std::min(buffer_size, remaining); + if (current_ptr > addr) { + int rc = metadata_->removeLocalMemoryBuffer(current_ptr, update_metadata); + if (rc) { + LOG(WARNING) << "unregisterLocalMemory, removeLocalMemoryBuffer failed, addr " << addr; + } + } + result = mempool_->DeregUserMr(current_ptr, dtype); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "unregisterLocalMemory, DeregUserMr, failed, ret " << result << ", addr " << current_ptr; + return -1; + } + current_ptr = static_cast(current_ptr) + buffer_len; + remaining -= buffer_len; + } + + return 0; +} + +int BarexTransport::allocateLocalSegmentID() { + auto desc = std::make_shared(); + if (!desc) return ERR_MEMORY; + desc->name = local_server_name_; + desc->protocol = "barex"; + for (auto &entry : server_context_list_) { + TransferMetadata::DeviceDesc device_desc; + device_desc.name = entry->getCtx()->GetXDevice()->GetName(); + // TODO is barex need this? 
+ device_desc.lid = 0; //entry->lid(); + device_desc.gid = "ignore"; //entry->gid(); + desc->devices.push_back(device_desc); + } + desc->topology = *(local_topology_.get()); + metadata_->addLocalSegment(LOCAL_SEGMENT_ID, local_server_name_, + std::move(desc)); + return 0; +} + +int BarexTransport::registerLocalMemoryBatch( + const std::vector &buffer_list, + const std::string &location) { + for (auto &buffer : buffer_list) { + int ret = registerLocalMemory(buffer.addr, buffer.length, location, true, false); + if (ret) { + LOG(ERROR) << "BarexTransport: Failed to register memory: addr " + << buffer.addr << " length " + << buffer.length; + return ERR_ADDRESS_NOT_REGISTERED; + } + } + + return metadata_->updateLocalSegmentDesc(); +} + +int BarexTransport::unregisterLocalMemoryBatch( + const std::vector &addr_list) { + std::vector> results; + for (auto &addr : addr_list) { + results.emplace_back( + std::async(std::launch::async, [this, addr]() -> int { + return unregisterLocalMemory(addr, false); + })); + } + + for (size_t i = 0; i < addr_list.size(); ++i) { + if (results[i].get()) + LOG(WARNING) << "BarexTransport: Failed to unregister memory: addr " + << addr_list[i]; + } + + return metadata_->updateLocalSegmentDesc(); +} + +Status BarexTransport::submitTransfer( + BatchID batch_id, const std::vector &entries) { + auto &batch_desc = *((BatchDesc *)(batch_id)); + if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { + LOG(ERROR) << "BarexTransport: Exceed the limitation of current batch's " + "capacity"; + return Status::InvalidArgument( + "BarexTransport: Exceed the limitation of capacity, batch id: " + + std::to_string(batch_id)); + } + + std::unordered_map, std::vector> + slices_to_post; + size_t task_id = batch_desc.task_list.size(); + batch_desc.task_list.resize(task_id + entries.size()); + auto local_segment_desc = metadata_->getSegmentDescByID(LOCAL_SEGMENT_ID); + // const size_t kBlockSize = globalConfig().slice_size; + const size_t 
kBlockMaxSize = globalConfig().eic_max_block_size; + const int kMaxRetryCount = globalConfig().retry_cnt; + std::unordered_map> segment_desc_map; + for (auto &request : entries) { + auto target_id = request.target_id; + if (!segment_desc_map.count(target_id)) + segment_desc_map[target_id] = metadata_->getSegmentDescByID(target_id); + } + for (auto &request : entries) { + TransferTask &task = batch_desc.task_list[task_id]; + ++task_id; + SegmentID target_id = request.target_id; + auto peer_segment_desc = segment_desc_map[target_id]; + if (!peer_segment_desc) { + LOG(ERROR) << "peer_segment_desc not found for target_id " << target_id; + return Status::InvalidArgument( + "BarexTransport: peer_segment_desc not found, batch id: " + + std::to_string(batch_id)); + } + size_t kBlockSize = std::min(request.length, kBlockMaxSize); + for (uint64_t offset = 0; offset < request.length; offset += kBlockSize) { + Slice *slice = getSliceCache().allocate(); + slice->source_addr = (char *)request.source + offset; + slice->length = std::min(request.length - offset, kBlockSize); + slice->opcode = request.opcode; + slice->rdma.dest_addr = request.target_offset + offset; + slice->rdma.retry_cnt = 0; + slice->rdma.max_retry_cnt = kMaxRetryCount; + slice->task = &task; + slice->target_id = request.target_id; + slice->ts = 0; + slice->status = Slice::PENDING; + task.slice_list.push_back(slice); + + int peer_buffer_id = -1, extra_peer_buffer_id = 0, peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, retry_cnt = 0; + while (retry_cnt < kMaxRetryCount) { + int ret = selectDevice(local_segment_desc.get(), + (uint64_t)slice->source_addr, slice->length, + local_buffer_id, device_id, retry_cnt++); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "local_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_local_buffer_id = local_buffer_id + 1; + } + } + ret = selectDevice(peer_segment_desc.get(), + 
slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + slice->rdma.retry_cnt); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "peer_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_peer_buffer_id = peer_buffer_id + 1; + } + } + assert(device_id >= 0); + if (device_id >= static_cast(client_context_list_.size()) || use_random_dev_) { + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, client_context_list_.size() - 1); + device_id = dis(gen); + } + auto &context = client_context_list_[device_id]; + if (!context->active()) continue; + assert(context->getCtx()->GetXDevice()->GetId() == device_id); + // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += slice->length; + __sync_fetch_and_add(&task.slice_count, 1); + break; + } else if (!extra_local_buffer_id && extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + 
__sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + break; + } else if (extra_local_buffer_id && !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; + size_t first_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; + size_t last_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + 
first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + break; + } else { // 2:2 + auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; + size_t first_local_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; + size_t last_local_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + assert(first_local_length + last_local_length == slice->length); + auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_peer_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; + size_t last_peer_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + assert(first_peer_length + last_peer_length == slice->length); + if (first_local_length == first_peer_length) { + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + 
peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else if (first_local_length > first_peer_length) { + // add first part + slice->length = first_peer_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = (char *)request.source + offset + first_peer_length; + second_slice->length = first_local_length - first_peer_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = 
kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += first_local_length - first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else { + // first_local_length < first_peer_length + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second 
part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = (char *)request.source + offset + first_local_length; + second_slice->length = first_peer_length - first_local_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += first_peer_length - first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_peer_length; + last_slice->length = last_peer_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + } + break; + } + } + if (device_id < 0) { + auto source_addr = 
slice->source_addr; + for (auto &entry : slices_to_post) + for (auto s : entry.second) delete s; + LOG(ERROR) + << "BarexTransport: Address not registered by any device(s) " + << source_addr; + return Status::AddressNotRegistered( + "BarexTransport: not registered by any device(s), " + "address: " + + std::to_string(reinterpret_cast(source_addr))); + } + } + } + for (auto &entry : slices_to_post) { + int ret = entry.first->submitPostSend(entry.second); + if (ret) { + return Status::InvalidArgument("submitPostSend failed"); + } + } + return Status::OK(); +} + +Status BarexTransport::submitTransferTask( + const std::vector &task_list) { + std::unordered_map, std::vector> + slices_to_post; + auto local_segment_desc = metadata_->getSegmentDescByID(LOCAL_SEGMENT_ID); + assert(local_segment_desc.get()); + // const size_t kBlockSize = globalConfig().slice_size; + const size_t kBlockMaxSize = globalConfig().eic_max_block_size; + const int kMaxRetryCount = globalConfig().retry_cnt; + std::unordered_map> segment_desc_map; + for (size_t index = 0; index < task_list.size(); ++index) { + assert(task_list[index]); + auto &task = *task_list[index]; + assert(task.request); + auto &request = *task.request; + auto target_id = request.target_id; + if (!segment_desc_map.count(target_id)) + segment_desc_map[target_id] = metadata_->getSegmentDescByID(target_id); + } + for (size_t index = 0; index < task_list.size(); ++index) { + auto &task = *task_list[index]; + auto &request = *task.request; + SegmentID target_id = request.target_id; + auto peer_segment_desc = segment_desc_map[target_id]; + if (!peer_segment_desc) { + LOG(ERROR) << "peer_segment_desc not found for target_id " << target_id; + return Status::InvalidArgument( + "BarexTransport: peer_segment_desc not found"); + } + size_t kBlockSize = std::min(request.length, kBlockMaxSize); + for (uint64_t offset = 0; offset < request.length; offset += kBlockSize) { + Slice *slice = getSliceCache().allocate(); + assert(slice); + 
slice->source_addr = (char *)request.source + offset; + slice->length = std::min(request.length - offset, kBlockSize); + slice->opcode = request.opcode; + slice->rdma.dest_addr = request.target_offset + offset; + slice->rdma.retry_cnt = request.advise_retry_cnt; + slice->rdma.max_retry_cnt = kMaxRetryCount; + slice->task = &task; + slice->target_id = request.target_id; + slice->status = Slice::PENDING; + slice->ts = 0; + task.slice_list.push_back(slice); + + int peer_buffer_id = -1, extra_peer_buffer_id = 0, peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, retry_cnt = request.advise_retry_cnt; + bool found_device = false; + while (retry_cnt < kMaxRetryCount) { + int ret = selectDevice(local_segment_desc.get(), + (uint64_t)slice->source_addr, slice->length, + local_buffer_id, device_id, retry_cnt++); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "local_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_local_buffer_id = local_buffer_id + 1; + } + } + ret = selectDevice(peer_segment_desc.get(), + slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + slice->rdma.retry_cnt); + if (ret) { + if (ret == ERR_ADDRESS_NOT_REGISTERED) { + LOG(WARNING) << "peer_segment_desc selectDevice failed"; + continue; + } else { + // need 2 blocks + extra_peer_buffer_id = peer_buffer_id + 1; + } + } + assert(device_id >= 0); + if (device_id >= static_cast(client_context_list_.size()) || use_random_dev_) { + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, client_context_list_.size() - 1); + device_id = dis(gen); + } + auto &context = client_context_list_[device_id]; + assert(context.get()); + if (!context->active()) continue; + assert(context->getCtx()->GetXDevice()->GetId() == device_id); + assert(local_buffer_id >= 0 && local_buffer_id < local_segment_desc->buffers.size()); + assert(local_segment_desc->buffers[local_buffer_id].lkey.size() == 
client_context_list_.size()); + // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += slice->length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else if (!extra_local_buffer_id && extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + 
local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else if (extra_local_buffer_id && !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; + size_t first_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; + size_t last_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + assert(first_length + last_length == slice->length); + // add first part + slice->length = first_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->length = last_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + 
peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_length; + __sync_fetch_and_add(&task.slice_count, 1); + found_device = true; + break; + } else { // 2:2 + auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; + size_t first_local_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; + size_t last_local_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + assert(first_local_length + last_local_length == slice->length); + auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_peer_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; + size_t last_peer_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + assert(first_peer_length + last_peer_length == slice->length); + if (first_local_length == first_peer_length) { + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + 
last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else if (first_local_length > first_peer_length) { + // add first part + slice->length = first_peer_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = (char *)request.source + offset + first_peer_length; + second_slice->length = first_local_length - first_peer_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += first_local_length - first_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + // 
add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->length = last_local_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + } else { + // first_local_length < first_peer_length + // add first part + slice->length = first_local_length; + slice->rdma.source_lkey = + local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + slice->rdma.lkey_index = device_id; + slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(slice); + task.total_bytes += first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add second part + Slice *second_slice = getSliceCache().allocate(); + second_slice->source_addr = (char *)request.source + offset + first_local_length; + second_slice->length = first_peer_length - first_local_length; + second_slice->opcode = request.opcode; + second_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + second_slice->rdma.retry_cnt = 0; + second_slice->rdma.max_retry_cnt = kMaxRetryCount; + second_slice->task = &task; + second_slice->target_id = request.target_id; + second_slice->ts = 0; + second_slice->status = Slice::PENDING; + 
task.slice_list.push_back(second_slice); + second_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + second_slice->rdma.lkey_index = device_id; + second_slice->dest_rkeys = + peer_segment_desc->buffers[peer_buffer_id].rkey; + slices_to_post[context].push_back(second_slice); + task.total_bytes += first_peer_length - first_local_length; + __sync_fetch_and_add(&task.slice_count, 1); + // add last part + Slice *last_slice = getSliceCache().allocate(); + last_slice->source_addr = (char *)request.source + offset + first_peer_length; + last_slice->length = last_peer_length; + last_slice->opcode = request.opcode; + last_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + last_slice->rdma.retry_cnt = 0; + last_slice->rdma.max_retry_cnt = kMaxRetryCount; + last_slice->task = &task; + last_slice->target_id = request.target_id; + last_slice->ts = 0; + last_slice->status = Slice::PENDING; + task.slice_list.push_back(last_slice); + last_slice->rdma.source_lkey = + local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + last_slice->rdma.lkey_index = device_id; + last_slice->dest_rkeys = + peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + slices_to_post[context].push_back(last_slice); + task.total_bytes += last_peer_length; + __sync_fetch_and_add(&task.slice_count, 1); + } + found_device = true; + break; + } + } + if (!found_device) { + auto source_addr = slice->source_addr; + for (auto &entry : slices_to_post) + for (auto s : entry.second) getSliceCache().deallocate(s); + LOG(ERROR) + << "Memory region not registered by any active device(s): " + << source_addr; + return Status::AddressNotRegistered( + "Memory region not registered by any active device(s): " + + std::to_string(reinterpret_cast(source_addr))); + } + } + } + for (auto &entry : slices_to_post) { + int ret = entry.first->submitPostSend(entry.second); + if (ret) { + return Status::InvalidArgument("submitPostSend failed"); + } 
+        return Status::InvalidArgument(
+            "BarexTransport::getTransferStatus invalid argument, batch id: " +
+            std::to_string(batch_id));
+    int rc = metadata_->sendHandshake(segment_name, local_desc, peer_desc);
+    if (rc) return Status::Socket("sendHandshake failed");
+    if (!peer_desc.reply_msg.empty()) {
+        LOG(ERROR) << "Reject the handshake request by peer "
+                   << segment_name;
+        return Status::Socket("handshake rejected by peer");
+    } else {
+    if (status) {
+        LOG(ERROR) << "CheckStatus for sid " << sid << " failed";
+        return Status::InvalidArgument("sid status error");
+    }
std::shared_ptr(client_threadpool); + auto &config = globalConfig(); + for (auto &dev : devices) { + if (std::find(hca_list.begin(), hca_list.end(), dev->GetName()) == hca_list.end()) { + LOG(WARNING) << "BarexTransport: device " << dev->GetName() << " not found in hca_list, ignore "; + continue; + } + ContextConfig server_config = XConfigUtil::DefaultContextConfig(); + XContext *raw_server_context = nullptr; + result = XContext::NewInstance(raw_server_context, server_config, new EmptyCallback(), dev, mempool, server_threadpool); + if (result != BAREX_SUCCESS) { + local_topology_->disableDevice(dev->GetName()); + LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device " << dev->GetName(); + } else { + raw_server_context->Start(); + auto server_context = std::make_shared(raw_server_context, barex_use_cpu_, barex_local_device_); + server_context->setQpNum(config.num_qp_per_ep); + server_context_list_.push_back(server_context); + } + ContextConfig client_config = XConfigUtil::DefaultContextConfig(); + XContext *raw_client_context = nullptr; + result = XContext::NewInstance(raw_client_context, client_config, new EmptyCallback(), dev, mempool, client_threadpool); + if (result != BAREX_SUCCESS) { + local_topology_->disableDevice(dev->GetName()); + LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device " << dev->GetName(); + } else { + raw_client_context->Start(); + auto client_context = std::make_shared(raw_client_context, barex_use_cpu_, barex_local_device_); + client_context->setQpNum(config.num_qp_per_ep); + client_context_list_.push_back(client_context); + } + } + + if (local_topology_->empty()) { + LOG(ERROR) << "BarexTransport: No available RNIC"; + return ERR_DEVICE_NOT_FOUND; + } + return 0; +} + +int BarexTransport::startHandshakeDaemon(std::string &local_server_name) { + std::vector raw_server_contexts; + std::vector raw_client_contexts; + for (auto ctx : server_context_list_) { + raw_server_contexts.emplace_back(ctx->getCtx()); 
+ } + for (auto ctx : client_context_list_) { + raw_client_contexts.emplace_back(ctx->getCtx()); + } + XListener* listerner = nullptr; + + int port = metadata_->localRpcMeta().barex_port; + setLocalPort(port); + BarexResult result = XListener::NewInstance(listerner, 2, getLocalPort(), TIMER_3S, raw_server_contexts); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listerner failed, result " << result; + return ERR_INVALID_ARGUMENT; + } + result = listerner->Listen(); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, Listen failed, result " << result; + return ERR_INVALID_ARGUMENT; + } + listerner_ = std::shared_ptr(listerner); + XConnector* connector = nullptr; + result = XConnector::NewInstance(connector, 2, TIMER_3S, raw_client_contexts); + if (result != BAREX_SUCCESS) { + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create connector failed, result " << result; + return ERR_INVALID_ARGUMENT; + } + connector_ = std::shared_ptr(connector); + return metadata_->startHandshakeDaemon( + std::bind(&BarexTransport::onSetupRdmaConnections, this, + std::placeholders::_1, std::placeholders::_2), + metadata_->localRpcMeta().rpc_port, metadata_->localRpcMeta().sockfd); +} + +// According to the request desc, offset and length information, find proper +// buffer_id and device_id as output. +// Return 0 if successful, ERR_ADDRESS_NOT_REGISTERED otherwise. 
+int BarexTransport::selectDevice(SegmentDesc *desc, uint64_t offset, + size_t length, int &buffer_id, int &device_id, + int retry_count) { + if (!desc) return ERR_ADDRESS_NOT_REGISTERED; + int ret = 0; + for (buffer_id = 0; buffer_id < (int)desc->buffers.size(); ++buffer_id) { + auto &buffer_desc = desc->buffers[buffer_id]; + if (buffer_desc.addr > offset || offset >= buffer_desc.addr + buffer_desc.length) { + continue; + } else { + if (offset + length > buffer_desc.addr + buffer_desc.length) { + // mr cross two buffers, need separate into two parts + if (buffer_id + 1 < (int)desc->buffers.size()) { + auto &next_buffer_desc = desc->buffers[buffer_id+1]; + if (offset + length > next_buffer_desc.addr && offset + length <= next_buffer_desc.addr + next_buffer_desc.length) { + ret = 1; + } else { + LOG(ERROR) << "selectDevice failed, 2 buffers in need but next buffer not fit," + << " offset " << offset + << " length " << length + << " buffer_id " << buffer_id + << " buffer_desc.addr " << buffer_desc.addr + << " buffer_desc.length " << buffer_desc.length + << " buffer_id " << buffer_id+1 + << " next_buffer_desc.addr " << next_buffer_desc.addr + << " next_buffer_desc.length " << next_buffer_desc.length; + return ERR_ADDRESS_NOT_REGISTERED; + } + } else { + LOG(ERROR) << "selectDevice failed, last buffer overflow," + << " offset " << offset + << " length " << length + << " buffer_id " << buffer_id + << " buffer_desc.addr " << buffer_desc.addr + << " buffer_desc.length " << buffer_desc.length; + return ERR_ADDRESS_NOT_REGISTERED; + } + } + device_id = desc->topology.selectDevice(buffer_desc.name, retry_count); + if (device_id >= 0) return ret; + device_id = desc->topology.selectDevice(kWildcardLocation, retry_count); + if (device_id >= 0) return ret; + } + } + + return ERR_ADDRESS_NOT_REGISTERED; +} +} // namespace mooncake diff --git a/scripts/build_wheel.sh b/scripts/build_wheel.sh index ecf20e7ce..22f55e855 100755 --- a/scripts/build_wheel.sh +++ 
b/scripts/build_wheel.sh @@ -37,9 +37,13 @@ else fi # Copy nvlink-allocator.so to mooncake directory (only if it exists - CUDA builds only) -if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ]; then - echo "Copying CUDA nvlink_allocator.so..." - cp build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so mooncake-wheel/mooncake/nvlink_allocator.so +if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ] \ + || [ -f /usr/lib/libaccl_barex.so ] \ + || [ -f /usr/lib64/libaccl_barex.so ]; then + if [ -f build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so ]; then + echo "Copying CUDA nvlink_allocator.so..." + cp build/mooncake-transfer-engine/nvlink-allocator/nvlink_allocator.so mooncake-wheel/mooncake/nvlink_allocator.so + fi echo "Copying allocator libraries..." # Copy allocator.py cp mooncake-integration/allocator.py mooncake-wheel/mooncake/allocator.py @@ -282,6 +286,7 @@ else --exclude libascend_trace.so* \ --exclude libmetadef*.so \ --exclude libllm_datadist*.so \ + --exclude libaccl_barex.so* \ -w ${REPAIRED_DIR}/ --plat ${PLATFORM_TAG} fi From 5e9f409dfc4d889d1ec9549c7af7b503f1133db9 Mon Sep 17 00:00:00 2001 From: "zhangzechao.zzc" Date: Tue, 11 Nov 2025 15:15:32 +0800 Subject: [PATCH 2/6] feat[accl-barex]: spell fix --- .../transfer_engine/transfer_engine_py.cpp | 9 ++++++--- .../transport/barex_transport/barex_context.h | 14 +++++++------- .../transport/barex_transport/barex_transport.h | 4 ++-- mooncake-transfer-engine/src/multi_transport.cpp | 2 +- .../src/transfer_metadata_plugin.cpp | 4 ++-- .../transport/barex_transport/barex_transport.cpp | 14 +++++++------- 6 files changed, 25 insertions(+), 22 deletions(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 4dc00de62..a3def6980 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ 
b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -135,9 +135,12 @@ int TransferEnginePy::initializeExt(const char *local_hostname, bool pass_alloc = false; const char *pass_alloc_env = std::getenv("PASS_ALLOC"); if (pass_alloc_env) { - int val = atoi(pass_alloc_env); - if (val != 0) { - pass_alloc = true; + try { + if (std::stoi(pass_alloc_env) != 0) { + pass_alloc = true; + } + } catch (const std::exception&) { + // Ignore invalid values or log a warning } } if (!pass_alloc) { diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h index 15c8f8e4d..17c2f82e2 100644 --- a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h @@ -53,7 +53,7 @@ using BarexResult = accl::barex::BarexResult; class ChannelCache { public: - // 添加一个 channel 到指定 key & nic_id + // put channel void put(SegmentID key, int nic_id, XChannel* channel) { RWSpinlock::WriteGuard guard(lock_); auto& channels = cache_[key]; @@ -62,7 +62,7 @@ class ChannelCache { vec.push_back(channel); } - // 获取 sid 下指定 nic_id 和 idx 的 channel + // get channel XChannel* find(SegmentID key, int nic_id, int idx) { RWSpinlock::ReadGuard guard(lock_); auto it = cache_.find(key); @@ -77,7 +77,7 @@ class ChannelCache { return nullptr; } - // 删除某个 channel(通过id和idx) + // delete channel bool erase(SegmentID key, int nic_id, int idx) { RWSpinlock::WriteGuard guard(lock_); auto it = cache_.find(key); @@ -101,9 +101,9 @@ class ChannelCache { return true; } - // 查询某个 SegmentID 下的 channel 状态 + // get channel state bool CheckAllChannels(SegmentID segment_id) { - RWSpinlock::WriteGuard guard(lock_); + RWSpinlock::ReadGuard guard(lock_); auto it = cache_.find(segment_id); if (it == cache_.end()) { return false; @@ -120,7 +120,7 @@ class ChannelCache { return true; } - // 检查并删除某个 SegmentID 下的异常channel,并返回删除的数量 + 
// check and delete invalid channels int RemoveInvalidChannels(SegmentID segment_id) { RWSpinlock::WriteGuard guard(lock_); auto it = cache_.find(segment_id); @@ -143,7 +143,7 @@ class ChannelCache { return invalid_count; } - // 将所有的 channel 以 vector 形式返回 + // get all channels std::vector copyAll() { RWSpinlock::WriteGuard guard(lock_); std::vector result; diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h index 2a9d80c25..84e51a32a 100644 --- a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h @@ -81,7 +81,7 @@ class BarexTransport : public Transport { std::shared_ptr meta, std::shared_ptr topo) override; - const char *getName() const override { return "rdma"; } + const char *getName() const override { return "barex"; } void setLocalPort(int port) { local_port_ = port; } @@ -156,7 +156,7 @@ class BarexTransport : public Transport { std::shared_ptr server_threadpool_; std::shared_ptr client_threadpool_; std::shared_ptr mempool_; - std::shared_ptr listerner_; + std::shared_ptr listener_; std::shared_ptr connector_; #endif std::shared_ptr local_topology_; diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index 1e16da43f..348e7fc70 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -273,7 +273,7 @@ Transport *MultiTransport::installTransport(const std::string &proto, nics += ","; } - // 移除最后一个多余的逗号 + // Remove the last extra comma if (!nics.empty()) { nics.pop_back(); } diff --git a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp index 34b2d5ffb..968430d43 100644 --- a/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp +++ 
b/mooncake-transfer-engine/src/transfer_metadata_plugin.cpp @@ -1138,8 +1138,8 @@ std::vector findLocalIpAddresses() { uint16_t findAvailableTcpPort(int &sockfd, bool set_range) { static std::random_device rand_gen; std::uniform_int_distribution rand_dist; - int min_port = globalConfig().rpc_min_port;; - int max_port = globalConfig().rpc_max_port;; + int min_port = globalConfig().rpc_min_port; + int max_port = globalConfig().rpc_max_port; #ifdef USE_BAREX if (set_range) { min_port = 17000; diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp index 0ffdb52ed..bd892ec51 100644 --- a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp @@ -63,8 +63,8 @@ BarexTransport::~BarexTransport() { batch_desc_set_.clear(); connector_->Shutdown(); connector_->WaitStop(); - listerner_->Shutdown(); - listerner_->WaitStop(); + listener_->Shutdown(); + listener_->WaitStop(); server_threadpool_->Shutdown(); server_threadpool_->WaitStop(); client_threadpool_->Shutdown(); @@ -1241,21 +1241,21 @@ int BarexTransport::startHandshakeDaemon(std::string &local_server_name) { for (auto ctx : client_context_list_) { raw_client_contexts.emplace_back(ctx->getCtx()); } - XListener* listerner = nullptr; + XListener* listener = nullptr; int port = metadata_->localRpcMeta().barex_port; setLocalPort(port); - BarexResult result = XListener::NewInstance(listerner, 2, getLocalPort(), TIMER_3S, raw_server_contexts); + BarexResult result = XListener::NewInstance(listener, 2, getLocalPort(), TIMER_3S, raw_server_contexts); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listerner failed, result " << result; + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listener failed, result " << result; return ERR_INVALID_ARGUMENT; } - result = 
listerner->Listen(); + result = listener->Listen(); if (result != BAREX_SUCCESS) { LOG(ERROR) << "BarexTransport: startHandshakeDaemon, Listen failed, result " << result; return ERR_INVALID_ARGUMENT; } - listerner_ = std::shared_ptr(listerner); + listener_ = std::shared_ptr(listener); XConnector* connector = nullptr; result = XConnector::NewInstance(connector, 2, TIMER_3S, raw_client_contexts); if (result != BAREX_SUCCESS) { From cd1ed0fff182da7eba8c0346a4d208f2bc7a8360 Mon Sep 17 00:00:00 2001 From: "zhangzechao.zzc" Date: Tue, 11 Nov 2025 15:15:59 +0800 Subject: [PATCH 3/6] feat[accl-barex]: clang-format --- .../transfer_engine/transfer_engine_py.cpp | 14 +- .../example/transfer_engine_bench.cpp | 2 +- .../include/transfer_metadata_plugin.h | 2 +- .../transport/barex_transport/barex_context.h | 22 +- .../barex_transport/barex_transport.h | 12 +- .../include/transport/transport.h | 4 +- .../src/multi_transport.cpp | 12 +- .../src/transfer_engine.cpp | 30 +- .../src/transfer_metadata.cpp | 20 +- .../barex_transport/barex_context.cpp | 78 +- .../barex_transport/barex_transport.cpp | 699 +++++++++++------- 11 files changed, 566 insertions(+), 329 deletions(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index a3def6980..7d5018ea2 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -139,7 +139,7 @@ int TransferEnginePy::initializeExt(const char *local_hostname, if (std::stoi(pass_alloc_env) != 0) { pass_alloc = true; } - } catch (const std::exception&) { + } catch (const std::exception &) { // Ignore invalid values or log a warning } } @@ -279,7 +279,9 @@ int TransferEnginePy::transferSync(const char *target_hostname, if (handle_map_.count(target_hostname)) { handle = handle_map_[target_hostname]; } else { - LOG(INFO) << "transferSync, cache not found, openSegment with target " << 
target_hostname; + LOG(INFO) + << "transferSync, cache not found, openSegment with target " + << target_hostname; handle = engine_->openSegment(target_hostname); if (handle == (Transport::SegmentHandle)-1) return -1; handle_map_[target_hostname] = handle; @@ -317,7 +319,9 @@ int TransferEnginePy::transferSync(const char *target_hostname, if (!s.ok()) { Status segment_status = engine_->CheckSegmentStatus(handle); if (!segment_status.ok()) { - LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment"; + LOG(WARNING) + << "submitTransfer failed with target " << target_hostname + << ", CheckSegmentStatus not ok, ready to closeSegment"; std::lock_guard guard(mutex_); engine_->closeSegment(handle); engine_->getMetadata()->removeSegmentDesc(target_hostname); @@ -413,7 +417,9 @@ int TransferEnginePy::batchTransferSync( engine_->freeBatchID(batch_id); Status segment_status = engine_->CheckSegmentStatus(handle); if (!segment_status.ok()) { - LOG(WARNING) << "submitTransfer failed with target " << target_hostname << ", CheckSegmentStatus not ok, ready to closeSegment"; + LOG(WARNING) + << "submitTransfer failed with target " << target_hostname + << ", CheckSegmentStatus not ok, ready to closeSegment"; std::lock_guard guard(mutex_); engine_->closeSegment(handle); engine_->getMetadata()->removeSegmentDesc(target_hostname); diff --git a/mooncake-transfer-engine/example/transfer_engine_bench.cpp b/mooncake-transfer-engine/example/transfer_engine_bench.cpp index 5edff972f..21687a068 100644 --- a/mooncake-transfer-engine/example/transfer_engine_bench.cpp +++ b/mooncake-transfer-engine/example/transfer_engine_bench.cpp @@ -426,7 +426,7 @@ int target() { void **args = (void **)malloc(2 * sizeof(void *)); args[0] = (void *)nic_priority_matrix.c_str(); args[1] = nullptr; - engine->installTransport("rdma", args); + engine->installTransport("rdma", args); } else if (FLAGS_protocol == "barex") { auto nic_priority_matrix 
= loadNicPriorityMatrix(); void **args = (void **)malloc(2 * sizeof(void *)); diff --git a/mooncake-transfer-engine/include/transfer_metadata_plugin.h b/mooncake-transfer-engine/include/transfer_metadata_plugin.h index 22c361161..b95f7d31f 100644 --- a/mooncake-transfer-engine/include/transfer_metadata_plugin.h +++ b/mooncake-transfer-engine/include/transfer_metadata_plugin.h @@ -69,7 +69,7 @@ struct HandShakePlugin { std::vector findLocalIpAddresses(); -uint16_t findAvailableTcpPort(int &sockfd, bool set_range=false); +uint16_t findAvailableTcpPort(int &sockfd, bool set_range = false); } // namespace mooncake diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h index 17c2f82e2..a8466f5ca 100644 --- a/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_context.h @@ -52,7 +52,7 @@ using XContext = accl::barex::XContext; using BarexResult = accl::barex::BarexResult; class ChannelCache { -public: + public: // put channel void put(SegmentID key, int nic_id, XChannel* channel) { RWSpinlock::WriteGuard guard(lock_); @@ -133,10 +133,9 @@ class ChannelCache { for (auto& pair : inner_map) { auto& channels = pair.second; - auto new_end = std::remove_if(channels.begin(), channels.end(), - [](XChannel* channel) { - return !channel->IsActive(); - }); + auto new_end = std::remove_if( + channels.begin(), channels.end(), + [](XChannel* channel) { return !channel->IsActive(); }); invalid_count += std::distance(new_end, channels.end()); channels.erase(new_end, channels.end()); } @@ -155,15 +154,17 @@ class ChannelCache { return result; } -private: - std::unordered_map>> cache_; + private: + std::unordered_map>> + cache_; std::unordered_map status_map_; RWSpinlock lock_; }; class BarexContext { - public: - int submitPostSend(const std::vector &slice_list); - int 
addChannel(SegmentID sid, int device_id, XChannel *ch); + public: + int submitPostSend(const std::vector& slice_list); + int addChannel(SegmentID sid, int device_id, XChannel* ch); XChannel* getChannel(SegmentID sid, int device_id, int idx); int checkStatus(SegmentID sid); XContext* getCtx(); @@ -186,7 +187,6 @@ class BarexContext { ChannelCache channel_cache_; bool active_ = true; int qp_num_per_ctx_ = 2; - }; #endif } // namespace mooncake diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h index 84e51a32a..c09a32949 100644 --- a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h @@ -42,15 +42,14 @@ using TransferStatusEnum = Transport::TransferStatusEnum; using SegmentID = Transport::SegmentID; using BatchID = Transport::BatchID; - class TransferMetadata; class CountDownLatch { -private: + private: int count_; std::mutex mtx; std::condition_variable cv; -public: + public: CountDownLatch(int count) : count_(count){}; void CountDown() { @@ -96,8 +95,9 @@ class BarexTransport : public Transport { bool update_metadata) override; int registerLocalMemoryBase(void *addr, size_t length, - const std::string &location, bool remote_accessible, - bool update_metadata, bool is_gpu); + const std::string &location, + bool remote_accessible, bool update_metadata, + bool is_gpu); int unregisterLocalMemory(void *addr, bool update_metadata = true) override; @@ -161,7 +161,7 @@ class BarexTransport : public Transport { #endif std::shared_ptr local_topology_; std::mutex buf_mutex_; - std::map> buf_length_map_; + std::map> buf_length_map_; bool use_random_dev_ = false; bool barex_use_cpu_ = false; int barex_local_device_ = 0; diff --git a/mooncake-transfer-engine/include/transport/transport.h b/mooncake-transfer-engine/include/transport/transport.h index 
13702a7c1..0510c19f3 100644 --- a/mooncake-transfer-engine/include/transport/transport.h +++ b/mooncake-transfer-engine/include/transport/transport.h @@ -259,7 +259,9 @@ class Transport { size_t length; }; - virtual Status OpenChannel(const std::string &segment_name, SegmentID sid) { return Status::OK(); } + virtual Status OpenChannel(const std::string &segment_name, SegmentID sid) { + return Status::OK(); + } virtual Status CheckStatus(SegmentID sid) { return Status::OK(); } protected: diff --git a/mooncake-transfer-engine/src/multi_transport.cpp b/mooncake-transfer-engine/src/multi_transport.cpp index 348e7fc70..3f2078940 100644 --- a/mooncake-transfer-engine/src/multi_transport.cpp +++ b/mooncake-transfer-engine/src/multi_transport.cpp @@ -254,17 +254,19 @@ Transport *MultiTransport::installTransport(const std::string &proto, #ifdef USE_BAREX bool use_eic = false; - for (auto& dev : topo->getHcaList()) { - if (dev.find("soe") != std::string::npos || dev.find("solar") != std::string::npos) { + for (auto &dev : topo->getHcaList()) { + if (dev.find("soe") != std::string::npos || + dev.find("solar") != std::string::npos) { use_eic = true; } } if (std::string(proto) == "barex") { std::string nics; - for (auto& dev : topo->getHcaList()) { + for (auto &dev : topo->getHcaList()) { if (use_eic) { - if (dev.find("soe") == std::string::npos && dev.find("solar") == std::string::npos) { + if (dev.find("soe") == std::string::npos && + dev.find("solar") == std::string::npos) { // ignore no eic nics continue; } @@ -281,7 +283,7 @@ Transport *MultiTransport::installTransport(const std::string &proto, if (!nics.empty()) { LOG(INFO) << "ACCL_USE_NICS is set to " << nics; setenv("ACCL_USE_NICS", nics.c_str(), 1); - } + } } #endif if (transport->install(local_server_name_, metadata_, topo)) { diff --git a/mooncake-transfer-engine/src/transfer_engine.cpp b/mooncake-transfer-engine/src/transfer_engine.cpp index c86ccce81..d644166c0 100644 --- 
a/mooncake-transfer-engine/src/transfer_engine.cpp +++ b/mooncake-transfer-engine/src/transfer_engine.cpp @@ -114,7 +114,8 @@ int TransferEngine::init(const std::string &metadata_conn_string, int tmp_fd = -1; desc.barex_port = findAvailableTcpPort(tmp_fd, true); if (desc.barex_port == 0) { - LOG(ERROR) << "Barex: No valid port found for local barex service."; + LOG(ERROR) + << "Barex: No valid port found for local barex service."; return -1; } close(tmp_fd); @@ -166,8 +167,10 @@ int TransferEngine::init(const std::string &metadata_conn_string, LOG(INFO) << "Transfer Engine RPC using " << rpc_binding_method << ", listening on " << desc.ip_or_host_name << ":" - << desc.rpc_port - << (use_barex_ ? ", barex use port:" + std::to_string(desc.barex_port) : ""); + << desc.rpc_port + << (use_barex_ + ? ", barex use port:" + std::to_string(desc.barex_port) + : ""); metadata_ = std::make_shared(metadata_conn_string); #ifdef USE_ASCEND @@ -250,22 +253,26 @@ int TransferEngine::init(const std::string &metadata_conn_string, if (local_topology_->getHcaList().size() > 0 && !getenv("MC_FORCE_TCP")) { // only install RDMA transport when there is at least one HCA - Transport* rdma_transport = nullptr; + Transport *rdma_transport = nullptr; if (use_barex_) { #ifdef USE_BAREX - rdma_transport = multi_transports_->installTransport("barex", local_topology_); + rdma_transport = multi_transports_->installTransport( + "barex", local_topology_); #else LOG(ERROR) << "Set USE BAREX while barex not compiled"; return -1; #endif } else { - rdma_transport = multi_transports_->installTransport("rdma", local_topology_); + rdma_transport = multi_transports_->installTransport( + "rdma", local_topology_); } if (rdma_transport == nullptr) { - LOG(ERROR) << "Failed to install RDMA transport, type=" << (use_barex_ ? "barex" : "rdma"); + LOG(ERROR) << "Failed to install RDMA transport, type=" + << (use_barex_ ? "barex" : "rdma"); return -1; } else { - LOG(INFO) << "installTransport, type=" << (use_barex_ ? 
"barex" : "rdma"); + LOG(INFO) << "installTransport, type=" + << (use_barex_ ? "barex" : "rdma"); } } else { Transport *tcp_transport = @@ -364,7 +371,7 @@ Transport::SegmentHandle TransferEngine::openSegment( SegmentID sid = metadata_->getSegmentID(trimmed_segment_name); #ifdef USE_BAREX if (use_barex_) { - Transport* transport = multi_transports_->getTransport("barex"); + Transport *transport = multi_transports_->getTransport("barex"); if (!transport) { LOG(ERROR) << "Barex proto not installed"; return (Transport::SegmentHandle)-1; @@ -382,8 +389,9 @@ Transport::SegmentHandle TransferEngine::openSegment( Status TransferEngine::CheckSegmentStatus(SegmentID sid) { #ifdef USE_BAREX if (use_barex_) { - Transport* transport = multi_transports_->getTransport("barex"); - BarexTransport* barex_transport = dynamic_cast(transport); + Transport *transport = multi_transports_->getTransport("barex"); + BarexTransport *barex_transport = + dynamic_cast(transport); return barex_transport->CheckStatus(sid); } else { return Status::OK(); diff --git a/mooncake-transfer-engine/src/transfer_metadata.cpp b/mooncake-transfer-engine/src/transfer_metadata.cpp index 76a71c504..1fe503fae 100644 --- a/mooncake-transfer-engine/src/transfer_metadata.cpp +++ b/mooncake-transfer-engine/src/transfer_metadata.cpp @@ -159,7 +159,8 @@ int TransferMetadata::encodeSegmentDesc(const SegmentDesc &desc, segmentJSON["tcp_data_port"] = desc.tcp_data_port; segmentJSON["timestamp"] = getCurrentDateTime(); - if (segmentJSON["protocol"] == "rdma" || segmentJSON["protocol"] == "barex") { + if (segmentJSON["protocol"] == "rdma" || + segmentJSON["protocol"] == "barex") { Json::Value devicesJSON(Json::arrayValue); for (const auto &device : desc.devices) { Json::Value deviceJSON; @@ -289,12 +290,13 @@ int TransferMetadata::updateSegmentDesc(const std::string &segment_name, int TransferMetadata::removeSegmentDesc(const std::string &segment_name) { if (p2p_handshake_mode_) { auto iter = 
segment_name_to_id_map_.find(segment_name); - if (iter != segment_name_to_id_map_.end()){ + if (iter != segment_name_to_id_map_.end()) { LOG(INFO) << "removeSegmentDesc " << segment_name << " finish"; segment_id_to_desc_map_.erase(iter->second); segment_name_to_id_map_.erase(iter); } else { - LOG(INFO) << "removeSegmentDesc " << segment_name << " not found, already removed maybe"; + LOG(INFO) << "removeSegmentDesc " << segment_name + << " not found, already removed maybe"; } return 0; } @@ -342,13 +344,11 @@ TransferMetadata::decodeSegmentDesc(Json::Value &segmentJSON, if (buffer.name.empty() || !buffer.addr || !buffer.length || buffer.rkey.empty() || buffer.rkey.size() != buffer.lkey.size()) { - LOG(WARNING) << "Corrupted segment descriptor, name " - << segment_name << " protocol " << desc->protocol - << ", " << buffer.name - << ", " << buffer.addr - << ", " << buffer.length - << ", " << buffer.rkey.size() - << ", " << buffer.lkey.size(); + LOG(WARNING) + << "Corrupted segment descriptor, name " << segment_name + << " protocol " << desc->protocol << ", " << buffer.name + << ", " << buffer.addr << ", " << buffer.length << ", " + << buffer.rkey.size() << ", " << buffer.lkey.size(); return nullptr; } desc->buffers.push_back(buffer); diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp index 3f65ea668..848f79fc9 100644 --- a/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_context.cpp @@ -18,8 +18,10 @@ namespace mooncake { using namespace accl::barex; -BarexContext::BarexContext(XContext* xcontext, bool use_cpu, int device_id) : xcontext_(xcontext), barex_use_cpu_(use_cpu), barex_local_device_(device_id) {} - +BarexContext::BarexContext(XContext* xcontext, bool use_cpu, int device_id) + : xcontext_(xcontext), + barex_use_cpu_(use_cpu), + barex_local_device_(device_id) {} 
BarexContext::~BarexContext() { if (xcontext_) { @@ -43,18 +45,20 @@ int BarexContext::checkStatus(SegmentID sid) { return channel_cache_.RemoveInvalidChannels(sid); } -XContext* BarexContext::getCtx() { - return xcontext_; -} +XContext* BarexContext::getCtx() { return xcontext_; } std::vector BarexContext::getAllChannel() { return channel_cache_.copyAll(); } int BarexContext::submitPostSend( - const std::vector &slice_list) { - std::unordered_map>> sid_dev_data_map; - std::unordered_map>> sid_dev_slice_map; + const std::vector& slice_list) { + std::unordered_map< + SegmentID, std::unordered_map>> + sid_dev_data_map; + std::unordered_map>> + sid_dev_slice_map; for (auto slice : slice_list) { accl::barex::rw_memp_t w_m; w_m.sg.addr = (uint64_t)slice->source_addr; @@ -78,7 +82,8 @@ int BarexContext::submitPostSend( for (auto& dev_pair : dev_map) { int dev = dev_pair.first; std::vector& data_vec = dev_pair.second; - std::vector& slice_vec = sid_dev_slice_map[sid][dev]; + std::vector& slice_vec = + sid_dev_slice_map[sid][dev]; size_t data_size = data_vec.size(); int qp_in_use = qp_num_per_ctx_; if (data_size < (size_t)qp_num_per_ctx_) { @@ -88,18 +93,23 @@ int BarexContext::submitPostSend( size_t end_idx = 0; size_t batch_size = data_size / qp_in_use; size_t reminder = data_size % qp_in_use; - + int retry_cnt = 5; - for (int i=0; i < qp_in_use; i++) { + for (int i = 0; i < qp_in_use; i++) { XChannel* channel = nullptr; - for (int j=0; j < retry_cnt; j++) { + for (int j = 0; j < retry_cnt; j++) { channel = channel_cache_.find(sid, dev, i); if (!channel) { - LOG(ERROR) << "Write fail, sid " << sid << ", dev " << dev << ", id " << i << " not found, retry " << j << "/" << retry_cnt; + LOG(ERROR) + << "Write fail, sid " << sid << ", dev " << dev + << ", id " << i << " not found, retry " << j << "/" + << retry_cnt; break; } if (!channel->IsActive()) { - LOG(WARNING) << "Write fail, channel status error " << channel << " retry " << j << "/" << retry_cnt; + LOG(WARNING) + << 
"Write fail, channel status error " << channel + << " retry " << j << "/" << retry_cnt; channel_cache_.erase(sid, dev, i); continue; } @@ -108,24 +118,30 @@ int BarexContext::submitPostSend( LOG(ERROR) << "Write fail, no channel found"; return -1; } - + end_idx += batch_size; if (i == qp_in_use - 1) { end_idx += reminder; } int peer_nic_id = channel->GetPeerNicId(); - auto data_chunk_read = std::make_shared>(); - auto data_chunk_write = std::make_shared>(); - auto slice_chunk_read = std::make_shared>(); - auto slice_chunk_write = std::make_shared>(); + auto data_chunk_read = + std::make_shared>(); + auto data_chunk_write = + std::make_shared>(); + auto slice_chunk_read = + std::make_shared>(); + auto slice_chunk_write = + std::make_shared>(); for (size_t idx = begin_idx; idx < end_idx; idx++) { - data_vec[idx].r_key = slice_vec[idx]->dest_rkeys[peer_nic_id]; - if (slice_vec[idx]->opcode == Transport::TransferRequest::READ) { + data_vec[idx].r_key = + slice_vec[idx]->dest_rkeys[peer_nic_id]; + if (slice_vec[idx]->opcode == + Transport::TransferRequest::READ) { data_chunk_read->emplace_back(data_vec[idx]); - slice_chunk_read->emplace_back(slice_vec[idx]); + slice_chunk_read->emplace_back(slice_vec[idx]); } else { data_chunk_write->emplace_back(data_vec[idx]); - slice_chunk_write->emplace_back(slice_vec[idx]); + slice_chunk_write->emplace_back(slice_vec[idx]); } } @@ -133,8 +149,9 @@ int BarexContext::submitPostSend( BarexResult r = channel->WriteBatch( data_chunk_write, [slice_chunk_write](accl::barex::Status s) { - if(!s.IsOk()) { - LOG(ERROR) << "WriteBatch fail, " << s.ErrMsg().c_str(); + if (!s.IsOk()) { + LOG(ERROR) << "WriteBatch fail, " + << s.ErrMsg().c_str(); for (auto slice : *slice_chunk_write) { slice->markFailed(); } @@ -143,7 +160,8 @@ int BarexContext::submitPostSend( slice->markSuccess(); } } - }, true); + }, + true); if (r != accl::barex::BAREX_SUCCESS) { LOG(ERROR) << "WriteBatch fail, ret " << r; return -2; @@ -153,8 +171,9 @@ int 
BarexContext::submitPostSend( BarexResult r = channel->ReadBatch( data_chunk_read, [slice_chunk_read](accl::barex::Status s) { - if(!s.IsOk()) { - LOG(ERROR) << "ReadBatch fail, " << s.ErrMsg().c_str(); + if (!s.IsOk()) { + LOG(ERROR) + << "ReadBatch fail, " << s.ErrMsg().c_str(); for (auto slice : *slice_chunk_read) { slice->markFailed(); } @@ -163,7 +182,8 @@ int BarexContext::submitPostSend( slice->markSuccess(); } } - }, true); + }, + true); if (r != accl::barex::BAREX_SUCCESS) { LOG(ERROR) << "ReadBatch fail, ret " << r; return -2; diff --git a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp index bd892ec51..3f9558940 100644 --- a/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp +++ b/mooncake-transfer-engine/src/transport/barex_transport/barex_transport.cpp @@ -34,8 +34,9 @@ namespace mooncake { using namespace accl::barex; class EmptyCallback : public XChannelCallback { -public: - void OnRecvCall(XChannel *channel, char *buf, size_t len, x_msg_header header) {} + public: + void OnRecvCall(XChannel *channel, char *buf, size_t len, + x_msg_header header) {} }; BarexTransport::BarexTransport() {} @@ -46,12 +47,13 @@ BarexTransport::~BarexTransport() { batch_desc_set_.clear(); #endif for (auto ctx : client_context_list_) { - std::vector chs = ctx->getAllChannel(); + std::vector chs = ctx->getAllChannel(); for (auto ch : chs) { - BarexResult ret = connector_->CloseChannel(ch, [&, ch](accl::barex::Status s) { - LOG(INFO) << "CloseChannel() finished, s.IsOk=" << s.IsOk(); - ch->Destroy(); - }); + BarexResult ret = + connector_->CloseChannel(ch, [&, ch](accl::barex::Status s) { + LOG(INFO) << "CloseChannel() finished, s.IsOk=" << s.IsOk(); + ch->Destroy(); + }); if (ret != accl::barex::BAREX_SUCCESS) { LOG(ERROR) << "CloseChannel() failed, ret " << ret; } @@ -74,8 +76,8 @@ BarexTransport::~BarexTransport() { } int 
BarexTransport::install(std::string &local_server_name, - std::shared_ptr meta, - std::shared_ptr topo) { + std::shared_ptr meta, + std::shared_ptr topo) { if (topo == nullptr) { LOG(ERROR) << "BarexTransport: missing topology"; return ERR_INVALID_ARGUMENT; @@ -99,7 +101,7 @@ int BarexTransport::install(std::string &local_server_name, int val = atoi(barex_use_cpu_env); if (val != 0) { LOG(INFO) << "BarexTransport: use_cpu"; - barex_use_cpu_ = true; + barex_use_cpu_ = true; } } @@ -149,12 +151,13 @@ int BarexTransport::registerLocalMemory(void *addr, size_t length, device_type dtype; if (name.find("cuda") != std::string::npos || name == kWildcardLocation) { - dtype = GPU; + dtype = GPU; } else if (name.find("cpu") != std::string::npos) { - dtype = CPU; + dtype = CPU; } else { - LOG(ERROR) << "BarexTransport: registerLocalMemory, cannot recognize: name " << name - << ", need include cpu or cuda in name"; + LOG(ERROR) + << "BarexTransport: registerLocalMemory, cannot recognize: name " + << name << ", need include cpu or cuda in name"; return ERR_INVALID_ARGUMENT; } @@ -162,12 +165,14 @@ int BarexTransport::registerLocalMemory(void *addr, size_t length, while (remaining > 0) { size_t buffer_len = std::min(buffer_size, remaining); - int ret = registerLocalMemoryBase(current_ptr, buffer_len, name, remote_accessible, update_metadata, is_gpu); + int ret = + registerLocalMemoryBase(current_ptr, buffer_len, name, + remote_accessible, update_metadata, is_gpu); if (ret) { LOG(ERROR) << "registerLocalMemoryBase failed, ret " << ret; return -1; } - current_ptr = static_cast(current_ptr) + buffer_len; + current_ptr = static_cast(current_ptr) + buffer_len; remaining -= buffer_len; } @@ -182,10 +187,9 @@ int BarexTransport::registerLocalMemory(void *addr, size_t length, } int BarexTransport::registerLocalMemoryBase(void *addr, size_t length, - const std::string &name, - bool remote_accessible, - bool update_metadata, - bool is_gpu) { + const std::string &name, + bool 
remote_accessible, + bool update_metadata, bool is_gpu) { (void)remote_accessible; BufferDesc buffer_desc; memp_t mem; @@ -194,8 +198,8 @@ int BarexTransport::registerLocalMemoryBase(void *addr, size_t length, result = mempool_->RegUserMr(mem, addr, length, dtype); if (result != BAREX_SUCCESS) { LOG(ERROR) << "BarexTransport: registerLocalMemory failed" - << ", result " << result << ", addr " << addr - << ", length " << length << ", name "<< name; + << ", result " << result << ", addr " << addr << ", length " + << length << ", name " << name; return ERR_ADDRESS_NOT_REGISTERED; } else { for (auto &mr : mem.mrs) { @@ -203,7 +207,7 @@ int BarexTransport::registerLocalMemoryBase(void *addr, size_t length, buffer_desc.rkey.push_back(mr.second->rkey); } } - + // Get the memory location automatically after registered MR(pinned), // when the name is kWildcardLocation("*"). if (name == kWildcardLocation) { @@ -253,17 +257,21 @@ int BarexTransport::unregisterLocalMemory(void *addr, bool update_metadata) { while (remaining > 0) { size_t buffer_len = std::min(buffer_size, remaining); if (current_ptr > addr) { - int rc = metadata_->removeLocalMemoryBuffer(current_ptr, update_metadata); + int rc = metadata_->removeLocalMemoryBuffer(current_ptr, + update_metadata); if (rc) { - LOG(WARNING) << "unregisterLocalMemory, removeLocalMemoryBuffer failed, addr " << addr; - } + LOG(WARNING) << "unregisterLocalMemory, " + "removeLocalMemoryBuffer failed, addr " + << addr; + } } result = mempool_->DeregUserMr(current_ptr, dtype); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "unregisterLocalMemory, DeregUserMr, failed, ret " << result << ", addr " << current_ptr; + LOG(ERROR) << "unregisterLocalMemory, DeregUserMr, failed, ret " + << result << ", addr " << current_ptr; return -1; } - current_ptr = static_cast(current_ptr) + buffer_len; + current_ptr = static_cast(current_ptr) + buffer_len; remaining -= buffer_len; } @@ -279,8 +287,8 @@ int BarexTransport::allocateLocalSegmentID() { 
TransferMetadata::DeviceDesc device_desc; device_desc.name = entry->getCtx()->GetXDevice()->GetName(); // TODO is barex need this? - device_desc.lid = 0; //entry->lid(); - device_desc.gid = "ignore"; //entry->gid(); + device_desc.lid = 0; // entry->lid(); + device_desc.gid = "ignore"; // entry->gid(); desc->devices.push_back(device_desc); } desc->topology = *(local_topology_.get()); @@ -293,11 +301,11 @@ int BarexTransport::registerLocalMemoryBatch( const std::vector &buffer_list, const std::string &location) { for (auto &buffer : buffer_list) { - int ret = registerLocalMemory(buffer.addr, buffer.length, location, true, false); + int ret = registerLocalMemory(buffer.addr, buffer.length, location, + true, false); if (ret) { LOG(ERROR) << "BarexTransport: Failed to register memory: addr " - << buffer.addr << " length " - << buffer.length; + << buffer.addr << " length " << buffer.length; return ERR_ADDRESS_NOT_REGISTERED; } } @@ -328,8 +336,9 @@ Status BarexTransport::submitTransfer( BatchID batch_id, const std::vector &entries) { auto &batch_desc = *((BatchDesc *)(batch_id)); if (batch_desc.task_list.size() + entries.size() > batch_desc.batch_size) { - LOG(ERROR) << "BarexTransport: Exceed the limitation of current batch's " - "capacity"; + LOG(ERROR) + << "BarexTransport: Exceed the limitation of current batch's " + "capacity"; return Status::InvalidArgument( "BarexTransport: Exceed the limitation of capacity, batch id: " + std::to_string(batch_id)); @@ -343,11 +352,13 @@ Status BarexTransport::submitTransfer( // const size_t kBlockSize = globalConfig().slice_size; const size_t kBlockMaxSize = globalConfig().eic_max_block_size; const int kMaxRetryCount = globalConfig().retry_cnt; - std::unordered_map> segment_desc_map; + std::unordered_map> + segment_desc_map; for (auto &request : entries) { auto target_id = request.target_id; if (!segment_desc_map.count(target_id)) - segment_desc_map[target_id] = metadata_->getSegmentDescByID(target_id); + 
segment_desc_map[target_id] = + metadata_->getSegmentDescByID(target_id); } for (auto &request : entries) { TransferTask &task = batch_desc.task_list[task_id]; @@ -355,13 +366,15 @@ Status BarexTransport::submitTransfer( SegmentID target_id = request.target_id; auto peer_segment_desc = segment_desc_map[target_id]; if (!peer_segment_desc) { - LOG(ERROR) << "peer_segment_desc not found for target_id " << target_id; + LOG(ERROR) << "peer_segment_desc not found for target_id " + << target_id; return Status::InvalidArgument( "BarexTransport: peer_segment_desc not found, batch id: " + std::to_string(batch_id)); } size_t kBlockSize = std::min(request.length, kBlockMaxSize); - for (uint64_t offset = 0; offset < request.length; offset += kBlockSize) { + for (uint64_t offset = 0; offset < request.length; + offset += kBlockSize) { Slice *slice = getSliceCache().allocate(); slice->source_addr = (char *)request.source + offset; slice->length = std::min(request.length - offset, kBlockSize); @@ -375,25 +388,28 @@ Status BarexTransport::submitTransfer( slice->status = Slice::PENDING; task.slice_list.push_back(slice); - int peer_buffer_id = -1, extra_peer_buffer_id = 0, peer_device_id = -1; - int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, retry_cnt = 0; + int peer_buffer_id = -1, extra_peer_buffer_id = 0, + peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, + retry_cnt = 0; while (retry_cnt < kMaxRetryCount) { - int ret = selectDevice(local_segment_desc.get(), - (uint64_t)slice->source_addr, slice->length, - local_buffer_id, device_id, retry_cnt++); + int ret = selectDevice( + local_segment_desc.get(), (uint64_t)slice->source_addr, + slice->length, local_buffer_id, device_id, retry_cnt++); if (ret) { if (ret == ERR_ADDRESS_NOT_REGISTERED) { - LOG(WARNING) << "local_segment_desc selectDevice failed"; + LOG(WARNING) + << "local_segment_desc selectDevice failed"; continue; } else { // need 2 blocks extra_local_buffer_id = 
local_buffer_id + 1; } } - ret = selectDevice(peer_segment_desc.get(), - slice->rdma.dest_addr, - slice->length, peer_buffer_id, peer_device_id, - slice->rdma.retry_cnt); + ret = + selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + slice->rdma.retry_cnt); if (ret) { if (ret == ERR_ADDRESS_NOT_REGISTERED) { LOG(WARNING) << "peer_segment_desc selectDevice failed"; @@ -404,18 +420,22 @@ Status BarexTransport::submitTransfer( } } assert(device_id >= 0); - if (device_id >= static_cast(client_context_list_.size()) || use_random_dev_) { + if (device_id >= + static_cast(client_context_list_.size()) || + use_random_dev_) { std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, client_context_list_.size() - 1); + std::uniform_int_distribution<> dis( + 0, client_context_list_.size() - 1); device_id = dis(gen); } auto &context = client_context_list_[device_id]; if (!context->active()) continue; - assert(context->getCtx()->GetXDevice()->GetId() == device_id); + assert(context->getCtx()->GetXDevice()->GetId() == device_id); // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 - if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -423,16 +443,23 @@ Status BarexTransport::submitTransfer( task.total_bytes += slice->length; __sync_fetch_and_add(&task.slice_count, 1); break; - } else if (!extra_local_buffer_id && extra_peer_buffer_id) { // 1:2 - auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; - auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; - size_t first_length = first_peer_buffer_desc.addr + 
first_peer_buffer_desc.length - slice->rdma.dest_addr; - size_t last_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + } else if (!extra_local_buffer_id && + extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - + last_peer_buffer_desc.addr; assert(first_length + last_length == slice->length); // add first part slice->length = first_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -441,10 +468,12 @@ Status BarexTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->source_addr = + (char *)request.source + offset + first_length; last_slice->length = last_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -453,7 +482,8 @@ Status BarexTransport::submitTransfer( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = 
peer_segment_desc->buffers[extra_peer_buffer_id].rkey; @@ -461,16 +491,25 @@ Status BarexTransport::submitTransfer( task.total_bytes += last_length; __sync_fetch_and_add(&task.slice_count, 1); break; - } else if (extra_local_buffer_id && !extra_peer_buffer_id) { // 2:1 - auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; - auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; - size_t first_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; - size_t last_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + } else if (extra_local_buffer_id && + !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + ->buffers[extra_local_buffer_id]; + size_t first_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + (size_t)slice->source_addr; + size_t last_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; assert(first_length + last_length == slice->length); // add first part slice->length = first_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -479,10 +518,12 @@ Status BarexTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->source_addr = + (char *)request.source + offset + first_length; last_slice->length = last_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + 
last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -491,7 +532,8 @@ Status BarexTransport::submitTransfer( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -499,22 +541,38 @@ Status BarexTransport::submitTransfer( task.total_bytes += last_length; __sync_fetch_and_add(&task.slice_count, 1); break; - } else { // 2:2 - auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; - auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; - size_t first_local_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; - size_t last_local_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; - assert(first_local_length + last_local_length == slice->length); - auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; - auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; - size_t first_peer_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; - size_t last_peer_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; - assert(first_peer_length + last_peer_length == slice->length); + } else { // 2:2 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + ->buffers[extra_local_buffer_id]; + size_t first_local_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + 
(size_t)slice->source_addr; + size_t last_local_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; + assert(first_local_length + last_local_length == + slice->length); + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_peer_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_peer_length = slice->rdma.dest_addr + + slice->length - + last_peer_buffer_desc.addr; + assert(first_peer_length + last_peer_length == + slice->length); if (first_local_length == first_peer_length) { // add first part slice->length = first_local_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -523,10 +581,12 @@ Status BarexTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->source_addr = (char *)request.source + + offset + first_local_length; last_slice->length = last_local_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -535,10 +595,12 @@ Status BarexTransport::submitTransfer( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + 
local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_local_length; __sync_fetch_and_add(&task.slice_count, 1); @@ -546,7 +608,8 @@ Status BarexTransport::submitTransfer( // add first part slice->length = first_peer_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -555,10 +618,13 @@ Status BarexTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); // add second part Slice *second_slice = getSliceCache().allocate(); - second_slice->source_addr = (char *)request.source + offset + first_peer_length; - second_slice->length = first_local_length - first_peer_length; + second_slice->source_addr = + (char *)request.source + offset + first_peer_length; + second_slice->length = + first_local_length - first_peer_length; second_slice->opcode = request.opcode; - second_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; second_slice->rdma.retry_cnt = 0; second_slice->rdma.max_retry_cnt = kMaxRetryCount; second_slice->task = &task; @@ -567,19 +633,24 @@ Status BarexTransport::submitTransfer( second_slice->status = Slice::PENDING; task.slice_list.push_back(second_slice); second_slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; second_slice->rdma.lkey_index = device_id; second_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + 
peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(second_slice); - task.total_bytes += first_local_length - first_peer_length; + task.total_bytes += + first_local_length - first_peer_length; __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->source_addr = (char *)request.source + + offset + first_local_length; last_slice->length = last_local_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -588,19 +659,22 @@ Status BarexTransport::submitTransfer( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_local_length; __sync_fetch_and_add(&task.slice_count, 1); - } else { + } else { // first_local_length < first_peer_length // add first part slice->length = first_local_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -609,10 +683,13 @@ Status BarexTransport::submitTransfer( __sync_fetch_and_add(&task.slice_count, 1); // add second part Slice 
*second_slice = getSliceCache().allocate(); - second_slice->source_addr = (char *)request.source + offset + first_local_length; - second_slice->length = first_peer_length - first_local_length; + second_slice->source_addr = (char *)request.source + + offset + first_local_length; + second_slice->length = + first_peer_length - first_local_length; second_slice->opcode = request.opcode; - second_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; second_slice->rdma.retry_cnt = 0; second_slice->rdma.max_retry_cnt = kMaxRetryCount; second_slice->task = &task; @@ -621,19 +698,23 @@ Status BarexTransport::submitTransfer( second_slice->status = Slice::PENDING; task.slice_list.push_back(second_slice); second_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; second_slice->rdma.lkey_index = device_id; second_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; slices_to_post[context].push_back(second_slice); - task.total_bytes += first_peer_length - first_local_length; + task.total_bytes += + first_peer_length - first_local_length; __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_peer_length; + last_slice->source_addr = + (char *)request.source + offset + first_peer_length; last_slice->length = last_peer_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -642,10 +723,12 @@ Status BarexTransport::submitTransfer( last_slice->status = 
Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_peer_length; __sync_fetch_and_add(&task.slice_count, 1); @@ -657,9 +740,9 @@ Status BarexTransport::submitTransfer( auto source_addr = slice->source_addr; for (auto &entry : slices_to_post) for (auto s : entry.second) delete s; - LOG(ERROR) - << "BarexTransport: Address not registered by any device(s) " - << source_addr; + LOG(ERROR) << "BarexTransport: Address not registered by any " + "device(s) " + << source_addr; return Status::AddressNotRegistered( "BarexTransport: not registered by any device(s), " "address: " + @@ -685,7 +768,8 @@ Status BarexTransport::submitTransferTask( // const size_t kBlockSize = globalConfig().slice_size; const size_t kBlockMaxSize = globalConfig().eic_max_block_size; const int kMaxRetryCount = globalConfig().retry_cnt; - std::unordered_map> segment_desc_map; + std::unordered_map> + segment_desc_map; for (size_t index = 0; index < task_list.size(); ++index) { assert(task_list[index]); auto &task = *task_list[index]; @@ -693,7 +777,8 @@ Status BarexTransport::submitTransferTask( auto &request = *task.request; auto target_id = request.target_id; if (!segment_desc_map.count(target_id)) - segment_desc_map[target_id] = metadata_->getSegmentDescByID(target_id); + segment_desc_map[target_id] = + metadata_->getSegmentDescByID(target_id); } for (size_t index = 0; index < task_list.size(); ++index) { auto &task = *task_list[index]; @@ -701,12 +786,14 @@ Status BarexTransport::submitTransferTask( SegmentID target_id = request.target_id; auto peer_segment_desc = 
segment_desc_map[target_id]; if (!peer_segment_desc) { - LOG(ERROR) << "peer_segment_desc not found for target_id " << target_id; + LOG(ERROR) << "peer_segment_desc not found for target_id " + << target_id; return Status::InvalidArgument( "BarexTransport: peer_segment_desc not found"); } size_t kBlockSize = std::min(request.length, kBlockMaxSize); - for (uint64_t offset = 0; offset < request.length; offset += kBlockSize) { + for (uint64_t offset = 0; offset < request.length; + offset += kBlockSize) { Slice *slice = getSliceCache().allocate(); assert(slice); slice->source_addr = (char *)request.source + offset; @@ -721,26 +808,29 @@ Status BarexTransport::submitTransferTask( slice->ts = 0; task.slice_list.push_back(slice); - int peer_buffer_id = -1, extra_peer_buffer_id = 0, peer_device_id = -1; - int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, retry_cnt = request.advise_retry_cnt; + int peer_buffer_id = -1, extra_peer_buffer_id = 0, + peer_device_id = -1; + int local_buffer_id = -1, extra_local_buffer_id = 0, device_id = -1, + retry_cnt = request.advise_retry_cnt; bool found_device = false; while (retry_cnt < kMaxRetryCount) { - int ret = selectDevice(local_segment_desc.get(), - (uint64_t)slice->source_addr, slice->length, - local_buffer_id, device_id, retry_cnt++); + int ret = selectDevice( + local_segment_desc.get(), (uint64_t)slice->source_addr, + slice->length, local_buffer_id, device_id, retry_cnt++); if (ret) { if (ret == ERR_ADDRESS_NOT_REGISTERED) { - LOG(WARNING) << "local_segment_desc selectDevice failed"; + LOG(WARNING) + << "local_segment_desc selectDevice failed"; continue; } else { // need 2 blocks extra_local_buffer_id = local_buffer_id + 1; } } - ret = selectDevice(peer_segment_desc.get(), - slice->rdma.dest_addr, - slice->length, peer_buffer_id, peer_device_id, - slice->rdma.retry_cnt); + ret = + selectDevice(peer_segment_desc.get(), slice->rdma.dest_addr, + slice->length, peer_buffer_id, peer_device_id, + 
slice->rdma.retry_cnt); if (ret) { if (ret == ERR_ADDRESS_NOT_REGISTERED) { LOG(WARNING) << "peer_segment_desc selectDevice failed"; @@ -751,21 +841,28 @@ Status BarexTransport::submitTransferTask( } } assert(device_id >= 0); - if (device_id >= static_cast(client_context_list_.size()) || use_random_dev_) { + if (device_id >= + static_cast(client_context_list_.size()) || + use_random_dev_) { std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(0, client_context_list_.size() - 1); + std::uniform_int_distribution<> dis( + 0, client_context_list_.size() - 1); device_id = dis(gen); } auto &context = client_context_list_[device_id]; assert(context.get()); if (!context->active()) continue; - assert(context->getCtx()->GetXDevice()->GetId() == device_id); - assert(local_buffer_id >= 0 && local_buffer_id < local_segment_desc->buffers.size()); - assert(local_segment_desc->buffers[local_buffer_id].lkey.size() == client_context_list_.size()); + assert(context->getCtx()->GetXDevice()->GetId() == device_id); + assert(local_buffer_id >= 0 && + local_buffer_id < local_segment_desc->buffers.size()); + assert( + local_segment_desc->buffers[local_buffer_id].lkey.size() == + client_context_list_.size()); // 4 types, local:peer = 1:1, 1:2, 2:1, 2:2 - if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 + if (!extra_local_buffer_id && !extra_peer_buffer_id) { // 1:1 slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -774,16 +871,23 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); found_device = true; break; - } else if (!extra_local_buffer_id && extra_peer_buffer_id) { // 1:2 - auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; - auto &last_peer_buffer_desc = 
peer_segment_desc.get()->buffers[extra_peer_buffer_id]; - size_t first_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; - size_t last_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; + } else if (!extra_local_buffer_id && + extra_peer_buffer_id) { // 1:2 + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_length = slice->rdma.dest_addr + slice->length - + last_peer_buffer_desc.addr; assert(first_length + last_length == slice->length); // add first part slice->length = first_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -792,10 +896,12 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->source_addr = + (char *)request.source + offset + first_length; last_slice->length = last_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -804,7 +910,8 @@ Status BarexTransport::submitTransferTask( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + 
local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = peer_segment_desc->buffers[extra_peer_buffer_id].rkey; @@ -813,16 +920,25 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); found_device = true; break; - } else if (extra_local_buffer_id && !extra_peer_buffer_id) { // 2:1 - auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; - auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; - size_t first_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; - size_t last_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; + } else if (extra_local_buffer_id && + !extra_peer_buffer_id) { // 2:1 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + ->buffers[extra_local_buffer_id]; + size_t first_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + (size_t)slice->source_addr; + size_t last_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; assert(first_length + last_length == slice->length); // add first part slice->length = first_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -831,10 +947,12 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_length; + last_slice->source_addr = + (char *)request.source + offset + first_length; last_slice->length = last_length; 
last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -843,7 +961,8 @@ Status BarexTransport::submitTransferTask( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -852,22 +971,38 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); found_device = true; break; - } else { // 2:2 - auto &first_local_buffer_desc = local_segment_desc.get()->buffers[local_buffer_id]; - auto &last_local_buffer_desc = local_segment_desc.get()->buffers[extra_local_buffer_id]; - size_t first_local_length = first_local_buffer_desc.addr + first_local_buffer_desc.length - (size_t)slice->source_addr; - size_t last_local_length = (size_t)slice->source_addr + slice->length - last_local_buffer_desc.addr; - assert(first_local_length + last_local_length == slice->length); - auto &first_peer_buffer_desc = peer_segment_desc.get()->buffers[peer_buffer_id]; - auto &last_peer_buffer_desc = peer_segment_desc.get()->buffers[extra_peer_buffer_id]; - size_t first_peer_length = first_peer_buffer_desc.addr + first_peer_buffer_desc.length - slice->rdma.dest_addr; - size_t last_peer_length = slice->rdma.dest_addr + slice->length - last_peer_buffer_desc.addr; - assert(first_peer_length + last_peer_length == slice->length); + } else { // 2:2 + auto &first_local_buffer_desc = + local_segment_desc.get()->buffers[local_buffer_id]; + auto &last_local_buffer_desc = + local_segment_desc.get() + 
->buffers[extra_local_buffer_id]; + size_t first_local_length = first_local_buffer_desc.addr + + first_local_buffer_desc.length - + (size_t)slice->source_addr; + size_t last_local_length = (size_t)slice->source_addr + + slice->length - + last_local_buffer_desc.addr; + assert(first_local_length + last_local_length == + slice->length); + auto &first_peer_buffer_desc = + peer_segment_desc.get()->buffers[peer_buffer_id]; + auto &last_peer_buffer_desc = + peer_segment_desc.get()->buffers[extra_peer_buffer_id]; + size_t first_peer_length = first_peer_buffer_desc.addr + + first_peer_buffer_desc.length - + slice->rdma.dest_addr; + size_t last_peer_length = slice->rdma.dest_addr + + slice->length - + last_peer_buffer_desc.addr; + assert(first_peer_length + last_peer_length == + slice->length); if (first_local_length == first_peer_length) { // add first part slice->length = first_local_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -876,10 +1011,12 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->source_addr = (char *)request.source + + offset + first_local_length; last_slice->length = last_local_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -888,10 +1025,12 @@ Status BarexTransport::submitTransferTask( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); 
last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_local_length; __sync_fetch_and_add(&task.slice_count, 1); @@ -899,7 +1038,8 @@ Status BarexTransport::submitTransferTask( // add first part slice->length = first_peer_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; @@ -908,10 +1048,13 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); // add second part Slice *second_slice = getSliceCache().allocate(); - second_slice->source_addr = (char *)request.source + offset + first_peer_length; - second_slice->length = first_local_length - first_peer_length; + second_slice->source_addr = + (char *)request.source + offset + first_peer_length; + second_slice->length = + first_local_length - first_peer_length; second_slice->opcode = request.opcode; - second_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; second_slice->rdma.retry_cnt = 0; second_slice->rdma.max_retry_cnt = kMaxRetryCount; second_slice->task = &task; @@ -920,19 +1063,24 @@ Status BarexTransport::submitTransferTask( second_slice->status = Slice::PENDING; task.slice_list.push_back(second_slice); second_slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; 
second_slice->rdma.lkey_index = device_id; second_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(second_slice); - task.total_bytes += first_local_length - first_peer_length; + task.total_bytes += + first_local_length - first_peer_length; __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_local_length; + last_slice->source_addr = (char *)request.source + + offset + first_local_length; last_slice->length = last_local_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; last_slice->rdma.retry_cnt = 0; last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -941,19 +1089,22 @@ Status BarexTransport::submitTransferTask( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_local_length; __sync_fetch_and_add(&task.slice_count, 1); - } else { + } else { // first_local_length < first_peer_length // add first part slice->length = first_local_length; slice->rdma.source_lkey = - local_segment_desc->buffers[local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[local_buffer_id] + .lkey[device_id]; slice->rdma.lkey_index = device_id; slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; 
@@ -962,10 +1113,13 @@ Status BarexTransport::submitTransferTask( __sync_fetch_and_add(&task.slice_count, 1); // add second part Slice *second_slice = getSliceCache().allocate(); - second_slice->source_addr = (char *)request.source + offset + first_local_length; - second_slice->length = first_peer_length - first_local_length; + second_slice->source_addr = (char *)request.source + + offset + first_local_length; + second_slice->length = + first_peer_length - first_local_length; second_slice->opcode = request.opcode; - second_slice->rdma.dest_addr = request.target_offset + offset + first_local_length; + second_slice->rdma.dest_addr = + request.target_offset + offset + first_local_length; second_slice->rdma.retry_cnt = 0; second_slice->rdma.max_retry_cnt = kMaxRetryCount; second_slice->task = &task; @@ -974,19 +1128,23 @@ Status BarexTransport::submitTransferTask( second_slice->status = Slice::PENDING; task.slice_list.push_back(second_slice); second_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; second_slice->rdma.lkey_index = device_id; second_slice->dest_rkeys = peer_segment_desc->buffers[peer_buffer_id].rkey; slices_to_post[context].push_back(second_slice); - task.total_bytes += first_peer_length - first_local_length; + task.total_bytes += + first_peer_length - first_local_length; __sync_fetch_and_add(&task.slice_count, 1); // add last part Slice *last_slice = getSliceCache().allocate(); - last_slice->source_addr = (char *)request.source + offset + first_peer_length; + last_slice->source_addr = + (char *)request.source + offset + first_peer_length; last_slice->length = last_peer_length; last_slice->opcode = request.opcode; - last_slice->rdma.dest_addr = request.target_offset + offset + first_peer_length; + last_slice->rdma.dest_addr = + request.target_offset + offset + first_peer_length; last_slice->rdma.retry_cnt = 0; 
last_slice->rdma.max_retry_cnt = kMaxRetryCount; last_slice->task = &task; @@ -995,10 +1153,12 @@ Status BarexTransport::submitTransferTask( last_slice->status = Slice::PENDING; task.slice_list.push_back(last_slice); last_slice->rdma.source_lkey = - local_segment_desc->buffers[extra_local_buffer_id].lkey[device_id]; + local_segment_desc->buffers[extra_local_buffer_id] + .lkey[device_id]; last_slice->rdma.lkey_index = device_id; last_slice->dest_rkeys = - peer_segment_desc->buffers[extra_peer_buffer_id].rkey; + peer_segment_desc->buffers[extra_peer_buffer_id] + .rkey; slices_to_post[context].push_back(last_slice); task.total_bytes += last_peer_length; __sync_fetch_and_add(&task.slice_count, 1); @@ -1030,7 +1190,7 @@ Status BarexTransport::submitTransferTask( } Status BarexTransport::getTransferStatus(BatchID batch_id, - std::vector &status) { + std::vector &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); status.resize(task_count); @@ -1054,7 +1214,7 @@ Status BarexTransport::getTransferStatus(BatchID batch_id, } Status BarexTransport::getTransferStatus(BatchID batch_id, size_t task_id, - TransferStatus &status) { + TransferStatus &status) { auto &batch_desc = *((BatchDesc *)(batch_id)); const size_t task_count = batch_desc.task_list.size(); if (task_id >= task_count) { @@ -1083,20 +1243,22 @@ BarexTransport::SegmentID BarexTransport::getSegmentID( return metadata_->getSegmentID(segment_name); } -Status BarexTransport::OpenChannel(const std::string &segment_name, SegmentID sid) { +Status BarexTransport::OpenChannel(const std::string &segment_name, + SegmentID sid) { auto [ip, port] = parseHostNameWithPort(segment_name); - + HandShakeDesc local_desc, peer_desc; local_desc.barex_port = getLocalPort(); int rc = metadata_->sendHandshake(segment_name, local_desc, peer_desc); - if (rc) return Status::Socket("sendHandshake failed");; + if (rc) return Status::Socket("sendHandshake failed"); + ; if 
(!peer_desc.reply_msg.empty()) { - LOG(ERROR) << "Reject the handshake request by peer " - << segment_name; + LOG(ERROR) << "Reject the handshake request by peer " << segment_name; return Status::Socket("empty peer_desc"); } else { - LOG(INFO) << "Handshake finish, get peer_server " << segment_name << ":" << peer_desc.barex_port; + LOG(INFO) << "Handshake finish, get peer_server " << segment_name << ":" + << peer_desc.barex_port; setPeerPort(peer_desc.barex_port); } @@ -1106,29 +1268,37 @@ Status BarexTransport::OpenChannel(const std::string &segment_name, SegmentID si std::vector channels; static std::mutex push_channel_mtx; for (int i = 0; i < total_channels; i++) { - BarexResult result = connector_->Connect(ip, getPeerPort(), [=, &channels, &connect_latch](XChannel *channel, accl::barex::Status s) { - if (!s.IsOk()) { - LOG(ERROR) << "BarexTransport::OpenChannel failed, " << s.ErrMsg(); - } else { - std::unique_lock lk(push_channel_mtx); - channels.push_back(channel); - LOG(INFO) << "Open channel " << i+1 << "/" << total_channels; - } - connect_latch.CountDown(); - }); + BarexResult result = connector_->Connect( + ip, getPeerPort(), + [=, &channels, &connect_latch](XChannel *channel, + accl::barex::Status s) { + if (!s.IsOk()) { + LOG(ERROR) + << "BarexTransport::OpenChannel failed, " << s.ErrMsg(); + } else { + std::unique_lock lk(push_channel_mtx); + channels.push_back(channel); + LOG(INFO) + << "Open channel " << i + 1 << "/" << total_channels; + } + connect_latch.CountDown(); + }); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "BarexTransport::OpenChannel failed, result=" << result; + LOG(ERROR) << "BarexTransport::OpenChannel failed, result=" + << result; connect_latch.CountDown(); } } connect_latch.Wait(); if ((int)channels.size() != total_channels) { - LOG(ERROR) << "open channel failed, need " << total_channels << " but got " << channels.size(); + LOG(ERROR) << "open channel failed, need " << total_channels + << " but got " << channels.size(); return 
Status::InvalidArgument("connect failed"); } for (auto channel : channels) { int idx = channel->GetContext()->GetXDevice()->GetId(); - assert(client_context_list_[idx]->getCtx()->GetXDevice()->GetId() == idx); + assert(client_context_list_[idx]->getCtx()->GetXDevice()->GetId() == + idx); client_context_list_[idx]->addChannel(sid, idx, channel); } return Status::OK(); @@ -1139,19 +1309,20 @@ Status BarexTransport::CheckStatus(SegmentID sid) { for (auto ctx : client_context_list_) { int ret = ctx->checkStatus(sid); if (ret) { - LOG(INFO) << "checkStatus failed in ctx" << ctx << ", bad channel cnt=" << ret; + LOG(INFO) << "checkStatus failed in ctx" << ctx + << ", bad channel cnt=" << ret; status = 1; } } if (!status) { LOG(ERROR) << "CheckStatus for sid " << sid << " failed"; - return Status::InvalidArgument("sid status error"); + return Status::InvalidArgument("sid status error"); } return Status::OK(); } int BarexTransport::onSetupRdmaConnections(const HandShakeDesc &peer_desc, - HandShakeDesc &local_desc) { + HandShakeDesc &local_desc) { local_desc.barex_port = getLocalPort(); return 0; } @@ -1181,13 +1352,15 @@ int BarexTransport::initializeRdmaResources() { return ERR_INVALID_ARGUMENT; } mempool_ = std::shared_ptr(mempool); - result = XThreadpool::NewInstance(server_threadpool, 10, "barex-server-threadpool"); + result = XThreadpool::NewInstance(server_threadpool, 10, + "barex-server-threadpool"); if (result != BAREX_SUCCESS) { LOG(ERROR) << "BarexTransport: Create Server XThreadpool failed"; return ERR_INVALID_ARGUMENT; } server_threadpool_ = std::shared_ptr(server_threadpool); - result = XThreadpool::NewInstance(client_threadpool, 10, "barex-client-threadpool"); + result = XThreadpool::NewInstance(client_threadpool, 10, + "barex-client-threadpool"); if (result != BAREX_SUCCESS) { LOG(ERROR) << "BarexTransport: Create Client XThreadpool failed"; return ERR_INVALID_ARGUMENT; @@ -1195,31 +1368,43 @@ int BarexTransport::initializeRdmaResources() { 
client_threadpool_ = std::shared_ptr(client_threadpool); auto &config = globalConfig(); for (auto &dev : devices) { - if (std::find(hca_list.begin(), hca_list.end(), dev->GetName()) == hca_list.end()) { - LOG(WARNING) << "BarexTransport: device " << dev->GetName() << " not found in hca_list, ignore "; + if (std::find(hca_list.begin(), hca_list.end(), dev->GetName()) == + hca_list.end()) { + LOG(WARNING) << "BarexTransport: device " << dev->GetName() + << " not found in hca_list, ignore "; continue; } ContextConfig server_config = XConfigUtil::DefaultContextConfig(); XContext *raw_server_context = nullptr; - result = XContext::NewInstance(raw_server_context, server_config, new EmptyCallback(), dev, mempool, server_threadpool); + result = XContext::NewInstance(raw_server_context, server_config, + new EmptyCallback(), dev, mempool, + server_threadpool); if (result != BAREX_SUCCESS) { local_topology_->disableDevice(dev->GetName()); - LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device " << dev->GetName(); + LOG(WARNING) + << "BarexTransport: Create XContext failed, Disable device " + << dev->GetName(); } else { raw_server_context->Start(); - auto server_context = std::make_shared(raw_server_context, barex_use_cpu_, barex_local_device_); + auto server_context = std::make_shared( + raw_server_context, barex_use_cpu_, barex_local_device_); server_context->setQpNum(config.num_qp_per_ep); server_context_list_.push_back(server_context); } ContextConfig client_config = XConfigUtil::DefaultContextConfig(); XContext *raw_client_context = nullptr; - result = XContext::NewInstance(raw_client_context, client_config, new EmptyCallback(), dev, mempool, client_threadpool); + result = XContext::NewInstance(raw_client_context, client_config, + new EmptyCallback(), dev, mempool, + client_threadpool); if (result != BAREX_SUCCESS) { local_topology_->disableDevice(dev->GetName()); - LOG(WARNING) << "BarexTransport: Create XContext failed, Disable device " << 
dev->GetName(); + LOG(WARNING) + << "BarexTransport: Create XContext failed, Disable device " + << dev->GetName(); } else { raw_client_context->Start(); - auto client_context = std::make_shared(raw_client_context, barex_use_cpu_, barex_local_device_); + auto client_context = std::make_shared( + raw_client_context, barex_use_cpu_, barex_local_device_); client_context->setQpNum(config.num_qp_per_ep); client_context_list_.push_back(client_context); } @@ -1233,33 +1418,41 @@ int BarexTransport::initializeRdmaResources() { } int BarexTransport::startHandshakeDaemon(std::string &local_server_name) { - std::vector raw_server_contexts; - std::vector raw_client_contexts; + std::vector raw_server_contexts; + std::vector raw_client_contexts; for (auto ctx : server_context_list_) { raw_server_contexts.emplace_back(ctx->getCtx()); } for (auto ctx : client_context_list_) { raw_client_contexts.emplace_back(ctx->getCtx()); } - XListener* listener = nullptr; + XListener *listener = nullptr; int port = metadata_->localRpcMeta().barex_port; setLocalPort(port); - BarexResult result = XListener::NewInstance(listener, 2, getLocalPort(), TIMER_3S, raw_server_contexts); + BarexResult result = XListener::NewInstance(listener, 2, getLocalPort(), + TIMER_3S, raw_server_contexts); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listener failed, result " << result; + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create listener " + "failed, result " + << result; return ERR_INVALID_ARGUMENT; } result = listener->Listen(); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "BarexTransport: startHandshakeDaemon, Listen failed, result " << result; + LOG(ERROR) + << "BarexTransport: startHandshakeDaemon, Listen failed, result " + << result; return ERR_INVALID_ARGUMENT; } listener_ = std::shared_ptr(listener); - XConnector* connector = nullptr; - result = XConnector::NewInstance(connector, 2, TIMER_3S, raw_client_contexts); + XConnector *connector = 
nullptr; + result = + XConnector::NewInstance(connector, 2, TIMER_3S, raw_client_contexts); if (result != BAREX_SUCCESS) { - LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create connector failed, result " << result; + LOG(ERROR) << "BarexTransport: startHandshakeDaemon, create connector " + "failed, result " + << result; return ERR_INVALID_ARGUMENT; } connector_ = std::shared_ptr(connector); @@ -1273,46 +1466,52 @@ int BarexTransport::startHandshakeDaemon(std::string &local_server_name) { // buffer_id and device_id as output. // Return 0 if successful, ERR_ADDRESS_NOT_REGISTERED otherwise. int BarexTransport::selectDevice(SegmentDesc *desc, uint64_t offset, - size_t length, int &buffer_id, int &device_id, - int retry_count) { + size_t length, int &buffer_id, int &device_id, + int retry_count) { if (!desc) return ERR_ADDRESS_NOT_REGISTERED; int ret = 0; for (buffer_id = 0; buffer_id < (int)desc->buffers.size(); ++buffer_id) { auto &buffer_desc = desc->buffers[buffer_id]; - if (buffer_desc.addr > offset || offset >= buffer_desc.addr + buffer_desc.length) { + if (buffer_desc.addr > offset || + offset >= buffer_desc.addr + buffer_desc.length) { continue; } else { if (offset + length > buffer_desc.addr + buffer_desc.length) { // mr cross two buffers, need separate into two parts if (buffer_id + 1 < (int)desc->buffers.size()) { - auto &next_buffer_desc = desc->buffers[buffer_id+1]; - if (offset + length > next_buffer_desc.addr && offset + length <= next_buffer_desc.addr + next_buffer_desc.length) { + auto &next_buffer_desc = desc->buffers[buffer_id + 1]; + if (offset + length > next_buffer_desc.addr && + offset + length <= + next_buffer_desc.addr + next_buffer_desc.length) { ret = 1; } else { - LOG(ERROR) << "selectDevice failed, 2 buffers in need but next buffer not fit," - << " offset " << offset - << " length " << length - << " buffer_id " << buffer_id + LOG(ERROR) << "selectDevice failed, 2 buffers in need " + "but next buffer not fit," + << " offset " << 
offset << " length " + << length << " buffer_id " << buffer_id << " buffer_desc.addr " << buffer_desc.addr - << " buffer_desc.length " << buffer_desc.length - << " buffer_id " << buffer_id+1 - << " next_buffer_desc.addr " << next_buffer_desc.addr - << " next_buffer_desc.length " << next_buffer_desc.length; + << " buffer_desc.length " + << buffer_desc.length << " buffer_id " + << buffer_id + 1 << " next_buffer_desc.addr " + << next_buffer_desc.addr + << " next_buffer_desc.length " + << next_buffer_desc.length; return ERR_ADDRESS_NOT_REGISTERED; } } else { LOG(ERROR) << "selectDevice failed, last buffer overflow," - << " offset " << offset - << " length " << length + << " offset " << offset << " length " << length << " buffer_id " << buffer_id << " buffer_desc.addr " << buffer_desc.addr << " buffer_desc.length " << buffer_desc.length; return ERR_ADDRESS_NOT_REGISTERED; } } - device_id = desc->topology.selectDevice(buffer_desc.name, retry_count); + device_id = + desc->topology.selectDevice(buffer_desc.name, retry_count); if (device_id >= 0) return ret; - device_id = desc->topology.selectDevice(kWildcardLocation, retry_count); + device_id = + desc->topology.selectDevice(kWildcardLocation, retry_count); if (device_id >= 0) return ret; } } From e2064246eb9c2b70db8a8078d222ec3050dc4d42 Mon Sep 17 00:00:00 2001 From: Teng Ma Date: Wed, 12 Nov 2025 16:33:58 +0800 Subject: [PATCH 4/6] feat[accl-barex]: fix clang format --- .../include/transport/barex_transport/barex_transport.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h index c09a32949..cf6f2fe1d 100644 --- a/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h +++ b/mooncake-transfer-engine/include/transport/barex_transport/barex_transport.h @@ -50,7 +50,7 @@ class CountDownLatch { std::condition_variable cv; public: - 
CountDownLatch(int count) : count_(count){}; + CountDownLatch(int count) : count_(count) {}; void CountDown() { std::unique_lock lk(mtx); From 947ee6cac314d739f917d759dfd1f85c838bc64d Mon Sep 17 00:00:00 2001 From: "zhangzechao.zzc" Date: Thu, 13 Nov 2025 15:31:29 +0800 Subject: [PATCH 5/6] feat[barex]: add log --- mooncake-integration/transfer_engine/transfer_engine_py.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mooncake-integration/transfer_engine/transfer_engine_py.cpp b/mooncake-integration/transfer_engine/transfer_engine_py.cpp index 7d5018ea2..26151786e 100644 --- a/mooncake-integration/transfer_engine/transfer_engine_py.cpp +++ b/mooncake-integration/transfer_engine/transfer_engine_py.cpp @@ -140,7 +140,8 @@ int TransferEnginePy::initializeExt(const char *local_hostname, pass_alloc = true; } } catch (const std::exception &) { - // Ignore invalid values or log a warning + LOG(WARNING) << "Ignore value from environment variable " + "PASS_ALLOC"; } } if (!pass_alloc) { From 4c3678baf7fa3e53d4797baa36a9128036ffcf5e Mon Sep 17 00:00:00 2001 From: Teng Ma Date: Wed, 19 Nov 2025 00:46:59 +0800 Subject: [PATCH 6/6] Update mooncake-common/common.cmake --- mooncake-common/common.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mooncake-common/common.cmake b/mooncake-common/common.cmake index 564b1b535..c2ba29edd 100644 --- a/mooncake-common/common.cmake +++ b/mooncake-common/common.cmake @@ -61,7 +61,7 @@ option(USE_MUSA "option for enabling gpu features for MTHREADS GPU" OFF) option(USE_HIP "option for enabling gpu features for AMD GPU" OFF) option(USE_NVMEOF "option for using NVMe over Fabric" OFF) option(USE_TCP "option for using TCP transport" ON) -option(USE_BAREX "option for using accl-barex transport" OFF) +option(USE_BAREX "option for using accl-barex transport" ON) option(USE_ASCEND "option for using npu with HCCL" OFF) option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" OFF) 
option(USE_ASCEND_HETEROGENEOUS "option for transferring between ascend npu and gpu" OFF)