Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

state:modified.vars #11007

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240729-173203.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Include models that depend on changed vars in state:modified, add state:modified.vars
selection method
time: 2024-07-29T17:32:03.368508-04:00
custom:
Author: michelleark
Issue: "4304"
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class ParsedResource(ParsedResourceMandatory):
unrendered_config_call_dict: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
raw_code: str = ""
vars: Dict[str, Any] = field(default_factory=dict)

def __post_serialize__(self, dct: Dict, context: Optional[Dict] = None):
dct = super().__post_serialize__(dct, context)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/exposure.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Exposure(GraphResource):
tags: List[str] = field(default_factory=list)
config: ExposureConfig = field(default_factory=ExposureConfig)
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
url: Optional[str] = None
depends_on: DependsOn = field(default_factory=DependsOn)
refs: List[RefArgs] = field(default_factory=list)
Expand Down
1 change: 1 addition & 0 deletions core/dbt/artifacts/resources/v1/source_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class SourceDefinition(ParsedSourceMandatory):
config: SourceConfig = field(default_factory=SourceConfig)
patch_path: Optional[str] = None
unrendered_config: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
relation_name: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time())
unrendered_database: Optional[str] = None
Expand Down
38 changes: 26 additions & 12 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,35 @@
self.resource_type = NodeType.Model


class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}


class ConfiguredVar(Var):
def __init__(
self,
context: Dict[str, Any],
config: AdapterRequiredConfig,
project_name: str,
schema_yaml_vars: Optional[SchemaYamlVars] = None,
):
super().__init__(context, config.cli_vars)
self._config = config
self._project_name = project_name
self.schema_yaml_vars = schema_yaml_vars

def __call__(self, var_name, default=Var._VAR_NOTSET):
my_config = self._config.load_dependencies()[self._project_name]

var_found = False
var_value = None

# cli vars > active project > local project
if var_name in self._config.cli_vars:
return self._config.cli_vars[var_name]
var_found = True
var_value = self._config.cli_vars[var_name]

adapter_type = self._config.credentials.type
lookup = FQNLookup(self._project_name)
Expand All @@ -58,19 +70,21 @@
all_vars.add(my_config.vars.vars_for(lookup, adapter_type))
all_vars.add(active_vars)

if var_name in all_vars:
return all_vars[var_name]
if not var_found and var_name in all_vars:
var_found = True
var_value = all_vars[var_name]

if default is not Var._VAR_NOTSET:
return default

return self.get_missing_var(var_name)
if not var_found and default is not Var._VAR_NOTSET:
var_found = True
var_value = default

if not var_found:
return self.get_missing_var(var_name)

Check warning on line 82 in core/dbt/context/configured.py

View check run for this annotation

Codecov / codecov/patch

core/dbt/context/configured.py#L82

Added line #L82 was not covered by tests
else:
if self.schema_yaml_vars:
self.schema_yaml_vars.vars[var_name] = var_value

class SchemaYamlVars:
def __init__(self):
self.env_vars = {}
self.vars = {}
return var_value


class SchemaYamlContext(ConfiguredContext):
Expand All @@ -82,7 +96,7 @@

@contextproperty()
def var(self) -> ConfiguredVar:
return ConfiguredVar(self._ctx, self.config, self._project_name)
return ConfiguredVar(self._ctx, self.config, self._project_name, self.schema_yaml_vars)

@contextmember()
def env_var(self, var: str, default: Optional[str] = None) -> str:
Expand Down
8 changes: 8 additions & 0 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,14 @@ def get_missing_var(self, var_name):
# in the parser, just always return None.
return None

def __call__(self, var_name: str, default: Any = ModelConfiguredVar._VAR_NOTSET) -> Any:
var_value = super().__call__(var_name, default)

if self._node and hasattr(self._node, "vars"):
self._node.vars[var_name] = var_value

return var_value


class RuntimeVar(ModelConfiguredVar):
pass
Expand Down
17 changes: 17 additions & 0 deletions core/dbt/contracts/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class SchemaSourceFile(BaseSourceFile):
unrendered_configs: Dict[str, Any] = field(default_factory=dict)
unrendered_databases: Dict[str, Any] = field(default_factory=dict)
unrendered_schemas: Dict[str, Any] = field(default_factory=dict)
vars: Dict[str, Any] = field(default_factory=dict)
pp_dict: Optional[Dict[str, Any]] = None
pp_test_index: Optional[Dict[str, Any]] = None

