From ff2726c3b59e52e1268d5b78ad59d0cf2453f409 Mon Sep 17 00:00:00 2001 From: Michelle Ark Date: Wed, 31 Jul 2024 20:02:27 -0400 Subject: [PATCH] more defensive node.all_constraints access (#10508) --- .../unreleased/Fixes-20240731-095152.yaml | 6 + core/dbt/compilation.py | 3 +- core/dbt/contracts/graph/manifest.py | 36 ++++- core/dbt/parser/schemas.py | 44 +++--- .../configs/test_disabled_configs.py | 44 ++++++ tests/unit/contracts/graph/test_manifest.py | 125 +++++++++++++++++- 6 files changed, 233 insertions(+), 25 deletions(-) create mode 100644 .changes/unreleased/Fixes-20240731-095152.yaml diff --git a/.changes/unreleased/Fixes-20240731-095152.yaml b/.changes/unreleased/Fixes-20240731-095152.yaml new file mode 100644 index 00000000000..c7899f6c30b --- /dev/null +++ b/.changes/unreleased/Fixes-20240731-095152.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: fix all_constraints access, disabled node parsing of non-uniquely named resources +time: 2024-07-31T09:51:52.751135-04:00 +custom: + Author: michelleark gshank + Issue: "10509" diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 47d7ffbdb51..90ccc42c479 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -21,6 +21,7 @@ InjectedCTE, ManifestNode, ManifestSQLNode, + ModelNode, SeedNode, UnitTestDefinition, UnitTestNode, @@ -441,7 +442,7 @@ def _compile_code( node.relation_name = relation_name # Compile 'ref' and 'source' expressions in foreign key constraints - if node.resource_type == NodeType.Model: + if isinstance(node, ModelNode): for constraint in node.all_constraints: if constraint.type == ConstraintType.foreign_key and constraint.to: constraint.to = self._compile_relation_for_foreign_key_constraint_to( diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 21c5571b74b..f4cdafea737 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -413,11 +413,11 @@ def __init__(self, manifest: "Manifest") -> None: self.storage: Dict[str, Dict[PackageName, List[Any]]] = {} self.populate(manifest) - def populate(self, manifest): + def populate(self, manifest: "Manifest"): for node in list(chain.from_iterable(manifest.disabled.values())): self.add_node(node) - def add_node(self, node): + def add_node(self, node: GraphMemberNode) -> None: if node.search_name not in self.storage: self.storage[node.search_name] = {} if node.package_name not in self.storage[node.search_name]: @@ -427,8 +427,12 @@ def add_node(self, node): # This should return a list of disabled nodes. It's different from # the other Lookup functions in that it returns full nodes, not just unique_ids def find( - self, search_name, package: Optional[PackageName], version: Optional[NodeVersion] = None - ): + self, + search_name, + package: Optional[PackageName], + version: Optional[NodeVersion] = None, + resource_types: Optional[List[NodeType]] = None, + ) -> Optional[List[Any]]: if version: search_name = f"{search_name}.v{version}" @@ -437,16 +441,29 @@ def find( pkg_dct: Mapping[PackageName, List[Any]] = self.storage[search_name] + nodes = [] if package is None: if not pkg_dct: return None else: - return next(iter(pkg_dct.values())) + nodes = next(iter(pkg_dct.values())) elif package in pkg_dct: - return pkg_dct[package] + nodes = pkg_dct[package] else: return None + if resource_types is None: + return nodes + else: + new_nodes = [] + for node in nodes: + if node.resource_type in resource_types: + new_nodes.append(node) + if not new_nodes: + return None + else: + return new_nodes + class AnalysisLookup(RefableLookup): _lookup_types: ClassVar[set] = set([NodeType.Analysis]) @@ -1295,7 +1312,12 @@ def resolve_ref( # it's possible that the node is disabled if disabled is None: - disabled = self.disabled_lookup.find(target_model_name, pkg, target_model_version) + disabled = self.disabled_lookup.find( + target_model_name, + pkg, + version=target_model_version, + resource_types=REFABLE_NODE_TYPES, + ) if disabled: return Disabled(disabled[0]) diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index 04f63d04e34..96313425faa 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -69,18 +69,20 @@ from dbt_common.exceptions import DbtValidationError from dbt_common.utils import deep_merge -schema_file_keys = ( - "models", - "seeds", - "snapshots", - "sources", - "macros", - "analyses", - "exposures", - "metrics", - "semantic_models", - "saved_queries", -) +schema_file_keys_to_resource_types = { + "models": NodeType.Model, + "seeds": NodeType.Seed, + "snapshots": NodeType.Snapshot, + "sources": NodeType.Source, + "macros": NodeType.Macro, + "analyses": NodeType.Analysis, + "exposures": NodeType.Exposure, + "metrics": NodeType.Metric, + "semantic_models": NodeType.SemanticModel, + "saved_queries": NodeType.SavedQuery, +} + +schema_file_keys = list(schema_file_keys_to_resource_types.keys()) # =============================================================================== @@ -678,7 +680,10 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None: # handle disabled nodes if unique_id is None: # Node might be disabled. Following call returns list of matching disabled nodes - found_nodes = self.manifest.disabled_lookup.find(patch.name, patch.package_name) + resource_type = schema_file_keys_to_resource_types[patch.yaml_key] + found_nodes = self.manifest.disabled_lookup.find( + patch.name, patch.package_name, resource_types=[resource_type] + ) if found_nodes: if len(found_nodes) > 1 and patch.config.get("enabled"): # There are multiple disabled nodes for this model and the schema file wants to enable one. @@ -810,7 +815,9 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef) if versioned_model_unique_id is None: # Node might be disabled. Following call returns list of matching disabled nodes - found_nodes = self.manifest.disabled_lookup.find(versioned_model_name, None) + found_nodes = self.manifest.disabled_lookup.find( + versioned_model_name, None, resource_types=[NodeType.Model] + ) if found_nodes: if len(found_nodes) > 1 and target.config.get("enabled"): # There are multiple disabled nodes for this model and the schema file wants to enable one. @@ -911,6 +918,11 @@ def _target_type(self) -> Type[UnparsedModelUpdate]: def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: super().patch_node_properties(node, patch) + + # Remaining patch properties are only relevant to ModelNode objects + if not isinstance(node, ModelNode): + return + node.version = patch.version node.latest_version = patch.latest_version node.deprecation_date = patch.deprecation_date @@ -927,7 +939,7 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None: self.patch_time_spine(node, patch.time_spine) node.build_contract_checksum() - def patch_constraints(self, node, constraints: List[Dict[str, Any]]) -> None: + def patch_constraints(self, node: ModelNode, constraints: List[Dict[str, Any]]) -> None: contract_config = node.config.get("contract") if contract_config.enforced is True: self._validate_constraint_prerequisites(node) @@ -963,7 +975,7 @@ def _process_constraints_refs_and_sources(self, model_node: ModelNode) -> None: else: model_node.sources.append(ref_or_source) - def patch_time_spine(self, node, time_spine: Optional[TimeSpine]) -> None: + def patch_time_spine(self, node: ModelNode, time_spine: Optional[TimeSpine]) -> None: node.time_spine = time_spine def _validate_pk_constraints( diff --git a/tests/functional/configs/test_disabled_configs.py b/tests/functional/configs/test_disabled_configs.py index d2ee83e801a..f0176788777 100644 --- a/tests/functional/configs/test_disabled_configs.py +++ b/tests/functional/configs/test_disabled_configs.py @@ -88,3 +88,47 @@ def test_conditional_model(self, project): assert len(results) == 2 results = run_dbt(["test"]) assert len(results) == 5 + + +my_analysis_sql = """ +{{ + config(enabled=False) +}} +select 1 as id +""" + + +schema_yml = """ +models: + - name: my_analysis + description: "A Sample model" + config: + meta: + owner: Joe + +analyses: + - name: my_analysis + description: "A sample analysis" + config: + enabled: false +""" + + +class TestDisabledConfigsSameName: + @pytest.fixture(scope="class") + def models(self): + return { + "my_analysis.sql": my_analysis_sql, + "schema.yml": schema_yml, + } + + @pytest.fixture(scope="class") + def analyses(self): + return { + "my_analysis.sql": my_analysis_sql, + } + + def test_disabled_analysis(self, project): + manifest = run_dbt(["parse"]) + assert len(manifest.disabled) == 2 + assert len(manifest.nodes) == 0 diff --git a/tests/unit/contracts/graph/test_manifest.py b/tests/unit/contracts/graph/test_manifest.py index 5eef57324b7..d8d1df0d900 100644 --- a/tests/unit/contracts/graph/test_manifest.py +++ b/tests/unit/contracts/graph/test_manifest.py @@ -26,7 +26,7 @@ WhereFilterIntersection, ) from dbt.contracts.files import FileHash -from dbt.contracts.graph.manifest import Manifest, ManifestMetadata +from dbt.contracts.graph.manifest import DisabledLookup, Manifest, ManifestMetadata from dbt.contracts.graph.nodes import ( DependsOn, Exposure, @@ -2013,3 +2013,126 @@ def test_find_node_from_ref_or_source_invalid_expression( ): with pytest.raises(ParsingError): mock_manifest.find_node_from_ref_or_source(invalid_expression) + + +class TestDisabledLookup: + @pytest.fixture(scope="class") + def manifest(self): + return Manifest( + nodes={}, + sources={}, + macros={}, + docs={}, + disabled={}, + files={}, + exposures={}, + selectors={}, + ) + + @pytest.fixture(scope="class") + def mock_model(self): + return MockNode("package", "name", NodeType.Model) + + @pytest.fixture(scope="class") + def mock_model_with_version(self): + return MockNode("package", "name", NodeType.Model, version=3) + + @pytest.fixture(scope="class") + def mock_seed(self): + return MockNode("package", "name", NodeType.Seed) + + def test_find(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package") == [mock_model] + + def test_find_wrong_name(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("missing_name", "package") is None + + def test_find_wrong_package(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "missing_package") is None + + def test_find_wrong_version(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", version=3) is None + + def test_find_wrong_resource_types(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", resource_types=[NodeType.Analysis]) is None + + def test_find_no_package(self, manifest, mock_model): + manifest.disabled = {"model.package.name": [mock_model]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", None) == [mock_model] + + def test_find_versioned_node(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", version=3) == [mock_model_with_version] + + def test_find_versioned_node_no_package(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", None, version=3) == [mock_model_with_version] + + def test_find_versioned_node_no_version(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package") is None + + def test_find_versioned_node_wrong_version(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", version=2) is None + + def test_find_versioned_node_wrong_name(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("wrong_name", "package", version=3) is None + + def test_find_versioned_node_wrong_package(self, manifest, mock_model_with_version): + manifest.disabled = {"model.package.name": [mock_model_with_version]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "wrong_package", version=3) is None + + def test_find_multiple_nodes(self, manifest, mock_model, mock_seed): + manifest.disabled = {"model.package.name": [mock_model, mock_seed]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package") == [mock_model, mock_seed] + + def test_find_multiple_nodes_with_resource_types(self, manifest, mock_model, mock_seed): + manifest.disabled = {"model.package.name": [mock_model, mock_seed]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", resource_types=[NodeType.Model]) == [mock_model] + + def test_find_multiple_nodes_with_wrong_resource_types(self, manifest, mock_model, mock_seed): + manifest.disabled = {"model.package.name": [mock_model, mock_seed]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", resource_types=[NodeType.Analysis]) is None + + def test_find_multiple_nodes_with_resource_types_empty(self, manifest, mock_model, mock_seed): + manifest.disabled = {"model.package.name": [mock_model, mock_seed]} + lookup = DisabledLookup(manifest) + + assert lookup.find("name", "package", resource_types=[]) is None