From e85826d9e07c25f8772c9653a7b7732eb498745b Mon Sep 17 00:00:00 2001 From: Kshitij Aranke Date: Fri, 18 Oct 2024 12:21:40 +0100 Subject: [PATCH] [Tidy-First]: Fix `timings` object for hooks and macros, and make types of timings explicit --- core/dbt/artifacts/schemas/results.py | 14 ++++-- core/dbt/task/run.py | 19 +++++--- core/dbt/task/run_operation.py | 46 +++++++++++-------- .../adapter/hooks/test_on_run_hooks.py | 8 ++++ .../run_operations/test_run_operations.py | 17 +++++++ 5 files changed, 75 insertions(+), 29 deletions(-) diff --git a/core/dbt/artifacts/schemas/results.py b/core/dbt/artifacts/schemas/results.py index 00746c87885..e827cb5154e 100644 --- a/core/dbt/artifacts/schemas/results.py +++ b/core/dbt/artifacts/schemas/results.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import datetime -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Union from dbt.contracts.graph.nodes import ResultNode from dbt_common.dataclass_schema import StrEnum, dbtClassMixin @@ -10,7 +10,13 @@ @dataclass class TimingInfo(dbtClassMixin): - name: str + """ + Represents a step in the execution of a node. + `name` should be one of: compile, execute, or other + Do not call directly, use `collect_timing_info` instead. + """ + + name: Literal["compile", "execute", "other"] started_at: Optional[datetime] = None completed_at: Optional[datetime] = None @@ -31,7 +37,9 @@ def to_msg_dict(self): # This is a context manager class collect_timing_info: - def __init__(self, name: str, callback: Callable[[TimingInfo], None]) -> None: + def __init__( + self, name: Literal["compile", "execute", "other"], callback: Callable[[TimingInfo], None] + ) -> None: self.timing_info = TimingInfo(name=name) self.callback = callback diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 99913a551c5..f159861e1ae 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -17,6 +17,7 @@ RunningStatus, RunStatus, TimingInfo, + collect_timing_info, ) from dbt.artifacts.schemas.run import RunResult from dbt.cli.flags import Flags @@ -633,7 +634,6 @@ def get_hooks_by_type(self, hook_type: RunHookType) -> List[HookNode]: def safe_run_hooks( self, adapter: BaseAdapter, hook_type: RunHookType, extra_context: Dict[str, Any] ) -> RunStatus: - started_at = datetime.utcnow() ordered_hooks = self.get_hooks_by_type(hook_type) if hook_type == RunHookType.End and ordered_hooks: @@ -653,14 +653,20 @@ def safe_run_hooks( hook.index = idx hook_name = f"{hook.package_name}.{hook_type}.{hook.index - 1}" execution_time = 0.0 - timing = [] + timing: List[TimingInfo] = [] failures = 1 if not failed: + with collect_timing_info("compile", timing.append): + sql = self.get_hook_sql( + adapter, hook, hook.index, num_hooks, extra_context + ) + + started_at = timing[0].started_at or datetime.utcnow() hook.update_event_status( started_at=started_at.isoformat(), node_status=RunningStatus.Started ) - sql = self.get_hook_sql(adapter, hook, hook.index, num_hooks, extra_context) + fire_event( LogHookStartLine( statement=hook_name, @@ -670,11 +676,12 @@ def safe_run_hooks( ) ) - status, message = get_execution_status(sql, adapter) - finished_at = datetime.utcnow() + with collect_timing_info("execute", timing.append): + status, message = get_execution_status(sql, adapter) + + finished_at = timing[1].completed_at or datetime.utcnow() hook.update_event_status(finished_at=finished_at.isoformat()) execution_time = (finished_at - started_at).total_seconds() - timing = [TimingInfo(hook_name, started_at, finished_at)] failures = 0 if status == RunStatus.Success else 1 if status == RunStatus.Success: diff --git a/core/dbt/task/run_operation.py b/core/dbt/task/run_operation.py index 793ba81fb01..ebe8b14352e 100644 --- a/core/dbt/task/run_operation.py +++ b/core/dbt/task/run_operation.py @@ -2,11 +2,11 @@ import threading import traceback from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List import dbt_common.exceptions from dbt.adapters.factory import get_adapter -from dbt.artifacts.schemas.results import RunStatus, TimingInfo +from dbt.artifacts.schemas.results import RunStatus, TimingInfo, collect_timing_info from dbt.artifacts.schemas.run import RunResult, RunResultsArtifact from dbt.contracts.files import FileHash from dbt.contracts.graph.nodes import HookNode @@ -51,25 +51,29 @@ def _run_unsafe(self, package_name, macro_name) -> "agate.Table": return res def run(self) -> RunResultsArtifact: - start = datetime.utcnow() - self.compile_manifest() + timing: List[TimingInfo] = [] - success = True + with collect_timing_info("compile", timing.append): + self.compile_manifest() + + start = timing[0].started_at + success = True package_name, macro_name = self._get_macro_parts() - try: - self._run_unsafe(package_name, macro_name) - except dbt_common.exceptions.DbtBaseException as exc: - fire_event(RunningOperationCaughtError(exc=str(exc))) - fire_event(LogDebugStackTrace(exc_info=traceback.format_exc())) - success = False - except Exception as exc: - fire_event(RunningOperationUncaughtError(exc=str(exc))) - fire_event(LogDebugStackTrace(exc_info=traceback.format_exc())) - success = False + with collect_timing_info("execute", timing.append): + try: + self._run_unsafe(package_name, macro_name) + except dbt_common.exceptions.DbtBaseException as exc: + fire_event(RunningOperationCaughtError(exc=str(exc))) + fire_event(LogDebugStackTrace(exc_info=traceback.format_exc())) + success = False + except Exception as exc: + fire_event(RunningOperationUncaughtError(exc=str(exc))) + fire_event(LogDebugStackTrace(exc_info=traceback.format_exc())) + success = False - end = datetime.utcnow() + end = timing[1].completed_at macro = ( self.manifest.find_macro_by_name(macro_name, self.config.project_name, package_name) @@ -85,10 +89,12 @@ def run(self) -> RunResultsArtifact: f"dbt could not find a macro with the name '{macro_name}' in any package" ) + execution_time = (end - start).total_seconds() if start and end else 0.0 + run_result = RunResult( adapter_response={}, status=RunStatus.Success if success else RunStatus.Error, - execution_time=(end - start).total_seconds(), + execution_time=execution_time, failures=0 if success else 1, message=None, node=HookNode( @@ -105,13 +111,13 @@ def run(self) -> RunResultsArtifact: original_file_path="", ), thread_id=threading.current_thread().name, - timing=[TimingInfo(name=macro_name, started_at=start, completed_at=end)], + timing=timing, batch_results=None, ) results = RunResultsArtifact.from_execution_results( - generated_at=end, - elapsed_time=(end - start).total_seconds(), + generated_at=end or datetime.utcnow(), + elapsed_time=execution_time, args={ k: v for k, v in self.args.__dict__.items() diff --git a/tests/functional/adapter/hooks/test_on_run_hooks.py b/tests/functional/adapter/hooks/test_on_run_hooks.py index 42edbdae970..b9239b93b4a 100644 --- a/tests/functional/adapter/hooks/test_on_run_hooks.py +++ b/tests/functional/adapter/hooks/test_on_run_hooks.py @@ -55,6 +55,14 @@ def test_results(self, project, log_counts, my_model_run_status): for result in results if isinstance(result.node, HookNode) ] == [(id, str(status)) for id, status in expected_results if id.startswith("operation")] + + for result in results: + if result.status == RunStatus.Skipped: + continue + + timing_keys = [timing.name for timing in result.timing] + assert timing_keys == ["compile", "execute"] + assert log_counts in log_output assert "4 project hooks, 1 view model" in log_output diff --git a/tests/functional/run_operations/test_run_operations.py b/tests/functional/run_operations/test_run_operations.py index 064c98b3a51..258ed679d7e 100644 --- a/tests/functional/run_operations/test_run_operations.py +++ b/tests/functional/run_operations/test_run_operations.py @@ -3,6 +3,7 @@ import pytest import yaml +from dbt.artifacts.schemas.results import RunStatus from dbt.tests.util import ( check_table_does_exist, mkdir, @@ -135,9 +136,25 @@ def test_run_operation_local_macro(self, project): run_dbt(["deps"]) results, log_output = run_dbt_and_capture(["run-operation", "something_cool"]) + + for result in results: + if result.status == RunStatus.Skipped: + continue + + timing_keys = [timing.name for timing in result.timing] + assert timing_keys == ["compile", "execute"] + assert "something cool" in log_output results, log_output = run_dbt_and_capture(["run-operation", "pkg.something_cool"]) + + for result in results: + if result.status == RunStatus.Skipped: + continue + + timing_keys = [timing.name for timing in result.timing] + assert timing_keys == ["compile", "execute"] + assert "something cool" in log_output rm_dir("pkg")