diff --git a/include/ylt/coro_io/ibverbs/ib_io.hpp b/include/ylt/coro_io/ibverbs/ib_io.hpp index 7b36a40e1..493681630 100644 --- a/include/ylt/coro_io/ibverbs/ib_io.hpp +++ b/include/ylt/coro_io/ibverbs/ib_io.hpp @@ -170,7 +170,6 @@ async_simple::coro:: is_canceled = true; } if (is_canceled) [[unlikely]] { - co_await coro_io::dispatch(ib_socket.get_executor()); ib_socket.close(); co_return std::pair{std::make_error_code(std::errc::operation_canceled), std::size_t{0}}; @@ -256,11 +255,9 @@ inline void reset_buffer(std::vector& buffer, std::size_t read_size) { template async_simple::coro::Lazy> -async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, - bool read_some = false) { +async_io_split_impl(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, + bool read_some) { if (!ib_socket.is_open()) { - co_await coro_io::dispatch( - ib_socket.get_coro_executor()->get_asio_executor()); co_return std::pair{std::make_error_code(std::errc::not_connected), std::size_t{0}}; } @@ -268,8 +265,6 @@ async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, make_sge(sge_list, raw_buffer); std::span sge_span = sge_list; if (sge_span.size() == 0) [[unlikely]] { - co_await coro_io::dispatch( - ib_socket.get_coro_executor()->get_asio_executor()); co_return std::pair{std::error_code{}, std::size_t{0}}; } @@ -278,8 +273,6 @@ async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, uint32_t max_size = ib_socket.get_buffer_size(); std::size_t io_completed_size = consume_buffer(ib_socket, sge_span); if (sge_span.empty()) { - co_await coro_io::dispatch( - ib_socket.get_coro_executor()->get_asio_executor()); co_return std::pair{std::error_code{}, io_completed_size}; } @@ -379,7 +372,6 @@ async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, is_canceled = true; } if (is_canceled) [[unlikely]] { - co_await coro_io::dispatch(ib_socket.get_executor()); ib_socket.close(); co_return std::pair{std::make_error_code(std::errc::operation_canceled), std::size_t{0}}; @@ -393,6 +385,19 @@ async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, co_return std::pair{ec, io_completed_size}; } +template +async_simple::coro::Lazy> +async_io_split(coro_io::ib_socket_t& ib_socket, Buffer&& raw_buffer, + bool read_some = false) { + auto ret = co_await async_io_split_impl( + ib_socket, std::forward(raw_buffer), read_some); + if (!ib_socket.get_executor().running_in_this_thread()) [[unlikely]] { + // switch to io_thread + co_await dispatch(ib_socket.get_executor()); + } + co_return ret; +} + } // namespace detail template diff --git a/include/ylt/coro_io/ibverbs/ib_socket.hpp b/include/ylt/coro_io/ibverbs/ib_socket.hpp index ca4ef0cc2..e4f722ecc 100644 --- a/include/ylt/coro_io/ibverbs/ib_socket.hpp +++ b/include/ylt/coro_io/ibverbs/ib_socket.hpp @@ -91,7 +91,7 @@ struct ib_socket_shared_state_t std::queue> recv_result; // TODO optimize with circle buffer callback_t recv_cb_; - callback_t send_cb_; + std::deque send_cb_; ib_buffer_t recv_buf_; std::shared_ptr ib_buffer_pool_; std::unique_ptr fd_; @@ -182,7 +182,7 @@ struct ib_socket_shared_state_t struct resume_struct { std::error_code ec; std::size_t len; - callback_t* cb; + uint64_t wr_id; }; void post_send_impl(std::span sge, callback_t&& handler) { @@ -195,12 +195,11 @@ struct ib_socket_shared_state_t sge = std::move(sge_copy)]() mutable { ibv_send_wr sr{}; ibv_send_wr* bad_wr = nullptr; - self->send_cb_ = std::move(handler); if (sge.size() && sge[0].lkey == 0) { sr.send_flags = IBV_SEND_INLINE; } sr.next = NULL; - sr.wr_id = (uintptr_t)&self->send_cb_; + sr.wr_id = 1; sr.sg_list = sge.data(); sr.num_sge = sge.size(); sr.opcode = IBV_WR_SEND; @@ -211,13 +210,16 @@ struct ib_socket_shared_state_t } // post the receive request to the RQ else if (auto ec = ibv_post_send(self->qp_.get(), &sr, &bad_wr); ec) { - err = std::make_error_code(std::errc{std::abs(ec)}); + err = std::make_error_code(std::errc{ec}); ELOG_ERROR << "ibv post send failed: " << err.message(); } if (err) { ib_socket_shared_state_t::resume(std::pair{err, std::size_t{0}}, handler); } + else { + self->send_cb_.push_back(std::move(handler)); + } }, executor_->get_asio_executor()) .start([](auto&& res) { @@ -243,7 +245,7 @@ struct ib_socket_shared_state_t struct ibv_wc wc{}; int ne = 0; std::vector vec; - callback_t tmp_callback; + callback_t tmp_recv_callback; while ((ne = ibv_poll_cq(cq_.get(), 1, &wc)) != 0) { if (ne < 0) { ELOG_ERROR << "poll CQ failed:" << ne; @@ -275,8 +277,8 @@ struct ib_socket_shared_state_t close(); } } - tmp_callback = std::move(recv_cb_); - vec.push_back({ec, wc.byte_len, &tmp_callback}); + tmp_recv_callback = std::move(recv_cb_); + vec.push_back({ec, wc.byte_len, 0}); } else { recv_result.push(std::pair{ec, (std::size_t)wc.byte_len}); @@ -289,7 +291,7 @@ struct ib_socket_shared_state_t } } else { - vec.push_back({ec, wc.byte_len, (callback_t*)wc.wr_id}); + vec.push_back({ec, wc.byte_len, wc.wr_id}); } if (cq_ == nullptr) { break; @@ -297,7 +299,13 @@ struct ib_socket_shared_state_t } } for (auto& result : vec) { - resume(std::pair{result.ec, result.len}, *result.cb); + if (result.wr_id == 0) { + resume(std::pair{result.ec, result.len}, tmp_recv_callback); + } + else { + resume(std::pair{result.ec, result.len}, send_cb_.front()); + send_cb_.pop_front(); + } } return ec; } @@ -333,7 +341,7 @@ struct ibverbs_config { uint32_t cq_size = 128; uint32_t recv_buffer_cnt = 4; ibv_qp_type qp_type = IBV_QPT_RC; - ibv_qp_cap cap = {.max_send_wr = 2, + ibv_qp_cap cap = {.max_send_wr = 32, .max_recv_wr = 32, .max_send_sge = 3, .max_recv_sge = 1, @@ -389,7 +397,8 @@ class ib_socket_t { ~ib_socket_t() { close(); } bool is_open() const noexcept { - return state_->fd_ != nullptr && state_->fd_->is_open(); + return state_->fd_ != nullptr && state_->fd_->is_open() && + !state_->has_close_; } std::shared_ptr buffer_pool() const noexcept { return state_->ib_buffer_pool_; @@ -799,7 +808,7 @@ class ib_socket_t { void init_fd() { int r = ibv_req_notify_cq(state_->cq_.get(), 0); if (r) { - auto err_code = std::make_error_code(std::errc{errno}); + auto err_code = std::make_error_code(std::errc{r}); ELOG_ERROR << "ibv_req_notify_cq failed: " << err_code.message(); throw std::system_error(err_code); } @@ -823,12 +832,15 @@ class ib_socket_t { } if (ec) { - auto send_cb = std::move(self->send_cb_); self->close(); ib_socket_shared_state_t::resume(std::pair{ec, std::size_t{0}}, self->recv_cb_); - ib_socket_shared_state_t::resume(std::pair{ec, std::size_t{0}}, - send_cb); + while (!self->send_cb_.empty()) { + ib_socket_shared_state_t::resume(std::pair{ec, std::size_t{0}}, + self->send_cb_.front()); + self->send_cb_.pop_front(); + } + break; } } diff --git a/src/coro_io/tests/ibverbs/CMakeLists.txt b/src/coro_io/tests/ibverbs/CMakeLists.txt index 3186a1d0d..242e3b529 100644 --- a/src/coro_io/tests/ibverbs/CMakeLists.txt +++ b/src/coro_io/tests/ibverbs/CMakeLists.txt @@ -4,11 +4,16 @@ if(YLT_HAVE_IBVERBS) add_executable(ibverbs_test test_device.cpp - ib_socket_pressure_test.cpp test_ib_socket.cpp main.cpp ) + + add_executable(ibverbs_pressure_test + ib_socket_pressure_test.cpp + main.cpp) target_link_libraries(ibverbs_test -libverbs) + target_link_libraries(ibverbs_pressure_test -libverbs) add_test(NAME ibverbs_test COMMAND ibverbs_test) + add_test(NAME ibverbs_pressure_test COMMAND ibverbs_pressure_test) endif()