Skip to content

Commit f3a9413

Browse files
authored
feat: service client cache (#685)
1 parent 6074a1b commit f3a9413

File tree

4 files changed

+88
-22
lines changed

4 files changed

+88
-22
lines changed

src/rai_core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
44

55
[tool.poetry]
66
name = "rai_core"
7-
version = "2.2.1"
7+
version = "2.3.0"
88
description = "Core functionality for RAI framework"
99
authors = ["Maciej Majek <maciej.majek@robotec.ai>", "Bartłomiej Boczek <bartlomiej.boczek@robotec.ai>", "Kajetan Rachwał <kajetan.rachwal@robotec.ai>"]
1010
readme = "README.md"

src/rai_core/rai/communication/ros2/api/service.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import os
1616
import uuid
17+
from threading import Lock
1718
from typing import (
1819
Any,
1920
Callable,
@@ -30,6 +31,7 @@
3031
import rclpy.qos
3132
import rclpy.subscription
3233
import rclpy.task
34+
from rclpy.client import Client
3335
from rclpy.service import Service
3436

3537
from rai.communication.ros2.api.base import (
@@ -39,43 +41,76 @@
3941

4042

4143
class ROS2ServiceAPI(BaseROS2API):
42-
"""Handles ROS2 service operations including calling services."""
44+
"""Handles ROS 2 service operations including calling services."""
4345

4446
def __init__(self, node: rclpy.node.Node) -> None:
4547
self.node = node
4648
self._logger = node.get_logger()
4749
self._services: Dict[str, Service] = {}
50+
self._persistent_clients: Dict[str, Client] = {}
51+
self._persistent_clients_lock = Lock()
52+
53+
def release_client(self, service_name: str) -> bool:
54+
with self._persistent_clients_lock:
55+
return self._persistent_clients.pop(service_name, None) is not None
4856

4957
def call_service(
5058
self,
5159
service_name: str,
5260
service_type: str,
5361
request: Any,
5462
timeout_sec: float = 5.0,
63+
*,
64+
reuse_client: bool = True,
5565
) -> Any:
5666
"""
57-
Call a ROS2 service.
67+
Call a ROS 2 service.
5868
5969
Args:
60-
service_name: Name of the service to call
61-
service_type: ROS2 service type as string
62-
request: Request message content
70+
service_name: Fully-qualified service name.
71+
service_type: ROS 2 service type string (e.g., 'std_srvs/srv/SetBool').
72+
request: Request payload dict.
73+
timeout_sec: Seconds to wait for availability/response.
74+
reuse_client: Reuse a cached client. Client creation is synchronized; set
75+
False to create a new client per call.
6376
6477
Returns:
65-
The response message
78+
Response message instance.
79+
80+
Raises:
81+
ValueError: Service not available within the timeout.
82+
AttributeError: Service type or request cannot be constructed.
83+
84+
Note:
85+
With reuse_client=True, access to the cached client (including the
86+
service call) is serialized by a lock, preventing concurrent calls
87+
through the same client. Use reuse_client=False for per-call clients
88+
when concurrent service calls are required.
6689
"""
6790
srv_msg, srv_cls = self.build_ros2_service_request(service_type, request)
68-
service_client = self.node.create_client(srv_cls, service_name) # type: ignore
69-
client_ready = service_client.wait_for_service(timeout_sec=timeout_sec)
70-
if not client_ready:
71-
raise ValueError(
72-
f"Service {service_name} not ready within {timeout_sec} seconds. "
73-
"Try increasing the timeout or check if the service is running."
74-
)
75-
if os.getenv("ROS_DISTRO") == "humble":
76-
return service_client.call(srv_msg)
91+
92+
def _call_service(client: Client, timeout_sec: float) -> Any:
93+
is_service_available = client.wait_for_service(timeout_sec=timeout_sec)
94+
if not is_service_available:
95+
raise ValueError(
96+
f"Service {service_name} not ready within {timeout_sec} seconds. "
97+
"Try increasing the timeout or check if the service is running."
98+
)
99+
if os.getenv("ROS_DISTRO") == "humble":
100+
return client.call(srv_msg)
101+
else:
102+
return client.call(srv_msg, timeout_sec=timeout_sec)
103+
104+
if reuse_client:
105+
with self._persistent_clients_lock:
106+
client = self._persistent_clients.get(service_name, None)
107+
if client is None:
108+
client = self.node.create_client(srv_cls, service_name) # type: ignore
109+
self._persistent_clients[service_name] = client
110+
return _call_service(client, timeout_sec)
77111
else:
78-
return service_client.call(srv_msg, timeout_sec=timeout_sec)
112+
client = self.node.create_client(srv_cls, service_name) # type: ignore
113+
return _call_service(client, timeout_sec)
79114

80115
def get_service_names_and_types(self) -> List[Tuple[str, List[str]]]:
81116
return self.node.get_service_names_and_types()

src/rai_core/rai/communication/ros2/connectors/service_mixin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,25 @@ def __post_init__(self, *args: Any, **kwargs: Any) -> None:
3030
f"{self.__class__.__name__} instance must have an attribute '_service_api' of type ROS2ServiceAPI"
3131
)
3232

33+
def release_client(self, service_name: str) -> bool:
34+
return self._service_api.release_client(service_name)
35+
3336
def service_call(
3437
self,
3538
message: ROS2Message,
3639
target: str,
3740
timeout_sec: float = 5.0,
3841
*,
3942
msg_type: str,
43+
reuse_client: bool = True,
4044
**kwargs: Any,
4145
) -> ROS2Message:
4246
msg = self._service_api.call_service(
4347
service_name=target,
4448
service_type=msg_type,
4549
request=message.payload,
4650
timeout_sec=timeout_sec,
51+
reuse_client=reuse_client,
4752
)
4853
return ROS2Message(
4954
payload=msg, metadata={"msg_type": str(type(msg)), "service": target}

tests/communication/ros2/test_api.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,14 @@ def test_ros2_single_message_publish_wrong_qos_setup(
138138
shutdown_executors_and_threads(executors, threads)
139139

140140

141-
def service_call_helper(service_name: str, service_api: ROS2ServiceAPI):
141+
def invoke_set_bool_service(
142+
service_name: str, service_api: ROS2ServiceAPI, reuse_client: bool = True
143+
):
142144
response = service_api.call_service(
143145
service_name,
144146
service_type="std_srvs/srv/SetBool",
145147
request={"data": True},
148+
reuse_client=reuse_client,
146149
)
147150
assert response.success
148151
assert response.message == "Test service called"
@@ -164,7 +167,7 @@ def test_ros2_service_single_call(
164167

165168
try:
166169
service_api = ROS2ServiceAPI(node)
167-
service_call_helper(service_name, service_api)
170+
invoke_set_bool_service(service_name, service_api)
168171
finally:
169172
shutdown_executors_and_threads(executors, threads)
170173

@@ -186,7 +189,30 @@ def test_ros2_service_multiple_calls(
186189
try:
187190
service_api = ROS2ServiceAPI(node)
188191
for _ in range(3):
189-
service_call_helper(service_name, service_api)
192+
invoke_set_bool_service(service_name, service_api, reuse_client=False)
193+
finally:
194+
shutdown_executors_and_threads(executors, threads)
195+
196+
197+
@pytest.mark.parametrize(
198+
"callback_group",
199+
[MutuallyExclusiveCallbackGroup(), ReentrantCallbackGroup()],
200+
ids=["MutuallyExclusiveCallbackGroup", "ReentrantCallbackGroup"],
201+
)
202+
def test_ros2_service_multiple_calls_with_reused_client(
203+
ros_setup: None, request: pytest.FixtureRequest, callback_group: CallbackGroup
204+
) -> None:
205+
service_name = f"{request.node.originalname}_service" # type: ignore
206+
node_name = f"{request.node.originalname}_node" # type: ignore
207+
service_server = ServiceServer(service_name, callback_group)
208+
node = Node(node_name)
209+
executors, threads = multi_threaded_spinner([service_server, node])
210+
211+
try:
212+
service_api = ROS2ServiceAPI(node)
213+
for _ in range(3):
214+
invoke_set_bool_service(service_name, service_api, reuse_client=True)
215+
assert service_api.release_client(service_name), "Client not released"
190216
finally:
191217
shutdown_executors_and_threads(executors, threads)
192218

@@ -210,7 +236,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_threading(
210236
service_threads: List[threading.Thread] = []
211237
for _ in range(10):
212238
thread = threading.Thread(
213-
target=service_call_helper, args=(service_name, service_api)
239+
target=invoke_set_bool_service, args=(service_name, service_api)
214240
)
215241
service_threads.append(thread)
216242
thread.start()
@@ -241,7 +267,7 @@ def test_ros2_service_multiple_calls_at_the_same_time_multiprocessing(
241267
service_api = ROS2ServiceAPI(node)
242268
with Pool(10) as pool:
243269
pool.map(
244-
lambda _: service_call_helper(service_name, service_api), range(10)
270+
lambda _: invoke_set_bool_service(service_name, service_api), range(10)
245271
)
246272
finally:
247273
shutdown_executors_and_threads(executors, threads)

0 commit comments

Comments
 (0)