Skip to content

Commit

Permalink
Fix global flags for lists (#863)
Browse files Browse the repository at this point in the history
Correctly deals with global flags when they are a list.
Note - in the module there's no distinguishing which are and which
aren't.

Co-authored-by: Tatiana Al-Chueyr <[email protected]>
  • Loading branch information
ms32035 and tatiana authored Apr 26, 2024
1 parent 73a9ba2 commit d1dbc41
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
35 changes: 17 additions & 18 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,28 @@ def add_global_flags(self) -> list[str]:

dbt_name = f"--{global_flag.replace('_', '-')}"
global_flag_value = self.__getattribute__(global_flag)
if global_flag_value is not None:
if isinstance(global_flag_value, dict):
yaml_string = yaml.dump(global_flag_value)
flags.extend([dbt_name, yaml_string])
else:
flags.extend([dbt_name, str(global_flag_value)])
flags.extend(self._process_global_flag(dbt_name, global_flag_value))

for global_boolean_flag in self.global_boolean_flags:
if self.__getattribute__(global_boolean_flag):
flags.append(f"--{global_boolean_flag.replace('_', '-')}")
return flags

@staticmethod
def _process_global_flag(flag_name: str, flag_value: Any) -> list[str]:
"""Helper method to process global flags and reduce complexity."""
if flag_value is None:
return []
elif isinstance(flag_value, dict):
yaml_string = yaml.dump(flag_value)
return [flag_name, yaml_string]
elif isinstance(flag_value, list) and flag_value:
return [flag_name, " ".join(flag_value)]
elif isinstance(flag_value, list):
return []
else:
return [flag_name, str(flag_value)]

def add_cmd_flags(self) -> list[str]:
"""Allows subclasses to override to add flags for their dbt command"""
return []
Expand Down Expand Up @@ -373,18 +384,6 @@ def __init__(
self.selector = selector
super().__init__(exclude=exclude, select=select, selector=selector, **kwargs) # type: ignore

def add_cmd_flags(self) -> list[str]:
flags = []
if self.exclude:
flags.extend(["--exclude", *self.exclude])

if self.select:
flags.extend(["--select", *self.select])

if self.selector:
flags.extend(["--selector", self.selector])
return flags


class DbtRunOperationMixin:
"""
Expand Down
46 changes: 37 additions & 9 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ def test_dbt_base_operator_add_global_flags() -> None:
"end_time": "{{ data_interval_end.strftime('%Y%m%d%H%M%S') }}",
},
no_version_check=True,
select=["my_first_model", "my_second_model"],
)
assert dbt_base_operator.add_global_flags() == [
"--select",
"my_first_model my_second_model",
"--vars",
"end_time: '{{ data_interval_end.strftime(''%Y%m%d%H%M%S'') }}'\n"
"start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n",
Expand Down Expand Up @@ -564,37 +567,62 @@ def test_store_compiled_sql() -> None:
@pytest.mark.parametrize(
"operator_class,kwargs,expected_call_kwargs",
[
(DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtBuildLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(
DbtSeedLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]},
),
(
DbtBuildLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["build", "--full-refresh"]},
),
(
DbtRunLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]},
),
(
DbtTestLocalOperator,
{},
{"context": {}, "env": {}, "cmd_flags": ["test"]},
),
(
DbtTestLocalOperator,
{"select": []},
{"context": {}, "env": {}, "cmd_flags": ["test"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]},
{"context": {}, "cmd_flags": ["--exclude", "tag:disabled", "--select", "tag:daily"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "selector": "nightly_snowplow"},
{"context": {}, "cmd_flags": ["--selector", "nightly_snowplow"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]},
),
(
DbtRunOperationLocalOperator,
{"args": {"days": 7, "dry_run": True}, "macro_name": "bla"},
{"context": {}, "cmd_flags": ["--args", "days: 7\ndry_run: true\n"]},
{"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]},
),
],
)
@patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd")
def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs):
@patch("cosmos.operators.local.DbtLocalBaseOperator.run_command")
def test_operator_execute_with_flags(mock_run_cmd, operator_class, kwargs, expected_call_kwargs):
task = operator_class(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
invocation_mode=InvocationMode.DBT_RUNNER,
**kwargs,
)
task.get_env = MagicMock(return_value={})
task.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs)
mock_run_cmd.assert_called_once_with(
cmd=[task.dbt_executable_path, *expected_call_kwargs.pop("cmd_flags")], **expected_call_kwargs
)


@pytest.mark.parametrize(
Expand Down

0 comments on commit d1dbc41

Please sign in to comment.