Expand Down Expand Up @@ -356,6 +357,22 @@ def delete_from_unrendered_configs(self, yaml_key, name):
if not self.unrendered_configs[yaml_key]:
del self.unrendered_configs[yaml_key]

def add_vars(self, vars: Dict[str, Any], yaml_key: str, name: str) -> None:
if yaml_key not in self.vars:
self.vars[yaml_key] = {}

if name not in self.vars[yaml_key]:
self.vars[yaml_key][name] = vars

def get_vars(self, yaml_key: str, name: str) -> Dict[str, Any]:
if yaml_key not in self.vars:
return {}

if name not in self.vars[yaml_key]:
return {}

return self.vars[yaml_key][name]

def add_env_var(self, var, yaml_key, name):
if yaml_key not in self.env_vars:
self.env_vars[yaml_key] = {}
Expand Down
32 changes: 32 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,19 +369,30 @@ def same_contract(self, old, adapter_type=None) -> bool:
# This would only apply to seeds
return True

def same_vars(self, old) -> bool:
return self.vars == old.vars

def same_contents(self, old, adapter_type) -> bool:
if old is None:
return False

# Need to ensure that same_contract is called because it
# could throw an error
same_contract = self.same_contract(old, adapter_type)

# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_body(old)
and self.same_config(old)
and self.same_persisted_description(old)
and self.same_fqn(old)
and self.same_database_representation(old)
and same_vars
and same_contract
and True
)
Expand Down Expand Up @@ -1264,6 +1275,9 @@ def same_config(self, old: "SourceDefinition") -> bool:
old.unrendered_config,
)

def same_vars(self, other: "SourceDefinition") -> bool:
return self.vars == other.vars

def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# existing when it didn't before is a change!
if old is None:
Expand All @@ -1277,13 +1291,20 @@ def same_contents(self, old: Optional["SourceDefinition"]) -> bool:
# freshness changes are changes, I guess
# metadata/tags changes are not "changes"
# patching/description changes are not "changes"
# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_database_representation(old)
and self.same_fqn(old)
and self.same_config(old)
and self.same_quoting(old)
and self.same_freshness(old)
and self.same_external(old)
and same_vars
and True
)

Expand Down Expand Up @@ -1380,12 +1401,21 @@ def same_config(self, old: "Exposure") -> bool:
old.unrendered_config,
)

def same_vars(self, old: "Exposure") -> bool:
return self.vars == old.vars

def same_contents(self, old: Optional["Exposure"]) -> bool:
# existing when it didn't before is a change!
# metadata/tags changes are not "changes"
if old is None:
return True

# Legacy behaviour
if not get_flags().state_modified_compare_vars:
same_vars = True
else:
same_vars = self.same_vars(old)

return (
self.same_fqn(old)
and self.same_exposure_type(old)
Expand All @@ -1396,6 +1426,7 @@ def same_contents(self, old: Optional["Exposure"]) -> bool:
and self.same_label(old)
and self.same_depends_on(old)
and self.same_config(old)
and same_vars
and True
)

Expand Down Expand Up @@ -1647,6 +1678,7 @@ class ParsedNodePatch(ParsedPatch):
latest_version: Optional[NodeVersion]
constraints: List[Dict[str, Any]]
deprecation_date: Optional[datetime]
vars: Dict[str, Any]
time_spine: Optional[TimeSpine] = None


Expand Down
1 change: 1 addition & 0 deletions core/dbt/graph/selector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,7 @@ def search(self, included_nodes: Set[UniqueId], selector: str) -> Iterator[Uniqu
"modified.relation": self.check_modified_factory("same_database_representation"),
"modified.macros": self.check_modified_macros,
"modified.contract": self.check_modified_contract("same_contract", adapter_type),
"modified.vars": self.check_modified_factory("same_vars"),
}
if selector in state_checks:
checker = state_checks[selector]
Expand Down
5 changes: 4 additions & 1 deletion core/dbt/parser/schema_yaml_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
unique_id = f"{NodeType.Exposure}.{package_name}.{unparsed.name}"
path = self.yaml.path.relative_path

assert isinstance(self.yaml.file, SchemaSourceFile)
exposure_vars = self.yaml.file.get_vars(self.key, unparsed.name)

fqn = self.schema_parser.get_fqn_prefix(path)
fqn.append(unparsed.name)

