Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install flake8 pytest pytest-cov pytest-asyncio
pip install -r requirements.txt
- name: Lint with flake8
run: |
Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ This will download the `orchestrator_service.proto` from the `microsoft/durablet
Unit tests can be run using the following command from the project root. Unit tests _don't_ require a sidecar process to be running.

```sh
pip3 install -r dev-requirements.txt
make test-unit
```

Expand All @@ -188,6 +189,7 @@ durabletask-go --port 4001
To run the E2E tests, run the following command from the project root:

```sh
pip3 install -r dev-requirements.txt
make test-e2e
```

Expand Down
4 changes: 4 additions & 0 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python
pytest
pytest-cov
pytest-asyncio
flake8
5 changes: 5 additions & 0 deletions durabletask/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .client import AsyncTaskHubGrpcClient

__all__ = [
"AsyncTaskHubGrpcClient",
]
160 changes: 160 additions & 0 deletions durabletask/aio/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Optional, Sequence, Union

import grpc
from google.protobuf import wrappers_pb2

import durabletask.internal.helpers as helpers
import durabletask.internal.orchestrator_service_pb2 as pb
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
import durabletask.internal.shared as shared
from durabletask.aio.internal.shared import get_grpc_aio_channel, AioClientInterceptor
from durabletask import task
from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput
from durabletask.aio.internal.grpc_interceptor import DefaultAioClientInterceptorImpl


class AsyncTaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[AioClientInterceptor]] = None):

if interceptors is not None:
interceptors = list(interceptors)
if metadata is not None:
interceptors.append(DefaultAioClientInterceptorImpl(metadata))
elif metadata is not None:
interceptors = [DefaultAioClientInterceptorImpl(metadata)]
else:
interceptors = None

channel = get_grpc_aio_channel(
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors
)
self._channel = channel
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)

async def aclose(self):
await self._channel.close()

async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:

name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=helpers.get_string_value(None),
orchestrationIdReusePolicy=reuse_id_policy,
)

self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.")
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req)
return res.instanceId

async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
res: pb.GetInstanceResponse = await self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

async def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
return new_orchestration_state(req.instanceId, res)
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError("Timed-out waiting for the orchestration to start")
else:
raise

async def wait_for_orchestration_completion(self, instance_id: str, *,
fetch_payloads: bool = True,
timeout: int = 0) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
state = new_orchestration_state(req.instanceId, res)
if not state:
return None

if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
details = state.failure_details
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
elif state.runtime_status == OrchestrationStatus.TERMINATED:
self._logger.info(f"Instance '{instance_id}' was terminated.")
elif state.runtime_status == OrchestrationStatus.COMPLETED:
self._logger.info(f"Instance '{instance_id}' completed.")

return state
except grpc.RpcError as rpc_error:
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore
# Replace gRPC error with the built-in TimeoutError
raise TimeoutError("Timed-out waiting for the orchestration to complete")
else:
raise

async def raise_orchestration_event(
self,
instance_id: str,
event_name: str,
*,
data: Optional[Any] = None):
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
await self._stub.RaiseEvent(req)

async def terminate_orchestration(self, instance_id: str, *,
output: Optional[Any] = None,
recursive: bool = True):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
recursive=recursive)

self._logger.info(f"Terminating instance '{instance_id}'.")
await self._stub.TerminateInstance(req)

async def suspend_orchestration(self, instance_id: str):
req = pb.SuspendRequest(instanceId=instance_id)
self._logger.info(f"Suspending instance '{instance_id}'.")
await self._stub.SuspendInstance(req)

async def resume_orchestration(self, instance_id: str):
req = pb.ResumeRequest(instanceId=instance_id)
self._logger.info(f"Resuming instance '{instance_id}'.")
await self._stub.ResumeInstance(req)

async def purge_orchestration(self, instance_id: str, recursive: bool = True):
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive)
self._logger.info(f"Purging instance '{instance_id}'.")
await self._stub.PurgeInstances(req)
Empty file.
55 changes: 55 additions & 0 deletions durabletask/aio/internal/grpc_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from collections import namedtuple

from grpc import aio as grpc_aio


class _AioClientCallDetails(
namedtuple(
'_AioClientCallDetails',
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
grpc_aio.ClientCallDetails):
pass


class DefaultAioClientInterceptorImpl(
grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor,
grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor):
"""Async gRPC client interceptor to add metadata to all calls."""

def __init__(self, metadata: list[tuple[str, str]]):
super().__init__()
self._metadata = metadata

def _intercept_call(self, client_call_details: _AioClientCallDetails) -> grpc_aio.ClientCallDetails:
if self._metadata is None:
return client_call_details

if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
else:
metadata = []

metadata.extend(self._metadata)
return _AioClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression)

async def intercept_unary_unary(self, continuation, client_call_details, request):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request)

async def intercept_unary_stream(self, continuation, client_call_details, request):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request)

async def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request_iterator)

async def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
new_client_call_details = self._intercept_call(client_call_details)
return await continuation(new_client_call_details, request_iterator)
46 changes: 46 additions & 0 deletions durabletask/aio/internal/shared.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, Sequence, Union

import grpc
from grpc import aio as grpc_aio

from durabletask.internal.shared import (
get_default_host_address,
SECURE_PROTOCOLS,
INSECURE_PROTOCOLS,
)


AioClientInterceptor = Union[
grpc_aio.UnaryUnaryClientInterceptor,
grpc_aio.UnaryStreamClientInterceptor,
grpc_aio.StreamUnaryClientInterceptor,
grpc_aio.StreamStreamClientInterceptor
]


def get_grpc_aio_channel(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence[AioClientInterceptor]] = None) -> grpc_aio.Channel:

if host_address is None:
host_address = get_default_host_address()

for protocol in SECURE_PROTOCOLS:
if host_address.lower().startswith(protocol):
secure_channel = True
host_address = host_address[len(protocol):]
break

for protocol in INSECURE_PROTOCOLS:
if host_address.lower().startswith(protocol):
secure_channel = False
host_address = host_address[len(protocol):]
break

if secure_channel:
channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors)
else:
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors)

return channel
1 change: 1 addition & 0 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def get_tasks(self) -> list[Task]:
def on_child_completed(self, task: Task[T]):
pass


class WhenAllTask(CompositeTask[list[T]]):
"""A task that completes when all of its child tasks complete."""

Expand Down
2 changes: 1 addition & 1 deletion durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,13 +880,13 @@ class ExecutionResults:
actions: list[pb.OrchestratorAction]
encoded_custom_status: Optional[str]


def __init__(
self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]
):
self.actions = actions
self.encoded_custom_status = encoded_custom_status


class _OrchestrationExecutor:
_generator: Optional[task.Orchestrator] = None

Expand Down
2 changes: 0 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
autopep8
grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible
protobuf
pytest
pytest-cov
asyncio
1 change: 1 addition & 0 deletions tests/durabletask/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_get_grpc_channel_secure():
get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS)
mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value)


def test_get_grpc_channel_default_host_address():
with patch('grpc.insecure_channel') as mock_channel:
get_grpc_channel(None, False, interceptors=INTERCEPTORS)
Expand Down
Loading
Loading