diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4610fea2c..5e76c4a98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -47,14 +47,14 @@ repos: - id: remove-tabs exclude: ^docs/make.bat$|^docs/Makefile$|^dev/dags/dbt/jaffle_shop/seeds/raw_orders.csv$ - repo: https://github.com/asottile/pyupgrade - rev: v3.15.1 + rev: v3.15.2 hooks: - id: pyupgrade args: - --py37-plus - --keep-runtime-typing - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.3.3 + rev: v0.3.5 hooks: - id: ruff args: diff --git a/cosmos/operators/azure_container_instance.py b/cosmos/operators/azure_container_instance.py index 397a47551..d8427b2fb 100644 --- a/cosmos/operators/azure_container_instance.py +++ b/cosmos/operators/azure_container_instance.py @@ -28,7 +28,7 @@ ) -class DbtAzureContainerInstanceBaseOperator(AzureContainerInstancesOperator, AbstractDbtBaseOperator): # type: ignore +class DbtAzureContainerInstanceBaseOperator(AbstractDbtBaseOperator, AzureContainerInstancesOperator): # type: ignore """ Executes a dbt core cli command in an Azure Container Instance """ @@ -66,7 +66,7 @@ def __init__( def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None: self.build_command(context, cmd_flags) self.log.info(f"Running command: {self.command}") - result = super().execute(context) + result = AzureContainerInstancesOperator.execute(self, context) logger.info(result) def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None: @@ -79,13 +79,13 @@ def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> self.command: list[str] = dbt_cmd -class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator): +class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core ls command. 
""" -class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator): +class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core seed command. @@ -95,14 +95,14 @@ class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInsta template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator): +class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core snapshot command. """ -class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator): +class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core run command. """ @@ -110,7 +110,7 @@ class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanc template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator] -class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator): +class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core test command. """ @@ -121,7 +121,7 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar self.on_warning_callback = on_warning_callback -class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator): +class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore """ Executes a dbt core run-operation command. 
diff --git a/cosmos/operators/base.py b/cosmos/operators/base.py index 94d5d4a8c..b9f25758d 100644 --- a/cosmos/operators/base.py +++ b/cosmos/operators/base.py @@ -43,10 +43,11 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta): environment variables for the new process; these are used instead of inheriting the current process environment, which is the default behavior. (templated) - :param append_env: If False(default) uses the environment variables passed in env params - and does not inherit the current process environment. If True, inherits the environment variables + :param append_env: If True (default), inherits the environment variables from current passes and then environment variable passed by the user will either update the existing - inherited environment variables or the new variables gets appended to it + inherited environment variables or the new variables gets appended to it. + If False, only uses the environment variables passed in env params + and does not inherit the current process environment. :param output_encoding: Output encoding of bash command :param skip_exit_code: If task exits with this exit code, leave the task in ``skipped`` state (default: 99). 
If set to ``None``, any non-zero @@ -99,7 +100,7 @@ def __init__( db_name: str | None = None, schema: str | None = None, env: dict[str, Any] | None = None, - append_env: bool = False, + append_env: bool = True, output_encoding: str = "utf-8", skip_exit_code: int = 99, partial_parse: bool = True, diff --git a/dev/dags/example_cosmos_sources.py b/dev/dags/example_cosmos_sources.py index 1a85b6d9f..346f37370 100644 --- a/dev/dags/example_cosmos_sources.py +++ b/dev/dags/example_cosmos_sources.py @@ -17,7 +17,11 @@ from pathlib import Path from airflow.models.dag import DAG -from airflow.operators.dummy import DummyOperator + +try: # available since Airflow 2.4.0 + from airflow.operators.empty import EmptyOperator +except ImportError: + from airflow.operators.dummy import DummyOperator as EmptyOperator from airflow.utils.task_group import TaskGroup from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig @@ -38,21 +42,21 @@ # [START custom_dbt_nodes] -# Cosmos will use this function to generate a DummyOperator task when it finds a source node, in the manifest. +# Cosmos will use this function to generate an empty task when it finds a source node, in the manifest. # A more realistic use case could be to use an Airflow sensor to represent a source. def convert_source(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): """ - Return an instance of DummyOperator to represent a dbt "source" node. + Return an instance of a desired operator to represent a dbt "source" node. """ - return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source") + return EmptyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source") -# Cosmos will use this function to generate a DummyOperator task when it finds a exposure node, in the manifest. +# Cosmos will use this function to generate an empty task when it finds an exposure node, 
def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs): """ - Return an instance of DummyOperator to represent a dbt "exposure" node. + Return an instance of a desired operator to represent a dbt "exposure" node. """ - return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_exposure") + return EmptyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_exposure") # Use `RenderConfig` to tell Cosmos, given a node type, how to convert a dbt node into an Airflow task or task group. diff --git a/tests/operators/test_azure_container_instance.py b/tests/operators/test_azure_container_instance.py index 01fa3e20e..84d733ce3 100644 --- a/tests/operators/test_azure_container_instance.py +++ b/tests/operators/test_azure_container_instance.py @@ -53,6 +53,7 @@ def test_dbt_azure_container_instance_operator_get_env(p_context_to_airflow_vars name="my-aci", resource_group="my-rg", project_dir="my/dir", + append_env=False, ) dbt_base_operator.env = { "start_date": "20220101", @@ -90,6 +91,7 @@ def test_dbt_azure_container_instance_operator_check_environment_variables( resource_group="my-rg", project_dir="my/dir", environment_variables={"FOO": "BAR"}, + append_env=False, ) dbt_base_operator.env = { "start_date": "20220101", @@ -143,3 +145,25 @@ def test_dbt_azure_container_instance_build_command(): "start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n", "--no-version-check", ] + + +@patch("cosmos.operators.azure_container_instance.AzureContainerInstancesOperator.execute") +def test_dbt_azure_container_instance_build_and_run_cmd(mock_execute): + dbt_base_operator = ConcreteDbtAzureContainerInstanceOperator( + ci_conn_id="my_airflow_connection", + task_id="my-task", + image="my_image", + region="Mordor", + name="my-aci", + resource_group="my-rg", + project_dir="my/dir", + environment_variables={"FOO": "BAR"}, + ) + mock_build_command = MagicMock() + dbt_base_operator.build_command = mock_build_command + + mock_context 
= MagicMock() + dbt_base_operator.build_and_run_cmd(context=mock_context) + + mock_build_command.assert_called_with(mock_context, None) + mock_execute.assert_called_once_with(dbt_base_operator, mock_context) diff --git a/tests/operators/test_docker.py b/tests/operators/test_docker.py index ad3ec5485..2cfb6b835 100644 --- a/tests/operators/test_docker.py +++ b/tests/operators/test_docker.py @@ -73,10 +73,7 @@ def test_dbt_docker_operator_get_env(p_context_to_airflow_vars: MagicMock, base_ If an end user passes in a """ dbt_base_operator = base_operator( - conn_id="my_airflow_connection", - task_id="my-task", - image="my_image", - project_dir="my/dir", + conn_id="my_airflow_connection", task_id="my-task", image="my_image", project_dir="my/dir", append_env=False ) dbt_base_operator.env = { "start_date": "20220101", diff --git a/tests/operators/test_kubernetes.py b/tests/operators/test_kubernetes.py index 75739111f..d0be2acad 100644 --- a/tests/operators/test_kubernetes.py +++ b/tests/operators/test_kubernetes.py @@ -81,10 +81,7 @@ def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock, b If an end user passes in a """ dbt_kube_operator = base_operator( - conn_id="my_airflow_connection", - task_id="my-task", - image="my_image", - project_dir="my/dir", + conn_id="my_airflow_connection", task_id="my-task", image="my_image", project_dir="my/dir", append_env=False ) dbt_kube_operator.env = { "start_date": "20220101", @@ -254,7 +251,7 @@ def cleanup(pod: str, remote_pod: str): def test_created_pod(): - ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo"} + ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo", "append_env": False} ls_kwargs.update(base_kwargs) ls_operator = DbtLSKubernetesOperator(**ls_kwargs) ls_operator.hook = MagicMock() diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index 956c2a8d1..17b98ee56 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -329,9 +329,7 @@ 
def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None If an end user passes in a """ dbt_base_operator = ConcreteDbtLocalBaseOperator( - profile_config=profile_config, - task_id="my-task", - project_dir="my/dir", + profile_config=profile_config, task_id="my-task", project_dir="my/dir", append_env=False ) dbt_base_operator.env = { "start_date": "20220101",