diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b0c33dfc6..a444dbc59 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -4,7 +4,7 @@ on: push: # Run on pushes to the default branch branches: [main] pull_request_target: # Also run on pull requests originated from forks - branches: [main] + branches: [main,poc-dbt-compile-task] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -176,6 +176,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -248,6 +250,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -316,6 +320,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -393,6 +399,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 @@ -537,6 +545,8 @@ jobs: POSTGRES_DB: postgres POSTGRES_SCHEMA: public POSTGRES_PORT: 5432 + AIRFLOW__COSMOS__REMOTE_TARGET_PATH: "s3://cosmos-remote-cache/target_compiled/" + AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID: aws_s3_conn - name: Upload coverage to Github uses: actions/upload-artifact@v4 diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py index 334e074e5..05f762702 100644 --- a/cosmos/operators/airflow_async.py +++ b/cosmos/operators/airflow_async.py @@ -1,15 +1,10 @@ -from typing import Any - -from cosmos.operators.base import DbtCompileMixin from cosmos.operators.local import ( DbtBuildLocalOperator, - DbtDepsLocalOperator, + DbtCompileLocalOperator, DbtDocsAzureStorageLocalOperator, - DbtDocsCloudLocalOperator, DbtDocsGCSLocalOperator, DbtDocsLocalOperator, DbtDocsS3LocalOperator, - DbtLocalBaseOperator, DbtLSLocalOperator, DbtRunLocalOperator, DbtRunOperationLocalOperator, @@ -56,10 +51,6 @@ class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator): pass -class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator): - pass - - class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator): pass @@ -72,15 +63,5 @@ class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator): pass -class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator): +class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator): pass - - -class DbtCompileAirflowAsyncOperator(DbtCompileMixin, DbtLocalBaseOperator): - """ - Executes a dbt core build command. - """ - - def __init__(self, *args: Any, **kwargs: Any) -> None: - kwargs["should_upload_compiled_sql"] = True - super().__init__(*args, **kwargs) diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py index 25b7d9dde..1083d5703 100644 --- a/cosmos/operators/local.py +++ b/cosmos/operators/local.py @@ -65,6 +65,7 @@ from cosmos.operators.base import ( AbstractDbtBaseOperator, DbtBuildMixin, + DbtCompileMixin, DbtLSMixin, DbtRunMixin, DbtRunOperationMixin, @@ -949,3 +950,9 @@ def __init__(self, **kwargs: str) -> None: raise DeprecationWarning( "The DbtDepsOperator has been deprecated. " "Please use the `install_deps` flag in dbt_args instead." ) + + +class DbtCompileLocalOperator(DbtCompileMixin, DbtLocalBaseOperator): + def __init__(self, *args: Any, **kwargs: Any) -> None: + kwargs["should_upload_compiled_sql"] = True + super().__init__(*args, **kwargs) diff --git a/dev/dags/simple_dag_async.py b/dev/dags/simple_dag_async.py new file mode 100644 index 000000000..d364ee6f2 --- /dev/null +++ b/dev/dags/simple_dag_async.py @@ -0,0 +1,36 @@ +import os +from datetime import datetime +from pathlib import Path + +from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig +from cosmos.profiles import PostgresUserPasswordProfileMapping + +DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt" +DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH)) + +profile_config = ProfileConfig( + profile_name="default", + target_name="dev", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="example_conn", + profile_args={"schema": "public"}, + disable_event_tracking=True, + ), +) + +simple_dag_async = DbtDag( + # dbt/cosmos-specific parameters + project_config=ProjectConfig( + DBT_ROOT_PATH / "jaffle_shop", + ), + profile_config=profile_config, + execution_config=ExecutionConfig( + execution_mode=ExecutionMode.AIRFLOW_ASYNC, + ), + # normal dag parameters + schedule_interval=None, + start_date=datetime(2023, 1, 1), + catchup=False, + dag_id="simple_dag_async", + tags=["simple"], +) diff --git a/tests/airflow/test_graph.py b/tests/airflow/test_graph.py index 72a09a5e5..d864b73e4 100644 --- a/tests/airflow/test_graph.py +++ b/tests/airflow/test_graph.py @@ -30,6 +30,7 @@ from cosmos.converter import airflow_kwargs from cosmos.dbt.graph import DbtNode from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.settings import dbt_compile_task_id SAMPLE_PROJ_PATH = Path("/home/user/path/dbt-proj/") SOURCE_RENDERING_BEHAVIOR = SourceRenderingBehavior(os.getenv("SOURCE_RENDERING_BEHAVIOR", "none")) @@ -226,6 +227,41 @@ def test_build_airflow_graph_with_after_all(): assert dag.leaves[0].select == ["tag:some"] +@pytest.mark.integration +def test_build_airflow_graph_with_dbt_compile_task(): + with DAG("test-id-dbt-compile", start_date=datetime(2022, 1, 1)) as dag: + task_args = { + "project_dir": SAMPLE_PROJ_PATH, + "conn_id": "fake_conn", + "profile_config": ProfileConfig( + profile_name="default", + target_name="default", + profile_mapping=PostgresUserPasswordProfileMapping( + conn_id="fake_conn", + profile_args={"schema": "public"}, + ), + ), + } + render_config = RenderConfig( + select=["tag:some"], + test_behavior=TestBehavior.AFTER_ALL, + source_rendering_behavior=SOURCE_RENDERING_BEHAVIOR, + ) + build_airflow_graph( + nodes=sample_nodes, + dag=dag, + execution_mode=ExecutionMode.AIRFLOW_ASYNC, + test_indirect_selection=TestIndirectSelection.EAGER, + task_args=task_args, + dbt_project_name="astro_shop", + render_config=render_config, + ) + + task_ids = [task.task_id for task in dag.tasks] + assert dbt_compile_task_id in task_ids + assert dbt_compile_task_id in dag.tasks[0].upstream_task_ids + + def test_calculate_operator_class(): class_module_import_path = calculate_operator_class(execution_mode=ExecutionMode.KUBERNETES, dbt_class="DbtSeed") assert class_module_import_path == "cosmos.operators.kubernetes.DbtSeedKubernetesOperator" diff --git a/tests/operators/test_airflow_async.py b/tests/operators/test_airflow_async.py new file mode 100644 index 000000000..fc085c7d0 --- /dev/null +++ b/tests/operators/test_airflow_async.py @@ -0,0 +1,82 @@ +from cosmos.operators.airflow_async import ( + DbtBuildAirflowAsyncOperator, + DbtCompileAirflowAsyncOperator, + DbtDocsAirflowAsyncOperator, + DbtDocsAzureStorageAirflowAsyncOperator, + DbtDocsGCSAirflowAsyncOperator, + DbtDocsS3AirflowAsyncOperator, + DbtLSAirflowAsyncOperator, + DbtRunAirflowAsyncOperator, + DbtRunOperationAirflowAsyncOperator, + DbtSeedAirflowAsyncOperator, + DbtSnapshotAirflowAsyncOperator, + DbtSourceAirflowAsyncOperator, + DbtTestAirflowAsyncOperator, +) +from cosmos.operators.local import ( + DbtBuildLocalOperator, + DbtCompileLocalOperator, + DbtDocsAzureStorageLocalOperator, + DbtDocsGCSLocalOperator, + DbtDocsLocalOperator, + DbtDocsS3LocalOperator, + DbtLSLocalOperator, + DbtRunLocalOperator, + DbtRunOperationLocalOperator, + DbtSeedLocalOperator, + DbtSnapshotLocalOperator, + DbtSourceLocalOperator, + DbtTestLocalOperator, +) + + +def test_dbt_build_airflow_async_operator_inheritance(): + assert issubclass(DbtBuildAirflowAsyncOperator, DbtBuildLocalOperator) + + +def test_dbt_ls_airflow_async_operator_inheritance(): + assert issubclass(DbtLSAirflowAsyncOperator, DbtLSLocalOperator) + + +def test_dbt_seed_airflow_async_operator_inheritance(): + assert issubclass(DbtSeedAirflowAsyncOperator, DbtSeedLocalOperator) + + +def test_dbt_snapshot_airflow_async_operator_inheritance(): + assert issubclass(DbtSnapshotAirflowAsyncOperator, DbtSnapshotLocalOperator) + + +def test_dbt_source_airflow_async_operator_inheritance(): + assert issubclass(DbtSourceAirflowAsyncOperator, DbtSourceLocalOperator) + + +def test_dbt_run_airflow_async_operator_inheritance(): + assert issubclass(DbtRunAirflowAsyncOperator, DbtRunLocalOperator) + + +def test_dbt_test_airflow_async_operator_inheritance(): + assert issubclass(DbtTestAirflowAsyncOperator, DbtTestLocalOperator) + + +def test_dbt_run_operation_airflow_async_operator_inheritance(): + assert issubclass(DbtRunOperationAirflowAsyncOperator, DbtRunOperationLocalOperator) + + +def test_dbt_docs_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsAirflowAsyncOperator, DbtDocsLocalOperator) + + +def test_dbt_docs_s3_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsS3AirflowAsyncOperator, DbtDocsS3LocalOperator) + + +def test_dbt_docs_azure_storage_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsAzureStorageAirflowAsyncOperator, DbtDocsAzureStorageLocalOperator) + + +def test_dbt_docs_gcs_airflow_async_operator_inheritance(): + assert issubclass(DbtDocsGCSAirflowAsyncOperator, DbtDocsGCSLocalOperator) + + +def test_dbt_compile_airflow_async_operator_inheritance(): + assert issubclass(DbtCompileAirflowAsyncOperator, DbtCompileLocalOperator) diff --git a/tests/operators/test_base.py b/tests/operators/test_base.py index 6f4425282..e97c2d396 100644 --- a/tests/operators/test_base.py +++ b/tests/operators/test_base.py @@ -8,6 +8,7 @@ from cosmos.operators.base import ( AbstractDbtBaseOperator, DbtBuildMixin, + DbtCompileMixin, DbtLSMixin, DbtRunMixin, DbtRunOperationMixin, @@ -143,6 +144,7 @@ def test_dbt_base_operator_context_merge( ("seed", DbtSeedMixin), ("run", DbtRunMixin), ("build", DbtBuildMixin), + ("compile", DbtCompileMixin), ], ) def test_dbt_mixin_base_cmd(dbt_command, dbt_operator_class): diff --git a/tests/operators/test_local.py b/tests/operators/test_local.py index d54bbb5e1..fa3d87e4f 100644 --- a/tests/operators/test_local.py +++ b/tests/operators/test_local.py @@ -25,9 +25,11 @@ parse_number_of_warnings_dbt_runner, parse_number_of_warnings_subprocess, ) +from cosmos.exceptions import CosmosValueError from cosmos.hooks.subprocess import FullOutputSubprocessResult from cosmos.operators.local import ( DbtBuildLocalOperator, + DbtCompileLocalOperator, DbtDocsAzureStorageLocalOperator, DbtDocsGCSLocalOperator, DbtDocsLocalOperator, @@ -42,6 +44,7 @@ DbtTestLocalOperator, ) from cosmos.profiles import PostgresUserPasswordProfileMapping +from cosmos.settings import AIRFLOW_IO_AVAILABLE from tests.utils import test_dag as run_test_dag DBT_PROJ_DIR = Path(__file__).parent.parent.parent / "dev/dags/dbt/jaffle_shop" @@ -1052,3 +1055,137 @@ def test_store_freshness_not_store_compiled_sql(mock_context, mock_session): # Verify the freshness attribute is set correctly assert instance.freshness == "" + + +def test_dbt_compile_local_operator_initialisation(): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + assert operator.should_upload_compiled_sql is True + assert "compile" in operator.base_cmd + + +@patch("cosmos.operators.local.remote_target_path", new="s3://some-bucket/target") +@patch("cosmos.operators.local.AIRFLOW_IO_AVAILABLE", new=False) +def test_configure_remote_target_path_object_storage_unavailable_on_earlier_airflow_versions(): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + with pytest.raises(CosmosValueError, match="Object Storage feature is unavailable"): + operator._configure_remote_target_path() + + +@pytest.mark.parametrize( + "rem_target_path, rem_target_path_conn_id", + [ + (None, "aws_s3_conn"), + ("unknown://some-bucket/cache", None), + ], +) +def test_config_remote_target_path_unset_settings(rem_target_path, rem_target_path_conn_id): + with patch("cosmos.operators.local.remote_target_path", new=rem_target_path): + with patch("cosmos.operators.local.remote_target_path_conn_id", new=rem_target_path_conn_id): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + target_path, target_conn = operator._configure_remote_target_path() + assert target_path is None + assert target_conn is None + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("cosmos.operators.local.remote_target_path", new="s3://some-bucket/target") +@patch("cosmos.operators.local.remote_target_path_conn_id", new="aws_s3_conn") +@patch("airflow.io.path.ObjectStoragePath") +def test_configure_remote_target_path(mock_object_storage_path): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + mock_remote_path = MagicMock() + mock_object_storage_path.return_value.exists.return_value = True + mock_object_storage_path.return_value = mock_remote_path + target_path, target_conn = operator._configure_remote_target_path() + assert target_path == mock_remote_path + assert target_conn == "aws_s3_conn" + mock_object_storage_path.assert_called_with("s3://some-bucket/target", conn_id="aws_s3_conn") + + mock_object_storage_path.return_value.exists.return_value = False + mock_object_storage_path.return_value.mkdir.return_value = MagicMock() + _, _ = operator._configure_remote_target_path() + mock_object_storage_path.return_value.mkdir.assert_called_with(parents=True, exist_ok=True) + + +@patch.object(DbtLocalBaseOperator, "_configure_remote_target_path") +def test_no_compiled_sql_upload_for_other_operators(mock_configure_remote_target_path): + operator = DbtSeedLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + assert operator.should_upload_compiled_sql is False + operator.upload_compiled_sql("fake-dir", MagicMock()) + mock_configure_remote_target_path.assert_not_called() + + +@patch("cosmos.operators.local.DbtCompileLocalOperator._configure_remote_target_path") +def test_upload_compiled_sql_no_remote_path_raises_error(mock_configure_remote): + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + + mock_configure_remote.return_value = (None, None) + + tmp_project_dir = "/fake/tmp/project" + context = {"dag": MagicMock(dag_id="test_dag")} + + with pytest.raises(CosmosValueError, match="remote target path is not configured"): + operator.upload_compiled_sql(tmp_project_dir, context) + + +@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release") +@patch("airflow.io.path.ObjectStoragePath.copy") +@patch("airflow.io.path.ObjectStoragePath") +@patch("cosmos.operators.local.DbtCompileLocalOperator._configure_remote_target_path") +def test_upload_compiled_sql_should_upload(mock_configure_remote, mock_object_storage_path, mock_copy): + """Test upload_compiled_sql when should_upload_compiled_sql is True and uploads files.""" + operator = DbtCompileLocalOperator( + task_id="fake-task", + profile_config=profile_config, + project_dir="fake-dir", + ) + + mock_configure_remote.return_value = ("mock_remote_path", "mock_conn_id") + + tmp_project_dir = "/fake/tmp/project" + source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled" + + file1 = MagicMock(spec=Path) + file1.is_file.return_value = True + file1.__str__.return_value = str(source_compiled_dir / "file1.sql") + + file2 = MagicMock(spec=Path) + file2.is_file.return_value = True + file2.__str__.return_value = str(source_compiled_dir / "file2.sql") + + files = [file1, file2] + + with patch.object(Path, "rglob", return_value=files): + context = {"dag": MagicMock(dag_id="test_dag")} + + operator.upload_compiled_sql(tmp_project_dir, context) + + for file_path in files: + rel_path = os.path.relpath(str(file_path), str(source_compiled_dir)) + expected_dest_path = f"mock_remote_path/test_dag/{rel_path.lstrip('/')}" + mock_object_storage_path.assert_any_call(expected_dest_path, conn_id="mock_conn_id") + mock_object_storage_path.return_value.copy.assert_any_call(mock_object_storage_path.return_value) diff --git a/tests/test_example_dags.py b/tests/test_example_dags.py index 9f8601156..9aa66432d 100644 --- a/tests/test_example_dags.py +++ b/tests/test_example_dags.py @@ -28,7 +28,7 @@ MIN_VER_DAG_FILE: dict[str, list[str]] = { "2.4": ["cosmos_seed_dag.py"], - "2.8": ["cosmos_manifest_example.py"], + "2.8": ["cosmos_manifest_example.py", "simple_dag_async.py"], } IGNORED_DAG_FILES = ["performance_dag.py", "jaffle_shop_kubernetes.py"]