Skip to content

Commit 47008b6

Browse files
committed
no-op
PiperOrigin-RevId: 660551866
1 parent 3db64bc commit 47008b6

File tree

8 files changed

+163
-22
lines changed

8 files changed

+163
-22
lines changed

tfx/orchestration/experimental/core/async_pipeline_task_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@ def _generate_tasks_for_node(
490490
execution_type=node.node_info.type,
491491
contexts=resolved_info.contexts,
492492
input_and_params=unprocessed_inputs,
493+
pipeline=self._pipeline,
494+
node_id=node.node_info.id,
493495
)
494496

495497
for execution in executions:

tfx/orchestration/experimental/core/pipeline_state.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1673,3 +1673,11 @@ def get_pipeline_and_node(
16731673
'pipeline nodes are supported for external executions.'
16741674
)
16751675
return (pipeline_state.pipeline, node)
1676+
1677+
1678+
def get_pipeline(
1679+
mlmd_handle: metadata.Metadata, pipeline_id: str
1680+
) -> pipeline_pb2.Pipeline:
1681+
"""Loads the pipeline proto for a pipeline from latest execution."""
1682+
pipeline_view = PipelineView.load(mlmd_handle, pipeline_id)
1683+
return pipeline_view.pipeline

tfx/orchestration/experimental/core/sync_pipeline_task_gen.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,8 @@ def _generate_tasks_from_resolved_inputs(
564564
execution_type=node.node_info.type,
565565
contexts=resolved_info.contexts,
566566
input_and_params=resolved_info.input_and_params,
567+
pipeline=self._pipeline,
568+
node_id=node.node_info.id,
567569
)
568570

