diff --git a/tfx/orchestration/experimental/core/component_generated_alert.proto b/tfx/orchestration/experimental/core/component_generated_alert.proto
new file mode 100644
index 0000000000..9ab6845ab1
--- /dev/null
+++ b/tfx/orchestration/experimental/core/component_generated_alert.proto
@@ -0,0 +1,28 @@
+// Copyright 2023 Google LLC. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Messages for configuring component generated alerts.
+
+syntax = "proto2";
+
+package tfx.orchestration.experimental.core;
+
+message ComponentGeneratedAlertInfo {
+  optional string alert_name = 1;
+  optional string alert_body = 2;
+}
+
+message ComponentGeneratedAlertList {
+  repeated ComponentGeneratedAlertInfo component_generated_alert_list = 1;
+}
diff --git a/tfx/orchestration/experimental/core/event_observer.py b/tfx/orchestration/experimental/core/event_observer.py
index 7f42603d94..2813550d8d 100644
--- a/tfx/orchestration/experimental/core/event_observer.py
+++ b/tfx/orchestration/experimental/core/event_observer.py
@@ -84,8 +84,18 @@ class NodeStateChange:
   new_state: Any
 
 
+@dataclasses.dataclass(frozen=True)
+class ComponentGeneratedAlert:
+  """ComponentGeneratedAlert event."""
+  execution: metadata_store_pb2.Execution
+  pipeline_uid: task_lib.PipelineUid
+  node_id: str
+  alert_name: str
+  alert_body: str
+
+
 Event = Union[PipelineStarted, PipelineFinished, NodeStateChange,
-              ExecutionStateChange]
+              ExecutionStateChange, ComponentGeneratedAlert]
 
 ObserverFn = Callable[[Event], None]
 
diff --git a/tfx/orchestration/experimental/core/post_execution_utils.py b/tfx/orchestration/experimental/core/post_execution_utils.py
index 78df60bd8c..5542dd9ceb 100644
--- a/tfx/orchestration/experimental/core/post_execution_utils.py
+++ b/tfx/orchestration/experimental/core/post_execution_utils.py
@@ -20,6 +20,7 @@
 from tfx.dsl.io import fileio
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import component_generated_alert_pb2
 from tfx.orchestration.experimental.core import constants
 from tfx.orchestration.experimental.core import event_observer
 from tfx.orchestration.experimental.core import garbage_collection
@@ -36,6 +37,9 @@
 from ml_metadata import proto
 
 
+_COMPONENT_GENERATED_ALERTS_KEY = '__component_generated_alerts__'
+
+
 def publish_execution_results_for_task(mlmd_handle: metadata.Metadata,
                                        task: task_lib.ExecNodeTask,
                                        result: ts.TaskSchedulerResult) -> None:
