Skip to content

Commit

Permalink
Merge branch 'astronomer:main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
SiddiqueAhmad authored Apr 4, 2024
2 parents f7d3ea4 + 7c4d7d6 commit a8e3e49
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 33 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ repos:
- id: remove-tabs
exclude: ^docs/make.bat$|^docs/Makefile$|^dev/dags/dbt/jaffle_shop/seeds/raw_orders.csv$
- repo: https://github.com/asottile/pyupgrade
rev: v3.15.1
rev: v3.15.2
hooks:
- id: pyupgrade
args:
- --py37-plus
- --keep-runtime-typing
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.3
rev: v0.3.5
hooks:
- id: ruff
args:
Expand Down
16 changes: 8 additions & 8 deletions cosmos/operators/azure_container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)


class DbtAzureContainerInstanceBaseOperator(AzureContainerInstancesOperator, AbstractDbtBaseOperator): # type: ignore
class DbtAzureContainerInstanceBaseOperator(AbstractDbtBaseOperator, AzureContainerInstancesOperator): # type: ignore
"""
Executes a dbt core cli command in an Azure Container Instance
"""
Expand Down Expand Up @@ -66,7 +66,7 @@ def __init__(
def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None) -> None:
self.build_command(context, cmd_flags)
self.log.info(f"Running command: {self.command}")
result = super().execute(context)
result = AzureContainerInstancesOperator.execute(self, context)
logger.info(result)

def build_command(self, context: Context, cmd_flags: list[str] | None = None) -> None:
Expand All @@ -79,13 +79,13 @@ def build_command(self, context: Context, cmd_flags: list[str] | None = None) ->
self.command: list[str] = dbt_cmd


