Draft: dbt compile task
pankajkoti committed Sep 29, 2024
1 parent e0a9fd3 commit 851564f
Showing 6 changed files with 181 additions and 5 deletions.
19 changes: 17 additions & 2 deletions cosmos/airflow/graph.py
@@ -20,6 +20,7 @@
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from cosmos.settings import dbt_compile_task_id

logger = get_logger(__name__)

@@ -332,18 +333,32 @@ def build_airflow_graph(
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

create_airflow_task_dependencies(nodes, tasks_map)
if execution_mode == ExecutionMode.AIRFLOW_ASYNC:
compile_task_metadata = TaskMetadata(
id=dbt_compile_task_id,
owner="", # Set appropriate owner if needed
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None)
tasks_map[dbt_compile_task_id] = compile_airflow_task

create_airflow_task_dependencies(nodes, tasks_map, execution_mode)


def create_airflow_task_dependencies(
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]]
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]], execution_mode: ExecutionMode
) -> None:
"""
Create the Airflow task dependencies between non-test nodes.
:param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
:param tasks_map: Dictionary mapping dbt nodes (node.unique_id to Airflow task)
:param execution_mode: The Cosmos execution mode; with ExecutionMode.AIRFLOW_ASYNC, nodes without upstream dependencies are made downstream of the dbt compile task
"""
for node_id, node in nodes.items():
if not node.depends_on and execution_mode == ExecutionMode.AIRFLOW_ASYNC:
tasks_map[dbt_compile_task_id] >> tasks_map[node_id]

for parent_node_id in node.depends_on:
# depending on the node type, it will not have mapped 1:1 to tasks_map
if (node_id in tasks_map) and (parent_node_id in tasks_map):
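For readers skimming the diff, a minimal self-contained sketch (not Cosmos code; node IDs are hypothetical) of the wiring rule added above: in AIRFLOW_ASYNC mode, every dbt node without upstream dependencies is placed downstream of the shared compile task, while ordinary parent/child edges are unchanged.

from dataclasses import dataclass, field


@dataclass
class FakeNode:
    unique_id: str
    depends_on: list[str] = field(default_factory=list)


nodes = {
    "model.jaffle_shop.stg_orders": FakeNode("model.jaffle_shop.stg_orders"),
    "model.jaffle_shop.orders": FakeNode(
        "model.jaffle_shop.orders", depends_on=["model.jaffle_shop.stg_orders"]
    ),
}

edges = []
for node_id, node in nodes.items():
    if not node.depends_on:
        # root nodes hang off the compile task so compiled SQL exists before any model runs
        edges.append(("dbt_compile", node_id))
    for parent_id in node.depends_on:
        edges.append((parent_id, node_id))

print(edges)
# [('dbt_compile', 'model.jaffle_shop.stg_orders'),
#  ('model.jaffle_shop.stg_orders', 'model.jaffle_shop.orders')]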
1 change: 1 addition & 0 deletions cosmos/constants.py
@@ -86,6 +86,7 @@ class ExecutionMode(Enum):
"""

LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
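The new enum member is selected the same way as the existing modes. A sketch of how a DAG author might opt in, assuming the usual Cosmos DbtDag entry point; project paths, profile values, and schedule are illustrative:

from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

jaffle_shop_async = DbtDag(
    dag_id="jaffle_shop_async",
    schedule=None,
    start_date=datetime(2024, 1, 1),
    project_config=ProjectConfig("/usr/local/airflow/dags/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="jaffle_shop",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dags/dbt/jaffle_shop/profiles.yml",
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
)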
91 changes: 91 additions & 0 deletions cosmos/operators/airflow_async.py
@@ -0,0 +1,91 @@
from typing import Any

from airflow.utils.context import Context

from cosmos.operators.base import DbtCompileMixin
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtDepsLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsCloudLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
pass


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileMixin, DbtLocalBaseOperator):
"""
Executes a dbt core compile command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
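Cosmos creates this operator automatically in build_airflow_graph when execution_mode is AIRFLOW_ASYNC, but as a rough sketch it can also be instantiated directly; the project path and profile values below are assumptions for illustration:

from cosmos.config import ProfileConfig
from cosmos.operators.airflow_async import DbtCompileAirflowAsyncOperator

profile_config = ProfileConfig(
    profile_name="jaffle_shop",
    target_name="dev",
    profiles_yml_filepath="/usr/local/airflow/dags/dbt/jaffle_shop/profiles.yml",
)

compile_task = DbtCompileAirflowAsyncOperator(
    task_id="dbt_compile",
    project_dir="/usr/local/airflow/dags/dbt/jaffle_shop",
    profile_config=profile_config,
)
# __init__ forces should_upload_compiled_sql=True, so run_command copies the
# compiled SQL under target/compiled to the configured remote_target_path.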
9 changes: 9 additions & 0 deletions cosmos/operators/base.py
@@ -429,3 +429,12 @@ def add_cmd_flags(self) -> list[str]:
flags.append("--args")
flags.append(yaml.dump(self.args))
return flags


class DbtCompileMixin:
"""
Mixin for dbt compile command.
"""

base_cmd = ["compile"]
ui_color = "#877c7c"
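As with the other command mixins in this module, base_cmd is what the operator prepends to the generated flags. A rough sketch of the resulting command line; the flag values are illustrative, not the exact Cosmos internals:

base_cmd = ["compile"]  # contributed by DbtCompileMixin
global_flags = ["--project-dir", "/tmp/cosmos_project", "--profiles-dir", "/tmp/cosmos_project"]

dbt_cmd = ["dbt"] + base_cmd + global_flags
print(" ".join(dbt_cmd))
# dbt compile --project-dir /tmp/cosmos_project --profiles-dir /tmp/cosmos_project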
62 changes: 59 additions & 3 deletions cosmos/operators/local.py
@@ -16,6 +16,7 @@
from airflow.models.taskinstance import TaskInstance
from airflow.utils.context import Context
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.version import version as airflow_version
from attr import define

from cosmos import cache
@@ -24,10 +25,10 @@
_get_latest_cached_package_lockfile,
is_cache_package_lockfile_enabled,
)
from cosmos.constants import InvocationMode
from cosmos.constants import FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, InvocationMode
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError
from cosmos.settings import LINEAGE_NAMESPACE
from cosmos.exceptions import AirflowCompatibilityError, CosmosValueError
from cosmos.settings import AIRFLOW_IO_AVAILABLE, LINEAGE_NAMESPACE, remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
@@ -131,6 +132,7 @@ def __init__(
install_deps: bool = False,
callback: Callable[[str], None] | None = None,
should_store_compiled_sql: bool = True,
should_upload_compiled_sql: bool = False,
append_env: bool = True,
**kwargs: Any,
) -> None:
@@ -139,6 +141,7 @@
self.compiled_sql = ""
self.freshness = ""
self.should_store_compiled_sql = should_store_compiled_sql
self.should_upload_compiled_sql = should_upload_compiled_sql
self.openlineage_events_completes: list[RunEvent] = []
self.invocation_mode = invocation_mode
self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
@@ -252,6 +255,58 @@ def store_compiled_sql(self, tmp_project_dir: str, context: Context, session: Se
else:
self.log.info("Warning: ti is of type TaskInstancePydantic. Cannot update template_fields.")

@staticmethod
def _configure_remote_target_path() -> Path | None:
"""Configure the remote target path if it is provided."""
if not remote_target_path:
return None

_configured_target_path = None

target_path_str = str(remote_target_path)

remote_conn_id = remote_target_path_conn_id
if not remote_conn_id:
target_path_schema = target_path_str.split("://")[0]
remote_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(target_path_schema, None) # type: ignore[assignment]
if remote_conn_id is None:
return _configured_target_path

if not AIRFLOW_IO_AVAILABLE:
raise CosmosValueError(
f"You're trying to specify remote target path {target_path_str}, but the required "
f"Object Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to "
"Airflow 2.8 or later."
)

from airflow.io.path import ObjectStoragePath

_configured_target_path = ObjectStoragePath(target_path_str, conn_id=remote_conn_id)

if not _configured_target_path.exists(): # type: ignore[no-untyped-call]
_configured_target_path.mkdir(parents=True, exist_ok=True)

return _configured_target_path

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
"""
if not self.should_upload_compiled_sql:
return

dest_target_dir = self._configure_remote_target_path()
if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
)

from airflow.io.path import ObjectStoragePath

source_target_dir = ObjectStoragePath(Path(tmp_project_dir) / "target" / "compiled")

source_target_dir.copy(dest_target_dir, recursive=True) # type: ignore[arg-type]

@provide_session
def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
"""
Expand Down Expand Up @@ -397,6 +452,7 @@ def run_command(

self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self.upload_compiled_sql(tmp_project_dir, context)
self.handle_exception(result)
if self.callback:
self.callback(tmp_project_dir)
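A self-contained sketch (not the Cosmos source; the mapping values are assumptions) of the connection-resolution fallback used by _configure_remote_target_path: an explicit remote_target_path_conn_id wins; otherwise the URI scheme of remote_target_path selects an Airflow default connection.

from __future__ import annotations

# Values here are assumptions for illustration, not the real constant in cosmos.constants.
SCHEME_TO_DEFAULT_CONN_ID = {
    "s3": "aws_default",
    "gs": "google_cloud_default",
    "abfs": "wasb_default",
}


def resolve_conn_id(remote_target_path: str, explicit_conn_id: str | None) -> str | None:
    if explicit_conn_id:
        return explicit_conn_id
    scheme = remote_target_path.split("://")[0]
    return SCHEME_TO_DEFAULT_CONN_ID.get(scheme)


print(resolve_conn_id("s3://my-bucket/cosmos/compiled_target", None))  # aws_default
print(resolve_conn_id("ftp://example.com/target", None))  # None: no destination resolved, so upload_compiled_sql raises CosmosValueError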
4 changes: 4 additions & 0 deletions cosmos/settings.py
@@ -34,6 +34,10 @@
remote_cache_dir = conf.get("cosmos", "remote_cache_dir", fallback=None)
remote_cache_dir_conn_id = conf.get("cosmos", "remote_cache_dir_conn_id", fallback=None)

dbt_compile_task_id = conf.get("cosmos", "dbt_compile_task_id", fallback="dbt_compile")
remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

try:
LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
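The new options map to the [cosmos] section of airflow.cfg (or the matching environment variables). An illustrative configuration, with the bucket name and connection id as placeholder assumptions:

[cosmos]
dbt_compile_task_id = dbt_compile
remote_target_path = s3://my-bucket/cosmos/compiled_target
remote_target_path_conn_id = aws_default

# Equivalent environment variables:
#   AIRFLOW__COSMOS__REMOTE_TARGET_PATH=s3://my-bucket/cosmos/compiled_target
#   AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID=aws_default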
