8
8
from datetime import datetime , timezone
9
9
from typing import Any , Callable , Dict , List , Optional , TypeVar , Union
10
10
11
+ from pydantic import BaseModel
11
12
from dapr .ext .workflow import (
12
13
DaprWorkflowClient ,
13
14
WorkflowActivityContext ,
@@ -83,12 +84,6 @@ def model_post_init(self, __context: Any) -> None:
83
84
84
85
self .start_runtime ()
85
86
86
- # Discover and register tasks and workflows
87
- discovered_tasks = self ._discover_tasks ()
88
- self ._register_tasks (discovered_tasks )
89
- discovered_wfs = self ._discover_workflows ()
90
- self ._register_workflows (discovered_wfs )
91
-
92
87
# Set up automatic signal handlers for graceful shutdown
93
88
try :
94
89
self .setup_signal_handlers ()
@@ -350,6 +345,10 @@ def register_tasks_from_package(self, package_name: str) -> None:
350
345
def _register_tasks (self , tasks : Dict [str , Callable ]) -> None :
351
346
"""Register each discovered task with the Dapr runtime using direct registration."""
352
347
for task_name , method in tasks .items ():
348
+ # Don't reregister tasks that are already registered
349
+ if task_name in self .tasks :
350
+ continue
351
+
353
352
llm = self ._choose_llm_for (method )
354
353
logger .debug (
355
354
f"Registering task '{ task_name } ' with llm={ getattr (llm , '__class__' , None )} "
@@ -436,6 +435,10 @@ def _discover_workflows(self) -> Dict[str, Callable]:
436
435
def _register_workflows (self , wfs : Dict [str , Callable ]) -> None :
437
436
"""Register each discovered workflow with the Dapr runtime."""
438
437
for wf_name , method in wfs .items ():
438
+ # Don't reregister workflows that are already registered
439
+ if wf_name in self .workflows :
440
+ continue
441
+
439
442
# Use a closure helper to avoid late-binding capture issues.
440
443
def make_wrapped (meth : Callable ) -> Callable :
441
444
@functools .wraps (meth )
@@ -523,6 +526,17 @@ def start_runtime(self):
523
526
else :
524
527
logger .debug ("Workflow runtime already running; skipping." )
525
528
529
+ self ._ensure_activities_registered ()
530
+
531
+ def _ensure_activities_registered (self ):
532
+ """Ensure all workflow activities are registered with the Dapr runtime."""
533
+ # Discover and register tasks and workflows
534
+ discovered_tasks = self ._discover_tasks ()
535
+ self ._register_tasks (discovered_tasks )
536
+ discovered_wfs = self ._discover_workflows ()
537
+ self ._register_workflows (discovered_wfs )
538
+ logger .debug ("Workflow activities registration completed." )
539
+
526
540
def _sync_workflow_state_after_startup (self ):
527
541
"""
528
542
Sync database workflow state with actual Dapr workflow status after runtime startup.
@@ -540,25 +554,17 @@ def _sync_workflow_state_after_startup(self):
540
554
)
541
555
return
542
556
543
- logger .debug ("Syncing workflow state with Dapr after runtime startup..." )
544
557
self .load_state ()
545
-
546
- # Check if we have instances to sync
547
- instances = (
548
- getattr (self .state , "instances" , {})
549
- if hasattr (self .state , "instances" )
550
- else self .state .get ("instances" , {})
551
- )
552
- if not instances :
553
- return
558
+ instances = self .state .get ("instances" , {})
554
559
555
560
logger .debug (f"Found { len (instances )} workflow instances to sync" )
556
561
557
562
# Sync each instance with Dapr's actual status
558
563
for instance_id , instance_data in instances .items ():
559
564
try :
560
565
# Skip if already completed
561
- if instance_data .get ("end_time" ) is not None :
566
+ end_time = instance_data .get ("end_time" )
567
+ if end_time is not None :
562
568
continue
563
569
564
570
# Get actual status from Dapr
@@ -580,6 +586,7 @@ def _sync_workflow_state_after_startup(self):
580
586
timezone .utc
581
587
).isoformat ()
582
588
instance_data ["status" ] = runtime_status .lower ()
589
+
583
590
logger .debug (
584
591
f"Marked workflow { instance_id } as { runtime_status .lower ()} in database"
585
592
)
@@ -600,6 +607,7 @@ def _sync_workflow_state_after_startup(self):
600
607
timezone .utc
601
608
).isoformat ()
602
609
instance_data ["status" ] = DaprWorkflowStatus .COMPLETED .value
610
+
603
611
logger .debug (
604
612
f"Workflow { instance_id } no longer in Dapr, marked as completed"
605
613
)
@@ -804,9 +812,7 @@ async def _create_agent_span_for_resumed_workflow(
804
812
# Get tracer and create AGENT span as child of the original trace
805
813
tracer = trace .get_tracer (__name__ )
806
814
agent_name = getattr (self , "name" , "DurableAgent" )
807
- workflow_name = instance_data .get (
808
- "workflow_name" , "ToolCallingWorkflow"
809
- )
815
+ workflow_name = instance_data .get ("workflow_name" , "AgenticWorkflow" )
810
816
span_name = f"{ agent_name } .{ workflow_name } "
811
817
812
818
# Create the AGENT span that will show up in the trace
0 commit comments