-
Notifications
You must be signed in to change notification settings - Fork 432
feat[accl-barex]: add barex_transport by build with USE_BAREX #1045
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
91d661f
3274770
5e9f409
cd1ed0f
e206424
947ee6c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -132,7 +132,21 @@ int TransferEnginePy::initializeExt(const char *local_hostname, | |
| free_list_.resize(kSlabSizeKBTabLen); | ||
| #if !defined(USE_ASCEND) && !defined(USE_ASCEND_DIRECT) && \ | ||
| !defined(USE_ASCEND_HETEROGENEOUS) | ||
| doBuddyAllocate(kMaxClassId); | ||
| bool pass_alloc = false; | ||
| const char *pass_alloc_env = std::getenv("PASS_ALLOC"); | ||
| if (pass_alloc_env) { | ||
| try { | ||
| if (std::stoi(pass_alloc_env) != 0) { | ||
| pass_alloc = true; | ||
| } | ||
| } catch (const std::exception &) { | ||
| LOG(WARNING) << "Ignore value from environment variable " | ||
| "PASS_ALLOC"; | ||
| } | ||
| } | ||
| if (!pass_alloc) { | ||
| doBuddyAllocate(kMaxClassId); | ||
| } | ||
| #endif | ||
| return 0; | ||
| } | ||
|
|
@@ -266,6 +280,9 @@ int TransferEnginePy::transferSync(const char *target_hostname, | |
| if (handle_map_.count(target_hostname)) { | ||
| handle = handle_map_[target_hostname]; | ||
| } else { | ||
| LOG(INFO) | ||
| << "transferSync, cache not found, openSegment with target " | ||
| << target_hostname; | ||
| handle = engine_->openSegment(target_hostname); | ||
| if (handle == (Transport::SegmentHandle)-1) return -1; | ||
| handle_map_[target_hostname] = handle; | ||
|
|
@@ -300,7 +317,19 @@ int TransferEnginePy::transferSync(const char *target_hostname, | |
| batch_id, {entry}, | ||
| TransferMetadata::NotifyDesc{notify->name, notify->msg}) | ||
| : engine_->submitTransfer(batch_id, {entry}); | ||
| if (!s.ok()) return -1; | ||
| if (!s.ok()) { | ||
| Status segment_status = engine_->CheckSegmentStatus(handle); | ||
| if (!segment_status.ok()) { | ||
| LOG(WARNING) | ||
| << "submitTransfer failed with target " << target_hostname | ||
| << ", CheckSegmentStatus not ok, ready to closeSegment"; | ||
| std::lock_guard<std::mutex> guard(mutex_); | ||
| engine_->closeSegment(handle); | ||
| engine_->getMetadata()->removeSegmentDesc(target_hostname); | ||
| handle_map_.erase(target_hostname); | ||
| } | ||
| return -1; | ||
| } | ||
|
|
||
| TransferStatus status; | ||
| bool completed = false; | ||
|
|
@@ -387,6 +416,16 @@ int TransferEnginePy::batchTransferSync( | |
| : engine_->submitTransfer(batch_id, entries); | ||
| if (!s.ok()) { | ||
| engine_->freeBatchID(batch_id); | ||
| Status segment_status = engine_->CheckSegmentStatus(handle); | ||
| if (!segment_status.ok()) { | ||
| LOG(WARNING) | ||
| << "submitTransfer failed with target " << target_hostname | ||
| << ", CheckSegmentStatus not ok, ready to closeSegment"; | ||
| std::lock_guard<std::mutex> guard(mutex_); | ||
| engine_->closeSegment(handle); | ||
| engine_->getMetadata()->removeSegmentDesc(target_hostname); | ||
| handle_map_.erase(target_hostname); | ||
| } | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This code block looks the same as the one above.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For each non-OK request, it should check the results. I guess we should wrap this code block in a USE_BAREX preprocessor guard. |
||
| return -1; | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,194 @@ | ||
| // Copyright 2024 KVCache.AI | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| #ifndef BAREX_CONTEXT_H_ | ||
| #define BAREX_CONTEXT_H_ | ||
|
|
||
| #include <infiniband/verbs.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <atomic> | ||
| #include <cstddef> | ||
| #include <iterator> | ||
| #include <map> | ||
| #include <memory> | ||
| #include <mutex> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <unordered_set> | ||
| #include <vector> | ||
|
|
||
| #include "common.h" | ||
| #include "transport/transport.h" | ||
|
|
||
| #ifdef USE_BAREX | ||
| #include <accl/barex/barex.h> | ||
| #include <accl/barex/xcontext.h> | ||
| #include <accl/barex/xlistener.h> | ||
| #include <accl/barex/xconnector.h> | ||
| #include <accl/barex/xsimple_mempool.h> | ||
| #include <accl/barex/xthreadpool.h> | ||
| #include <accl/barex/xtimer.h> | ||
| #include <accl/barex/xconfig_util.h> | ||
| #endif | ||
|
|
||
| namespace mooncake { | ||
|
|
||
| #ifdef USE_BAREX | ||
|
|
||
| // NOTE(review): this using-directive sits at namespace scope in a header, so | ||
| // every accl::barex name becomes visible inside mooncake for all includers. | ||
| // The aliases below already name XChannel, XContext and BarexResult | ||
| // explicitly; confirm whether the directive is still required. | ||
| using namespace accl::barex; | ||
| using XChannel = accl::barex::XChannel; | ||
| // Segment identifier type, taken from Transport. | ||
| using SegmentID = Transport::SegmentID; | ||
| using XContext = accl::barex::XContext; | ||
| using BarexResult = accl::barex::BarexResult; | ||
|
|
||
| class ChannelCache { | ||
| public: | ||
| // Appends `channel` to the list kept for (key, nic_id) and records | ||
| // status_map_[key] = true. Missing segment/NIC entries are created on | ||
| // demand. Runs under the exclusive write guard. | ||
| void put(SegmentID key, int nic_id, XChannel* channel) { | ||
|     RWSpinlock::WriteGuard write_guard(lock_); | ||
|     auto& nic_channels = cache_[key][nic_id]; | ||
|     status_map_[key] = true; | ||
|     nic_channels.push_back(channel); | ||
| } | ||
|
|
||
| // Returns the channel stored at position `idx` for (key, nic_id), or | ||
| // nullptr when the segment, the NIC, or the index is not present. | ||
| // Runs under the shared read guard. | ||
| XChannel* find(SegmentID key, int nic_id, int idx) { | ||
|     RWSpinlock::ReadGuard read_guard(lock_); | ||
|     const auto segment_it = cache_.find(key); | ||
|     if (segment_it == cache_.end()) { | ||
|         return nullptr; | ||
|     } | ||
|     const auto nic_it = segment_it->second.find(nic_id); | ||
|     if (nic_it == segment_it->second.end()) { | ||
|         return nullptr; | ||
|     } | ||
|     const auto& list = nic_it->second; | ||
|     const bool in_range = | ||
|         idx >= 0 && idx < static_cast<int>(list.size()); | ||
|     return in_range ? list[idx] : nullptr; | ||
| } | ||
|
|
||
| // Drops the channel pointer at position `idx` of (key, nic_id). | ||
| // Returns false when the segment, the NIC, or the index does not exist. | ||
| // On success status_map_[key] is set to false, and the per-NIC list and | ||
| // the segment entry are removed once they become empty. Only the pointer | ||
| // is removed from the cache; the pointee is not touched here. | ||
| bool erase(SegmentID key, int nic_id, int idx) { | ||
|     RWSpinlock::WriteGuard write_guard(lock_); | ||
|     auto segment_it = cache_.find(key); | ||
|     if (segment_it == cache_.end()) { | ||
|         return false; | ||
|     } | ||
|     auto& per_nic = segment_it->second; | ||
|     auto nic_it = per_nic.find(nic_id); | ||
|     if (nic_it == per_nic.end()) { | ||
|         return false; | ||
|     } | ||
|     auto& list = nic_it->second; | ||
|     const bool in_range = | ||
|         idx >= 0 && idx < static_cast<int>(list.size()); | ||
|     if (!in_range) { | ||
|         return false; | ||
|     } | ||
|     list.erase(list.begin() + idx); | ||
|     status_map_[key] = false; | ||
|     if (!list.empty()) { | ||
|         return true; | ||
|     } | ||
|     // Last channel of this NIC is gone: prune the now-empty containers. | ||
|     per_nic.erase(nic_it); | ||
|     if (per_nic.empty()) { | ||
|         cache_.erase(segment_it); | ||
|     } | ||
|     return true; | ||
| } | ||
|
|
||
| // Reports whether `segment_id` is cached and none of its channels | ||
| // returns false from IsActive(). An unknown segment yields false; a | ||
| // cached segment whose lists hold no channels yields true because the | ||
| // loops below never execute. Runs under the shared read guard. | ||
| bool CheckAllChannels(SegmentID segment_id) { | ||
|     RWSpinlock::ReadGuard read_guard(lock_); | ||
|     const auto segment_it = cache_.find(segment_id); | ||
|     if (segment_it == cache_.end()) return false; | ||
|     for (auto& nic_entry : segment_it->second) { | ||
|         for (XChannel* ch : nic_entry.second) { | ||
|             if (!ch->IsActive()) return false; | ||
|         } | ||
|     } | ||
|     return true; | ||
| } | ||
|
|
||
| // Removes every channel of `segment_id` that is null or reports | ||
| // !IsActive(), and returns how many were removed (0 when the segment is | ||
| // not cached). Per-NIC lists and the segment entry are pruned once they | ||
| // become empty, matching erase(); otherwise CheckAllChannels() would | ||
| // report true for a segment that has no channel left. Only the pointers | ||
| // are dropped from the cache; the pointees are not touched here. | ||
| int RemoveInvalidChannels(SegmentID segment_id) { | ||
|     RWSpinlock::WriteGuard guard(lock_); | ||
|     auto it = cache_.find(segment_id); | ||
|     if (it == cache_.end()) { | ||
|         return 0; | ||
|     } | ||
| | ||
|     int invalid_count = 0; | ||
|     auto& inner_map = it->second; | ||
| | ||
|     for (auto nic_it = inner_map.begin(); nic_it != inner_map.end();) { | ||
|         auto& channels = nic_it->second; | ||
|         auto new_end = std::remove_if( | ||
|             channels.begin(), channels.end(), [](XChannel* channel) { | ||
|                 // A null entry can never be used, so count it as invalid | ||
|                 // instead of dereferencing it. | ||
|                 return channel == nullptr || !channel->IsActive(); | ||
|             }); | ||
|         invalid_count += | ||
|             static_cast<int>(std::distance(new_end, channels.end())); | ||
|         channels.erase(new_end, channels.end()); | ||
|         // unordered_map::erase returns the next iterator, so pruning | ||
|         // while iterating is safe. | ||
|         if (channels.empty()) { | ||
|             nic_it = inner_map.erase(nic_it); | ||
|         } else { | ||
|             ++nic_it; | ||
|         } | ||
|     } | ||
|     if (inner_map.empty()) { | ||
|         cache_.erase(it); | ||
|     } | ||
|     return invalid_count; | ||
| } | ||
|
|
||
| // Returns a flat copy of every cached channel pointer across all | ||
| // segments and NICs, in the iteration order of the underlying unordered | ||
| // maps. The cache is only read, so the shared ReadGuard is taken (as in | ||
| // find() and CheckAllChannels()) instead of the exclusive write guard. | ||
| std::vector<XChannel*> copyAll() { | ||
|     RWSpinlock::ReadGuard guard(lock_); | ||
|     std::vector<XChannel*> result; | ||
|     for (const auto& [key, channels] : cache_) { | ||
|         for (const auto& [nic_id, vec] : channels) { | ||
|             result.insert(result.end(), vec.begin(), vec.end()); | ||
|         } | ||
|     } | ||
|     return result; | ||
| } | ||
|
|
||
| private: | ||
| std::unordered_map<SegmentID, | ||
| std::unordered_map<int, std::vector<XChannel*>>> | ||
| cache_; | ||
| std::unordered_map<SegmentID, bool> status_map_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| RWSpinlock lock_; | ||
| }; | ||
| // Bundles an XContext pointer, a CPU flag and a local device id with a | ||
| // ChannelCache. Only the inline accessors are defined in this header; | ||
| // every other member function is defined out of line. | ||
| class BarexContext { | ||
| public: | ||
| // Takes a list of Transport::Slice pointers and returns an int status. | ||
| // NOTE(review): presumably posts the slices on cached channels -- confirm | ||
| // against the implementation file. | ||
| int submitPostSend(const std::vector<Transport::Slice*>& slice_list); | ||
| // NOTE(review): presumably stores `ch` in channel_cache_ under | ||
| // (sid, device_id) -- confirm against the implementation file. | ||
| int addChannel(SegmentID sid, int device_id, XChannel* ch); | ||
| // NOTE(review): presumably looks up channel `idx` of (sid, device_id) in | ||
| // channel_cache_ -- confirm against the implementation file. | ||
| XChannel* getChannel(SegmentID sid, int device_id, int idx); | ||
| // Returns an int for segment `sid`; its meaning is defined by the | ||
| // implementation file (TODO confirm). | ||
| int checkStatus(SegmentID sid); | ||
| // NOTE(review): presumably returns xcontext_ -- confirm. | ||
| XContext* getCtx(); | ||
| // int ClearAllChannel(); | ||
| // NOTE(review): presumably returns all channels held in channel_cache_ | ||
| // -- confirm. | ||
| std::vector<XChannel*> getAllChannel(); | ||
| // Returns the active_ flag (initialised to true below). | ||
| bool active() const { return active_; } | ||
| // Overrides the stored per-context QP count. | ||
| void setQpNum(int qp_num) { qp_num_per_ctx_ = qp_num; } | ||
| // Returns the stored per-context QP count (2 unless setQpNum was called). | ||
| int getQpNum() const { return qp_num_per_ctx_; } | ||
|
|
||
| public: | ||
| // Constructor and destructor are defined out of line. | ||
| BarexContext(XContext* xcontext, bool use_cpu, int device_id); | ||
|
|
||
| ~BarexContext(); | ||
|
|
||
| // Publicly accessible state. NOTE(review): presumably set from the | ||
| // constructor arguments of matching name -- confirm. Ownership of | ||
| // xcontext_ cannot be determined from this header. | ||
| XContext* xcontext_; | ||
| bool barex_use_cpu_; | ||
| int barex_local_device_; | ||
|
|
||
| private: | ||
| // Channels cached by this context. | ||
| ChannelCache channel_cache_; | ||
| // Plain (non-atomic) flag returned by active(). | ||
| bool active_ = true; | ||
| // Per-context QP count; defaults to 2. | ||
| int qp_num_per_ctx_ = 2; | ||
| }; | ||
| #endif | ||
| } // namespace mooncake | ||
|
|
||
| #endif // BAREX_CONTEXT_H_ | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure whether this has a compatibility issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't worry. It works.