From 9dc2c9cd4ab77267207e0e62c32bbf13692428c2 Mon Sep 17 00:00:00 2001
From: Pankaj Koti
Date: Sun, 29 Sep 2024 19:58:51 +0530
Subject: [PATCH] Put compiled files under dag_id folder & refactor few snippets

---
 cosmos/airflow/graph.py           | 43 ++++++++++++++++++++-----------
 cosmos/operators/airflow_async.py |  5 ----
 cosmos/operators/local.py         | 23 ++++++++++++-----
 3 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/cosmos/airflow/graph.py b/cosmos/airflow/graph.py
index 230c51d63..d84a1fafb 100644
--- a/cosmos/airflow/graph.py
+++ b/cosmos/airflow/graph.py
@@ -253,6 +253,30 @@ def generate_task_or_group(
     return task_or_group
 
 
+def _add_dbt_compile_task(
+    nodes: dict[str, DbtNode],
+    dag: DAG,
+    execution_mode: ExecutionMode,
+    task_args: dict[str, Any],
+    tasks_map: dict[str, Any],
+) -> None:
+    if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
+        return
+
+    compile_task_metadata = TaskMetadata(
+        id=dbt_compile_task_id,
+        operator_class=f"cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
+        arguments=task_args,
+        extra_context={},
+    )
+    compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None)
+    tasks_map[dbt_compile_task_id] = compile_airflow_task
+
+    for node_id, node in nodes.items():
+        if not node.depends_on and node_id in tasks_map:
+            tasks_map[dbt_compile_task_id] >> tasks_map[node_id]
+
+
 def build_airflow_graph(
     nodes: dict[str, DbtNode],
     dag: DAG,  # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
@@ -333,22 +357,14 @@ def build_airflow_graph(
         for leaf_node_id in leaves_ids:
             tasks_map[leaf_node_id] >> test_task
 
-    if execution_mode == ExecutionMode.AIRFLOW_ASYNC:
-        compile_task_metadata = TaskMetadata(
-            id=dbt_compile_task_id,
-            owner="",  # Set appropriate owner if needed
-            operator_class=f"cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
-            arguments=task_args,
-            extra_context={},
-        )
-        compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None)
-        tasks_map[dbt_compile_task_id] = compile_airflow_task
+    _add_dbt_compile_task(nodes, dag, execution_mode, task_args, tasks_map)
 
-    create_airflow_task_dependencies(nodes, tasks_map, execution_mode)
+    create_airflow_task_dependencies(nodes, tasks_map)
 
 
 def create_airflow_task_dependencies(
-    nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]], execution_mode: ExecutionMode
+    nodes: dict[str, DbtNode],
+    tasks_map: dict[str, Union[TaskGroup, BaseOperator]],
 ) -> None:
     """
     Create the Airflow task dependencies between non-test nodes.
@@ -356,9 +372,6 @@ def create_airflow_task_dependencies(
     :param tasks_map: Dictionary mapping dbt nodes (node.unique_id to Airflow task)
     """
     for node_id, node in nodes.items():
-        if not node.depends_on and execution_mode == ExecutionMode.AIRFLOW_ASYNC:
-            tasks_map[dbt_compile_task_id] >> tasks_map[node_id]
-
         for parent_node_id in node.depends_on:
             # depending on the node type, it will not have mapped 1:1 to tasks_map
             if (node_id in tasks_map) and (parent_node_id in tasks_map):
diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index b02a54416..334e074e5 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -1,7 +1,5 @@
 from typing import Any
 
-from airflow.utils.context import Context
-
 from cosmos.operators.base import DbtCompileMixin
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
@@ -86,6 +84,3 @@ class DbtCompileAirflowAsyncOperator(DbtCompileMixin, DbtLocalBaseOperator):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         kwargs["should_upload_compiled_sql"] = True
         super().__init__(*args, **kwargs)
-
-    def execute(self, context: Context) -> None:
-        self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 3254c8959..25b7d9dde 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -256,10 +256,10 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
             self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")
 
     @staticmethod
-    def _configure_remote_target_path() -> Path | None:
+    def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
         """Configure the remote target path if it is provided."""
         if not remote_target_path:
-            return None
+            return None, None
 
         _configured_target_path = None
 
@@ -270,7 +270,7 @@ def _configure_remote_target_path() -> Path | None:
         target_path_schema = target_path_str.split("://")[0]
         remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None)  # type: ignore[assignment]
         if remote_conn_id is None:
-            return _configured_target_path
+            return None, None
 
         if not AIRFLOW_IO_AVAILABLE:
             raise CosmosValueError(
@@ -286,7 +286,7 @@ def _configure_remote_target_path() -> Path | None:
         if not _configured_target_path.exists():  # type: ignore[no-untyped-call]
             _configured_target_path.mkdir(parents=True, exist_ok=True)
 
-        return _configured_target_path
+        return _configured_target_path, remote_conn_id
 
     def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
         """
@@ -295,7 +295,7 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
         if not self.should_upload_compiled_sql:
             return
 
-        dest_target_dir = self._configure_remote_target_path()
+        dest_target_dir, dest_conn_id = self._configure_remote_target_path()
         if not dest_target_dir:
             raise CosmosValueError(
                 "You're trying to upload compiled SQL files, but the remote target path is not configured. "
@@ -303,9 +303,18 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
 
         from airflow.io.path import ObjectStoragePath
 
-        source_target_dir = ObjectStoragePath(Path(tmp_project_dir) / "target" / "compiled")
+        source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
+        files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
 
-        source_target_dir.copy(dest_target_dir, recursive=True)  # type: ignore[arg-type]
+        for file_path in files:
+            rel_path = os.path.relpath(file_path, source_compiled_dir)
+
+            dest_path = ObjectStoragePath(
+                f"{str(dest_target_dir).rstrip('/')}/{context['dag'].dag_id}/{rel_path.lstrip('/')}",
+                conn_id=dest_conn_id,
+            )
+            ObjectStoragePath(file_path).copy(dest_path)
+            self.log.debug("Copied %s to %s", file_path, dest_path)
 
     @provide_session
     def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
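With this patch, upload_compiled_sql no longer copies target/compiled as a single tree; each compiled file is written under a dag_id prefix, i.e. <remote_target_path>/<dag_id>/<path relative to target/compiled>. A minimal sketch (not part of the patch) of how another task could locate one of those files, assuming a hypothetical remote_target_path of "s3://my-bucket/dbt/target" whose scheme resolves to the "aws_default" connection; all names and paths below are illustrative:

    from airflow.io.path import ObjectStoragePath

    remote_target_path = "s3://my-bucket/dbt/target"         # assumed Cosmos setting
    dag_id = "my_cosmos_dag"                                  # DAG that ran the dbt_compile task
    rel_path = "jaffle_shop/models/staging/stg_orders.sql"    # path relative to target/compiled

    # Mirror the destination built in upload_compiled_sql: <target>/<dag_id>/<relative path>
    dest_path = ObjectStoragePath(
        f"{remote_target_path.rstrip('/')}/{dag_id}/{rel_path.lstrip('/')}",
        conn_id="aws_default",
    )
    with dest_path.open("r") as f:
        compiled_sql = f.read()

Scoping the upload by dag_id means two DAGs built from the same dbt project no longer overwrite each other's compiled SQL in the remote target.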
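The new _add_dbt_compile_task helper only acts when execution_mode is ExecutionMode.AIRFLOW_ASYNC: it registers a single dbt_compile task and wires it upstream of every root node (nodes with an empty depends_on), while create_airflow_task_dependencies still adds the remaining edges from each node's depends_on. A rough sketch of the resulting dependency shape, using EmptyOperator stand-ins rather than the real Cosmos operators (DAG and task names are made up):

    import pendulum
    from airflow import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(dag_id="example_async_dbt", start_date=pendulum.datetime(2024, 1, 1), schedule=None):
        dbt_compile = EmptyOperator(task_id="dbt_compile")           # stand-in for DbtCompileAirflowAsyncOperator
        stg_orders = EmptyOperator(task_id="stg_orders_run")         # root node: empty depends_on
        stg_customers = EmptyOperator(task_id="stg_customers_run")   # root node: empty depends_on
        orders_report = EmptyOperator(task_id="orders_report_run")   # depends on both staging models

        # _add_dbt_compile_task: compile runs before every root node
        dbt_compile >> [stg_orders, stg_customers]
        # create_airflow_task_dependencies: edges derived from each node's depends_on
        [stg_orders, stg_customers] >> orders_report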