From 5e9f1b515f37dfe6cdae1ab1aa7d190b92490e24 Mon Sep 17 00:00:00 2001
From: Kshitij Aranke <kshitij.aranke@dbtlabs.com>
Date: Tue, 1 Oct 2024 08:05:36 +0100
Subject: [PATCH] [Round 2] Fix #9005: Allow singular tests to be documented in
 properties.yml (#10792)

---
 .../unreleased/Fixes-20240923-190758.yaml     |  6 ++
 core/dbt/config/project.py                    | 13 ++--
 core/dbt/contracts/graph/manifest.py          | 50 ++++++++++++-
 core/dbt/contracts/graph/nodes.py             |  5 ++
 core/dbt/contracts/graph/unparsed.py          |  5 ++
 core/dbt/parser/common.py                     |  2 +
 core/dbt/parser/schemas.py                    | 72 ++++++++++++++++++-
 tests/functional/data_test_patch/fixtures.py  | 38 ++++++++++
 .../test_singular_test_patch.py               | 65 +++++++++++++++++
 tests/unit/config/test_project.py             |  7 +-
 tests/unit/config/test_runtime.py             |  2 +-
 11 files changed, 249 insertions(+), 16 deletions(-)
 create mode 100644 .changes/unreleased/Fixes-20240923-190758.yaml
 create mode 100644 tests/functional/data_test_patch/fixtures.py
 create mode 100644 tests/functional/data_test_patch/test_singular_test_patch.py

diff --git a/.changes/unreleased/Fixes-20240923-190758.yaml b/.changes/unreleased/Fixes-20240923-190758.yaml
new file mode 100644
index 00000000000..4d005ec5999
--- /dev/null
+++ b/.changes/unreleased/Fixes-20240923-190758.yaml
@@ -0,0 +1,6 @@
+kind: Fixes
+body: Allow singular tests to be documented in properties.yml
+time: 2024-09-23T19:07:58.151069+01:00
+custom:
+    Author: aranke
+    Issue: "9005"
diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py
index 25b0f343ef2..cbad5a38434 100644
--- a/core/dbt/config/project.py
+++ b/core/dbt/config/project.py
@@ -158,14 +158,8 @@ def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
     return [VersionSpecifier.from_version_string(v) for v in versions]
 
 
-def _all_source_paths(
-    model_paths: List[str],
-    seed_paths: List[str],
-    snapshot_paths: List[str],
-    analysis_paths: List[str],
-    macro_paths: List[str],
-) -> List[str]:
-    paths = chain(model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths)
+def _all_source_paths(*args: List[str]) -> List[str]:
+    paths = chain(*args)  # flatten every supplied path list into a single iterable
     # Strip trailing slashes since the path is the same even though the name is not
     stripped_paths = map(lambda s: s.rstrip("/"), paths)
     return list(set(stripped_paths))
@@ -409,7 +403,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
         snapshot_paths: List[str] = value_or(cfg.snapshot_paths, ["snapshots"])
 
         all_source_paths: List[str] = _all_source_paths(
-            model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths
+            model_paths, seed_paths, snapshot_paths, analysis_paths, macro_paths, test_paths
         )
 
         docs_paths: List[str] = value_or(cfg.docs_paths, all_source_paths)
@@ -652,6 +646,7 @@ def all_source_paths(self) -> List[str]:
             self.snapshot_paths,
             self.analysis_paths,
             self.macro_paths,
+            self.test_paths,
         )
 
     @property
diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py
index f4cdafea737..b556b479fb4 100644
--- a/core/dbt/contracts/graph/manifest.py
+++ b/core/dbt/contracts/graph/manifest.py
@@ -58,6 +58,7 @@
     SavedQuery,
     SeedNode,
     SemanticModel,
+    SingularTestNode,
     SourceDefinition,
     UnitTestDefinition,
     UnitTestFileFixture,
@@ -89,7 +90,7 @@
 RefName = str
 
 
-def find_unique_id_for_package(storage, key, package: Optional[PackageName]):
+def find_unique_id_for_package(storage, key, package: Optional[PackageName]) -> Optional[UniqueID]:
     if key not in storage:
         return None
 
@@ -470,6 +471,43 @@ class AnalysisLookup(RefableLookup):
     _versioned_types: ClassVar[set] = set()
 
 
