Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mooncake-common/common.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ option(USE_MUSA "option for enabling gpu features for MTHREADS GPU" OFF)
option(USE_HIP "option for enabling gpu features for AMD GPU" OFF)
option(USE_NVMEOF "option for using NVMe over Fabric" OFF)
option(USE_TCP "option for using TCP transport" ON)
option(USE_BAREX "option for using accl-barex transport" OFF)
option(USE_ASCEND "option for using npu with HCCL" OFF)
option(USE_ASCEND_DIRECT "option for using ascend npu with adxl engine" OFF)
option(USE_ASCEND_HETEROGENEOUS "option for transferring between ascend npu and gpu" OFF)
Expand Down Expand Up @@ -143,6 +144,10 @@ if (USE_TCP)
add_compile_definitions(USE_TCP)
endif()

if (USE_BAREX)
add_compile_definitions(USE_BAREX)
endif()

if (USE_ASCEND OR USE_ASCEND_DIRECT)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DOPEN_BUILD_PROJECT ")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DOPEN_BUILD_PROJECT ")
Expand Down
2 changes: 1 addition & 1 deletion mooncake-integration/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,6 @@ def get_allocator(cls, device: torch_device) -> CUDAPluggableAllocator:
if device not in cls._instances:
so_path = cls._get_so_path()
cls._instances[device] = CUDAPluggableAllocator(
so_path, "u2mm_alloc_wrapper", "u2mm_free_wrapper"
so_path, "u2mm_alloc_wrapper_with_stream", "u2mm_free_wrapper_with_stream"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this has a compatible issue

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't worry. It works.

)
return cls._instances[device]
43 changes: 41 additions & 2 deletions mooncake-integration/transfer_engine/transfer_engine_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,21 @@ int TransferEnginePy::initializeExt(const char *local_hostname,
free_list_.resize(kSlabSizeKBTabLen);
#if !defined(USE_ASCEND) && !defined(USE_ASCEND_DIRECT) && \
!defined(USE_ASCEND_HETEROGENEOUS)
doBuddyAllocate(kMaxClassId);
bool pass_alloc = false;
const char *pass_alloc_env = std::getenv("PASS_ALLOC");
if (pass_alloc_env) {
try {
if (std::stoi(pass_alloc_env) != 0) {
pass_alloc = true;
}
} catch (const std::exception &) {
LOG(WARNING) << "Ignore value from environment variable "
"PASS_ALLOC";
}
}
if (!pass_alloc) {
doBuddyAllocate(kMaxClassId);
}
#endif
return 0;
}
Expand Down Expand Up @@ -266,6 +280,9 @@ int TransferEnginePy::transferSync(const char *target_hostname,
if (handle_map_.count(target_hostname)) {
handle = handle_map_[target_hostname];
} else {
LOG(INFO)
<< "transferSync, cache not found, openSegment with target "
<< target_hostname;
handle = engine_->openSegment(target_hostname);
if (handle == (Transport::SegmentHandle)-1) return -1;
handle_map_[target_hostname] = handle;
Expand Down Expand Up @@ -300,7 +317,19 @@ int TransferEnginePy::transferSync(const char *target_hostname,
batch_id, {entry},
TransferMetadata::NotifyDesc{notify->name, notify->msg})
: engine_->submitTransfer(batch_id, {entry});
if (!s.ok()) return -1;
if (!s.ok()) {
Status segment_status = engine_->CheckSegmentStatus(handle);
if (!segment_status.ok()) {
LOG(WARNING)
<< "submitTransfer failed with target " << target_hostname
<< ", CheckSegmentStatus not ok, ready to closeSegment";
std::lock_guard<std::mutex> guard(mutex_);
engine_->closeSegment(handle);
engine_->getMetadata()->removeSegmentDesc(target_hostname);
handle_map_.erase(target_hostname);
}
return -1;
}

TransferStatus status;
bool completed = false;
Expand Down Expand Up @@ -387,6 +416,16 @@ int TransferEnginePy::batchTransferSync(
: engine_->submitTransfer(batch_id, entries);
if (!s.ok()) {
engine_->freeBatchID(batch_id);
Status segment_status = engine_->CheckSegmentStatus(handle);
if (!segment_status.ok()) {
LOG(WARNING)
<< "submitTransfer failed with target " << target_hostname
<< ", CheckSegmentStatus not ok, ready to closeSegment";
std::lock_guard<std::mutex> guard(mutex_);
engine_->closeSegment(handle);
engine_->getMetadata()->removeSegmentDesc(target_hostname);
handle_map_.erase(target_hostname);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code block looks like the same as above

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For each non-OK request, it should check the results. I guess we should wrap this code block with USE_BAREX.

return -1;
}

Expand Down
14 changes: 13 additions & 1 deletion mooncake-transfer-engine/example/transfer_engine_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ DEFINE_string(mode, "initiator",
"data blocks from target node");
DEFINE_string(operation, "read", "Operation type: read or write");

DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|tcp");
DEFINE_string(protocol, "rdma", "Transfer protocol: rdma|barex|tcp");

DEFINE_string(device_name, "mlx5_2",
"Device name to use, valid if protocol=rdma");
Expand Down Expand Up @@ -301,6 +301,12 @@ int initiator() {
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
xport = engine->installTransport("rdma", args);
} else if (FLAGS_protocol == "barex") {
auto nic_priority_matrix = loadNicPriorityMatrix();
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
xport = engine->installTransport("barex", args);
} else if (FLAGS_protocol == "tcp") {
xport = engine->installTransport("tcp", nullptr);
} else if (FLAGS_protocol == "nvlink") {
Expand Down Expand Up @@ -421,6 +427,12 @@ int target() {
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
engine->installTransport("rdma", args);
} else if (FLAGS_protocol == "barex") {
auto nic_priority_matrix = loadNicPriorityMatrix();
void **args = (void **)malloc(2 * sizeof(void *));
args[0] = (void *)nic_priority_matrix.c_str();
args[1] = nullptr;
engine->installTransport("barex", args);
} else if (FLAGS_protocol == "tcp") {
engine->installTransport("tcp", nullptr);
} else if (FLAGS_protocol == "nvlink") {
Expand Down
1 change: 1 addition & 0 deletions mooncake-transfer-engine/include/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct GlobalConfig {
bool use_ipv6 = false;
size_t fragment_limit = 16384;
bool enable_dest_device_affinity = false;
size_t eic_max_block_size = 64UL * 1024 * 1024;
EndpointStoreType endpoint_store_type = EndpointStoreType::SIEVE;
};

Expand Down
3 changes: 3 additions & 0 deletions mooncake-transfer-engine/include/transfer_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class TransferEngine {

SegmentHandle openSegment(const std::string &segment_name);

Status CheckSegmentStatus(SegmentID sid);

int closeSegment(SegmentHandle handle);

int removeLocalSegment(const std::string &segment_name);
Expand Down Expand Up @@ -249,6 +251,7 @@ class TransferEngine {
// Set it to false only for testing.
bool auto_discover_;
std::vector<std::string> filter_;
bool use_barex_ = false;

#ifdef WITH_METRICS
ylt::metric::counter_t transferred_bytes_counter_{
Expand Down
2 changes: 2 additions & 0 deletions mooncake-transfer-engine/include/transfer_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,14 @@ class TransferMetadata {
struct RpcMetaDesc {
std::string ip_or_host_name;
uint16_t rpc_port;
uint16_t barex_port;
int sockfd; // local cache
};

struct HandShakeDesc {
std::string local_nic_path;
std::string peer_nic_path;
uint16_t barex_port;
std::vector<uint32_t> qp_num;
std::string reply_msg; // on error
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct HandShakePlugin {

std::vector<std::string> findLocalIpAddresses();

uint16_t findAvailableTcpPort(int &sockfd);
uint16_t findAvailableTcpPort(int &sockfd, bool set_range = false);

} // namespace mooncake

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
// Copyright 2024 KVCache.AI
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BAREX_CONTEXT_H_
#define BAREX_CONTEXT_H_

#include <infiniband/verbs.h>

#include <atomic>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "common.h"
#include "transport/transport.h"

#ifdef USE_BAREX
#include <accl/barex/barex.h>
#include <accl/barex/xcontext.h>
#include <accl/barex/xlistener.h>
#include <accl/barex/xconnector.h>
#include <accl/barex/xsimple_mempool.h>
#include <accl/barex/xthreadpool.h>
#include <accl/barex/xtimer.h>
#include <accl/barex/xconfig_util.h>
#endif

namespace mooncake {

#ifdef USE_BAREX

using namespace accl::barex;
using XChannel = accl::barex::XChannel;
using SegmentID = Transport::SegmentID;
using XContext = accl::barex::XContext;
using BarexResult = accl::barex::BarexResult;

class ChannelCache {
public:
// put channel
void put(SegmentID key, int nic_id, XChannel* channel) {
RWSpinlock::WriteGuard guard(lock_);
auto& channels = cache_[key];
auto& vec = channels[nic_id];
status_map_[key] = true;
vec.push_back(channel);
}

// get channel
XChannel* find(SegmentID key, int nic_id, int idx) {
RWSpinlock::ReadGuard guard(lock_);
auto it = cache_.find(key);
if (it == cache_.end()) return nullptr;
auto& channels = it->second;
auto ch_it = channels.find(nic_id);
if (ch_it == channels.end()) return nullptr;
auto& vec = ch_it->second;
if (idx >= 0 && idx < static_cast<int>(vec.size())) {
return vec[idx];
}
return nullptr;
}

// delete channel
bool erase(SegmentID key, int nic_id, int idx) {
RWSpinlock::WriteGuard guard(lock_);
auto it = cache_.find(key);
if (it == cache_.end()) return false;

auto& channels = it->second;
auto ch_it = channels.find(nic_id);
if (ch_it == channels.end()) return false;

auto& vec = ch_it->second;
if (idx < 0 || idx >= static_cast<int>(vec.size())) return false;

vec.erase(vec.begin() + idx);
status_map_[key] = false;
if (vec.empty()) {
channels.erase(ch_it);
if (channels.empty()) {
cache_.erase(it);
}
}
return true;
}

// get channel state
bool CheckAllChannels(SegmentID segment_id) {
RWSpinlock::ReadGuard guard(lock_);
auto it = cache_.find(segment_id);
if (it == cache_.end()) {
return false;
}
auto& inner_map = it->second;
for (auto& pair : inner_map) {
auto& channels = pair.second;
for (XChannel* channel : channels) {
if (!channel->IsActive()) {
return false;
}
}
}
return true;
}

// check and delete invalid channels
int RemoveInvalidChannels(SegmentID segment_id) {
RWSpinlock::WriteGuard guard(lock_);
auto it = cache_.find(segment_id);
if (it == cache_.end()) {
return 0;
}

int invalid_count = 0;
auto& inner_map = it->second;

for (auto& pair : inner_map) {
auto& channels = pair.second;
auto new_end = std::remove_if(
channels.begin(), channels.end(),
[](XChannel* channel) { return !channel->IsActive(); });
invalid_count += std::distance(new_end, channels.end());
channels.erase(new_end, channels.end());
}
return invalid_count;
}

// get all channels
std::vector<XChannel*> copyAll() {
RWSpinlock::WriteGuard guard(lock_);
std::vector<XChannel*> result;
for (const auto& [key, channels] : cache_) {
for (const auto& [nic_id, vec] : channels) {
result.insert(result.end(), vec.begin(), vec.end());
}
}
return result;
}

private:
std::unordered_map<SegmentID,
std::unordered_map<int, std::vector<XChannel*>>>
cache_;
std::unordered_map<SegmentID, bool> status_map_;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The status_map_ member of ChannelCache is written to in put() and erase(), but it is never read. This appears to be dead code and should be removed to simplify the class.

RWSpinlock lock_;
};
class BarexContext {
public:
int submitPostSend(const std::vector<Transport::Slice*>& slice_list);
int addChannel(SegmentID sid, int device_id, XChannel* ch);
XChannel* getChannel(SegmentID sid, int device_id, int idx);
int checkStatus(SegmentID sid);
XContext* getCtx();
// int ClearAllChannel();
std::vector<XChannel*> getAllChannel();
bool active() const { return active_; }
void setQpNum(int qp_num) { qp_num_per_ctx_ = qp_num; }
int getQpNum() const { return qp_num_per_ctx_; }

public:
BarexContext(XContext* xcontext, bool use_cpu, int device_id);

~BarexContext();

XContext* xcontext_;
bool barex_use_cpu_;
int barex_local_device_;

private:
ChannelCache channel_cache_;
bool active_ = true;
int qp_num_per_ctx_ = 2;
};
#endif
} // namespace mooncake

#endif // BAREX_CONTEXT_H_
Loading
Loading