Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install flake8 pytest pytest-cov pytest-asyncio
pip install -r requirements.txt
- name: Lint with flake8
run: |
Expand Down
2 changes: 1 addition & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code
5 changes: 5 additions & 0 deletions durabletask/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .client import AsyncTaskHubGrpcClient

__all__ = [
"AsyncTaskHubGrpcClient",
]
170 changes: 170 additions & 0 deletions durabletask/aio/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
# Copyright (c) The Dapr Authors.
# Licensed under the MIT License.

import logging
import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union

import grpc
from google.protobuf import wrappers_pb2

import durabletask.internal.helpers as helpers
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor
from durabletask import task
from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl


class AsyncTaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None):

if interceptors is not None:
interceptors = list(interceptors)
if metadata is not None:
interceptors.append(DefaultClientInterceptorImpl(metadata))
elif metadata is not None:
interceptors = [DefaultClientInterceptorImpl(metadata)]
else:
interceptors = None

channel = get_grpc_aio_channel(
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors
)
self._channel = channel
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)

async def aclose(self):
await self._channel.close()

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
return False

async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:

name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=helpers.get_string_value(None),
orchestrationIdReusePolicy=reuse_id_policy,
)

self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
return res.instanceId

async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

async def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
return new_orchestration_state(req.instanceId, res)
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError("Timed-out waiting for the orchestration to start")
else:
raise

async def wait_for_orchestration_completion(self, instance_id: str, *,
fetch_payloads: bool = True,
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
state = new_orchestration_state(req.instanceId, res)
if not state:
return None

if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
details = state.failure_details
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
elif state.runtime_status == OrchestrationStatus.TERMINATED:
self._logger.info(f"Instance '{instance_id}' was terminated.")
elif state.runtime_status == OrchestrationStatus.COMPLETED:
self._logger.info(f"Instance '{instance_id}' completed.")

return state
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError("Timed-out waiting for the orchestration to complete")
else:
raise

async def raise_orchestration_event(
self,
instance_id: str,
event_name: str,
*,
data: Optional[Any] = None):
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
await self._stub.RaiseEvent(req)

async def terminate_orchestration(self, instance_id: str, *,
output: Optional[Any] = None,
recursive: bool = True):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
recursive=recursive)

self._logger.info(f"Terminating instance '{instance_id}'.")
await self._stub.TerminateInstance(req)

async def suspend_orchestration(self, instance_id: str):
req = pb.SuspendRequest(instanceId=instance_id)
self._logger.info(f"Suspending instance '{instance_id}'.")
await self._stub.SuspendInstance(req)

async def resume_orchestration(self, instance_id: str):
req = pb.ResumeRequest(instanceId=instance_id)
self._logger.info(f"Resuming instance '{instance_id}'.")
await self._stub.ResumeInstance(req)

async def purge_orchestration(self, instance_id: str, recursive: bool = True):
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
self._logger.info(f"Purging instance '{instance_id}'.")
await self._stub.PurgeInstances(req)
Empty file.
58 changes: 58 additions & 0 deletions durabletask/aio/internal/grpc_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) The Dapr Authors.
# Licensed under the MIT License.

from collections import namedtuple

from grpc import aio as grpc_aio


class _ClientCallDetails(
namedtuple(
'_ClientCallDetails',
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
grpc_aio.ClientCallDetails):
pass


class DefaultClientInterceptorImpl(
grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor,
grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor):
"""Async gRPC client interceptor to add metadata to all calls."""

def __init__(self, metadata: list[tuple[str, str]]):
super().__init__()
self._metadata = metadata

def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails:
if self._metadata is None:
return client_call_details

if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
else:
metadata = []

metadata.extend(self._metadata)
return _ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression)

async def intercept_unary_unary(self, continuation, client_call_details, request):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request)

async def intercept_unary_stream(self, continuation, client_call_details, request):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request)

async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request_iterator)

async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request_iterator)
49 changes: 49 additions & 0 deletions durabletask/aio/internal/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) The Dapr Authors.
# Licensed under the MIT License.

from typing import Optional, Sequence, Union

import grpc
from grpc import aio as grpc_aio

from durabletask.internal.shared import (
get_default_host_address,
SECURE_PROTOCOLS,
INSECURE_PROTOCOLS,
)


ClientInterceptor = Union[
grpc_aio.UnaryUnaryClientInterceptor,
grpc_aio.UnaryStreamClientInterceptor,
grpc_aio.StreamUnaryClientInterceptor,
grpc_aio.StreamStreamClientInterceptor
]


def get_grpc_aio_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel:

if host_address is None:
host_address = get_default_host_address()

for protocol in SECURE_PROTOCOLS:
if host_address.lower().startswith(protocol):
secure_channel = True
host_address = host_address[len(protocol):]
break

for protocol in INSECURE_PROTOCOLS:
if host_address.lower().startswith(protocol):
secure_channel = False
host_address = host_address[len(protocol):]
break

if secure_channel:
channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
else:
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)

return channel
9 changes: 9 additions & 0 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,18 @@ def get_tasks(self) -> list[Task]:
def on_child_completed(self, task: Task[T]):
pass


class WhenAllTask(CompositeTask[list[T]]):
"""A task that completes when all of its child tasks complete."""

def __init__(self, tasks: list[Task[T]]):
super().__init__(tasks)
self._completed_tasks = 0
self._failed_tasks = 0
# If there are no child tasks, this composite should complete immediately
if len(self._tasks) == 0:
self._result = [] # type: ignore[assignment]
self._is_complete = True

@property
def pending_tasks(self) -> int:
Expand Down Expand Up @@ -387,6 +392,10 @@ class WhenAnyTask(CompositeTask[Task]):

def __init__(self, tasks: list[Task]):
super().__init__(tasks)
# If there are no child tasks, complete immediately with an empty result
if len(self._tasks) == 0:
self._result = [] # type: ignore[assignment]
self._is_complete = True

def on_child_completed(self, task: Task):
# The first task to complete is the result of the WhenAnyTask.
Expand Down
2 changes: 1 addition & 1 deletion durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,13 +880,13 @@ class ExecutionResults:
actions: list[pb.OrchestratorAction]
encoded_custom_status: Optional[str]


def __init__(
self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
):
self.actions = actions
self.encoded_custom_status = encoded_custom_status


class _OrchestrationExecutor:
_generator: Optional[task.Orchestrator] = None

Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
autopep8
grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible
protobuf
asyncio
pytest
pytest-cov
asyncio
pytest-asyncio
flake8
1 change: 1 addition & 0 deletions tests/durabletask/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_get_grpc_channel_secure():
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)


def test_get_grpc_channel_default_host_address():
with patch('grpc.insecure_channel') as mock_channel:
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
Expand Down
Loading
Loading