Skip to content

Commit

Permalink
Minor refactor on VirtualenvOperators & add test for PR #1253 (#1286)
Browse files Browse the repository at this point in the history

Cosmos virtualenv operators are using the system dbt instead of the
virtualenv dbt.

Create a test case that illustrates issue #1246. This test fails with
Cosmos 1.7 (and the current main branch) and passes when using PR #1252.

This PR also introduces two refactors:
- Reuse the parent class method where applicable, as opposed to
re-writing it completely
- Force the Virtualenv invocation mode to be `SUBPROCESS` since
Airflow/Cosmos are not able to import dbt as a library if it is not part
of the same Python virtualenv
  • Loading branch information
tatiana authored Oct 29, 2024
1 parent 191c635 commit 4c9b28f
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 9 deletions.
12 changes: 3 additions & 9 deletions cosmos/operators/virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from airflow.utils.python_virtualenv import prepare_virtualenv

from cosmos import settings
from cosmos.constants import InvocationMode
from cosmos.exceptions import CosmosValueError
from cosmos.hooks.subprocess import FullOutputSubprocessResult
from cosmos.log import get_logger
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
self.is_virtualenv_dir_temporary = is_virtualenv_dir_temporary
self.max_retries_lock = settings.virtualenv_max_retries_lock
self._py_bin: str | None = None
kwargs["invocation_mode"] = InvocationMode.SUBPROCESS
super().__init__(**kwargs)
if not self.py_requirements:
self.log.error("Cosmos virtualenv operators require the `py_requirements` parameter")
Expand All @@ -86,15 +88,7 @@ def run_subprocess(self, command: list[str], env: dict[str, str], cwd: str) -> F
if self._py_bin is not None:
self.log.info(f"Using Python binary from virtualenv: {self._py_bin}")
command[0] = str(Path(self._py_bin).parent / "dbt")
self.log.info("Trying to run the command:\n %s\nFrom %s", command, cwd)
subprocess_result = self.subprocess_hook.run_command(
command=command,
env=env,
cwd=cwd,
output_encoding=self.output_encoding,
)
self.log.info(subprocess_result.output)
return subprocess_result
return super().run_subprocess(command, env, cwd)

def run_command(
self,
Expand Down
35 changes: 35 additions & 0 deletions dev/dags/example_virtualenv_mini.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
from datetime import datetime
from pathlib import Path

from airflow.models import DAG

from cosmos import ProfileConfig
from cosmos.operators.virtualenv import DbtSeedVirtualenvOperator
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_PROJ_DIR = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) / "jaffle_shop"

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
),
)

with DAG("example_virtualenv_mini", start_date=datetime(2022, 1, 1)) as dag:
seed_operator = DbtSeedVirtualenvOperator(
profile_config=profile_config,
project_dir=DBT_PROJ_DIR,
task_id="seed",
dbt_cmd_flags=["--select", "raw_customers"],
install_deps=True,
append_env=True,
py_system_site_packages=False,
py_requirements=["dbt-postgres"],
virtualenv_dir=Path("/tmp/persistent-venv2"),
)
seed_operator
37 changes: 37 additions & 0 deletions tests/operators/test_virtualenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,24 @@
from pathlib import Path
from unittest.mock import MagicMock, patch

import airflow
import pytest
from airflow.models import DAG
from airflow.models.connection import Connection
from packaging.version import Version

from cosmos.config import ProfileConfig
from cosmos.constants import InvocationMode
from cosmos.exceptions import CosmosValueError
from cosmos.operators.virtualenv import DbtVirtualenvBaseOperator
from cosmos.profiles import PostgresUserPasswordProfileMapping

AIRFLOW_VERSION = Version(airflow.__version__)

DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop"

DAGS_FOLDER = Path(__file__).parent.parent.parent / "dev/dags/"

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
Expand All @@ -25,6 +33,15 @@
),
)

# ProfileConfig wired to the ``example_conn`` Postgres connection; presumably
# consumed by integration tests that execute dbt for real — verify against the
# tests below. Matches the profile used by the ``example_virtualenv_mini`` DAG.
real_profile_config = ProfileConfig(
    profile_name="default",
    target_name="dev",
    profile_mapping=PostgresUserPasswordProfileMapping(
        conn_id="example_conn",
        profile_args={"schema": "public"},
    ),
)


class ConcreteDbtVirtualenvBaseOperator(DbtVirtualenvBaseOperator):

Expand Down Expand Up @@ -339,3 +356,23 @@ def test__release_venv_lock_current_process(tmpdir):
)
assert venv_operator._release_venv_lock() is None
assert not lockfile.exists()


@pytest.mark.skipif(
    AIRFLOW_VERSION < Version("2.5"),
    reason="This error is only reproducible with dag.test, which was introduced in Airflow 2.5",
)
@pytest.mark.integration
def test_integration_virtualenv_operator(caplog):
    """
    Confirm we're using the correct dbt command to run with virtualenv.
    """
    from airflow.models.dagbag import DagBag

    bag = DagBag(dag_folder=DAGS_FOLDER, include_examples=False)
    example_dag = bag.get_dag("example_virtualenv_mini")
    example_dag.test()

    # Both dbt subcommands must have been invoked via the virtualenv's own
    # dbt binary, not the system-wide one.
    for subcommand in ("deps", "seed"):
        expected = f"Trying to run the command:\n ['/tmp/persistent-venv2/bin/dbt', '{subcommand}'"
        assert expected in caplog.text

0 comments on commit 4c9b28f

Please sign in to comment.