diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index d84a1fafb..8edc9b232 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -81,6 +81,7 @@ def create_test_task_metadata(
     on_warning_callback: Callable[..., Any] | None = None,
     node: DbtNode | None = None,
     render_config: RenderConfig | None = None,
+    async_op_args: dict | None = None,
 ) -> TaskMetadata:
     """
     Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
@@ -125,6 +126,7 @@ def create_test_task_metadata(
         ),
         arguments=task_args,
         extra_context=extra_context,
+        async_op_args=async_op_args,
     )
 
 
@@ -287,6 +289,7 @@ def build_airflow_graph(
     render_config: RenderConfig,
     task_group: TaskGroup | None = None,
     on_warning_callback: Callable[..., Any] | None = None,  # argument specific to the DBT test command
+    async_op_args: dict | None = None,
 ) -> None:
     """
     Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -351,6 +354,7 @@ def build_airflow_graph(
             task_args=task_args,
             on_warning_callback=on_warning_callback,
             render_config=render_config,
+            async_op_args=async_op_args,
         )
         test_task = create_airflow_task(test_meta, dag, task_group=task_group)
         leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index f278dcba8..4f1fae227 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -6,25 +6,22 @@
 from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
 from airflow.utils.context import Context
 
-from cosmos.operators.base import DbtCompileMixin
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
+    DbtCompileLocalOperator,
     DbtDepsLocalOperator,
     DbtDocsAzureStorageLocalOperator,
     DbtDocsCloudLocalOperator,
     DbtDocsGCSLocalOperator,
     DbtDocsLocalOperator,
     DbtDocsS3LocalOperator,
-    DbtLocalBaseOperator,
     DbtLSLocalOperator,
+    DbtRunOperationLocalOperator,
     DbtSeedLocalOperator,
     DbtSnapshotLocalOperator,
     DbtSourceLocalOperator,
     DbtTestLocalOperator,
-    DbtRunOperationLocalOperator,
-    DbtCompileLocalOperator,
 )
-
 from cosmos.settings import remote_target_path, remote_target_path_conn_id
 
 _SUPPORTED_DATABASES = ["bigquery"]
@@ -73,9 +70,9 @@ def get_remote_sql(self):
         project_name = str(self.project_dir).split("/")[-1]
         model_name: str = self.task_id.split(".")[0]
         if model_name.startswith("stg_"):
-            remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
+            remote_model_path = f"{remote_target_path}/{self.dag_id}/{project_name}/models/staging/{model_name}.sql"
         else:
-            remote_model_path = f"{remote_target_path}/{project_name}/models/{model_name}.sql"
+            remote_model_path = f"{remote_target_path}/{self.dag_id}/{project_name}/models/{model_name}.sql"
         print("remote_model_path: ", remote_model_path)
 
         object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
@@ -95,7 +92,6 @@ def execute(self, context: Context) -> Any | None:
         super().execute(context)
 
 
-
 class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
     pass
 
@@ -127,7 +123,6 @@ class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
     pass
 class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
     pass
+
 class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
     pass
-
-
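
Note on usage: `async_op_args` is threaded from `build_airflow_graph` through `create_test_task_metadata` into `TaskMetadata`, so a caller can hand BigQuery job options down to the async operators. A minimal sketch of the intended call site follows; the `DbtDag` pass-through keyword and the `location` value are illustrative assumptions, not shown in this diff.

# Sketch only: assumes `async_op_args` is exposed as a pass-through keyword on
# the user-facing DbtDag and forwarded down to build_airflow_graph(). The
# "location" key is illustrative; valid keys depend on BigQueryInsertJobOperator.
from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

dag = DbtDag(
    dag_id="dbt_bigquery_async",
    project_config=ProjectConfig("/usr/local/airflow/dbt/my_project"),
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/my_project/profiles.yml",
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    async_op_args={"location": "US"},  # assumption: forwarded to the BigQuery job
)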