Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[flake8]
ignore = E501,C901
exclude =
.git
*_pb2*
__pycache__
ignore = E203,E501,W503,E701,E704,F821,C901
extend-exclude = .tox,venv,.venv,build,**/.venv,**/venv,*pb2_grpc.py,*pb2.py
per-file-ignores=
examples/**:F541 setup.py:E121
tests/**:F541,E712
max-line-length = 100
3 changes: 1 addition & 2 deletions .github/workflows/pr-validation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install flake8 pytest
pip install -r requirements.txt
pip install -r dev-requirements.txt
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics --exit-zero
Expand Down
9 changes: 8 additions & 1 deletion dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,8 @@
grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python
# TODO: move to pyproject optional-dependencies
pytest-asyncio>=0.23
flake8
tox>=4.0.0
pytest
pytest-cov
grpcio-tools==1.75.1
protobuf>=6.31.1
1 change: 0 additions & 1 deletion durabletask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@

"""Durable Task SDK for Python"""


PACKAGE_NAME = "durabletask"
116 changes: 71 additions & 45 deletions durabletask/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from durabletask import task
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl

TInput = TypeVar('TInput')
TOutput = TypeVar('TOutput')
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")


class OrchestrationStatus(Enum):
"""The status of an orchestration instance."""

RUNNING = pb.ORCHESTRATION_STATUS_RUNNING
COMPLETED = pb.ORCHESTRATION_STATUS_COMPLETED
FAILED = pb.ORCHESTRATION_STATUS_FAILED
Expand Down Expand Up @@ -52,7 +53,8 @@ def raise_if_failed(self):
if self.failure_details is not None:
raise OrchestrationFailedError(
f"Orchestration '{self.instance_id}' failed: {self.failure_details.message}",
self.failure_details)
self.failure_details,
)


class OrchestrationFailedError(Exception):
Expand All @@ -65,18 +67,23 @@ def failure_details(self):
return self._failure_details


def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Optional[OrchestrationState]:
def new_orchestration_state(
instance_id: str, res: pb.GetInstanceResponse
) -> Optional[OrchestrationState]:
if not res.exists:
return None

state = res.orchestrationState

failure_details = None
if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '':
if state.failureDetails.errorMessage != "" or state.failureDetails.errorType != "":
failure_details = task.FailureDetails(
state.failureDetails.errorMessage,
state.failureDetails.errorType,
state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None)
state.failureDetails.stackTrace.value
if not helpers.is_empty(state.failureDetails.stackTrace)
else None,
)

return OrchestrationState(
instance_id,
Expand All @@ -87,19 +94,21 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op
state.input.value if not helpers.is_empty(state.input) else None,
state.output.value if not helpers.is_empty(state.output) else None,
state.customStatus.value if not helpers.is_empty(state.customStatus) else None,
failure_details)
failure_details,
)


class TaskHubGrpcClient:

def __init__(self, *,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None):

def __init__(
self,
*,
host_address: Optional[str] = None,
metadata: Optional[list[tuple[str, str]]] = None,
log_handler: Optional[logging.Handler] = None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False,
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None,
):
# If the caller provided metadata, we need to create a new interceptor for it and
# add it to the list of interceptors.
if interceptors is not None:
Expand All @@ -112,25 +121,28 @@ def __init__(self, *,
interceptors = None

channel = shared.get_grpc_channel(
host_address=host_address,
secure_channel=secure_channel,
interceptors=interceptors
host_address=host_address, secure_channel=secure_channel, interceptors=interceptors
)
self._stub = stubs.TaskHubSidecarServiceStub(channel)
self._logger = shared.get_logger("client", log_handler, log_formatter)

def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str:

def schedule_new_orchestration(
self,
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
*,
input: Optional[TInput] = None,
instance_id: Optional[str] = None,
start_at: Optional[datetime] = None,
reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None,
) -> str:
name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator)

req = pb.CreateInstanceRequest(
name=name,
instanceId=instance_id if instance_id else uuid.uuid4().hex,
input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None,
input=wrappers_pb2.StringValue(value=shared.to_json(input))
if input is not None
else None,
scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None,
version=wrappers_pb2.StringValue(value=""),
orchestrationIdReusePolicy=reuse_id_policy,
Expand All @@ -140,19 +152,22 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu
res: pb.CreateInstanceResponse = self._stub.StartInstance(req)
return res.instanceId

def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]:
def get_orchestration_state(
self, instance_id: str, *, fetch_payloads: bool = True
) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
res: pb.GetInstanceResponse = self._stub.GetInstance(req)
return new_orchestration_state(req.instanceId, res)

def wait_for_orchestration_start(self, instance_id: str, *,
fetch_payloads: bool = False,
timeout: int = 0) -> Optional[OrchestrationState]:
def wait_for_orchestration_start(
self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0
) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.")
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start."
)
res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout)
return new_orchestration_state(req.instanceId, res)
except grpc.RpcError as rpc_error:
Expand All @@ -162,22 +177,30 @@ def wait_for_orchestration_start(self, instance_id: str, *,
else:
raise

def wait_for_orchestration_completion(self, instance_id: str, *,
fetch_payloads: bool = True,
timeout: int = 0) -> Optional[OrchestrationState]:
def wait_for_orchestration_completion(
self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0
) -> Optional[OrchestrationState]:
req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads)
try:
grpc_timeout = None if timeout == 0 else timeout
self._logger.info(
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.")
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout)
f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete."
)
res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(
req, timeout=grpc_timeout
)
state = new_orchestration_state(req.instanceId, res)
if not state:
return None

if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None:
if (
state.runtime_status == OrchestrationStatus.FAILED
and state.failure_details is not None
):
details = state.failure_details
self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}")
self._logger.info(
f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}"
)
elif state.runtime_status == OrchestrationStatus.TERMINATED:
self._logger.info(f"Instance '{instance_id}' was terminated.")
elif state.runtime_status == OrchestrationStatus.COMPLETED:
Expand All @@ -191,23 +214,26 @@ def wait_for_orchestration_completion(self, instance_id: str, *,
else:
raise

def raise_orchestration_event(self, instance_id: str, event_name: str, *,
data: Optional[Any] = None):
def raise_orchestration_event(
self, instance_id: str, event_name: str, *, data: Optional[Any] = None
):
req = pb.RaiseEventRequest(
instanceId=instance_id,
name=event_name,
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None)
input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None,
)

self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.")
self._stub.RaiseEvent(req)

def terminate_orchestration(self, instance_id: str, *,
output: Optional[Any] = None,
recursive: bool = True):
def terminate_orchestration(
self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True
):
req = pb.TerminateRequest(
instanceId=instance_id,
output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None,
recursive=recursive)
recursive=recursive,
)

self._logger.info(f"Terminating instance '{instance_id}'.")
self._stub.TerminateInstance(req)
Expand Down
34 changes: 22 additions & 12 deletions durabletask/internal/grpc_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,26 @@


class _ClientCallDetails(
namedtuple(
'_ClientCallDetails',
['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']),
grpc.ClientCallDetails):
namedtuple(
"_ClientCallDetails",
["method", "timeout", "metadata", "credentials", "wait_for_ready", "compression"],
),
grpc.ClientCallDetails,
):
"""This is an implementation of the ClientCallDetails interface needed for interceptors.
This class takes six named values and inherits the ClientCallDetails from grpc package.
This class encloses the values that describe a RPC to be invoked.
"""

pass


class DefaultClientInterceptorImpl (
grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor):
class DefaultClientInterceptorImpl(
grpc.UnaryUnaryClientInterceptor,
grpc.UnaryStreamClientInterceptor,
grpc.StreamUnaryClientInterceptor,
grpc.StreamStreamClientInterceptor,
):
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
interceptor to add additional headers to all calls as needed."""
Expand All @@ -29,10 +35,9 @@ def __init__(self, metadata: list[tuple[str, str]]):
super().__init__()
self._metadata = metadata

def _intercept_call(
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:
"""Internal intercept_call implementation which adds metadata to grpc metadata in the RPC
call details."""
call details."""
if self._metadata is None:
return client_call_details

Expand All @@ -43,8 +48,13 @@ def _intercept_call(

metadata.extend(self._metadata)
client_call_details = _ClientCallDetails(
client_call_details.method, client_call_details.timeout, metadata,
client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression)
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression,
)

return client_call_details

Expand Down
Loading