class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator):
class DbtLSAzureContainerInstanceOperator(DbtLSMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core ls command.
"""


class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator):
class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core seed command.
Expand All @@ -95,22 +95,22 @@ class DbtSeedAzureContainerInstanceOperator(DbtSeedMixin, DbtAzureContainerInsta
template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]


class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator):
class DbtSnapshotAzureContainerInstanceOperator(DbtSnapshotMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core snapshot command.
"""


class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator):
class DbtRunAzureContainerInstanceOperator(DbtRunMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core run command.
"""

template_fields: Sequence[str] = DbtAzureContainerInstanceBaseOperator.template_fields + DbtRunMixin.template_fields # type: ignore[operator]


class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator):
class DbtTestAzureContainerInstanceOperator(DbtTestMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core test command.
"""
Expand All @@ -121,7 +121,7 @@ def __init__(self, on_warning_callback: Callable[..., Any] | None = None, **kwar
self.on_warning_callback = on_warning_callback


class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator):
class DbtRunOperationAzureContainerInstanceOperator(DbtRunOperationMixin, DbtAzureContainerInstanceBaseOperator): # type: ignore
"""
Executes a dbt core run-operation command.
Expand Down
9 changes: 5 additions & 4 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,11 @@ class AbstractDbtBaseOperator(BaseOperator, metaclass=ABCMeta):
environment variables for the new process; these are used instead
of inheriting the current process environment, which is the default
behavior. (templated)
:param append_env: If False(default) uses the environment variables passed in env params
and does not inherit the current process environment. If True, inherits the environment variables
:param append_env: . If True (default), inherits the environment variables
from current passes and then environment variable passed by the user will either update the existing
inherited environment variables or the new variables gets appended to it
inherited environment variables or the new variables gets appended to it.
If False, only uses the environment variables passed in env params
and does not inherit the current process environment.
:param output_encoding: Output encoding of bash command
:param skip_exit_code: If task exits with this exit code, leave the task
in ``skipped`` state (default: 99). If set to ``None``, any non-zero
Expand Down Expand Up @@ -99,7 +100,7 @@ def __init__(
db_name: str | None = None,
schema: str | None = None,
env: dict[str, Any] | None = None,
append_env: bool = False,
append_env: bool = True,
output_encoding: str = "utf-8",
skip_exit_code: int = 99,
partial_parse: bool = True,
Expand Down
18 changes: 11 additions & 7 deletions dev/dags/example_cosmos_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from pathlib import Path

from airflow.models.dag import DAG
from airflow.operators.dummy import DummyOperator

try: # available since Airflow 2.4.0
from airflow.operators.empty import EmptyOperator
except ImportError:
from airflow.operators.dummy import DummyOperator as EmptyOperator
from airflow.utils.task_group import TaskGroup

from cosmos import DbtDag, ProfileConfig, ProjectConfig, RenderConfig
Expand All @@ -38,21 +42,21 @@


# [START custom_dbt_nodes]
# Cosmos will use this function to generate a DummyOperator task when it finds a source node, in the manifest.
# Cosmos will use this function to generate an empty task when it finds a source node, in the manifest.
# A more realistic use case could be to use an Airflow sensor to represent a source.
def convert_source(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
"""
Return an instance of DummyOperator to represent a dbt "source" node.
Return an instance of a desired operator to represent a dbt "source" node.
"""
return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source")
return EmptyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_source")


# Cosmos will use this function to generate a DummyOperator task when it finds a exposure node, in the manifest.
# Cosmos will use this function to generate an empty task when it finds a exposure node, in the manifest.
def convert_exposure(dag: DAG, task_group: TaskGroup, node: DbtNode, **kwargs):
"""
Return an instance of DummyOperator to represent a dbt "exposure" node.
Return an instance of a desired operator to represent a dbt "exposure" node.
"""
return DummyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_exposure")
return EmptyOperator(dag=dag, task_group=task_group, task_id=f"{node.name}_exposure")


# Use `RenderConfig` to tell Cosmos, given a node type, how to convert a dbt node into an Airflow task or task group.
Expand Down
24 changes: 24 additions & 0 deletions tests/operators/test_azure_container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_dbt_azure_container_instance_operator_get_env(p_context_to_airflow_vars
name="my-aci",
resource_group="my-rg",
project_dir="my/dir",
append_env=False,
)
dbt_base_operator.env = {
"start_date": "20220101",
Expand Down Expand Up @@ -90,6 +91,7 @@ def test_dbt_azure_container_instance_operator_check_environment_variables(
resource_group="my-rg",
project_dir="my/dir",
environment_variables={"FOO": "BAR"},
append_env=False,
)
dbt_base_operator.env = {
"start_date": "20220101",
Expand Down Expand Up @@ -143,3 +145,25 @@ def test_dbt_azure_container_instance_build_command():
"start_time: '{{ data_interval_start.strftime(''%Y%m%d%H%M%S'') }}'\n",
"--no-version-check",
]


@patch("cosmos.operators.azure_container_instance.AzureContainerInstancesOperator.execute")
def test_dbt_azure_container_instance_build_and_run_cmd(mock_execute):
dbt_base_operator = ConcreteDbtAzureContainerInstanceOperator(
ci_conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
region="Mordor",
name="my-aci",
resource_group="my-rg",
project_dir="my/dir",
environment_variables={"FOO": "BAR"},
)
mock_build_command = MagicMock()
dbt_base_operator.build_command = mock_build_command

mock_context = MagicMock()
dbt_base_operator.build_and_run_cmd(context=mock_context)

mock_build_command.assert_called_with(mock_context, None)
mock_execute.assert_called_once_with(dbt_base_operator, mock_context)
5 changes: 1 addition & 4 deletions tests/operators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,7 @@ def test_dbt_docker_operator_get_env(p_context_to_airflow_vars: MagicMock, base_
If an end user passes in a
"""
dbt_base_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
project_dir="my/dir",
conn_id="my_airflow_connection", task_id="my-task", image="my_image", project_dir="my/dir", append_env=False
)
dbt_base_operator.env = {
"start_date": "20220101",
Expand Down
7 changes: 2 additions & 5 deletions tests/operators/test_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,7 @@ def test_dbt_kubernetes_operator_get_env(p_context_to_airflow_vars: MagicMock, b
If an end user passes in a
"""
dbt_kube_operator = base_operator(
conn_id="my_airflow_connection",
task_id="my-task",
image="my_image",
project_dir="my/dir",
conn_id="my_airflow_connection", task_id="my-task", image="my_image", project_dir="my/dir", append_env=False
)
dbt_kube_operator.env = {
"start_date": "20220101",
Expand Down Expand Up @@ -254,7 +251,7 @@ def cleanup(pod: str, remote_pod: str):


def test_created_pod():
ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo"}
ls_kwargs = {"env_vars": {"FOO": "BAR"}, "namespace": "foo", "append_env": False}
ls_kwargs.update(base_kwargs)
ls_operator = DbtLSKubernetesOperator(**ls_kwargs)
ls_operator.hook = MagicMock()
Expand Down
4 changes: 1 addition & 3 deletions tests/operators/test_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,7 @@ def test_dbt_base_operator_get_env(p_context_to_airflow_vars: MagicMock) -> None
If an end user passes in a
"""
dbt_base_operator = ConcreteDbtLocalBaseOperator(
profile_config=profile_config,
task_id="my-task",
project_dir="my/dir",
profile_config=profile_config, task_id="my-task", project_dir="my/dir", append_env=False
)
dbt_base_operator.env = {
"start_date": "20220101",
Expand Down

0 comments on commit a8e3e49

Please sign in to comment.