569571
result.extend(

tfx/orchestration/experimental/core/task_gen_utils.py

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from tfx.orchestration import metadata
3131
from tfx.orchestration import node_proto_view
3232
from tfx.orchestration.experimental.core import constants
33+
from tfx.orchestration.experimental.core import env
3334
from tfx.orchestration.experimental.core import mlmd_state
3435
from tfx.orchestration.experimental.core import task as task_lib
3536
from tfx.orchestration import mlmd_connection_manager as mlmd_cm
@@ -548,21 +549,41 @@ def register_executions_from_existing_executions(
548549
contexts = metadata_handle.store.get_contexts_by_execution(
549550
existing_executions[0].id
550551
)
551-
return execution_lib.put_executions(
552+
executions = execution_lib.put_executions(
552553
metadata_handle,
553554
new_executions,
554555
contexts,
555556
input_artifacts_maps=input_artifacts,
556557
)
557558

559+
pipeline_asset = metadata_handle.store.pipeline_asset
560+
if pipeline_asset:
561+
env.get_env().create_pipeline_run_node_executions(
562+
pipeline_asset.owner,
563+
pipeline_asset.name,
564+
pipeline,
565+
node.node_info.id,
566+
executions,
567+
)
568+
else:
569+
logging.warning(
570+
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
571+
' node executions.',
572+
pipeline_asset,
573+
)
574+
return executions
575+
558576

577+
# TODO(b/349654866): make pipeline and node_id non-optional.
559578
def register_executions(
560579
metadata_handle: metadata.Metadata,
561580
execution_type: metadata_store_pb2.ExecutionType,
562581
contexts: Sequence[metadata_store_pb2.Context],
563582
input_and_params: Sequence[InputAndParam],
583+
pipeline: Optional[pipeline_pb2.Pipeline] = None,
584+
node_id: Optional[str] = None,
564585
) -> Sequence[metadata_store_pb2.Execution]:
565-
"""Registers multiple executions in MLMD.
586+
"""Registers multiple executions in storage backends.
566587
567588
Along with the execution:
568589
- the input artifacts will be linked to the executions.
@@ -575,6 +596,8 @@ def register_executions(
575596
input_and_params: A list of InputAndParams, which includes input_dicts
576597
(dictionaries of artifacts. One execution will be registered for each of
577598
the input_dict) and corresponding exec_properties.
599+
pipeline: Optional. The pipeline proto.
600+
node_id: Optional. The node id of the executions to be registered.
578601
579602
Returns:
580603
A list of MLMD executions that are registered in MLMD, with id populated.
@@ -603,21 +626,41 @@ def register_executions(
603626
executions.append(execution)
604627

605628
if len(executions) == 1:
606-
return [
629+
new_executions = [
607630
execution_lib.put_execution(
608631
metadata_handle,
609632
executions[0],
610633
contexts,
611634
input_artifacts=input_and_params[0].input_artifacts,
612635
)
613636
]
637+
else:
638+
new_executions = execution_lib.put_executions(
639+
metadata_handle,
640+
executions,
641+
contexts,
642+
[
643+
input_and_param.input_artifacts
644+
for input_and_param in input_and_params
645+
],
646+
)
614647

615-
return execution_lib.put_executions(
616-
metadata_handle,
617-
executions,
618-
contexts,
619-
[input_and_param.input_artifacts for input_and_param in input_and_params],
620-
)
648+
pipeline_asset = metadata_handle.store.pipeline_asset
649+
if pipeline_asset and pipeline and node_id:
650+
env.get_env().create_pipeline_run_node_executions(
651+
pipeline_asset.owner,
652+
pipeline_asset.name,
653+
pipeline,
654+
node_id,
655+
new_executions,
656+
)
657+
else:
658+
logging.warning(
659+
'Skipping creating pipeline run node executions for pipeline asset %s.',
660+
pipeline_asset,
661+
)
662+
663+
return new_executions
621664

622665

623666
def update_external_artifact_type(

tfx/orchestration/portable/execution_publish_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def publish_cached_executions(
3737
output_artifacts_maps: Optional[
3838
Sequence[typing_utils.ArtifactMultiMap]
3939
] = None,
40-
) -> None:
40+
) -> Sequence[metadata_store_pb2.Execution]:
4141
"""Marks an existing execution as using cached outputs from a previous execution.
4242
4343
Args:
@@ -46,11 +46,14 @@ def publish_cached_executions(
4646
executions: Executions that will be published as CACHED executions.
4747
output_artifacts_maps: A list of output artifacts of the executions. Each
4848
artifact will be linked with the execution through an event of type OUTPUT
49+
50+
Returns:
51+
A list of MLMD executions that are published to MLMD, with id populated.
4952
"""
5053
for execution in executions:
5154
execution.last_known_state = metadata_store_pb2.Execution.CACHED
5255

53-
execution_lib.put_executions(
56+
return execution_lib.put_executions(
5457
metadata_handle,
5558
executions,
5659
contexts,

tfx/orchestration/portable/importer_node_handler.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from tfx.dsl.components.common import importer
2121
from tfx.orchestration import data_types_utils
2222
from tfx.orchestration import metadata
23+
from tfx.orchestration.experimental.core import env
24+
from tfx.orchestration.experimental.core import pipeline_state as pstate
2325
from tfx.orchestration.portable import data_types
2426
from tfx.orchestration.portable import execution_publish_utils
2527
from tfx.orchestration.portable import inputs_utils
@@ -57,7 +59,7 @@ def run(
5759
5860
Args:
5961
mlmd_connection: ML metadata connection.
60-
pipeline_node: The specification of the node that this launcher lauches.
62+
pipeline_node: The specification of the node that this launcher launches.
6163
pipeline_info: The information of the pipeline that this node runs in.
6264
pipeline_runtime_spec: The runtime information of the pipeline that this
6365
node runs in.
@@ -78,13 +80,29 @@ def run(
7880
inputs_utils.resolve_parameters_with_schema(
7981
node_parameters=pipeline_node.parameters))
8082

81-
# 3. Registers execution in metadata.
83+
# 3. Registers execution in storage backend.
8284
execution = execution_publish_utils.register_execution(
8385
metadata_handle=m,
8486
execution_type=pipeline_node.node_info.type,
8587
contexts=contexts,
8688
exec_properties=exec_properties,
8789
)
90+
pipeline_asset = m.store.pipeline_asset
91+
if pipeline_asset:
92+
env.get_env().create_pipeline_run_node_executions(
93+
pipeline_asset.owner,
94+
pipeline_asset.name,
95+
pstate.get_pipeline(m, pipeline_info.id),
96+
pipeline_node.node_info.id,
97+
[execution],
98+
)
99+
else:
100+
logging.warning(
101+
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
102+
' node execution %s.',
103+
pipeline_asset,
104+
execution,
105+
)
88106

89107
# 4. Generate output artifacts to represent the imported artifacts.
90108
output_key = cast(str, exec_properties[importer.OUTPUT_KEY_KEY])

tfx/orchestration/portable/partial_run_utils.py

Lines changed: 39 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tfx.dsl.compiler import constants
2525
from tfx.orchestration import metadata
2626
from tfx.orchestration import node_proto_view
27+
from tfx.orchestration.experimental.core import env
2728
from tfx.orchestration.portable import execution_publish_utils
2829
from tfx.orchestration.portable.mlmd import context_lib
2930
from tfx.orchestration.portable.mlmd import execution_lib
@@ -599,6 +600,8 @@ def __init__(
599600
for node in node_proto_view.get_view_for_all_in(new_pipeline_run_ir)
600601
}
601602

603+
self._pipeline = new_pipeline_run_ir
604+
602605
def _get_base_pipeline_run_context(
603606
self, base_run_id: Optional[str] = None
604607
) -> metadata_store_pb2.Context:
@@ -788,7 +791,12 @@ def _cache_and_publish(
788791
contexts=[self._new_pipeline_run_context] + node_contexts,
789792
)
790793
)
791-
if not prev_cache_executions:
794+
795+
# If there are no previous attempts to cache and publish, we will create new
796+
# cache executions.
797+
create_new_cache_executions: bool = not prev_cache_executions
798+
799+
if create_new_cache_executions:
792800
new_cached_executions = []
793801
for e in existing_executions:
794802
new_cached_executions.append(
@@ -820,12 +828,36 @@ def _cache_and_publish(
820828
execution_lib.get_output_artifacts(self._mlmd, e.id)
821829
for e in existing_executions
822830
]
823-
execution_publish_utils.publish_cached_executions(
824-
self._mlmd,
825-
contexts=cached_execution_contexts,
826-
executions=new_cached_executions,
827-
output_artifacts_maps=output_artifacts_maps,
828-
)
831+
832+
if create_new_cache_executions:
833+
new_executions = execution_publish_utils.publish_cached_executions(
834+
self._mlmd,
835+
contexts=cached_execution_contexts,
836+
executions=new_cached_executions,
837+
output_artifacts_maps=output_artifacts_maps,
838+
)
839+
pipeline_asset = self._mlmd.store.pipeline_asset
840+
if pipeline_asset:
841+
env.get_env().create_pipeline_run_node_executions(
842+
pipeline_asset.owner,
843+
pipeline_asset.name,
844+
self._pipeline,
845+
node.node_info.id,
846+
new_executions,
847+
)
848+
else:
849+
logging.warning(
850+
'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
851+
' node executions.',
852+
pipeline_asset,
853+
)
854+
else:
855+
execution_publish_utils.publish_cached_executions(
856+
self._mlmd,
857+
contexts=cached_execution_contexts,
858+
executions=new_cached_executions,
859+
output_artifacts_maps=output_artifacts_maps,
860+
)
829861

830862
def put_parent_context(self):
831863
"""Puts a ParentContext edge in MLMD."""

tfx/orchestration/portable/resolver_node_handler.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import grpc
2121
from tfx.orchestration import data_types_utils
2222
from tfx.orchestration import metadata
23+
from tfx.orchestration.experimental.core import env
24+
from tfx.orchestration.experimental.core import pipeline_state as pstate
2325
from tfx.orchestration.portable import data_types
2426
from tfx.orchestration.portable import execution_publish_utils
2527
from tfx.orchestration.portable import inputs_utils
@@ -86,6 +88,22 @@ def run(
8688
contexts=contexts,
8789
exec_properties=exec_properties,
8890
)
91+
pipeline_asset = m.store.pipeline_asset
92+
if pipeline_asset:
93+
env.get_env().create_pipeline_run_node_executions(
94+
pipeline_asset.owner,
95+
pipeline_asset.name,
96+
pstate.get_pipeline(m, pipeline_info.id),
97+
pipeline_node.node_info.id,
98+
[execution],
99+
)
100+
else:
101+
logging.warning(
102+
'Pipeline asset %s not found in MLMD. Unable to create pipeline'
103+
' run node execution %s.',
104+
pipeline_asset,
105+
execution,
106+
)
89107
execution_publish_utils.publish_failed_execution(
90108
metadata_handle=m,
91109
contexts=contexts,
@@ -103,14 +121,29 @@ def run(
103121
if isinstance(resolved_inputs, inputs_utils.Skip):
104122
return data_types.ExecutionInfo()
105123

106-
# 3. Registers execution in metadata.
124+
# 3. Registers execution in storage backends.
107125
execution = execution_publish_utils.register_execution(
108126
metadata_handle=m,
109127
execution_type=pipeline_node.node_info.type,
110128
contexts=contexts,
111129
exec_properties=exec_properties,
112130
)
113-
131+
pipeline_asset = m.store.pipeline_asset
132+
if pipeline_asset:
133+
env.get_env().create_pipeline_run_node_executions(
134+
pipeline_asset.owner,
135+
pipeline_asset.name,
136+
pstate.get_pipeline(m, pipeline_info.id),
137+
pipeline_node.node_info.id,
138+
[execution],
139+
)
140+
else:
141+
logging.warning(
142+
'Pipeline asset %s not found in MLMD. Unable to create pipeline'
143+
' run node execution %s.',
144+
pipeline_asset,
145+
execution,
146+
)
114147
# TODO(b/197741942): Support len > 1.
115148
if len(resolved_inputs) > 1:
116149
execution_publish_utils.publish_failed_execution(

0 commit comments

Comments
 (0)