
Commit f4d8a38

- Introduced AsyncTaskHubGrpcClient as an async implementation of TaskHubGrpcClient
- Added e2e tests

Signed-off-by: Patrick Assuied <patrick.assuied@elationhealth.com>

1 parent: 06357df

13 files changed: +864, -8 lines

durabletask/aio/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
from .client import AsyncTaskHubGrpcClient

__all__ = [
    "AsyncTaskHubGrpcClient",
]

durabletask/aio/client.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union

import grpc
from google.protobuf import wrappers_pb2

import durabletask.internal.helpers as helpers
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
from durabletask.aio.internal.shared import get_grpc_aio_channel, AioClientInterceptor
from durabletask import task
from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl


class AsyncTaskHubGrpcClient:

    def __init__(self, *,
                 host_address: Optional[str] = None,
                 metadata: Optional[list[tuple[str, str]]] = None,
                 log_handler: Optional[logging.Handler] = None,
                 log_formatter: Optional[logging.Formatter] = None,
                 secure_channel: bool = False,
                 interceptors: Optional[Sequence[AioClientInterceptor]] = None):

        if interceptors is not None:
            interceptors = list(interceptors)
            if metadata is not None:
                interceptors.append(DefaultAioClientInterceptorImpl(metadata))
        elif metadata is not None:
            interceptors = [DefaultAioClientInterceptorImpl(metadata)]
        else:
            interceptors = None

        channel = get_grpc_aio_channel(
            host_address=host_address,
            secure_channel=secure_channel,
            interceptors=interceptors
        )
        self._channel = channel
        self._stub = stubs.TaskHubSidecarServiceStub(channel)
        self._logger = shared.get_logger("client", log_handler, log_formatter)

    async def aclose(self):
        await self._channel.close()

    async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
                                         input: Optional[TInput] = None,
                                         instance_id: Optional[str] = None,
                                         start_at: Optional[datetime] = None,
                                         reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:

        name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

        req = pb.CreateInstanceRequest(
            name=name,
            instanceId=instance_id if instance_id else uuid.uuid4().hex,
            input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
            scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
            version=helpers.get_string_value(None),
            orchestrationIdReusePolicy=reuse_id_policy,
        )

        self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
        res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
        return res.instanceId

    async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
        return new_orchestration_state(req.instanceId, res)

    async def wait_for_orchestration_start(self, instance_id: str, *,
                                           fetch_payloads: bool = False,
                                           timeout: int = 0) -> Optional[OrchestrationState]:
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            grpc_timeout = None if timeout == 0 else timeout
            self._logger.info(
                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
            return new_orchestration_state(req.instanceId, res)
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                # Replace gRPC error with the built-in TimeoutError
                raise TimeoutError("Timed-out waiting for the orchestration to start")
            else:
                raise

    async def wait_for_orchestration_completion(self, instance_id: str, *,
                                                fetch_payloads: bool = True,
                                                timeout: int = 0) -> Optional[OrchestrationState]:
        req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
        try:
            grpc_timeout = None if timeout == 0 else timeout
            self._logger.info(
                f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
            res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
            state = new_orchestration_state(req.instanceId, res)
            if not state:
                return None

            if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
                details = state.failure_details
                self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
            elif state.runtime_status == OrchestrationStatus.TERMINATED:
                self._logger.info(f"Instance '{instance_id}' was terminated.")
            elif state.runtime_status == OrchestrationStatus.COMPLETED:
                self._logger.info(f"Instance '{instance_id}' completed.")

            return state
        except grpc.RpcError as rpc_error:
            if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED:  # type: ignore
                # Replace gRPC error with the built-in TimeoutError
                raise TimeoutError("Timed-out waiting for the orchestration to complete")
            else:
                raise

    async def raise_orchestration_event(
            self,
            instance_id: str,
            event_name: str,
            *,
            data: Optional[Any] = None):
        req = pb.RaiseEventRequest(
            instanceId=instance_id,
            name=event_name,
            input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)

        self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
        await self._stub.RaiseEvent(req)

    async def terminate_orchestration(self, instance_id: str, *,
                                      output: Optional[Any] = None,
                                      recursive: bool = True):
        req = pb.TerminateRequest(
            instanceId=instance_id,
            output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
            recursive=recursive)

        self._logger.info(f"Terminating instance '{instance_id}'.")
        await self._stub.TerminateInstance(req)

    async def suspend_orchestration(self, instance_id: str):
        req = pb.SuspendRequest(instanceId=instance_id)
        self._logger.info(f"Suspending instance '{instance_id}'.")
        await self._stub.SuspendInstance(req)

    async def resume_orchestration(self, instance_id: str):
        req = pb.ResumeRequest(instanceId=instance_id)
        self._logger.info(f"Resuming instance '{instance_id}'.")
        await self._stub.ResumeInstance(req)

    async def purge_orchestration(self, instance_id: str, recursive: bool = True):
        req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
        self._logger.info(f"Purging instance '{instance_id}'.")
        await self._stub.PurgeInstances(req)
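
A minimal usage sketch of the new async client, assuming the Durable Task sidecar is reachable at the default address and that an orchestrator named 'hello_orchestrator' (a hypothetical name, not part of this diff) has already been registered by a worker:

import asyncio

from durabletask.aio import AsyncTaskHubGrpcClient


async def main():
    # Connects to the default sidecar address unless host_address is given.
    client = AsyncTaskHubGrpcClient()
    try:
        # 'hello_orchestrator' is a placeholder orchestrator name registered elsewhere.
        instance_id = await client.schedule_new_orchestration("hello_orchestrator", input="world")
        state = await client.wait_for_orchestration_completion(instance_id, timeout=60)
        if state:
            print(state.runtime_status)
    finally:
        # aclose() shuts down the underlying grpc.aio channel.
        await client.aclose()


asyncio.run(main())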

durabletask/aio/internal/__init__.py

Whitespace-only changes.
durabletask/aio/internal/grpc_interceptor.py

Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
from collections import namedtuple

from grpc import aio as grpc_aio


class _AioClientCallDetails(
        namedtuple(
            '_AioClientCallDetails',
            ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
        grpc_aio.ClientCallDetails):
    pass


class DefaultAioClientInterceptorImpl(
        grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor,
        grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor):
    """Async gRPC client interceptor to add metadata to all calls."""

    def __init__(self, metadata: list[tuple[str, str]]):
        super().__init__()
        self._metadata = metadata

    def _intercept_call(self, client_call_details: _AioClientCallDetails) -> grpc_aio.ClientCallDetails:
        if self._metadata is None:
            return client_call_details

        if client_call_details.metadata is not None:
            metadata = list(client_call_details.metadata)
        else:
            metadata = []

        metadata.extend(self._metadata)
        return _AioClientCallDetails(
            client_call_details.method,
            client_call_details.timeout,
            metadata,
            client_call_details.credentials,
            client_call_details.wait_for_ready,
            client_call_details.compression)

    async def intercept_unary_unary(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request)

    async def intercept_unary_stream(self, continuation, client_call_details, request):
        new_client_call_details = self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request)

    async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        new_client_call_details = self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request_iterator)

    async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        new_client_call_details = self._intercept_call(client_call_details)
        return await continuation(new_client_call_details, request_iterator)
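
A short sketch of how this interceptor is used in practice: passing metadata to AsyncTaskHubGrpcClient (see client.py above) appends a DefaultAioClientInterceptorImpl to the channel's interceptor chain, so the headers are attached to every sidecar call. The header names below are illustrative only, not required by the SDK:

import asyncio

from durabletask.aio import AsyncTaskHubGrpcClient


async def main():
    # 'taskhub' and 'authorization' are placeholder header names for illustration.
    client = AsyncTaskHubGrpcClient(metadata=[
        ("taskhub", "my-hub"),
        ("authorization", "Bearer <token>"),
    ])
    try:
        # Every gRPC call made through the client now carries the metadata above.
        await client.resume_orchestration("some-instance-id")
    finally:
        await client.aclose()


asyncio.run(main())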

durabletask/aio/internal/shared.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
from typing import Optional, Sequence, Union

import grpc
from grpc import aio as grpc_aio

from durabletask.internal.shared import (
    get_default_host_address,
    SECURE_PROTOCOLS,
    INSECURE_PROTOCOLS,
)


AioClientInterceptor = Union[
    grpc_aio.UnaryUnaryClientInterceptor,
    grpc_aio.UnaryStreamClientInterceptor,
    grpc_aio.StreamUnaryClientInterceptor,
    grpc_aio.StreamStreamClientInterceptor
]


def get_grpc_aio_channel(
        host_address: Optional[str],
        secure_channel: bool = False,
        interceptors: Optional[Sequence[AioClientInterceptor]] = None) -> grpc_aio.Channel:

    if host_address is None:
        host_address = get_default_host_address()

    for protocol in SECURE_PROTOCOLS:
        if host_address.lower().startswith(protocol):
            secure_channel = True
            host_address = host_address[len(protocol):]
            break

    for protocol in INSECURE_PROTOCOLS:
        if host_address.lower().startswith(protocol):
            secure_channel = False
            host_address = host_address[len(protocol):]
            break

    if secure_channel:
        channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
    else:
        channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)

    return channel
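
To illustrate the protocol handling above: a recognized URL prefix on host_address overrides the secure_channel flag, because the prefix is stripped and the flag is derived from it. A small sketch, assuming "grpcs://" appears in SECURE_PROTOCOLS and "grpc://" in INSECURE_PROTOCOLS (the exact lists live in durabletask.internal.shared and are not shown in this diff):

from durabletask.aio.internal.shared import get_grpc_aio_channel

# Assuming "grpcs://" is listed in SECURE_PROTOCOLS, this creates a TLS channel
# to my-sidecar.example.com:443 even though secure_channel defaults to False.
secure = get_grpc_aio_channel("grpcs://my-sidecar.example.com:443")

# Assuming "grpc://" is listed in INSECURE_PROTOCOLS, this forces a plaintext
# channel even though secure_channel=True is passed.
insecure = get_grpc_aio_channel("grpc://localhost:4001", secure_channel=True)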

durabletask/task.py

Lines changed: 1 addition & 0 deletions
@@ -283,6 +283,7 @@ def get_tasks(self) -> list[Task]:
     def on_child_completed(self, task: Task[T]):
         pass
 
+
 class WhenAllTask(CompositeTask[list[T]]):
     """A task that completes when all of its child tasks complete."""

durabletask/worker.py

Lines changed: 1 addition & 1 deletion
@@ -880,13 +880,13 @@ class ExecutionResults:
     actions: list[pb.OrchestratorAction]
     encoded_custom_status: Optional[str]
 
-
     def __init__(
         self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
    ):
        self.actions = actions
        self.encoded_custom_status = encoded_custom_status
 
+
 class _OrchestrationExecutor:
     _generator: Optional[task.Orchestrator] = None

tests/durabletask/test_client.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ def test_get_grpc_channel_secure():
     get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
     mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)
 
+
 def test_get_grpc_channel_default_host_address():
     with patch('grpc.insecure_channel') as mock_channel:
         get_grpc_channel(None, False, interceptors=INTERCEPTORS)
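
The e2e tests mentioned in the commit message are not included in this excerpt. For illustration only, a hypothetical pytest-asyncio style sketch of how the async client might be exercised end to end against a running sidecar; the test name, the 'empty_orchestrator' name, and the pytest-asyncio dependency are all assumptions, not part of this diff:

import pytest

from durabletask.aio import AsyncTaskHubGrpcClient
from durabletask.client import OrchestrationStatus


@pytest.mark.asyncio
async def test_async_client_completes_orchestration():
    # Assumes a sidecar is running locally and an orchestrator named
    # 'empty_orchestrator' (hypothetical) is registered by a worker.
    client = AsyncTaskHubGrpcClient()
    try:
        instance_id = await client.schedule_new_orchestration("empty_orchestrator")
        state = await client.wait_for_orchestration_completion(instance_id, timeout=30)
        assert state is not None
        assert state.runtime_status == OrchestrationStatus.COMPLETED
    finally:
        await client.aclose()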
