diff --git a/olp-cpp-sdk-dataservice-read/src/RequestBroker.cpp b/olp-cpp-sdk-dataservice-read/src/RequestBroker.cpp new file mode 100644 index 000000000..9e2e68d9e --- /dev/null +++ b/olp-cpp-sdk-dataservice-read/src/RequestBroker.cpp @@ -0,0 +1,134 @@ +/* + * Copyright (C) 2019-2020 HERE Europe B.V. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + * License-Filename: LICENSE + */ + +#include "RequestBroker.h" + +namespace olp { +namespace dataservice { +namespace read { + +namespace { +std::string UniqueId() { + // TODO: generate a better unique id + static unsigned int id = 0; + return std::to_string(id++); +} +} // namespace + +void RequestBroker::RequestContext::AddCallback(CallerId id, + Callback callback) { + callbacks[id] = std::move(callback); +} + +void RequestBroker::RequestContext::PropagateResponse(DataResponse response) { + for (auto& callback : callbacks) { + callback.second(response); + } + callbacks.clear(); +} + +// Return true if the operation was canceled +bool RequestBroker::RequestContext::CancelRequest(CallerId id) { + // Cancel individual request + { + auto callback_it = callbacks.find(id); + if (callback_it != callbacks.end()) { + Callback callback = std::move(callback_it->second); + callback(client::ApiError(client::ErrorCode::Cancelled, "Canceled")); + callbacks.erase(callback_it); + } else { + assert(false); + } + } + + const bool cancel_operation = callbacks.empty(); + + if (cancel_operation) { + cancelation_context.CancelOperation(); + } + + return cancel_operation; +} + +client::CancellationContext +RequestBroker::RequestContext::CancelationContext() { + return cancelation_context; +} + +RequestBroker::CreateOrAssociateResult RequestBroker::CreateOrAssociateRequest( + RequestId req_id, Callback callback) { + const CallerId caller_id = UniqueId(); + GetOrCreateResult result = GetOrCreateContext(req_id); + result.ctx.AddCallback(caller_id, std::move(callback)); + return {result.ctx.CancelationContext(), CancelToken(req_id, caller_id), + result.just_created}; +} + +DataResponseCallback RequestBroker::ResponseHandler(RequestId req_id) { + return [=](DataResponse response) { + PropagateResponse(req_id, std::move(response)); + }; +} + +RequestBroker::GetOrCreateResult RequestBroker::GetOrCreateContext( + RequestId req_id) { + std::unique_lock lock(mutex_); + + auto request_ctx_it = request_map_.find(req_id); + if (request_ctx_it != request_map_.end()) { + return {request_ctx_it->second, false}; + } else { + request_ctx_it = + request_map_.insert(std::make_pair(req_id, RequestContext{})).first; + return {request_ctx_it->second, true}; + } +} + +void RequestBroker::PropagateResponse(RequestId req_id, DataResponse response) { + std::unique_lock lock(mutex_); + + auto request_ctx_it = request_map_.find(req_id); + if (request_ctx_it == request_map_.end()) { + assert(!response.IsSuccessful()); // Expect cancel here + return; + } + + auto ctx = std::move(request_ctx_it->second); + request_map_.erase(request_ctx_it); + ctx.PropagateResponse(std::move(response)); +} + +void RequestBroker::CancelRequest(RequestId req_id, CallerId id) { + std::unique_lock lock(mutex_); + + auto request_ctx_it = request_map_.find(req_id); + if (request_ctx_it == request_map_.end()) { + assert(false); + return; + } + + RequestContext& ctx = request_ctx_it->second; + if (ctx.CancelRequest(id)) { + request_map_.erase(request_ctx_it); + } +} + +} // namespace read +} // namespace dataservice +} // namespace olp diff --git a/olp-cpp-sdk-dataservice-read/src/RequestBroker.h b/olp-cpp-sdk-dataservice-read/src/RequestBroker.h new file mode 100644 index 000000000..33ca5cbc9 --- /dev/null +++ b/olp-cpp-sdk-dataservice-read/src/RequestBroker.h @@ -0,0 +1,87 @@ +/* + * Copyright (C) 2019-2020 HERE Europe B.V. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + * License-Filename: LICENSE + */ + +#pragma once + +#include + +#include +#include + +namespace olp { +namespace dataservice { +namespace read { + +class RequestBroker { + public: + using Callback = DataResponseCallback; + using RequestId = std::string; + + struct CreateOrAssociateResult { + client::CancellationContext context; + client::CancellationToken caller_cancelation_token; + bool just_created; + }; + + CreateOrAssociateResult CreateOrAssociateRequest(RequestId req_id, + Callback callback); + + DataResponseCallback ResponseHandler(RequestId req_id); + + private: + using CallerId = std::string; + + class RequestContext { + public: + void AddCallback(CallerId id, Callback callback); + void PropagateResponse(DataResponse response); + // Return true if the operation was canceled + bool CancelRequest(CallerId id); + + client::CancellationContext CancelationContext(); + + private: + client::CancellationContext cancelation_context; + std::unordered_map callbacks; + }; + + using RequestMap = std::unordered_map; + + inline client::CancellationToken CancelToken(RequestId req_id, CallerId id) { + return client::CancellationToken([=]() { CancelRequest(req_id, id); }); + } + + struct GetOrCreateResult { + RequestContext& ctx; + bool just_created; + }; + + GetOrCreateResult GetOrCreateContext(RequestId req_id); + + void PropagateResponse(RequestId req_id, DataResponse response); + + void CancelRequest(RequestId req_id, CallerId id); + + std::mutex mutex_; + RequestMap request_map_; +}; + +} // namespace read +} // namespace dataservice +} // namespace olp diff --git a/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.cpp b/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.cpp index 807849ce8..c11b6e7d6 100644 --- a/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.cpp +++ b/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.cpp @@ -152,24 +152,37 @@ client::CancellationToken VersionedLayerClientImpl::GetData( auto layer_id = layer_id_; auto settings = settings_; - auto data_task = - [=](client::CancellationContext context) mutable -> DataResponse { - if (!request.GetDataHandle()) { - auto version_response = GetVersion(request.GetBillingTag(), - request.GetFetchOption(), context); - if (!version_response.IsSuccessful()) { - return version_response.GetError(); + auto request_id = request.CreateKey(layer_id); + + auto result = + broker_.CreateOrAssociateRequest(request_id, std::move(callback)); + + if (result.just_created) { + // We just created a task, so trigger a real one and consume the + // CancellationContext provided. + auto data_task = + [=](client::CancellationContext context) mutable -> DataResponse { + if (!request.GetDataHandle()) { + auto version_response = GetVersion(request.GetBillingTag(), + request.GetFetchOption(), context); + if (!version_response.IsSuccessful()) { + return version_response.GetError(); + } + request.WithVersion(version_response.GetResult().GetVersion()); } - request.WithVersion(version_response.GetResult().GetVersion()); - } - return repository::DataRepository::GetVersionedData( - std::move(catalog), std::move(layer_id), std::move(request), context, - std::move(settings)); - }; + return repository::DataRepository::GetVersionedData( + std::move(catalog), std::move(layer_id), std::move(request), + context, std::move(settings)); + }; - return AddTask(settings.task_scheduler, pending_requests_, - std::move(data_task), std::move(callback)); + // We don't care about result here since we are passing the context + // outside. + AddTask(settings.task_scheduler, pending_requests_, std::move(data_task), + broker_.ResponseHandler(request_id), result.context); + } + + return result.caller_cancelation_token; }; return ScheduleFetch(std::move(schedule_get_data), std::move(request), diff --git a/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.h b/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.h index 401e89dd4..1740e46ff 100644 --- a/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.h +++ b/olp-cpp-sdk-dataservice-read/src/VersionedLayerClientImpl.h @@ -21,7 +21,6 @@ #include -#include #include #include #include @@ -33,6 +32,8 @@ #include #include #include +#include +#include "RequestBroker.h" #include "repositories/ExecuteOrSchedule.inl" namespace olp { @@ -93,6 +94,7 @@ class VersionedLayerClientImpl { client::HRN catalog_; std::string layer_id_; client::OlpClientSettings settings_; + RequestBroker broker_; std::shared_ptr pending_requests_; std::atomic catalog_version_; }; diff --git a/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientTest.cpp b/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientTest.cpp index 967295e73..a19fdfae6 100644 --- a/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientTest.cpp +++ b/tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientTest.cpp @@ -348,6 +348,170 @@ TEST_F(DataserviceReadVersionedLayerClientTest, GetDataFromPartitionAsync) { ASSERT_NE(response.GetResult()->size(), 0u); } +TEST_F(DataserviceReadVersionedLayerClientTest, GetDataFromPartitionMergeRequestsPositive) { + EXPECT_CALL(*network_mock_, Send(_, _, _, _, _)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + HTTP_RESPONSE_LOOKUP)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + kHttpResponsePartition_269)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + kHttpResponseBlobData_269)); + + auto catalog = olp::client::HRN::FromString( + GetArgument("dataservice_read_test_catalog")); + auto layer = GetArgument("dataservice_read_test_layer"); + auto version = std::stoi(GetArgument("dataservice_read_test_layer_version")); + + std::mutex mutex; + std::condition_variable cv; + + settings_->task_scheduler->ScheduleTask([&]() { + std::unique_lock lock(mutex); + cv.wait(lock); + }); + + auto client = std::make_shared( + catalog, layer, version, *settings_); + ASSERT_TRUE(client); + + auto promise1 = std::make_shared>(); + auto promise2 = std::make_shared>(); + std::future future1 = promise1->get_future(); + std::future future2 = promise2->get_future(); + auto partition = GetArgument("dataservice_read_test_partition"); + auto token1 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise1](DataResponse response) { promise1->set_value(response); }); + + auto token2 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise2](DataResponse response) { promise2->set_value(response); }); + + cv.notify_one(); + + ASSERT_NE(future1.wait_for(kWaitTimeout), std::future_status::timeout); + ASSERT_NE(future2.wait_for(kWaitTimeout), std::future_status::timeout); + DataResponse response1 = future1.get(); + DataResponse response2 = future2.get(); + + ASSERT_TRUE(response1.IsSuccessful()) << response1.GetError().GetMessage(); + ASSERT_NE(response1.GetResult(), nullptr); + ASSERT_EQ(response1.GetResult(), response2.GetResult()); + ASSERT_NE(response1.GetResult()->size(), 0u); +} + +TEST_F(DataserviceReadVersionedLayerClientTest, GetDataFromPartitionMergeRequestsCancelOne) { + EXPECT_CALL(*network_mock_, Send(_, _, _, _, _)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + HTTP_RESPONSE_LOOKUP)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + kHttpResponsePartition_269)) + .WillOnce(ReturnHttpResponse(olp::http::NetworkResponse().WithStatus( + olp::http::HttpStatusCode::OK), + kHttpResponseBlobData_269)); + + auto catalog = olp::client::HRN::FromString( + GetArgument("dataservice_read_test_catalog")); + auto layer = GetArgument("dataservice_read_test_layer"); + auto version = std::stoi(GetArgument("dataservice_read_test_layer_version")); + + std::mutex mutex; + std::condition_variable cv; + + settings_->task_scheduler->ScheduleTask([&]() { + std::unique_lock lock(mutex); + cv.wait(lock); + }); + + auto client = std::make_shared( + catalog, layer, version, *settings_); + ASSERT_TRUE(client); + + auto promise1 = std::make_shared>(); + auto promise2 = std::make_shared>(); + std::future future1 = promise1->get_future(); + std::future future2 = promise2->get_future(); + auto partition = GetArgument("dataservice_read_test_partition"); + auto token1 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise1](DataResponse response) { promise1->set_value(response); }); + + auto token2 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise2](DataResponse response) { promise2->set_value(response); }); + + token1.Cancel(); + + cv.notify_one(); + + ASSERT_NE(future1.wait_for(kWaitTimeout), std::future_status::timeout); + ASSERT_NE(future2.wait_for(kWaitTimeout), std::future_status::timeout); + DataResponse response1 = future1.get(); + DataResponse response2 = future2.get(); + + ASSERT_FALSE(response1.IsSuccessful()); + EXPECT_EQ(response1.GetError().GetErrorCode(), + olp::client::ErrorCode::Cancelled); + ASSERT_TRUE(response2.IsSuccessful()) << response1.GetError().GetMessage(); + ASSERT_NE(response2.GetResult(), nullptr); + ASSERT_NE(response2.GetResult()->size(), 0u); +} + + +TEST_F(DataserviceReadVersionedLayerClientTest, GetDataFromPartitionMergeRequestsCancelBoth) { + auto catalog = olp::client::HRN::FromString( + GetArgument("dataservice_read_test_catalog")); + auto layer = GetArgument("dataservice_read_test_layer"); + auto version = std::stoi(GetArgument("dataservice_read_test_layer_version")); + + std::mutex mutex; + std::condition_variable cv; + + settings_->task_scheduler->ScheduleTask([&]() { + std::unique_lock lock(mutex); + cv.wait(lock); + }); + + auto client = std::make_shared( + catalog, layer, version, *settings_); + ASSERT_TRUE(client); + + auto promise1 = std::make_shared>(); + auto promise2 = std::make_shared>(); + std::future future1 = promise1->get_future(); + std::future future2 = promise2->get_future(); + auto partition = GetArgument("dataservice_read_test_partition"); + auto token1 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise1](DataResponse response) { promise1->set_value(response); }); + + auto token2 = client->GetData( + olp::dataservice::read::DataRequest().WithPartitionId(partition), + [promise2](DataResponse response) { promise2->set_value(response); }); + + token1.Cancel(); + token2.Cancel(); + + cv.notify_one(); + + ASSERT_NE(future1.wait_for(kWaitTimeout), std::future_status::timeout); + ASSERT_NE(future2.wait_for(kWaitTimeout), std::future_status::timeout); + DataResponse response1 = future1.get(); + DataResponse response2 = future2.get(); + + ASSERT_FALSE(response1.IsSuccessful()); + ASSERT_FALSE(response2.IsSuccessful()); + EXPECT_EQ(response1.GetError().GetErrorCode(), + olp::client::ErrorCode::Cancelled); + EXPECT_EQ(response2.GetError().GetErrorCode(), + olp::client::ErrorCode::Cancelled); +} + TEST_F(DataserviceReadVersionedLayerClientTest, GetDataFromPartitionAsyncWithCancellableFuture) { EXPECT_CALL(*network_mock_, Send(_, _, _, _, _))