+class SingularTestLookup(dbtClassMixin):
+    """Index singular tests by search name, then package name, to their unique IDs."""
+
+    def __init__(self, manifest: "Manifest") -> None:
+        self.storage: Dict[str, Dict[PackageName, UniqueID]] = {}
+        self.populate(manifest)
+
+    def get_unique_id(self, search_name, package: Optional[PackageName]) -> Optional[UniqueID]:
+        return find_unique_id_for_package(self.storage, search_name, package)
+
+    def find(
+        self, search_name, package: Optional[PackageName], manifest: "Manifest"
+    ) -> Optional[SingularTestNode]:
+        """Return the node indexed for search_name/package, or None if no ID is found."""
+        test_id = self.get_unique_id(search_name, package)
+        return None if test_id is None else self.perform_lookup(test_id, manifest)
+
+    def add_singular_test(self, source: SingularTestNode) -> None:
+        """Record the node's unique ID under its search name and package name."""
+        self.storage.setdefault(source.search_name, {})[source.package_name] = source.unique_id
+
+    def populate(self, manifest: "Manifest") -> None:
+        """Index every SingularTestNode found in manifest.nodes."""
+        for candidate in manifest.nodes.values():
+            if isinstance(candidate, SingularTestNode):
+                self.add_singular_test(candidate)
+
+    def perform_lookup(self, unique_id: UniqueID, manifest: "Manifest") -> SingularTestNode:
+        """Fetch the node for unique_id; raise DbtInternalError if the manifest lacks it."""
+        if unique_id not in manifest.nodes:
+            msg = f"Singular test {unique_id} found in cache but not found in manifest"
+            raise dbt_common.exceptions.DbtInternalError(msg)
+        found = manifest.nodes[unique_id]
+        assert isinstance(found, SingularTestNode)
+        return found
+
+
 def _packages_to_search(
     current_project: str,
     node_package: str,
@@ -869,6 +907,9 @@ class Manifest(MacroMethods, dbtClassMixin):
     _analysis_lookup: Optional[AnalysisLookup] = field(
         default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
     )
+    _singular_test_lookup: Optional[SingularTestLookup] = field(
+        default=None, metadata={"serialize": lambda x: None, "deserialize": lambda x: None}
+    )
     _parsing_info: ParsingInfo = field(
         default_factory=ParsingInfo,
         metadata={"serialize": lambda x: None, "deserialize": lambda x: None},
@@ -1264,6 +1305,12 @@ def analysis_lookup(self) -> AnalysisLookup:
             self._analysis_lookup = AnalysisLookup(self)
         return self._analysis_lookup
 
+    @property
+    def singular_test_lookup(self) -> SingularTestLookup:
+        if self._singular_test_lookup is None:  # not built yet: create and cache it
+            self._singular_test_lookup = SingularTestLookup(self)
+        return self._singular_test_lookup
+
     @property
     def external_node_unique_ids(self):
         return [node.unique_id for node in self.nodes.values() if node.is_external_node]
@@ -1708,6 +1755,7 @@ def __reduce_ex__(self, protocol):
             self._semantic_model_by_measure_lookup,
             self._disabled_lookup,
             self._analysis_lookup,
+            self._singular_test_lookup,
         )
         return self.__class__, args
 
diff --git a/core/dbt/contracts/graph/nodes.py b/core/dbt/contracts/graph/nodes.py
index d5ef3d51174..f50eb1d9641 100644
--- a/core/dbt/contracts/graph/nodes.py
+++ b/core/dbt/contracts/graph/nodes.py
@@ -1678,6 +1678,11 @@ class ParsedMacroPatch(ParsedPatch):
     arguments: List[MacroArgument] = field(default_factory=list)
 
 
+@dataclass
+class ParsedSingularTestPatch(ParsedPatch):
+    """Patch for a singular test; it declares no fields beyond those of ParsedPatch."""
+
+
 # ====================================
 # Node unions/categories
 # ====================================
diff --git a/core/dbt/contracts/graph/unparsed.py b/core/dbt/contracts/graph/unparsed.py
index bb931123749..f20e76a8b68 100644
--- a/core/dbt/contracts/graph/unparsed.py
+++ b/core/dbt/contracts/graph/unparsed.py
@@ -202,6 +202,11 @@ class UnparsedAnalysisUpdate(HasConfig, HasColumnDocs, HasColumnProps, HasYamlMe
     access: Optional[str] = None
 
 
+@dataclass
+class UnparsedSingularTestUpdate(HasConfig, HasColumnProps, HasYamlMetadata):
+    """Properties-file entry for a singular test; every field comes from its bases."""
+
+
 @dataclass
 class UnparsedNodeUpdate(HasConfig, HasColumnTests, HasColumnAndTestProps, HasYamlMetadata):
     quote_columns: Optional[bool] = None
diff --git a/core/dbt/parser/common.py b/core/dbt/parser/common.py
index 66d84d2db9b..5cc4385ea1c 100644
--- a/core/dbt/parser/common.py
+++ b/core/dbt/parser/common.py
@@ -13,6 +13,7 @@
     UnparsedMacroUpdate,
     UnparsedModelUpdate,
     UnparsedNodeUpdate,
+    UnparsedSingularTestUpdate,
 )
 from dbt.exceptions import ParsingError
 from dbt.node_types import NodeType
@@ -58,6 +59,7 @@ def trimmed(inp: str) -> str:
     UnpatchedSourceDefinition,
     UnparsedExposure,
     UnparsedModelUpdate,
+    UnparsedSingularTestUpdate,
 )
 
 
diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py
index ae6aca476a3..727f4bcf676 100644
--- a/core/dbt/parser/schemas.py
+++ b/core/dbt/parser/schemas.py
@@ -17,6 +17,7 @@
     ModelNode,
     ParsedMacroPatch,
     ParsedNodePatch,
+    ParsedSingularTestPatch,
     UnpatchedSourceDefinition,
 )
 from dbt.contracts.graph.unparsed import (
@@ -27,6 +28,7 @@
     UnparsedMacroUpdate,
     UnparsedModelUpdate,
     UnparsedNodeUpdate,
+    UnparsedSingularTestUpdate,
     UnparsedSourceDefinition,
 )
 from dbt.events.types import (
@@ -65,7 +67,9 @@
 from dbt.utils import coerce_dict_str
 from dbt_common.contracts.constraints import ConstraintType, ModelLevelConstraint
 from dbt_common.dataclass_schema import ValidationError, dbtClassMixin
-from dbt_common.events.functions import warn_or_error
+from dbt_common.events import EventLevel
+from dbt_common.events.functions import fire_event, warn_or_error
+from dbt_common.events.types import Note
 from dbt_common.exceptions import DbtValidationError
 from dbt_common.utils import deep_merge
 
@@ -207,6 +211,18 @@ def parse_file(self, block: FileBlock, dct: Optional[Dict] = None) -> None:
                 parser = MacroPatchParser(self, yaml_block, "macros")
                 parser.parse()
 
+            if "data_tests" in dct:
+                parser = SingularTestPatchParser(self, yaml_block, "data_tests")
+                try:
+                    parser.parse()
+                except ParsingError as e:
+                    fire_event(
+                        Note(
+                            msg=f"Unable to parse 'data_tests' section of file '{block.path.original_file_path}'\n{e}",
+                        ),
+                        EventLevel.WARN,
+                    )
+
             # PatchParser.parse() (but never test_blocks)
             if "analyses" in dct:
                 parser = AnalysisPatchParser(self, yaml_block, "analyses")
@@ -301,7 +317,9 @@ def _add_yaml_snapshot_nodes_to_manifest(
             self.manifest.rebuild_ref_lookup()
 
 
-Parsed = TypeVar("Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch)
+Parsed = TypeVar(
+    "Parsed", UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch, ParsedSingularTestPatch
+)
 NodeTarget = TypeVar("NodeTarget", UnparsedNodeUpdate, UnparsedAnalysisUpdate, UnparsedModelUpdate)
 NonSourceTarget = TypeVar(
     "NonSourceTarget",
@@ -309,6 +327,7 @@ def _add_yaml_snapshot_nodes_to_manifest(
     UnparsedAnalysisUpdate,
     UnparsedMacroUpdate,
     UnparsedModelUpdate,
+    UnparsedSingularTestUpdate,
 )
 
 
@@ -1144,6 +1163,55 @@ def _target_type(self) -> Type[UnparsedAnalysisUpdate]:
         return UnparsedAnalysisUpdate
 
 
+class SingularTestPatchParser(PatchParser[UnparsedSingularTestUpdate, ParsedSingularTestPatch]):
+    """Apply `data_tests` entries from a properties file to singular test nodes."""
+
+    def get_block(self, node: UnparsedSingularTestUpdate) -> TargetBlock:
+        return TargetBlock.from_yaml_block(self.yaml, node)
+
+    def _target_type(self) -> Type[UnparsedSingularTestUpdate]:
+        return UnparsedSingularTestUpdate
+
+    def parse_patch(self, block: TargetBlock[UnparsedSingularTestUpdate], refs: ParserRef) -> None:
+        """Patch the matching test node; report NoNodeForYamlKey and stop if none is found."""
+        target = block.target
+        patch = ParsedSingularTestPatch(
+            name=target.name,
+            description=target.description,
+            meta=target.meta,
+            docs=target.docs,
+            config=target.config,
+            original_file_path=target.original_file_path,
+            yaml_key=target.yaml_key,
+            package_name=target.package_name,
+        )
+
+        assert isinstance(self.yaml.file, SchemaSourceFile)
+        schema_file: SchemaSourceFile = self.yaml.file
+        lookup = self.manifest.singular_test_lookup
+        unique_id = lookup.get_unique_id(block.name, target.package_name)
+        if not unique_id:
+            missing = NoNodeForYamlKey(
+                patch_name=patch.name,
+                yaml_key=patch.yaml_key,
+                file_path=schema_file.path.original_file_path,
+            )
+            warn_or_error(missing)
+            return
+
+        node = self.manifest.nodes.get(unique_id)
+        assert node is not None
+        schema_file.append_patch(patch.yaml_key, unique_id)
+        if patch.config:
+            self.patch_node_config(node, patch)
+
+        node.patch_path = patch.file_id
+        node.description = patch.description
+        node.created_at = time.time()
+        node.meta = patch.meta
+        node.docs = patch.docs
+
+
 class MacroPatchParser(PatchParser[UnparsedMacroUpdate, ParsedMacroPatch]):
     def get_block(self, node: UnparsedMacroUpdate) -> TargetBlock:
         return TargetBlock.from_yaml_block(self.yaml, node)
diff --git a/tests/functional/data_test_patch/fixtures.py b/tests/functional/data_test_patch/fixtures.py
new file mode 100644
index 00000000000..be056f32680
--- /dev/null
+++ b/tests/functional/data_test_patch/fixtures.py
@@ -0,0 +1,38 @@
+tests__my_singular_test_sql = """
+with my_cte as (
+    select 1 as id, 'foo' as name
+    union all
+    select 2 as id, 'bar' as name
+)
+select * from my_cte
+"""
+
+tests__schema_yml = """
+data_tests:
+  - name: my_singular_test
+    description: "{{ doc('my_singular_test_documentation') }}"
+    config:
+      error_if: ">10"
+    meta:
+      some_key: some_val
+"""
+
+tests__doc_block_md = """
+{% docs my_singular_test_documentation %}
+
+Some docs from a doc block
+
+{% enddocs %}
+"""
+
+tests__invalid_name_schema_yml = """
+data_tests:
+  - name: my_double_test
+    description: documentation, but make it double
+"""
+
+tests__malformed_schema_yml = """
+data_tests: &not_null
+  - not_null:
+      where: some_condition
+"""
diff --git a/tests/functional/data_test_patch/test_singular_test_patch.py b/tests/functional/data_test_patch/test_singular_test_patch.py
new file mode 100644
index 00000000000..df359c5e645
--- /dev/null
+++ b/tests/functional/data_test_patch/test_singular_test_patch.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+
+import pytest
+
+from dbt.tests.util import get_artifact, run_dbt, run_dbt_and_capture
+from tests.functional.data_test_patch.fixtures import (
+    tests__doc_block_md,
+    tests__invalid_name_schema_yml,
+    tests__malformed_schema_yml,
+    tests__my_singular_test_sql,
+    tests__schema_yml,
+)
+
+
+class TestPatchSingularTest:
+    """A `data_tests` entry patches description, config and meta onto the singular test."""
+
+    @pytest.fixture(scope="class")
+    def tests(self):
+        return {
+            "my_singular_test.sql": tests__my_singular_test_sql,
+            "schema.yml": tests__schema_yml,
+            "doc_block.md": tests__doc_block_md,
+        }
+
+    def test_compile(self, project):
+        run_dbt(["compile"])
+        nodes = get_artifact(project.project_root, "target", "manifest.json")["nodes"]
+        assert len(nodes) == 1
+
+        patched = nodes["test.test.my_singular_test"]
+        assert patched["description"] == "Some docs from a doc block"
+        assert patched["config"]["error_if"] == ">10"
+        assert patched["config"]["meta"] == {"some_key": "some_val"}
+
+
+class TestPatchSingularTestInvalidName:
+    """An entry whose name matches no singular test is reported in the compile log."""
+
+    @pytest.fixture(scope="class")
+    def tests(self):
+        return {
+            "my_singular_test.sql": tests__my_singular_test_sql,
+            "schema_with_invalid_name.yml": tests__invalid_name_schema_yml,
+        }
+
+    def test_compile(self, project):
+        _, captured = run_dbt_and_capture(["compile"])
+
+        file_path = Path("tests/schema_with_invalid_name.yml")
+        expected = f"Did not find matching node for patch with name 'my_double_test' in the 'data_tests' section of file '{file_path}'"
+        assert expected in captured
+
+
+class TestPatchSingularTestMalformedYaml:
+    """A malformed `data_tests` section yields the 'Unable to parse' message in the log."""
+
+    @pytest.fixture(scope="class")
+    def tests(self):
+        return {
+            "my_singular_test.sql": tests__my_singular_test_sql,
+            "schema.yml": tests__malformed_schema_yml,
+        }
+
+    def test_compile(self, project):
+        _, captured = run_dbt_and_capture(["compile"])
+        file_path = Path("tests/schema.yml")
+        assert f"Unable to parse 'data_tests' section of file '{file_path}'" in captured
+        assert "Entry did not contain a name" in captured
diff --git a/tests/unit/config/test_project.py b/tests/unit/config/test_project.py
index ab842c164d7..ddd519cc6ee 100644
--- a/tests/unit/config/test_project.py
+++ b/tests/unit/config/test_project.py
@@ -31,7 +31,7 @@ class TestProjectMethods:
     def test_all_source_paths(self, project: Project):
         assert (
-            project.all_source_paths.sort()
-            == ["models", "seeds", "snapshots", "analyses", "macros"].sort()
+            sorted(project.all_source_paths)
+            == sorted(["models", "seeds", "snapshots", "analyses", "macros", "tests"])
         )
 
     def test_generic_test_paths(self, project: Project):
@@ -99,7 +99,8 @@ def test_defaults(self):
         self.assertEqual(project.test_paths, ["tests"])
         self.assertEqual(project.analysis_paths, ["analyses"])
         self.assertEqual(
-            set(project.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"])
+            set(project.docs_paths),
+            {"models", "seeds", "snapshots", "analyses", "macros", "tests"},
         )
         self.assertEqual(project.asset_paths, [])
         self.assertEqual(project.target_path, "target")
@@ -128,7 +129,7 @@ def test_implicit_overrides(self):
         )
         self.assertEqual(
             set(project.docs_paths),
-            set(["other-models", "seeds", "snapshots", "analyses", "macros"]),
+            {"other-models", "seeds", "snapshots", "analyses", "macros", "tests"},
         )
 
     def test_all_overrides(self):
diff --git a/tests/unit/config/test_runtime.py b/tests/unit/config/test_runtime.py
index 816ec8f98c3..d03d33dab94 100644
--- a/tests/unit/config/test_runtime.py
+++ b/tests/unit/config/test_runtime.py
@@ -129,7 +129,7 @@ def test_from_args(self):
         self.assertEqual(config.test_paths, ["tests"])
         self.assertEqual(config.analysis_paths, ["analyses"])
         self.assertEqual(
-            set(config.docs_paths), set(["models", "seeds", "snapshots", "analyses", "macros"])
+            set(config.docs_paths), {"models", "seeds", "snapshots", "analyses", "macros", "tests"}
         )
         self.assertEqual(config.asset_paths, [])
         self.assertEqual(config.target_path, "target")