From f4d8a38e0db80271ff7722ff0fd7fab1faafb3ef Mon Sep 17 00:00:00 2001 From: Patrick Assuied Date: Sat, 27 Sep 2025 12:31:32 -0700 Subject: [PATCH 1/6] - Introduced `AsyncTaskHubGrpcClient` as async implementation of `TaskHubGrpcClient` - Added e2e tests Signed-off-by: Patrick Assuied --- durabletask/aio/__init__.py | 5 + durabletask/aio/client.py | 160 ++++++ durabletask/aio/internal/__init__.py | 0 durabletask/aio/internal/grpc_interceptor.py | 55 ++ durabletask/aio/internal/shared.py | 46 ++ durabletask/task.py | 1 + durabletask/worker.py | 2 +- tests/durabletask/test_client.py | 1 + tests/durabletask/test_client_async.py | 103 ++++ tests/durabletask/test_orchestration_e2e.py | 1 - .../test_orchestration_e2e_async.py | 487 ++++++++++++++++++ .../test_orchestration_executor.py | 4 +- tests/durabletask/test_orchestration_wait.py | 7 +- 13 files changed, 864 insertions(+), 8 deletions(-) create mode 100644 durabletask/aio/__init__.py create mode 100644 durabletask/aio/client.py create mode 100644 durabletask/aio/internal/__init__.py create mode 100644 durabletask/aio/internal/grpc_interceptor.py create mode 100644 durabletask/aio/internal/shared.py create mode 100644 tests/durabletask/test_client_async.py create mode 100644 tests/durabletask/test_orchestration_e2e_async.py diff --git a/durabletask/aio/__init__.py b/durabletask/aio/__init__.py new file mode 100644 index 0000000..d446228 --- /dev/null +++ b/durabletask/aio/__init__.py @@ -0,0 +1,5 @@ +from .client import AsyncTaskHubGrpcClient + +__all__ = [ + "AsyncTaskHubGrpcClient", +] diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py new file mode 100644 index 0000000..51797f3 --- /dev/null +++ b/durabletask/aio/client.py @@ -0,0 +1,160 @@ +import logging +import uuid +from datetime import datetime +from typing import Any, Optional, Sequence, Union + +import grpc +from google.protobuf import wrappers_pb2 + +import durabletask.internal.helpers as helpers +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.orchestrator_service_pb2_grpc as stubs +import durabletask.internal.shared as shared +from durabletask.aio.internal.shared import get_grpc_aio_channel, AioClientInterceptor +from durabletask import task +from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput +from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl + + +class AsyncTaskHubGrpcClient: + + def __init__(self, *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[AioClientInterceptor]] = None): + + if interceptors is not None: + interceptors = list(interceptors) + if metadata is not None: + interceptors.append(DefaultAioClientInterceptorImpl(metadata)) + elif metadata is not None: + interceptors = [DefaultAioClientInterceptorImpl(metadata)] + else: + interceptors = None + + channel = get_grpc_aio_channel( + host_address=host_address, + secure_channel=secure_channel, + interceptors=interceptors + ) + self._channel = channel + self._stub = stubs.TaskHubSidecarServiceStub(channel) + self._logger = shared.get_logger("client", log_handler, log_formatter) + + async def aclose(self): + await self._channel.close() + + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, + input: 
Optional[TInput] = None, + instance_id: Optional[str] = None, + start_at: Optional[datetime] = None, + reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str: + + name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + + req = pb.CreateInstanceRequest( + name=name, + instanceId=instance_id if instance_id else uuid.uuid4().hex, + input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None, + scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, + version=helpers.get_string_value(None), + orchestrationIdReusePolicy=reuse_id_policy, + ) + + self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") + res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) + return res.instanceId + + async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + res: pb.GetInstanceResponse = await self._stub.GetInstance(req) + return new_orchestration_state(req.instanceId, res) + + async def wait_for_orchestration_start(self, instance_id: str, *, + fetch_payloads: bool = False, + timeout: int = 0) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + try: + grpc_timeout = None if timeout == 0 else timeout + self._logger.info( + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.") + res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) + return new_orchestration_state(req.instanceId, res) + except grpc.RpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore + # Replace gRPC error with the built-in TimeoutError + raise TimeoutError("Timed-out waiting for the orchestration to start") + else: + raise + + async def wait_for_orchestration_completion(self, instance_id: str, *, + fetch_payloads: bool = True, + timeout: int = 0) -> Optional[OrchestrationState]: + req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) + try: + grpc_timeout = None if timeout == 0 else timeout + self._logger.info( + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.") + res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout) + state = new_orchestration_state(req.instanceId, res) + if not state: + return None + + if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None: + details = state.failure_details + self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}") + elif state.runtime_status == OrchestrationStatus.TERMINATED: + self._logger.info(f"Instance '{instance_id}' was terminated.") + elif state.runtime_status == OrchestrationStatus.COMPLETED: + self._logger.info(f"Instance '{instance_id}' completed.") + + return state + except grpc.RpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore + # Replace gRPC error with the built-in TimeoutError + raise TimeoutError("Timed-out waiting for the orchestration to complete") + else: + raise + + async def raise_orchestration_event( + self, + instance_id: str, + event_name: str, + *, + data: Optional[Any] = None): + req = pb.RaiseEventRequest( + 
instanceId=instance_id, + name=event_name, + input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None) + + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + await self._stub.RaiseEvent(req) + + async def terminate_orchestration(self, instance_id: str, *, + output: Optional[Any] = None, + recursive: bool = True): + req = pb.TerminateRequest( + instanceId=instance_id, + output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, + recursive=recursive) + + self._logger.info(f"Terminating instance '{instance_id}'.") + await self._stub.TerminateInstance(req) + + async def suspend_orchestration(self, instance_id: str): + req = pb.SuspendRequest(instanceId=instance_id) + self._logger.info(f"Suspending instance '{instance_id}'.") + await self._stub.SuspendInstance(req) + + async def resume_orchestration(self, instance_id: str): + req = pb.ResumeRequest(instanceId=instance_id) + self._logger.info(f"Resuming instance '{instance_id}'.") + await self._stub.ResumeInstance(req) + + async def purge_orchestration(self, instance_id: str, recursive: bool = True): + req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) + self._logger.info(f"Purging instance '{instance_id}'.") + await self._stub.PurgeInstances(req) diff --git a/durabletask/aio/internal/__init__.py b/durabletask/aio/internal/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/durabletask/aio/internal/grpc_interceptor.py b/durabletask/aio/internal/grpc_interceptor.py new file mode 100644 index 0000000..d2c1eb0 --- /dev/null +++ b/durabletask/aio/internal/grpc_interceptor.py @@ -0,0 +1,55 @@ +from collections import namedtuple + +from grpc import aio as grpc_aio + + +class _AioClientCallDetails( + namedtuple( + '_AioClientCallDetails', + ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), + grpc_aio.ClientCallDetails): + pass + + +class DefaultAioClientInterceptorImpl( + grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, + grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor): + """Async gRPC client interceptor to add metadata to all calls.""" + + def __init__(self, metadata: list[tuple[str, str]]): + super().__init__() + self._metadata = metadata + + def _intercept_call(self, client_call_details: _AioClientCallDetails) -> grpc_aio.ClientCallDetails: + if self._metadata is None: + return client_call_details + + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + else: + metadata = [] + + metadata.extend(self._metadata) + return _AioClientCallDetails( + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression) + + async def intercept_unary_unary(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_unary_stream(self, continuation, client_call_details, request): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request) + + async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request_iterator) + + async def 
intercept_stream_stream(self, continuation, client_call_details, request_iterator): + new_client_call_details = self._intercept_call(client_call_details) + return await continuation(new_client_call_details, request_iterator) diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py new file mode 100644 index 0000000..b15523d --- /dev/null +++ b/durabletask/aio/internal/shared.py @@ -0,0 +1,46 @@ +from typing import Optional, Sequence, Union + +import grpc +from grpc import aio as grpc_aio + +from durabletask.internal.shared import ( + get_default_host_address, + SECURE_PROTOCOLS, + INSECURE_PROTOCOLS, +) + + +AioClientInterceptor = Union[ + grpc_aio.UnaryUnaryClientInterceptor, + grpc_aio.UnaryStreamClientInterceptor, + grpc_aio.StreamUnaryClientInterceptor, + grpc_aio.StreamStreamClientInterceptor +] + + +def get_grpc_aio_channel( + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence[AioClientInterceptor]] = None) -> grpc_aio.Channel: + + if host_address is None: + host_address = get_default_host_address() + + for protocol in SECURE_PROTOCOLS: + if host_address.lower().startswith(protocol): + secure_channel = True + host_address = host_address[len(protocol):] + break + + for protocol in INSECURE_PROTOCOLS: + if host_address.lower().startswith(protocol): + secure_channel = False + host_address = host_address[len(protocol):] + break + + if secure_channel: + channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors) + else: + channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors) + + return channel diff --git a/durabletask/task.py b/durabletask/task.py index 29af2c5..50970fd 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -283,6 +283,7 @@ def get_tasks(self) -> list[Task]: def on_child_completed(self, task: Task[T]): pass + class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" diff --git a/durabletask/worker.py b/durabletask/worker.py index 7a04649..e8e1fa9 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -880,13 +880,13 @@ class ExecutionResults: actions: list[pb.OrchestratorAction] encoded_custom_status: Optional[str] - def __init__( self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] ): self.actions = actions self.encoded_custom_status = encoded_custom_status + class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index e5a8e9b..e750134 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -21,6 +21,7 @@ def test_get_grpc_channel_secure(): get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) + def test_get_grpc_channel_default_host_address(): with patch('grpc.insecure_channel') as mock_channel: get_grpc_channel(None, False, interceptors=INTERCEPTORS) diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py new file mode 100644 index 0000000..691f39e --- /dev/null +++ b/tests/durabletask/test_client_async.py @@ -0,0 +1,103 @@ +from unittest.mock import ANY, patch + +from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl +from durabletask.internal.shared import get_default_host_address +from durabletask.aio.internal.shared import 
get_grpc_aio_channel +from durabletask.aio.client import AsyncTaskHubGrpcClient + + +HOST_ADDRESS = 'localhost:50051' +METADATA = [('key1', 'value1'), ('key2', 'value2')] +INTERCEPTORS_AIO = [DefaultAioClientInterceptorImpl(METADATA)] + + +def test_get_grpc_aio_channel_insecure(): + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO) + mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO) + + +def test_get_grpc_aio_channel_secure(): + with patch('durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_channel, patch( + 'grpc.ssl_channel_credentials') as mock_credentials: + get_grpc_aio_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS_AIO) + mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value, interceptors=INTERCEPTORS_AIO) + + +def test_get_grpc_aio_channel_default_host_address(): + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + get_grpc_aio_channel(None, False, interceptors=INTERCEPTORS_AIO) + mock_channel.assert_called_once_with(get_default_host_address(), interceptors=INTERCEPTORS_AIO) + + +def test_get_grpc_aio_channel_with_interceptors(): + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + get_grpc_aio_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS_AIO) + mock_channel.assert_called_once_with(HOST_ADDRESS, interceptors=INTERCEPTORS_AIO) + + # Capture and check the arguments passed to insecure_channel() + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'interceptors' in kwargs + interceptors = kwargs['interceptors'] + assert isinstance(interceptors[0], DefaultAioClientInterceptorImpl) + assert interceptors[0]._metadata == METADATA + + +def test_grpc_aio_channel_with_host_name_protocol_stripping(): + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_insecure_channel, patch( + 'durabletask.aio.internal.shared.grpc_aio.secure_channel') as mock_secure_channel: + + host_name = "myserver.com:1234" + + prefix = "grpc://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + + prefix = "http://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + + prefix = "HTTP://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + + prefix = "GRPC://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + + prefix = "" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_insecure_channel.assert_called_with(host_name, interceptors=INTERCEPTORS_AIO) + + prefix = "grpcs://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + + prefix = "https://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + + prefix = "HTTPS://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + 
mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + + prefix = "GRPCS://" + get_grpc_aio_channel(prefix + host_name, interceptors=INTERCEPTORS_AIO) + mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + + prefix = "" + get_grpc_aio_channel(prefix + host_name, True, interceptors=INTERCEPTORS_AIO) + mock_secure_channel.assert_called_with(host_name, ANY, interceptors=INTERCEPTORS_AIO) + + +def test_async_client_construct_with_metadata(): + with patch('durabletask.aio.internal.shared.grpc_aio.insecure_channel') as mock_channel: + AsyncTaskHubGrpcClient(host_address=HOST_ADDRESS, metadata=METADATA) + # Ensure channel created with an interceptor that has the expected metadata + args, kwargs = mock_channel.call_args + assert args[0] == HOST_ADDRESS + assert 'interceptors' in kwargs + interceptors = kwargs['interceptors'] + assert isinstance(interceptors[0], DefaultAioClientInterceptorImpl) + assert interceptors[0]._metadata == METADATA diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 2343184..76ec355 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -316,7 +316,6 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): output = "Recursive termination = {recurse}" task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse) - metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) assert metadata is not None diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py new file mode 100644 index 0000000..b35d33f --- /dev/null +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -0,0 +1,487 @@ +import asyncio +import json +import threading +from datetime import timedelta + +import pytest + +from durabletask.aio.client import AsyncTaskHubGrpcClient +from durabletask.client import OrchestrationStatus +from durabletask import task, worker + + +# NOTE: These tests assume a sidecar process is running. Example command: +# go install github.com/microsoft/durabletask-go@main +# durabletask-go --port 4001 +pytestmark = [pytest.mark.e2e, pytest.mark.anyio] + + +@pytest.fixture +def anyio_backend(): + return 'asyncio' + + +async def test_empty_orchestration(): + + invoked = False + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + nonlocal invoked # don't do this in a real app! 
+ invoked = True + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(empty_orchestrator) + w.start() + + c = AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(empty_orchestrator) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert invoked + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert state.serialized_custom_status is None + + +async def test_activity_sequence(): + + def plus_one(_: task.ActivityContext, input: int) -> int: + return input + 1 + + def sequence(ctx: task.OrchestrationContext, start_val: int): + numbers = [start_val] + current = start_val + for _ in range(10): + current = yield ctx.call_activity(plus_one, input=current) + numbers.append(current) + return numbers + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(sequence) + w.add_activity(plus_one) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(sequence, input=1) + state = await client.wait_for_orchestration_completion(id, timeout=30) + await client.aclose() + + assert state is not None + assert state.name == task.get_name(sequence) + assert state.instance_id == id + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert state.serialized_input == json.dumps(1) + assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + assert state.serialized_custom_status is None + + +async def test_activity_error_handling(): + + def throw(_: task.ActivityContext, input: int) -> int: + raise RuntimeError("Kah-BOOOOM!!!") + + compensation_counter = 0 + + def increment_counter(ctx, _): + nonlocal compensation_counter + compensation_counter += 1 + + def orchestrator(ctx: task.OrchestrationContext, input: int): + error_msg = "" + try: + yield ctx.call_activity(throw, input=input) + except task.TaskFailedError as e: + error_msg = e.details.message + + # compensating actions + yield ctx.call_activity(increment_counter) + yield ctx.call_activity(increment_counter) + + return error_msg + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.add_activity(throw) + w.add_activity(increment_counter) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator, input=1) + state = await client.wait_for_orchestration_completion(id, timeout=30) + await client.aclose() + + assert state is not None + assert state.name == task.get_name(orchestrator) + assert state.instance_id == id + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Kah-BOOOOM!!!") + assert state.failure_details is None + assert state.serialized_custom_status is None + assert compensation_counter == 2 + + +async def test_sub_orchestration_fan_out(): + threadLock = threading.Lock() + activity_counter = 0 + + def increment(ctx, _): + with threadLock: + nonlocal activity_counter + activity_counter += 1 + + def orchestrator_child(ctx: task.OrchestrationContext, 
activity_count: int): + for _ in range(activity_count): + yield ctx.call_activity(increment) + + def parent_orchestrator(ctx: task.OrchestrationContext, count: int): + # Fan out to multiple sub-orchestrations + tasks = [] + for _ in range(count): + tasks.append(ctx.call_sub_orchestrator( + orchestrator_child, input=3)) + # Wait for all sub-orchestrations to complete + yield task.when_all(tasks) + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_activity(increment) + w.add_orchestrator(orchestrator_child) + w.add_orchestrator(parent_orchestrator) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(parent_orchestrator, input=10) + state = await client.wait_for_orchestration_completion(id, timeout=30) + await client.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.failure_details is None + assert activity_counter == 30 + + +async def test_wait_for_multiple_external_events(): + def orchestrator(ctx: task.OrchestrationContext, _): + a = yield ctx.wait_for_external_event('A') + b = yield ctx.wait_for_external_event('B') + c = yield ctx.wait_for_external_event('C') + return [a, b, c] + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + # Start the orchestration and immediately raise events to it. + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator) + await client.raise_orchestration_event(id, 'A', data='a') + await client.raise_orchestration_event(id, 'B', data='b') + await client.raise_orchestration_event(id, 'C', data='c') + state = await client.wait_for_orchestration_completion(id, timeout=30) + await client.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(['a', 'b', 'c']) + + +@pytest.mark.parametrize("raise_event", [True, False]) +async def test_wait_for_external_event_timeout(raise_event: bool): + def orchestrator(ctx: task.OrchestrationContext, _): + approval: task.Task[bool] = ctx.wait_for_external_event('Approval') + timeout = ctx.create_timer(timedelta(seconds=3)) + winner = yield task.when_any([approval, timeout]) + if winner == approval: + return "approved" + else: + return "timed out" + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + # Start the orchestration and immediately raise events to it. 
+ client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator) + if raise_event: + await client.raise_orchestration_event(id, 'Approval') + state = await client.wait_for_orchestration_completion(id, timeout=30) + await client.aclose() + + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + if raise_event: + assert state.serialized_output == json.dumps("approved") + else: + assert state.serialized_output == json.dumps("timed out") + + +async def test_suspend_and_resume(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator) + state = await client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + + # Suspend the orchestration and wait for it to go into the SUSPENDED state + await client.suspend_orchestration(id) + while state.runtime_status == OrchestrationStatus.RUNNING: + await asyncio.sleep(0.1) + state = await client.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == OrchestrationStatus.SUSPENDED + + # Raise an event to the orchestration and confirm that it does NOT complete + await client.raise_orchestration_event(id, "my_event", data=42) + try: + state = await client.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + await client.resume_orchestration(id) + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(42) + await client.aclose() + + +async def test_terminate(): + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator) + state = await client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.RUNNING + + await client.terminate_orchestration(id, output="some reason for termination") + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.TERMINATED + assert state.serialized_output == json.dumps("some reason for termination") + await client.aclose() + + +async def test_terminate_recursive(): + def root(ctx: task.OrchestrationContext, _): + result = yield ctx.call_sub_orchestrator(child) + return result + + def child(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(root) + w.add_orchestrator(child) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(root) + state = await 
client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.RUNNING + + # Terminate root orchestration(recursive set to True by default) + await client.terminate_orchestration(id, output="some reason for termination") + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.TERMINATED + + # Verify that child orchestration is also terminated + await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.TERMINATED + + await client.purge_orchestration(id) + state = await client.get_orchestration_state(id) + assert state is None + await client.aclose() + + +async def test_continue_as_new(): + all_results = [] + + def orchestrator(ctx: task.OrchestrationContext, input: int): + result = yield ctx.wait_for_external_event("my_event") + if not ctx.is_replaying: + # NOTE: Real orchestrations should never interact with nonlocal variables like this. + nonlocal all_results # noqa: F824 + all_results.append(result) + + if len(all_results) <= 4: + ctx.continue_as_new(max(all_results), save_events=True) + else: + return all_results + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(orchestrator, input=0) + await client.raise_orchestration_event(id, "my_event", data=1) + await client.raise_orchestration_event(id, "my_event", data=2) + await client.raise_orchestration_event(id, "my_event", data=3) + await client.raise_orchestration_event(id, "my_event", data=4) + await client.raise_orchestration_event(id, "my_event", data=5) + + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(all_results) + assert state.serialized_input == json.dumps(4) + assert all_results == [1, 2, 3, 4, 5] + await client.aclose() + + +async def test_retry_policies(): + # This test verifies that the retry policies are working as expected. + # It does this by creating an orchestration that calls a sub-orchestrator, + # which in turn calls an activity that always fails. + # In this test, the retry policies are added, and the orchestration + # should still fail. But, number of times the sub-orchestrator and activity + # is called should increase as per the retry policies. + + child_orch_counter = 0 + throw_activity_counter = 0 + + # Second setup: With retry policies + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=1, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=30)) + + def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): + yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy) + + def child_orchestrator_with_retry(ctx: task.OrchestrationContext, _): + nonlocal child_orch_counter + if not ctx.is_replaying: + # NOTE: Real orchestrations should never interact with nonlocal variables like this. + # This is done only for testing purposes. 
+ child_orch_counter += 1 + yield ctx.call_activity(throw_activity_with_retry, retry_policy=retry_policy) + + def throw_activity_with_retry(ctx: task.ActivityContext, _): + nonlocal throw_activity_counter + throw_activity_counter += 1 + raise RuntimeError("Kah-BOOOOM!!!") + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(parent_orchestrator_with_retry) + w.add_orchestrator(child_orchestrator_with_retry) + w.add_activity(throw_activity_with_retry) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(parent_orchestrator_with_retry) + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:") + assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") + assert state.failure_details.stack_trace is not None + assert throw_activity_counter == 9 + assert child_orch_counter == 3 + await client.aclose() + + +async def test_retry_timeout(): + # This test verifies that the retry timeout is working as expected. + # Max number of attempts is 5 and retry timeout is 14 seconds. + # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds. + # So, the 5th attempt should not be made and the orchestration should fail. + throw_activity_counter = 0 + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=14)) + + def mock_orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.call_activity(throw_activity, retry_policy=retry_policy) + + def throw_activity(ctx: task.ActivityContext, _): + nonlocal throw_activity_counter + throw_activity_counter += 1 + raise RuntimeError("Kah-BOOOOM!!!") + + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(mock_orchestrator) + w.add_activity(throw_activity) + w.start() + + client = AsyncTaskHubGrpcClient() + id = await client.schedule_new_orchestration(mock_orchestrator) + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.FAILED + assert state.failure_details is not None + assert state.failure_details.error_type == "TaskFailedError" + assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") + assert state.failure_details.stack_trace is not None + assert throw_activity_counter == 4 + await client.aclose() + + +async def test_custom_status(): + + def empty_orchestrator(ctx: task.OrchestrationContext, _): + ctx.set_custom_status("foobaz") + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(empty_orchestrator) + w.start() + + c = AsyncTaskHubGrpcClient() + id = await c.schedule_new_orchestration(empty_orchestrator) + state = await c.wait_for_orchestration_completion(id, timeout=30) + await c.aclose() + + assert state is not None + assert state.name == task.get_name(empty_orchestrator) + assert state.instance_id == id + assert state.failure_details is None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_input is None + assert state.serialized_output is None + assert 
state.serialized_custom_status == "\"foobaz\"" diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 21f6c6c..c784135 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -634,7 +634,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(suborchestrator, input=None) registry = worker._Registry() - suborchestrator_name = registry.add_orchestrator(suborchestrator) + registry.add_orchestrator(suborchestrator) orchestrator_name = registry.add_orchestrator(orchestrator) exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) @@ -666,7 +666,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app") registry = worker._Registry() - suborchestrator_name = registry.add_orchestrator(suborchestrator) + registry.add_orchestrator(suborchestrator) orchestrator_name = registry.add_orchestrator(orchestrator) exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) diff --git a/tests/durabletask/test_orchestration_wait.py b/tests/durabletask/test_orchestration_wait.py index 03f7e30..c27345f 100644 --- a/tests/durabletask/test_orchestration_wait.py +++ b/tests/durabletask/test_orchestration_wait.py @@ -1,11 +1,9 @@ -from unittest.mock import patch, ANY, Mock +from unittest.mock import Mock from durabletask.client import TaskHubGrpcClient -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import (get_default_host_address, - get_grpc_channel) import pytest + @pytest.mark.parametrize("timeout", [None, 0, 5]) def test_wait_for_orchestration_start_timeout(timeout): instance_id = "test-instance" @@ -34,6 +32,7 @@ def test_wait_for_orchestration_start_timeout(timeout): else: assert kwargs.get('timeout') == timeout + @pytest.mark.parametrize("timeout", [None, 0, 5]) def test_wait_for_orchestration_completion_timeout(timeout): instance_id = "test-instance" From f3f1c4babc8c6959cc5b9b28fc21900f1ad4964b Mon Sep 17 00:00:00 2001 From: Patrick Assuied Date: Sun, 28 Sep 2025 08:54:44 -0700 Subject: [PATCH 2/6] Switch to pytest-asyncio fixed dev dependencies Signed-off-by: Patrick Assuied --- .github/workflows/pr-validation.yml | 2 +- README.md | 2 ++ dev-requirements.txt | 4 ++++ requirements.txt | 2 -- tests/durabletask/test_orchestration_e2e_async.py | 7 +------ 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 63540ac..33de31f 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest + pip install flake8 pytest pytest-cov pytest-asyncio pip install -r requirements.txt - name: Lint with flake8 run: | diff --git a/README.md b/README.md index 4a45d9b..3f691c0 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,7 @@ This will download the `orchestrator_service.proto` from the `microsoft/durablet Unit tests can be run using the following command from the project root. Unit tests _don't_ require a sidecar process to be running. 
```sh +pip3 install -r dev-requirements.txt make test-unit ``` @@ -188,6 +189,7 @@ durabletask-go --port 4001 To run the E2E tests, run the following command from the project root: ```sh +pip3 install -r dev-requirements.txt make test-e2e ``` diff --git a/dev-requirements.txt b/dev-requirements.txt index 119f072..58f0b35 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1 +1,5 @@ grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python +pytest +pytest-cov +pytest-asyncio +flake8 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 07426eb..41566b3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,4 @@ autopep8 grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible protobuf -pytest -pytest-cov asyncio diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index b35d33f..7717840 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -13,12 +13,7 @@ # NOTE: These tests assume a sidecar process is running. Example command: # go install github.com/microsoft/durabletask-go@main # durabletask-go --port 4001 -pytestmark = [pytest.mark.e2e, pytest.mark.anyio] - - -@pytest.fixture -def anyio_backend(): - return 'asyncio' +pytestmark = [pytest.mark.e2e, pytest.mark.asyncio] async def test_empty_orchestration(): From a21da3f5fecf404e6a000a50f754366bada6dd7a Mon Sep 17 00:00:00 2001 From: Patrick Assuied Date: Tue, 30 Sep 2025 11:10:26 -0700 Subject: [PATCH 3/6] Rename classes to avoid repeating `Aio` per PR feedback. Also cleaning dependencies to align protobuf dependencies between grpc-tools and grpcio Signed-off-by: Patrick Assuied --- dev-requirements.txt | 2 +- durabletask/aio/client.py | 10 +++++----- durabletask/aio/internal/grpc_interceptor.py | 10 +++++----- durabletask/aio/internal/shared.py | 4 ++-- requirements.txt | 2 +- tests/durabletask/test_client_async.py | 8 ++++---- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 58f0b35..80d1ba7 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python +grpcio-tools>=1.75.1 # supports protobuf 6.x and aligns with generated code pytest pytest-cov pytest-asyncio diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py index 51797f3..ee5abd7 100644 --- a/durabletask/aio/client.py +++ b/durabletask/aio/client.py @@ -10,10 +10,10 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared -from durabletask.aio.internal.shared import get_grpc_aio_channel, AioClientInterceptor +from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor from durabletask import task from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput -from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl +from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl class AsyncTaskHubGrpcClient: @@ -24,14 +24,14 @@ def __init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] 
= None, secure_channel: bool = False, - interceptors: Optional[Sequence[AioClientInterceptor]] = None): + interceptors: Optional[Sequence[ClientInterceptor]] = None): if interceptors is not None: interceptors = list(interceptors) if metadata is not None: - interceptors.append(DefaultAioClientInterceptorImpl(metadata)) + interceptors.append(DefaultClientInterceptorImpl(metadata)) elif metadata is not None: - interceptors = [DefaultAioClientInterceptorImpl(metadata)] + interceptors = [DefaultClientInterceptorImpl(metadata)] else: interceptors = None diff --git a/durabletask/aio/internal/grpc_interceptor.py b/durabletask/aio/internal/grpc_interceptor.py index d2c1eb0..06dae95 100644 --- a/durabletask/aio/internal/grpc_interceptor.py +++ b/durabletask/aio/internal/grpc_interceptor.py @@ -3,15 +3,15 @@ from grpc import aio as grpc_aio -class _AioClientCallDetails( +class _ClientCallDetails( namedtuple( - '_AioClientCallDetails', + '_ClientCallDetails', ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), grpc_aio.ClientCallDetails): pass -class DefaultAioClientInterceptorImpl( +class DefaultClientInterceptorImpl( grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor): """Async gRPC client interceptor to add metadata to all calls.""" @@ -20,7 +20,7 @@ def __init__(self, metadata: list[tuple[str, str]]): super().__init__() self._metadata = metadata - def _intercept_call(self, client_call_details: _AioClientCallDetails) -> grpc_aio.ClientCallDetails: + def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails: if self._metadata is None: return client_call_details @@ -30,7 +30,7 @@ def _intercept_call(self, client_call_details: _AioClientCallDetails) -> grpc_ai metadata = [] metadata.extend(self._metadata) - return _AioClientCallDetails( + return _ClientCallDetails( client_call_details.method, client_call_details.timeout, metadata, diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index b15523d..3e09ff3 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -10,7 +10,7 @@ ) -AioClientInterceptor = Union[ +ClientInterceptor = Union[ grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, grpc_aio.StreamUnaryClientInterceptor, @@ -21,7 +21,7 @@ def get_grpc_aio_channel( host_address: Optional[str], secure_channel: bool = False, - interceptors: Optional[Sequence[AioClientInterceptor]] = None) -> grpc_aio.Channel: + interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel: if host_address is None: host_address = get_default_host_address() diff --git a/requirements.txt b/requirements.txt index 41566b3..0f47c7f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ autopep8 grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible -protobuf +protobuf>=6,<7 asyncio diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py index 691f39e..6e2b919 100644 --- a/tests/durabletask/test_client_async.py +++ b/tests/durabletask/test_client_async.py @@ -1,6 +1,6 @@ from unittest.mock import ANY, patch -from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl +from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.internal.shared import 
get_default_host_address
 from durabletask.aio.internal.shared import get_grpc_aio_channel
 from durabletask.aio.client import AsyncTaskHubGrpcClient
@@ -8,7 +8,7 @@
 
 HOST_ADDRESS = 'localhost:50051'
 METADATA = [('key1', 'value1'), ('key2', 'value2')]
-INTERCEPTORS_AIO = [DefaultAioClientInterceptorImpl(METADATA)]
+INTERCEPTORS_AIO = [DefaultClientInterceptorImpl(METADATA)]
 
 
 def test_get_grpc_aio_channel_insecure():
@@ -40,7 +40,7 @@ def test_get_grpc_aio_channel_with_interceptors():
         assert args[0] == HOST_ADDRESS
         assert 'interceptors' in kwargs
         interceptors = kwargs['interceptors']
-        assert isinstance(interceptors[0], DefaultAioClientInterceptorImpl)
+        assert isinstance(interceptors[0], DefaultClientInterceptorImpl)
         assert interceptors[0]._metadata == METADATA
 
 
@@ -99,5 +99,5 @@ def test_async_client_construct_with_metadata():
         assert args[0] == HOST_ADDRESS
         assert 'interceptors' in kwargs
         interceptors = kwargs['interceptors']
-        assert isinstance(interceptors[0], DefaultAioClientInterceptorImpl)
+        assert isinstance(interceptors[0], DefaultClientInterceptorImpl)
         assert interceptors[0]._metadata == METADATA

From eac4b8269d746c7015c032d637b58de76523201f Mon Sep 17 00:00:00 2001
From: Patrick Assuied
Date: Tue, 30 Sep 2025 11:30:35 -0700
Subject: [PATCH 4/6] Fixed a bug where `when_all()` and `when_any()` would
 never complete when passed empty lists; they now complete immediately and
 return successfully. Added corresponding unit tests for the happy path and
 the empty-list edge case

Signed-off-by: Patrick Assuied
---
 durabletask/task.py                           |  8 +++
 .../test_orchestration_e2e_async.py           |  2 +
 tests/durabletask/test_task.py                | 67 +++++++++++++++++++
 3 files changed, 77 insertions(+)
 create mode 100644 tests/durabletask/test_task.py

diff --git a/durabletask/task.py b/durabletask/task.py
index 50970fd..5210c99 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -291,6 +291,10 @@ def __init__(self, tasks: list[Task[T]]):
         super().__init__(tasks)
         self._completed_tasks = 0
         self._failed_tasks = 0
+        # If there are no child tasks, this composite should complete immediately
+        if len(self._tasks) == 0:
+            self._result = []  # type: ignore[assignment]
+            self._is_complete = True
 
     @property
     def pending_tasks(self) -> int:
@@ -388,6 +392,10 @@ class WhenAnyTask(CompositeTask[Task]):
 
     def __init__(self, tasks: list[Task]):
         super().__init__(tasks)
+        # If there are no child tasks, complete immediately with an empty result
+        if len(self._tasks) == 0:
+            self._result = []  # type: ignore[assignment]
+            self._is_complete = True
 
     def on_child_completed(self, task: Task):
         # The first task to complete is the result of the WhenAnyTask.
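To make the new semantics concrete, here is a minimal sketch of the fan-out pattern this fix protects (an illustration only, not part of the patch; the orchestrator and the "process_item" activity name are hypothetical). Before this change, yielding `task.when_all([])` waited forever for child completions that could never arrive; it now completes immediately with an empty result:

```python
from durabletask import task


def fan_out_orchestrator(ctx: task.OrchestrationContext, work_items: list):
    # Schedule one activity per work item; the incoming list may legitimately be empty.
    # "process_item" is a hypothetical activity assumed to be registered with the worker.
    pending = [ctx.call_activity("process_item", input=item) for item in work_items]

    # With this fix, when_all([]) is already complete, so an empty batch
    # yields [] instead of hanging the orchestration.
    results = yield task.when_all(pending)
    return results
```

`when_any([])` gets the same treatment in the hunk above: it completes immediately, with an empty result rather than a winner task.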
diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index 7717840..eab2135 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -232,6 +232,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): with worker.TaskHubGrpcWorker() as w: w.add_orchestrator(orchestrator) w.start() + # there could be a race condition if the workflow is scheduled before orchestrator is started + await asyncio.sleep(0.2) client = AsyncTaskHubGrpcClient() id = await client.schedule_new_orchestration(orchestrator) diff --git a/tests/durabletask/test_task.py b/tests/durabletask/test_task.py new file mode 100644 index 0000000..914df5b --- /dev/null +++ b/tests/durabletask/test_task.py @@ -0,0 +1,67 @@ +"""Unit tests for durabletask.task primitives.""" + +from durabletask import task + + +def test_when_all_empty_returns_successfully(): + """task.when_all([]) should complete immediately and return an empty list.""" + when_all_task = task.when_all([]) + + assert when_all_task.is_complete + assert when_all_task.get_result() == [] + +def test_when_any_empty_returns_successfully(): + """task.when_any([]) should complete immediately and return an empty list.""" + when_any_task = task.when_any([]) + + assert when_any_task.is_complete + assert when_any_task.get_result() == [] + + +def test_when_all_happy_path_returns_ordered_results_and_completes_last(): + c1 = task.CompletableTask() + c2 = task.CompletableTask() + c3 = task.CompletableTask() + + all_task = task.when_all([c1, c2, c3]) + + assert not all_task.is_complete + + c2.complete("two") + + assert not all_task.is_complete + + c1.complete("one") + + assert not all_task.is_complete + + c3.complete("three") + + assert all_task.is_complete + + assert all_task.get_result() == ["one", "two", "three"] + + +def test_when_any_happy_path_returns_winner_task_and_completes_on_first(): + a = task.CompletableTask() + b = task.CompletableTask() + + any_task = task.when_any([a, b]) + + assert not any_task.is_complete + + b.complete("B") + + assert any_task.is_complete + + winner = any_task.get_result() + + assert winner is b + + assert winner.get_result() == "B" + + # Completing the other child should not change the winner + a.complete("A") + + assert any_task.get_result() is b + From ecccbef3a0b287494fcb92afd5ca120bec45f0b7 Mon Sep 17 00:00:00 2001 From: Patrick Assuied Date: Wed, 1 Oct 2025 10:47:29 -0700 Subject: [PATCH 5/6] Enabled context manager in client. Added copyright headers on new files Signed-off-by: Patrick Assuied --- durabletask/aio/client.py | 10 + durabletask/aio/internal/grpc_interceptor.py | 3 + durabletask/aio/internal/shared.py | 3 + tests/durabletask/test_client_async.py | 3 + .../test_orchestration_e2e_async.py | 190 +++++++++--------- tests/durabletask/test_task.py | 5 +- 6 files changed, 116 insertions(+), 98 deletions(-) diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py index ee5abd7..4ec9bbf 100644 --- a/durabletask/aio/client.py +++ b/durabletask/aio/client.py @@ -1,3 +1,6 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. 
+ import logging import uuid from datetime import datetime @@ -47,6 +50,13 @@ def __init__(self, *, async def aclose(self): await self._channel.close() + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + await self.aclose() + return False + async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, input: Optional[TInput] = None, instance_id: Optional[str] = None, diff --git a/durabletask/aio/internal/grpc_interceptor.py b/durabletask/aio/internal/grpc_interceptor.py index 06dae95..bf1ac98 100644 --- a/durabletask/aio/internal/grpc_interceptor.py +++ b/durabletask/aio/internal/grpc_interceptor.py @@ -1,3 +1,6 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. + from collections import namedtuple from grpc import aio as grpc_aio diff --git a/durabletask/aio/internal/shared.py b/durabletask/aio/internal/shared.py index 3e09ff3..6bdb256 100644 --- a/durabletask/aio/internal/shared.py +++ b/durabletask/aio/internal/shared.py @@ -1,3 +1,6 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. + from typing import Optional, Sequence, Union import grpc diff --git a/tests/durabletask/test_client_async.py b/tests/durabletask/test_client_async.py index 6e2b919..8f2b83e 100644 --- a/tests/durabletask/test_client_async.py +++ b/tests/durabletask/test_client_async.py @@ -1,3 +1,6 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. + from unittest.mock import ANY, patch from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index eab2135..de586f1 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -1,3 +1,6 @@ +# Copyright (c) The Dapr Authors. +# Licensed under the MIT License. 
+ import asyncio import json import threading @@ -235,34 +238,33 @@ def orchestrator(ctx: task.OrchestrationContext, _): # there could be a race condition if the workflow is scheduled before orchestrator is started await asyncio.sleep(0.2) - client = AsyncTaskHubGrpcClient() - id = await client.schedule_new_orchestration(orchestrator) - state = await client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - - # Suspend the orchestration and wait for it to go into the SUSPENDED state - await client.suspend_orchestration(id) - while state.runtime_status == OrchestrationStatus.RUNNING: - await asyncio.sleep(0.1) - state = await client.get_orchestration_state(id) + async with AsyncTaskHubGrpcClient() as client: + id = await client.schedule_new_orchestration(orchestrator) + state = await client.wait_for_orchestration_start(id, timeout=30) assert state is not None - assert state.runtime_status == OrchestrationStatus.SUSPENDED - # Raise an event to the orchestration and confirm that it does NOT complete - await client.raise_orchestration_event(id, "my_event", data=42) - try: - state = await client.wait_for_orchestration_completion(id, timeout=3) - assert False, "Orchestration should not have completed" - except TimeoutError: - pass - - # Resume the orchestration and wait for it to complete - await client.resume_orchestration(id) - state = await client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(42) - await client.aclose() + # Suspend the orchestration and wait for it to go into the SUSPENDED state + await client.suspend_orchestration(id) + while state.runtime_status == OrchestrationStatus.RUNNING: + await asyncio.sleep(0.1) + state = await client.get_orchestration_state(id) + assert state is not None + assert state.runtime_status == OrchestrationStatus.SUSPENDED + + # Raise an event to the orchestration and confirm that it does NOT complete + await client.raise_orchestration_event(id, "my_event", data=42) + try: + state = await client.wait_for_orchestration_completion(id, timeout=3) + assert False, "Orchestration should not have completed" + except TimeoutError: + pass + + # Resume the orchestration and wait for it to complete + await client.resume_orchestration(id) + state = await client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(42) async def test_terminate(): @@ -275,18 +277,17 @@ def orchestrator(ctx: task.OrchestrationContext, _): w.add_orchestrator(orchestrator) w.start() - client = AsyncTaskHubGrpcClient() - id = await client.schedule_new_orchestration(orchestrator) - state = await client.wait_for_orchestration_start(id, timeout=30) - assert state is not None - assert state.runtime_status == OrchestrationStatus.RUNNING + async with AsyncTaskHubGrpcClient() as client: + id = await client.schedule_new_orchestration(orchestrator) + state = await client.wait_for_orchestration_start(id, timeout=30) + assert state is not None + assert state.runtime_status == OrchestrationStatus.RUNNING - await client.terminate_orchestration(id, output="some reason for termination") - state = await client.wait_for_orchestration_completion(id, timeout=30) - assert state is not None - assert state.runtime_status == OrchestrationStatus.TERMINATED - assert state.serialized_output == json.dumps("some reason for 
@@ -304,27 +305,26 @@ def child(ctx: task.OrchestrationContext, _):
     w.add_orchestrator(child)
     w.start()

-    client = AsyncTaskHubGrpcClient()
-    id = await client.schedule_new_orchestration(root)
-    state = await client.wait_for_orchestration_start(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.RUNNING
+    async with AsyncTaskHubGrpcClient() as client:
+        id = await client.schedule_new_orchestration(root)
+        state = await client.wait_for_orchestration_start(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.RUNNING

-    # Terminate root orchestration(recursive set to True by default)
-    await client.terminate_orchestration(id, output="some reason for termination")
-    state = await client.wait_for_orchestration_completion(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.TERMINATED
+        # Terminate root orchestration (recursive set to True by default)
+        await client.terminate_orchestration(id, output="some reason for termination")
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.TERMINATED

-    # Verify that child orchestration is also terminated
-    await client.wait_for_orchestration_completion(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.TERMINATED
+        # Verify that child orchestration is also terminated
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.TERMINATED

-    await client.purge_orchestration(id)
-    state = await client.get_orchestration_state(id)
-    assert state is None
-    await client.aclose()
+        await client.purge_orchestration(id)
+        state = await client.get_orchestration_state(id)
+        assert state is None


 async def test_continue_as_new():
@@ -347,21 +347,20 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
     w.add_orchestrator(orchestrator)
     w.start()

-    client = AsyncTaskHubGrpcClient()
-    id = await client.schedule_new_orchestration(orchestrator, input=0)
-    await client.raise_orchestration_event(id, "my_event", data=1)
-    await client.raise_orchestration_event(id, "my_event", data=2)
-    await client.raise_orchestration_event(id, "my_event", data=3)
-    await client.raise_orchestration_event(id, "my_event", data=4)
-    await client.raise_orchestration_event(id, "my_event", data=5)
+    async with AsyncTaskHubGrpcClient() as client:
+        id = await client.schedule_new_orchestration(orchestrator, input=0)
+        await client.raise_orchestration_event(id, "my_event", data=1)
+        await client.raise_orchestration_event(id, "my_event", data=2)
+        await client.raise_orchestration_event(id, "my_event", data=3)
+        await client.raise_orchestration_event(id, "my_event", data=4)
+        await client.raise_orchestration_event(id, "my_event", data=5)

-    state = await client.wait_for_orchestration_completion(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.COMPLETED
-    assert state.serialized_output == json.dumps(all_results)
-    assert state.serialized_input == json.dumps(4)
-    assert all_results == [1, 2, 3, 4, 5]
-    await client.aclose()
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.COMPLETED
+        assert state.serialized_output == json.dumps(all_results)
+        assert state.serialized_input == json.dumps(4)
+        assert all_results == [1, 2, 3, 4, 5]


 async def test_retry_policies():
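
`test_continue_as_new` drives an orchestrator that restarts itself with `ctx.continue_as_new`; the orchestrator body itself sits in context lines elided from the hunk above. An illustrative reconstruction, not the exact test code:

```python
from durabletask import task


def orchestrator(ctx: task.OrchestrationContext, current: int):
    # Wait for the external event raised by the client, then restart the
    # orchestration with an incremented input.
    event = yield ctx.wait_for_external_event("my_event")
    if current < 4:
        # save_events=True carries already-buffered external events over
        # into the new generation, so none of the five raised events is lost.
        ctx.continue_as_new(current + 1, save_events=True)
    else:
        return [event]
```

This is consistent with the assertions above: after four restarts the final generation's input is `4`, so `state.serialized_input == json.dumps(4)`.
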
@@ -405,19 +404,18 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
     w.add_activity(throw_activity_with_retry)
     w.start()

-    client = AsyncTaskHubGrpcClient()
-    id = await client.schedule_new_orchestration(parent_orchestrator_with_retry)
-    state = await client.wait_for_orchestration_completion(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.FAILED
-    assert state.failure_details is not None
-    assert state.failure_details.error_type == "TaskFailedError"
-    assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
-    assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
-    assert state.failure_details.stack_trace is not None
-    assert throw_activity_counter == 9
-    assert child_orch_counter == 3
-    await client.aclose()
+    async with AsyncTaskHubGrpcClient() as client:
+        id = await client.schedule_new_orchestration(parent_orchestrator_with_retry)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert state.failure_details.error_type == "TaskFailedError"
+        assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:")
+        assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
+        assert state.failure_details.stack_trace is not None
+        assert throw_activity_counter == 9
+        assert child_orch_counter == 3


 async def test_retry_timeout():
@@ -446,17 +444,16 @@ def throw_activity(ctx: task.ActivityContext, _):
     w.add_activity(throw_activity)
     w.start()

-    client = AsyncTaskHubGrpcClient()
-    id = await client.schedule_new_orchestration(mock_orchestrator)
-    state = await client.wait_for_orchestration_completion(id, timeout=30)
-    assert state is not None
-    assert state.runtime_status == OrchestrationStatus.FAILED
-    assert state.failure_details is not None
-    assert state.failure_details.error_type == "TaskFailedError"
-    assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
-    assert state.failure_details.stack_trace is not None
-    assert throw_activity_counter == 4
-    await client.aclose()
+    async with AsyncTaskHubGrpcClient() as client:
+        id = await client.schedule_new_orchestration(mock_orchestrator)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)
+        assert state is not None
+        assert state.runtime_status == OrchestrationStatus.FAILED
+        assert state.failure_details is not None
+        assert state.failure_details.error_type == "TaskFailedError"
+        assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!")
+        assert state.failure_details.stack_trace is not None
+        assert throw_activity_counter == 4


 async def test_custom_status():
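
Both retry tests configure a `task.RetryPolicy`; the concrete policy values live in elided context lines. An illustrative policy and wiring of the same shape (the values here are assumptions, not the tests' actual numbers):

```python
from datetime import timedelta

from durabletask import task

# Illustrative values only; the elided test bodies may use different ones.
retry_policy = task.RetryPolicy(
    first_retry_interval=timedelta(seconds=1),
    max_number_of_attempts=3,
    backoff_coefficient=2,
    max_retry_interval=timedelta(seconds=10),
    retry_timeout=timedelta(seconds=100),
)


def throw_activity(ctx: task.ActivityContext, _):
    raise RuntimeError("Kah-BOOOOM!!!")


def orchestrator_with_retry(ctx: task.OrchestrationContext, _):
    # Once retries are exhausted (or retry_timeout elapses), the failure
    # surfaces as a TaskFailedError on the awaited task, failing the
    # orchestration with the FAILED status asserted above.
    yield ctx.call_activity(throw_activity, retry_policy=retry_policy)
```
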
@@ -469,10 +466,9 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _):
     w.add_orchestrator(empty_orchestrator)
     w.start()

-    c = AsyncTaskHubGrpcClient()
-    id = await c.schedule_new_orchestration(empty_orchestrator)
-    state = await c.wait_for_orchestration_completion(id, timeout=30)
-    await c.aclose()
+    async with AsyncTaskHubGrpcClient() as client:
+        id = await client.schedule_new_orchestration(empty_orchestrator)
+        state = await client.wait_for_orchestration_completion(id, timeout=30)

     assert state is not None
     assert state.name == task.get_name(empty_orchestrator)
diff --git a/tests/durabletask/test_task.py b/tests/durabletask/test_task.py
index 914df5b..81cc8a2 100644
--- a/tests/durabletask/test_task.py
+++ b/tests/durabletask/test_task.py
@@ -1,3 +1,6 @@
+# Copyright (c) The Dapr Authors.
+# Licensed under the MIT License.
+
 """Unit tests for durabletask.task primitives."""

 from durabletask import task
@@ -10,6 +13,7 @@ def test_when_all_empty_returns_successfully():
     assert when_all_task.is_complete
     assert when_all_task.get_result() == []

+
 def test_when_any_empty_returns_successfully():
     """task.when_any([]) should complete immediately and return an empty list."""
     when_any_task = task.when_any([])
@@ -64,4 +68,3 @@ def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
     a.complete("A")

     assert any_task.get_result() is b
-

From 3d8528d423af0bb2794d482a521d7831b24d1dd0 Mon Sep 17 00:00:00 2001
From: Patrick Assuied
Date: Wed, 1 Oct 2025 12:36:22 -0700
Subject: [PATCH 6/6] reverting dependency updates and readme changes

Signed-off-by: Patrick Assuied
---
 README.md            | 2 --
 dev-requirements.txt | 6 +-----
 requirements.txt     | 6 +++++-
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 3f691c0..4a45d9b 100644
--- a/README.md
+++ b/README.md
@@ -173,7 +173,6 @@ This will download the `orchestrator_service.proto` from the `microsoft/durablet
 Unit tests can be run using the following command from the project root. Unit tests _don't_ require a sidecar process to be running.

 ```sh
-pip3 install -r dev-requirements.txt
 make test-unit
 ```

@@ -189,7 +188,6 @@ durabletask-go --port 4001
 To run the E2E tests, run the following command from the project root:

 ```sh
-pip3 install -r dev-requirements.txt
 make test-e2e
 ```

diff --git a/dev-requirements.txt b/dev-requirements.txt
index 80d1ba7..ba589ab 100644
--- a/dev-requirements.txt
+++ b/dev-requirements.txt
@@ -1,5 +1 @@
-grpcio-tools>=1.75.1 # supports protobuf 6.x and aligns with generated code
-pytest
-pytest-cov
-pytest-asyncio
-flake8
\ No newline at end of file
+grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 5.26.X is used, which has breaking changes for Python
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 0f47c7f..06750e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,8 @@
 autopep8
 grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible
-protobuf>=6,<7
+protobuf
 asyncio
+pytest
+pytest-cov
+pytest-asyncio
+flake8
\ No newline at end of file
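
The test-only dependencies (`pytest`, `pytest-asyncio`, `pytest-cov`, `flake8`) now live in requirements.txt, and the async e2e tests rely on pytest-asyncio to run coroutine tests. A minimal sketch of that test shape, assuming the explicit-marker configuration and a sidecar started separately (e.g. `durabletask-go --port 4001`):

```python
import pytest

from durabletask.aio import AsyncTaskHubGrpcClient


@pytest.mark.asyncio
async def test_client_context_manager_smoke():
    # Illustrative only: a real e2e test would register an orchestrator with
    # a worker in the same process and schedule it, as the tests above do.
    async with AsyncTaskHubGrpcClient() as client:
        assert client is not None
```
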