Expand Down Expand Up @@ -133,6 +136,7 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
maturity=unparsed.maturity,
config=config,
unrendered_config=unrendered_config,
vars=exposure_vars,
)
ctx = generate_parse_exposure(
parsed,
Expand All @@ -144,7 +148,6 @@ def parse_exposure(self, unparsed: UnparsedExposure) -> None:
get_rendered(depends_on_jinja, ctx, parsed, capture_macros=True)
# parsed now has a populated refs/sources/metrics

assert isinstance(self.yaml.file, SchemaSourceFile)
if parsed.config.enabled:
self.manifest.add_exposure(self.yaml.file, parsed)
else:
Expand Down
18 changes: 14 additions & 4 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,14 @@ def get_key_dicts(self) -> Iterable[Dict[str, Any]]:

if self.schema_yaml_vars.env_vars:
self.schema_parser.manifest.env_vars.update(self.schema_yaml_vars.env_vars)
for var in self.schema_yaml_vars.env_vars.keys():
schema_file.add_env_var(var, self.key, entry["name"])
for env_var in self.schema_yaml_vars.env_vars.keys():
schema_file.add_env_var(env_var, self.key, entry["name"])
self.schema_yaml_vars.env_vars = {}

if self.schema_yaml_vars.vars:
schema_file.add_vars(self.schema_yaml_vars.vars, self.key, entry["name"])
self.schema_yaml_vars.vars = {}

yield entry

def render_entry(self, dct):
Expand Down Expand Up @@ -715,6 +719,9 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
# code consistency.
deprecation_date: Optional[datetime.datetime] = None
time_spine: Optional[TimeSpine] = None
assert isinstance(self.yaml.file, SchemaSourceFile)
source_file: SchemaSourceFile = self.yaml.file

if isinstance(block.target, UnparsedModelUpdate):
deprecation_date = block.target.deprecation_date
time_spine = (
Expand Down Expand Up @@ -747,9 +754,9 @@ def parse_patch(self, block: TargetBlock[NodeTarget], refs: ParserRef) -> None:
constraints=block.target.constraints,
deprecation_date=deprecation_date,
time_spine=time_spine,
vars=source_file.get_vars(block.target.yaml_key, block.target.name),
)
assert isinstance(self.yaml.file, SchemaSourceFile)
source_file: SchemaSourceFile = self.yaml.file

if patch.yaml_key in ["models", "seeds", "snapshots"]:
unique_id = self.manifest.ref_lookup.get_unique_id(
patch.name, self.project.project_name, None
Expand Down Expand Up @@ -843,6 +850,8 @@ def patch_node_properties(self, node, patch: "ParsedNodePatch") -> None:
node.description = patch.description
node.columns = patch.columns
node.name = patch.name
# Prefer node-level vars to vars from patch
node.vars = {**patch.vars, **node.vars}

if not isinstance(node, ModelNode):
for attr in ["latest_version", "access", "version", "constraints"]:
Expand Down Expand Up @@ -992,6 +1001,7 @@ def parse_patch(self, block: TargetBlock[UnparsedModelUpdate], refs: ParserRef)
latest_version=latest_version,
constraints=unparsed_version.constraints or target.constraints,
deprecation_date=unparsed_version.deprecation_date,
vars=source_file.get_vars(block.target.yaml_key, block.target.name),
)
# Node patched before config because config patching depends on model name,
# which may have been updated in the version patch
Expand Down
6 changes: 6 additions & 0 deletions core/dbt/parser/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ContextConfigGenerator,
UnrenderedConfigGenerator,
)
from dbt.contracts.files import SchemaSourceFile
from dbt.contracts.graph.manifest import Manifest, SourceKey
from dbt.contracts.graph.nodes import (
GenericTestNode,
Expand Down Expand Up @@ -158,6 +159,10 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition:
rendered=False,
)

schema_file = self.manifest.files[target.file_id]
assert isinstance(schema_file, SchemaSourceFile)
source_vars = schema_file.get_vars("sources", source.name)

if not isinstance(config, SourceConfig):
raise DbtInternalError(
f"Calculated a {type(config)} for a source, but expected a SourceConfig"
Expand Down Expand Up @@ -192,6 +197,7 @@ def parse_source(self, target: UnpatchedSourceDefinition) -> SourceDefinition:
tags=tags,
config=config,
unrendered_config=unrendered_config,
vars=source_vars,
)

if (
Expand Down
Loading
Loading