forked from microsoft/durabletask-python
-
Notifications
You must be signed in to change notification settings - Fork 5
Add async version of durabletask client #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
passuied
wants to merge
6
commits into
dapr:main
Choose a base branch
from
passuied:feature/asyncio-dapr
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
f4d8a38
- Introduced `AsyncTaskHubGrpcClient` as async implementation of `Tas…
passuied f3f1c4b
Switch to pytest-asyncio
passuied a21da3f
Rename classes to avoid repeating `Aio` per PR feedback.
passuied eac4b82
Fixed a bug where `when_all()` and `when_any()` are passed empty list…
passuied ecccbef
Enabled context manager in client.
passuied 3d8528d
reverting dependency updates and readme changes
passuied File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python | ||
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python # supports protobuf 6.x and aligns with generated code |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .client import AsyncTaskHubGrpcClient | ||
|
||
__all__ = [ | ||
"AsyncTaskHubGrpcClient", | ||
] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# Copyright (c) The Dapr Authors. | ||
# Licensed under the MIT License. | ||
|
||
import logging | ||
import uuid | ||
from datetime import datetime | ||
from typing import Any, Optional, Sequence, Union | ||
|
||
import grpc | ||
from google.protobuf import wrappers_pb2 | ||
|
||
import durabletask.internal.helpers as helpers | ||
import durabletask.internal.orchestrator_service_pb2 as pb | ||
import durabletask.internal.orchestrator_service_pb2_grpc as stubs | ||
import durabletask.internal.shared as shared | ||
from durabletask.aio.internal.shared import get_grpc_aio_channel, ClientInterceptor | ||
from durabletask import task | ||
from durabletask.client import OrchestrationState, OrchestrationStatus, new_orchestration_state, TInput, TOutput | ||
from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl | ||
|
||
|
||
class AsyncTaskHubGrpcClient: | ||
|
||
def __init__(self, *, | ||
host_address: Optional[str] = None, | ||
metadata: Optional[list[tuple[str, str]]] = None, | ||
log_handler: Optional[logging.Handler] = None, | ||
log_formatter: Optional[logging.Formatter] = None, | ||
secure_channel: bool = False, | ||
interceptors: Optional[Sequence[ClientInterceptor]] = None): | ||
|
||
if interceptors is not None: | ||
interceptors = list(interceptors) | ||
if metadata is not None: | ||
interceptors.append(DefaultClientInterceptorImpl(metadata)) | ||
elif metadata is not None: | ||
interceptors = [DefaultClientInterceptorImpl(metadata)] | ||
else: | ||
interceptors = None | ||
|
||
channel = get_grpc_aio_channel( | ||
host_address=host_address, | ||
secure_channel=secure_channel, | ||
interceptors=interceptors | ||
) | ||
self._channel = channel | ||
self._stub = stubs.TaskHubSidecarServiceStub(channel) | ||
self._logger = shared.get_logger("client", log_handler, log_formatter) | ||
|
||
async def aclose(self): | ||
await self._channel.close() | ||
|
||
async def __aenter__(self): | ||
return self | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
await self.aclose() | ||
return False | ||
|
||
async def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, | ||
input: Optional[TInput] = None, | ||
instance_id: Optional[str] = None, | ||
start_at: Optional[datetime] = None, | ||
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str: | ||
|
||
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) | ||
|
||
req = pb.CreateInstanceRequest( | ||
name=name, | ||
instanceId=instance_id if instance_id else uuid.uuid4().hex, | ||
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None, | ||
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, | ||
version=helpers.get_string_value(None), | ||
orchestrationIdReusePolicy=reuse_id_policy, | ||
) | ||
|
||
self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") | ||
res: pb.CreateInstanceResponse = await self._stub.StartInstance(req) | ||
return res.instanceId | ||
|
||
async def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: | ||
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) | ||
res: pb.GetInstanceResponse = await self._stub.GetInstance(req) | ||
return new_orchestration_state(req.instanceId, res) | ||
|
||
async def wait_for_orchestration_start(self, instance_id: str, *, | ||
fetch_payloads: bool = False, | ||
timeout: int = 0) -> Optional[OrchestrationState]: | ||
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) | ||
try: | ||
grpc_timeout = None if timeout == 0 else timeout | ||
self._logger.info( | ||
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.") | ||
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) | ||
return new_orchestration_state(req.instanceId, res) | ||
except grpc.RpcError as rpc_error: | ||
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore | ||
# Replace gRPC error with the built-in TimeoutError | ||
raise TimeoutError("Timed-out waiting for the orchestration to start") | ||
else: | ||
raise | ||
|
||
async def wait_for_orchestration_completion(self, instance_id: str, *, | ||
fetch_payloads: bool = True, | ||
timeout: int = 0) -> Optional[OrchestrationState]: | ||
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) | ||
try: | ||
grpc_timeout = None if timeout == 0 else timeout | ||
self._logger.info( | ||
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.") | ||
res: pb.GetInstanceResponse = await self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout) | ||
state = new_orchestration_state(req.instanceId, res) | ||
if not state: | ||
return None | ||
|
||
if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None: | ||
details = state.failure_details | ||
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}") | ||
elif state.runtime_status == OrchestrationStatus.TERMINATED: | ||
self._logger.info(f"Instance '{instance_id}' was terminated.") | ||
elif state.runtime_status == OrchestrationStatus.COMPLETED: | ||
self._logger.info(f"Instance '{instance_id}' completed.") | ||
|
||
return state | ||
except grpc.RpcError as rpc_error: | ||
if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore | ||
# Replace gRPC error with the built-in TimeoutError | ||
raise TimeoutError("Timed-out waiting for the orchestration to complete") | ||
else: | ||
raise | ||
|
||
async def raise_orchestration_event( | ||
self, | ||
instance_id: str, | ||
event_name: str, | ||
*, | ||
data: Optional[Any] = None): | ||
req = pb.RaiseEventRequest( | ||
instanceId=instance_id, | ||
name=event_name, | ||
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None) | ||
|
||
self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") | ||
await self._stub.RaiseEvent(req) | ||
|
||
async def terminate_orchestration(self, instance_id: str, *, | ||
output: Optional[Any] = None, | ||
recursive: bool = True): | ||
req = pb.TerminateRequest( | ||
instanceId=instance_id, | ||
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, | ||
recursive=recursive) | ||
|
||
self._logger.info(f"Terminating instance '{instance_id}'.") | ||
await self._stub.TerminateInstance(req) | ||
|
||
async def suspend_orchestration(self, instance_id: str): | ||
req = pb.SuspendRequest(instanceId=instance_id) | ||
self._logger.info(f"Suspending instance '{instance_id}'.") | ||
await self._stub.SuspendInstance(req) | ||
|
||
async def resume_orchestration(self, instance_id: str): | ||
req = pb.ResumeRequest(instanceId=instance_id) | ||
self._logger.info(f"Resuming instance '{instance_id}'.") | ||
await self._stub.ResumeInstance(req) | ||
|
||
async def purge_orchestration(self, instance_id: str, recursive: bool = True): | ||
req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) | ||
self._logger.info(f"Purging instance '{instance_id}'.") | ||
await self._stub.PurgeInstances(req) |
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright (c) The Dapr Authors. | ||
# Licensed under the MIT License. | ||
|
||
from collections import namedtuple | ||
|
||
from grpc import aio as grpc_aio | ||
|
||
|
||
class _ClientCallDetails( | ||
namedtuple( | ||
'_ClientCallDetails', | ||
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), | ||
grpc_aio.ClientCallDetails): | ||
pass | ||
|
||
|
||
class DefaultClientInterceptorImpl( | ||
grpc_aio.UnaryUnaryClientInterceptor, grpc_aio.UnaryStreamClientInterceptor, | ||
grpc_aio.StreamUnaryClientInterceptor, grpc_aio.StreamStreamClientInterceptor): | ||
"""Async gRPC client interceptor to add metadata to all calls.""" | ||
|
||
def __init__(self, metadata: list[tuple[str, str]]): | ||
super().__init__() | ||
self._metadata = metadata | ||
|
||
def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc_aio.ClientCallDetails: | ||
if self._metadata is None: | ||
return client_call_details | ||
|
||
if client_call_details.metadata is not None: | ||
metadata = list(client_call_details.metadata) | ||
else: | ||
metadata = [] | ||
|
||
metadata.extend(self._metadata) | ||
return _ClientCallDetails( | ||
client_call_details.method, | ||
client_call_details.timeout, | ||
metadata, | ||
client_call_details.credentials, | ||
client_call_details.wait_for_ready, | ||
client_call_details.compression) | ||
|
||
async def intercept_unary_unary(self, continuation, client_call_details, request): | ||
new_client_call_details = self._intercept_call(client_call_details) | ||
return await continuation(new_client_call_details, request) | ||
|
||
async def intercept_unary_stream(self, continuation, client_call_details, request): | ||
new_client_call_details = self._intercept_call(client_call_details) | ||
return await continuation(new_client_call_details, request) | ||
|
||
async def intercept_stream_unary(self, continuation, client_call_details, request_iterator): | ||
new_client_call_details = self._intercept_call(client_call_details) | ||
return await continuation(new_client_call_details, request_iterator) | ||
|
||
async def intercept_stream_stream(self, continuation, client_call_details, request_iterator): | ||
new_client_call_details = self._intercept_call(client_call_details) | ||
return await continuation(new_client_call_details, request_iterator) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright (c) The Dapr Authors. | ||
# Licensed under the MIT License. | ||
|
||
from typing import Optional, Sequence, Union | ||
|
||
import grpc | ||
from grpc import aio as grpc_aio | ||
|
||
from durabletask.internal.shared import ( | ||
get_default_host_address, | ||
SECURE_PROTOCOLS, | ||
INSECURE_PROTOCOLS, | ||
) | ||
|
||
|
||
ClientInterceptor = Union[ | ||
grpc_aio.UnaryUnaryClientInterceptor, | ||
grpc_aio.UnaryStreamClientInterceptor, | ||
grpc_aio.StreamUnaryClientInterceptor, | ||
grpc_aio.StreamStreamClientInterceptor | ||
] | ||
|
||
|
||
def get_grpc_aio_channel( | ||
host_address: Optional[str], | ||
secure_channel: bool = False, | ||
interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc_aio.Channel: | ||
|
||
if host_address is None: | ||
host_address = get_default_host_address() | ||
|
||
for protocol in SECURE_PROTOCOLS: | ||
if host_address.lower().startswith(protocol): | ||
secure_channel = True | ||
host_address = host_address[len(protocol):] | ||
break | ||
|
||
for protocol in INSECURE_PROTOCOLS: | ||
if host_address.lower().startswith(protocol): | ||
secure_channel = False | ||
host_address = host_address[len(protocol):] | ||
break | ||
|
||
if secure_channel: | ||
channel = grpc_aio.secure_channel(host_address, grpc.ssl_channel_credentials(), interceptors=interceptors) | ||
else: | ||
channel = grpc_aio.insecure_channel(host_address, interceptors=interceptors) | ||
|
||
return channel |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
autopep8 | ||
grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible | ||
protobuf | ||
asyncio | ||
pytest | ||
pytest-cov | ||
asyncio | ||
pytest-asyncio | ||
flake8 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.