diff --git a/.changes/unreleased/Fixes-20241029-182034.yaml b/.changes/unreleased/Fixes-20241029-182034.yaml new file mode 100644 index 00000000000..5b5f6f6a07d --- /dev/null +++ b/.changes/unreleased/Fixes-20241029-182034.yaml @@ -0,0 +1,7 @@ +kind: Fixes +body: Handle exceptions in `get_execution_status` more broadly to better ensure `run_results.json` + gets written +time: 2024-10-29T18:20:34.782845-05:00 +custom: + Author: QMalcolm + Issue: "10934" diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index e0c436b5a61..ce613c0e44d 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -94,11 +94,16 @@ def get_execution_status(sql: str, adapter: BaseAdapter) -> Tuple[RunStatus, str response, _ = adapter.execute(sql, auto_begin=False, fetch=False) status = RunStatus.Success message = response._message + except (KeyboardInterrupt, SystemExit): + raise except DbtRuntimeError as exc: status = RunStatus.Error message = exc.msg - finally: - return status, message + except Exception as exc: + status = RunStatus.Error + message = str(exc) + + return (status, message) def _get_adapter_info(adapter, run_model_result) -> Dict[str, Any]: @@ -792,8 +797,23 @@ def after_run(self, adapter, results) -> None: ], # exclude that didn't fail to preserve backwards compatibility "database_schemas": list(database_schema_set), } - with adapter.connection_named("master"): - self.safe_run_hooks(adapter, RunHookType.End, extras) + + try: + with adapter.connection_named("master"): + self.safe_run_hooks(adapter, RunHookType.End, extras) + except (KeyboardInterrupt, SystemExit): + run_result = self.get_result( + results=self.node_results, + elapsed_time=time.time() - self.started_at, + generated_at=datetime.utcnow(), + ) + + if self.args.write_json and hasattr(run_result, "write"): + run_result.write(self.result_path()) + + print_run_end_messages(self.node_results, keyboard_interrupt=True) + + raise def get_node_selector(self) -> ResourceTypeSelector: if self.manifest is None or self.graph is None: diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index d2ed50ec1ad..7063ca4200d 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -3,21 +3,29 @@ from dataclasses import dataclass from datetime import datetime, timedelta from importlib import import_module -from typing import Optional +from typing import Optional, Type, Union +from unittest import mock from unittest.mock import MagicMock, patch import pytest +from psycopg2 import DatabaseError from pytest_mock import MockerFixture +from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.postgres import PostgresAdapter +from dbt.artifacts.resources.base import FileHash +from dbt.artifacts.resources.types import NodeType, RunHookType +from dbt.artifacts.resources.v1.components import DependsOn +from dbt.artifacts.resources.v1.config import NodeConfig from dbt.artifacts.resources.v1.model import ModelConfig from dbt.artifacts.schemas.batch_results import BatchResults from dbt.artifacts.schemas.results import RunStatus from dbt.artifacts.schemas.run import RunResult from dbt.config.runtime import RuntimeConfig from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.nodes import ModelNode +from dbt.contracts.graph.nodes import HookNode, ModelNode from dbt.events.types import LogModelResult +from dbt.exceptions import DbtRuntimeError from dbt.flags import get_flags, set_from_args from dbt.task.run import ModelRunner, RunTask, _get_adapter_info from dbt.tests.util import safe_set_invocation_context @@ -256,3 +264,85 @@ class Relation: # Assert result of _is_incremental assert model_runner._is_incremental(model) == expectation + + +class TestRunTask: + @pytest.fixture + def hook_node(self) -> HookNode: + return HookNode( + package_name="test", + path="/root/x/path.sql", + original_file_path="/root/path.sql", + language="sql", + raw_code="select * from wherever", + name="foo", + resource_type=NodeType.Operation, + unique_id="model.test.foo", + fqn=["test", "models", "foo"], + refs=[], + sources=[], + metrics=[], + depends_on=DependsOn(), + description="", + database="test_db", + schema="test_schema", + alias="bar", + tags=[], + config=NodeConfig(), + index=None, + checksum=FileHash.from_contents(""), + unrendered_config={}, + ) + + @pytest.mark.parametrize( + "error_to_raise,expected_result", + [ + (None, RunStatus.Success), + (DbtRuntimeError, RunStatus.Error), + (DatabaseError, RunStatus.Error), + (KeyboardInterrupt, KeyboardInterrupt), + ], + ) + def test_safe_run_hooks( + self, + mocker: MockerFixture, + runtime_config: RuntimeConfig, + manifest: Manifest, + hook_node: HookNode, + error_to_raise: Optional[Type[Exception]], + expected_result: Union[RunStatus, Type[Exception]], + ): + mocker.patch("dbt.task.run.RunTask.get_hooks_by_type").return_value = [hook_node] + mocker.patch("dbt.task.run.RunTask.get_hook_sql").return_value = hook_node.raw_code + + flags = mock.Mock() + flags.state = None + flags.defer_state = None + + run_task = RunTask( + args=flags, + config=runtime_config, + manifest=manifest, + ) + + adapter = mock.Mock() + adapter_execute = mock.Mock() + adapter_execute.return_value = (AdapterResponse(_message="Success"), None) + + if error_to_raise: + adapter_execute.side_effect = error_to_raise("Oh no!") + + adapter.execute = adapter_execute + + try: + result = run_task.safe_run_hooks( + adapter=adapter, + hook_type=RunHookType.End, + extra_context={}, + ) + assert isinstance(expected_result, RunStatus) + assert result == expected_result + except BaseException as e: + assert not isinstance(expected_result, RunStatus) + assert issubclass(expected_result, BaseException) + assert type(e) == expected_result