diff --git a/cosmos/converter.py b/cosmos/converter.py index 9a8d9d366..05b523a1d 100644 --- a/cosmos/converter.py +++ b/cosmos/converter.py @@ -142,6 +142,7 @@ def __init__( select=select, dbt_cmd=dbt_executable_path, profile_config=profile_config, + operator_args=operator_args, dbt_deps=dbt_deps, ) dbt_graph.load(method=load_mode, execution_mode=execution_mode) diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index ee29fdad0..07a55ad79 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -80,12 +80,14 @@ def __init__( select: list[str] | None = None, dbt_cmd: str = get_system_dbt(), profile_config: ProfileConfig | None = None, + operator_args: dict[str, Any] | None = None, dbt_deps: bool | None = True, ): self.project = project self.exclude = exclude or [] self.select = select or [] self.profile_config = profile_config + self.operator_args = operator_args or {} self.dbt_deps = dbt_deps # specific to loading using ls @@ -282,6 +284,7 @@ def load_via_custom_parser(self) -> None: dbt_models_dir=self.project.models_dir.stem if self.project.models_dir else None, dbt_seeds_dir=self.project.seeds_dir.stem if self.project.seeds_dir else None, project_name=self.project.name, + operator_args=self.operator_args, ) nodes = {} models = itertools.chain(project.models.items(), project.snapshots.items(), project.seeds.items()) diff --git a/cosmos/dbt/parser/project.py b/cosmos/dbt/parser/project.py index 4ad52d293..3e1b18e10 100644 --- a/cosmos/dbt/parser/project.py +++ b/cosmos/dbt/parser/project.py @@ -129,6 +129,7 @@ class DbtModel: name: str type: DbtModelType path: Path + operator_args: Dict[str, Any] = field(default_factory=dict) config: DbtModelConfig = field(default_factory=DbtModelConfig) def __post_init__(self) -> None: @@ -137,6 +138,7 @@ def __post_init__(self) -> None: """ # first, get an empty config config = DbtModelConfig() + var_args: Dict[str, Any] = self.operator_args.get("vars", {}) if self.type == DbtModelType.DBT_MODEL: # get the code from the file @@ -165,23 +167,40 @@ def __post_init__(self) -> None: # iterate over the jinja nodes to extract info for base_node in jinja2_ast.find_all(jinja2.nodes.Call): if hasattr(base_node.node, "name"): - # check we have a ref - this indicates a dependency - if base_node.node.name == "ref": - # if it is, get the first argument - first_arg = base_node.args[0] - if isinstance(first_arg, jinja2.nodes.Const): - # and add it to the config - config.upstream_models.add(first_arg.value) - - # check if we have a config - this could contain tags - if base_node.node.name == "config": - # if it is, check if any kwargs are tags - for kwarg in base_node.kwargs: - for selector in self.config.config_types: - extracted_config = self._extract_config(kwarg=kwarg, config_name=selector) - config.config_selectors |= ( - set(extracted_config) if isinstance(extracted_config, (str, List)) else set() - ) + try: + # check we have a ref - this indicates a dependency + if base_node.node.name == "ref": + # if it is, get the first argument + first_arg = base_node.args[0] + # if it contains vars, render the value of the var + if isinstance(first_arg, jinja2.nodes.Concat): + value = "" + for node in first_arg.nodes: + if isinstance(node, jinja2.nodes.Const): + value += node.value + elif ( + isinstance(node, jinja2.nodes.Call) + and isinstance(node.node, jinja2.nodes.Name) + and isinstance(node.args[0], jinja2.nodes.Const) + and node.node.name == "var" + ): + value += var_args[node.args[0].value] + config.upstream_models.add(value) + elif isinstance(first_arg, jinja2.nodes.Const): + # and add it to the config + config.upstream_models.add(first_arg.value) + + # check if we have a config - this could contain tags + if base_node.node.name == "config": + # if it is, check if any kwargs are tags + for kwarg in base_node.kwargs: + for selector in self.config.config_types: + extracted_config = self._extract_config(kwarg=kwarg, config_name=selector) + config.config_selectors |= ( + set(extracted_config) if isinstance(extracted_config, (str, List)) else set() + ) + except KeyError as e: + logger.warning(f"Could not add upstream model for config in {self.path}: {e}") # set the config and set the parsed file flag to true self.config = config @@ -236,6 +255,8 @@ class DbtProject: snapshots_dir: Path = field(init=False) seeds_dir: Path = field(init=False) + operator_args: Dict[str, Any] = field(default_factory=dict) + def __post_init__(self) -> None: """ Initializes the parser. @@ -287,6 +308,7 @@ def _handle_csv_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_SEED, path=path, + operator_args=self.operator_args, ) # add the model to the project self.seeds[model_name] = model @@ -304,6 +326,7 @@ def _handle_sql_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_MODEL, path=path, + operator_args=self.operator_args, ) # add the model to the project self.models[model.name] = model @@ -313,6 +336,7 @@ def _handle_sql_file(self, path: Path) -> None: name=model_name, type=DbtModelType.DBT_SNAPSHOT, path=path, + operator_args=self.operator_args, ) # add the snapshot to the project self.snapshots[model.name] = model diff --git a/tests/dbt/parser/test_project.py b/tests/dbt/parser/test_project.py index 1a3d66f83..544801526 100644 --- a/tests/dbt/parser/test_project.py +++ b/tests/dbt/parser/test_project.py @@ -190,3 +190,18 @@ def test_dbtmodelconfig_with_sources(tmp_path): dbt_model = DbtModel(name="some_name", type=DbtModelType.DBT_MODEL, path=path_with_sources) assert "sample_source" not in dbt_model.config.upstream_models + + +def test_dbtmodelconfig_with_vars(tmp_path): + model_sql = SAMPLE_MODEL_SQL_PATH.read_text() + model_with_vars_sql = model_sql.replace("ref('stg_customers')", "ref('stg_customers_'~ var('country_code'))") + path_with_sources = tmp_path / "customers_with_sources.sql" + path_with_sources.write_text(model_with_vars_sql) + + dbt_model = DbtModel( + name="some_name", + type=DbtModelType.DBT_MODEL, + path=path_with_sources, + operator_args={"vars": {"country_code": "us"}}, + ) + assert "stg_customers_us" in dbt_model.config.upstream_models