Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix global flags for lists #863

Merged
merged 12 commits into from
Apr 26, 2024
14 changes: 2 additions & 12 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def add_global_flags(self) -> list[str]:
if isinstance(global_flag_value, dict):
yaml_string = yaml.dump(global_flag_value)
flags.extend([dbt_name, yaml_string])
elif isinstance(global_flag_value, list):
flags.extend([i for j in global_flag_value for i in [dbt_name, j]])
ms32035 marked this conversation as resolved.
Show resolved Hide resolved
else:
flags.extend([dbt_name, str(global_flag_value)])
for global_boolean_flag in self.global_boolean_flags:
Expand Down Expand Up @@ -337,18 +339,6 @@ def __init__(
self.selector = selector
super().__init__(exclude=exclude, select=select, selector=selector, **kwargs) # type: ignore

def add_cmd_flags(self) -> list[str]:
flags = []
if self.exclude:
flags.extend(["--exclude", *self.exclude])

if self.select:
flags.extend(["--select", *self.select])

if self.selector:
flags.extend(["--selector", self.selector])
return flags


class DbtRunOperationMixin:
"""
Expand Down
26 changes: 18 additions & 8 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,27 +566,35 @@ def test_store_compiled_sql() -> None:
@pytest.mark.parametrize(
"operator_class,kwargs,expected_call_kwargs",
[
(DbtSeedLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(DbtRunLocalOperator, {"full_refresh": True}, {"context": {}, "cmd_flags": ["--full-refresh"]}),
(
DbtSeedLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["seed", "--full-refresh"]},
),
(
DbtRunLocalOperator,
{"full_refresh": True},
{"context": {}, "env": {}, "cmd_flags": ["run", "--full-refresh"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "select": ["tag:daily"], "exclude": ["tag:disabled"]},
{"context": {}, "cmd_flags": ["--exclude", "tag:disabled", "--select", "tag:daily"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--select", "tag:daily", "--exclude", "tag:disabled"]},
),
(
DbtTestLocalOperator,
{"full_refresh": True, "selector": "nightly_snowplow"},
{"context": {}, "cmd_flags": ["--selector", "nightly_snowplow"]},
{"context": {}, "env": {}, "cmd_flags": ["test", "--selector", "nightly_snowplow"]},
),
(
DbtRunOperationLocalOperator,
{"args": {"days": 7, "dry_run": True}, "macro_name": "bla"},
{"context": {}, "cmd_flags": ["--args", "days: 7\ndry_run: true\n"]},
{"context": {}, "env": {}, "cmd_flags": ["run-operation", "bla", "--args", "days: 7\ndry_run: true\n"]},
),
],
)
@patch("cosmos.operators.local.DbtLocalBaseOperator.build_and_run_cmd")
def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwargs, expected_call_kwargs):
@patch("cosmos.operators.local.DbtLocalBaseOperator.run_command")
def test_operator_execute_with_flags(mock_run_cmd, operator_class, kwargs, expected_call_kwargs):
ms32035 marked this conversation as resolved.
Show resolved Hide resolved
task = operator_class(
profile_config=profile_config,
task_id="my-task",
Expand All @@ -595,7 +603,9 @@ def test_operator_execute_with_flags(mock_build_and_run_cmd, operator_class, kwa
**kwargs,
)
task.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(**expected_call_kwargs)
mock_run_cmd.assert_called_once_with(
cmd=[task.dbt_executable_path, *expected_call_kwargs.pop("cmd_flags")], **expected_call_kwargs
)


@pytest.mark.parametrize(
Expand Down
Loading