@@ -87,7 +91,7 @@ def _update_state(
     # TODO(b/262040844): Instead of directly using the context manager here, we
     # should consider creating and using wrapper functions.
     with mlmd_state.evict_from_cache(task.execution_id):
-      execution_publish_utils.publish_succeeded_execution(
+      _, execution = execution_publish_utils.publish_succeeded_execution(
           mlmd_handle,
           execution_id=task.execution_id,
           contexts=task.contexts,
@@ -96,6 +100,24 @@ def _update_state(
     garbage_collection.run_garbage_collection_for_node(mlmd_handle,
                                                        task.node_uid,
                                                        task.get_node())
+    if _COMPONENT_GENERATED_ALERTS_KEY in execution.custom_properties:
+      alerts_proto = component_generated_alert_pb2.ComponentGeneratedAlertList()
+      execution.custom_properties[
+          _COMPONENT_GENERATED_ALERTS_KEY
+      ].proto_value.Unpack(alerts_proto)
+
+      for alert in alerts_proto.component_generated_alert_list:
+        alert_event = event_observer.ComponentGeneratedAlert(
+            execution=execution,
+            pipeline_uid=task_lib.PipelineUid(
+                pipeline_id=task.pipeline.pipeline_info.id
+            ),
+            node_id=task.node_uid.node_id,
+            alert_body=alert.alert_body,
+            alert_name=alert.alert_name,
+        )
+        event_observer.notify(alert_event)
+
   elif isinstance(result.output, ts.ImporterNodeOutput):
     output_artifacts = result.output.output_artifacts
     _remove_temporary_task_dirs(
@@ -156,12 +178,13 @@ def publish_execution_results(
   # TODO(b/262040844): Instead of directly using the context manager here, we
   # should consider creating and using wrapper functions.
   with mlmd_state.evict_from_cache(execution_info.execution_id):
-    return execution_publish_utils.publish_succeeded_execution(
+    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
         mlmd_handle,
         execution_id=execution_info.execution_id,
         contexts=contexts,
         output_artifacts=execution_info.output_dict,
         executor_output=executor_output)
+    return output_dict
 
 
 def _update_execution_state_in_mlmd(
diff --git a/tfx/orchestration/experimental/core/post_execution_utils_test.py b/tfx/orchestration/experimental/core/post_execution_utils_test.py
index bc38963455..8422ddd01a 100644
--- a/tfx/orchestration/experimental/core/post_execution_utils_test.py
+++ b/tfx/orchestration/experimental/core/post_execution_utils_test.py
@@ -20,11 +20,17 @@
 from tfx.dsl.io import fileio
 from tfx.orchestration import data_types_utils
 from tfx.orchestration import metadata
+from tfx.orchestration.experimental.core import component_generated_alert_pb2
+from tfx.orchestration.experimental.core import event_observer
 from tfx.orchestration.experimental.core import post_execution_utils
+from tfx.orchestration.experimental.core import task as task_lib
+from tfx.orchestration.experimental.core import task_scheduler as ts
+from tfx.orchestration.experimental.core import test_utils
 from tfx.orchestration.portable import data_types
 from tfx.orchestration.portable import execution_publish_utils
 from tfx.proto.orchestration import execution_invocation_pb2
 from tfx.proto.orchestration import execution_result_pb2
+from tfx.proto.orchestration import pipeline_pb2
 from tfx.types import standard_artifacts
 from tfx.utils import status as status_lib
 from tfx.utils import test_case_utils as tu
@@ -102,6 +108,8 @@ def test_publish_execution_results_succeeded_execution(self, mock_publish):
     executor_output = execution_result_pb2.ExecutorOutput()
     executor_output.execution_result.code = 0
 
+    mock_publish.return_value = [None, None]
+
     post_execution_utils.publish_execution_results(
         self.mlmd_handle, executor_output, execution_info, contexts=[])
 
@@ -113,6 +121,63 @@ def test_publish_execution_results_succeeded_execution(self, mock_publish):
         output_artifacts=execution_info.output_dict,
         executor_output=executor_output)
 
+  @mock.patch.object(event_observer, 'notify')
+  def test_publish_execution_results_for_task_with_alerts(self, mock_notify):
+    _ = self._prepare_execution_info()
+
+    executor_output = execution_result_pb2.ExecutorOutput()
+    executor_output.execution_result.code = 0
+
+    component_generated_alerts = (
+        component_generated_alert_pb2.ComponentGeneratedAlertList()
+    )
+    component_generated_alerts.component_generated_alert_list.append(
+        component_generated_alert_pb2.ComponentGeneratedAlertInfo(
+            alert_name='test_alert',
+            alert_body='test_alert_body',
+        )
+    )
+    executor_output.execution_properties[
+        post_execution_utils._COMPONENT_GENERATED_ALERTS_KEY
+    ].proto_value.Pack(component_generated_alerts)
+
+    [execution] = self.mlmd_handle.store.get_executions()
+
+    # Create test pipeline.
+    deployment_config = pipeline_pb2.IntermediateDeploymentConfig()
+    executor_spec = pipeline_pb2.ExecutorSpec.PythonClassExecutorSpec(
+        class_path='trainer.TrainerExecutor')
+    deployment_config.executor_specs['AlertGenerator'].Pack(
+        executor_spec
+    )
+    pipeline = pipeline_pb2.Pipeline()
+    pipeline.nodes.add().pipeline_node.node_info.id = 'AlertGenerator'
+    pipeline.pipeline_info.id = 'test-pipeline'
+    pipeline.deployment_config.Pack(deployment_config)
+
+    node_uid = task_lib.NodeUid(
+        pipeline_uid=task_lib.PipelineUid(
+            pipeline_id=pipeline.pipeline_info.id
+        ),
+        node_id='AlertGenerator',
+    )
+    task = test_utils.create_exec_node_task(
+        node_uid=node_uid,
+        execution=execution,
+        pipeline=pipeline,
+    )
+    result = ts.TaskSchedulerResult(
+        status=status_lib.Status(
+            code=status_lib.Code.OK,
+            message='test TaskScheduler result'
+        ),
+        output=ts.ExecutorNodeOutput(executor_output=executor_output)
+    )
+    post_execution_utils.publish_execution_results_for_task(
+        self.mlmd_handle, task, result
+    )
+    mock_notify.assert_called_once()
+
 
 if __name__ == '__main__':
   tf.test.main()
diff --git a/tfx/orchestration/experimental/core/test_utils.py b/tfx/orchestration/experimental/core/test_utils.py
index b3a44d44aa..02f4b2b0d6 100644
--- a/tfx/orchestration/experimental/core/test_utils.py
+++ b/tfx/orchestration/experimental/core/test_utils.py
@@ -266,8 +266,10 @@ def fake_finish_node_with_handle(
   else:
     output_artifacts = None
   contexts = context_lib.prepare_contexts(mlmd_handle, node.contexts)
-  return execution_publish_utils.publish_succeeded_execution(
-      mlmd_handle, execution_id, contexts, output_artifacts)
+  output_dict, _ = execution_publish_utils.publish_succeeded_execution(
+      mlmd_handle, execution_id, contexts, output_artifacts
+  )
+  return output_dict
 
 
 def create_exec_node_task(
diff --git a/tfx/orchestration/portable/cache_utils_test.py b/tfx/orchestration/portable/cache_utils_test.py
index 6bac1f114c..429c3d8d5c 100644
--- a/tfx/orchestration/portable/cache_utils_test.py
+++ b/tfx/orchestration/portable/cache_utils_test.py
@@ -215,7 +215,7 @@ def testGetCachedOutputArtifacts(self, mock_verify_artifacts):
           })
       execution_two = execution_publish_utils.register_execution(
           m, metadata_store_pb2.ExecutionType(name='my_type'), [cache_context])
-      output_artifacts = execution_publish_utils.publish_succeeded_execution(
+      output_artifacts, _ = execution_publish_utils.publish_succeeded_execution(
           m,
           execution_two.id, [cache_context],
           output_artifacts={
diff --git a/tfx/orchestration/portable/execution_publish_utils.py b/tfx/orchestration/portable/execution_publish_utils.py
index 0043e136a0..ef3b98e383 100644
--- a/tfx/orchestration/portable/execution_publish_utils.py
+++ b/tfx/orchestration/portable/execution_publish_utils.py
@@ -78,7 +78,10 @@ def publish_succeeded_execution(
     contexts: Sequence[metadata_store_pb2.Context],
     output_artifacts: Optional[typing_utils.ArtifactMultiMap] = None,
     executor_output: Optional[execution_result_pb2.ExecutorOutput] = None,
-) -> Optional[typing_utils.ArtifactMultiMap]:
+) -> tuple[
+    Optional[typing_utils.ArtifactMultiMap],
+    metadata_store_pb2.Execution,
+]:
   """Marks an existing execution as success.
 
   Also publishes the output artifacts produced by the execution. This method
@@ -102,9 +105,10 @@ def publish_succeeded_execution(
       artifact should not change the type of the artifact.
 
   Returns:
-    The maybe updated output_artifacts, note that only outputs whose key are in
-    executor_output will be updated and others will be untouched. That said,
-    it can be partially updated.
+    The tuple containing the maybe updated output_artifacts (note that only
+    outputs whose key are in executor_output will be updated and others will be
+    untouched, that said, it can be partially updated) and the written
+    execution.
   Raises:
     RuntimeError: if the executor output to a output channel is partial.
   """
@@ -147,14 +151,14 @@ def publish_succeeded_execution(
         execution.custom_properties[key].CopyFrom(value)
     set_execution_result_if_not_empty(executor_output, execution)
 
-  execution_lib.put_execution(
+  execution = execution_lib.put_execution(
       metadata_handle,
       execution,
       contexts,
       output_artifacts=output_artifacts_to_publish,
   )
 
-  return output_artifacts_to_publish
+  return output_artifacts_to_publish, execution
 
 
 def publish_failed_execution(
diff --git a/tfx/orchestration/portable/execution_publish_utils_test.py b/tfx/orchestration/portable/execution_publish_utils_test.py
index 1341856a20..8def6775ab 100644
--- a/tfx/orchestration/portable/execution_publish_utils_test.py
+++ b/tfx/orchestration/portable/execution_publish_utils_test.py
@@ -191,10 +191,15 @@ def testPublishSuccessfulExecution(self):
             value {int_value: 1}
           }
           """, executor_output.output_artifacts[output_key].artifacts.add())
-      output_dict = execution_publish_utils.publish_succeeded_execution(
-          m, execution_id, contexts, {output_key: [output_example]},
-          executor_output)
-      [execution] = m.store.get_executions()
+      output_dict, execution = (
+          execution_publish_utils.publish_succeeded_execution(
+              m,
+              execution_id,
+              contexts,
+              {output_key: [output_example]},
+              executor_output,
+          )
+      )
       self.assertProtoPartiallyEquals(
           """
           id: 1
@@ -303,7 +308,7 @@ def testPublishSuccessfulExecutionWithRuntimeResolvedUri(self):
             }}
             """, executor_output.output_artifacts[output_key].artifacts.add())
 
-      output_dict = execution_publish_utils.publish_succeeded_execution(
+      output_dict, _ = execution_publish_utils.publish_succeeded_execution(
           m, execution_id, contexts, {output_key: [output_example]},
           executor_output)
       self.assertLen(output_dict[output_key], 2)
@@ -361,7 +366,7 @@ def testPublishSuccessfulExecutionOmitsArtifactIfNotResolvedDuringRuntime(
             value {{int_value: 1}}
           }}
           """, executor_output.output_artifacts['key1'].artifacts.add())
-      output_dict = execution_publish_utils.publish_succeeded_execution(
+      output_dict, _ = execution_publish_utils.publish_succeeded_execution(
           m, execution_id, contexts, original_artifacts, executor_output)
       self.assertEmpty(output_dict['key1'])
       self.assertNotEmpty(output_dict['key2'])
@@ -414,10 +419,15 @@ def testPublishSuccessExecutionExecutorEditedOutputDict(self):
           }
           """, executor_output.output_artifacts[output_key].artifacts.add())
 
-      output_dict = execution_publish_utils.publish_succeeded_execution(
-          m, execution_id, contexts, {output_key: [output_example]},
-          executor_output)
-      [execution] = m.store.get_executions()
+      output_dict, execution = (
+          execution_publish_utils.publish_succeeded_execution(
+              m,
+              execution_id,
+              contexts,
+              {output_key: [output_example]},
+              executor_output,
+          )
+      )
       self.assertProtoPartiallyEquals(
           """
           id: 1
diff --git a/tfx/orchestration/portable/inputs_utils_test.py b/tfx/orchestration/portable/inputs_utils_test.py
index 085eb5929f..8e61c45902 100644
--- a/tfx/orchestration/portable/inputs_utils_test.py
+++ b/tfx/orchestration/portable/inputs_utils_test.py
@@ -72,9 +72,10 @@ def fake_execute(self, metadata_handle, pipeline_node, input_map, output_map):
     execution = execution_publish_utils.register_execution(
         metadata_handle, pipeline_node.node_info.type, contexts, input_map
     )
-    return execution_publish_utils.publish_succeeded_execution(
+    output_dict, _ = execution_publish_utils.publish_succeeded_execution(
         metadata_handle, execution.id, contexts, output_map
     )
+    return output_dict
 
   def assertArtifactEqual(self, expected, actual):
     self.assertProtoPartiallyEquals(