Commit
Add async run operator
Remove print stmt

Fix query
Fix query

Remove oss execute method code
pankajkoti authored and pankajastro committed Sep 30, 2024
1 parent 3414513 commit 407d311
Showing 11 changed files with 268 additions and 9 deletions.
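
For context, the sketch below illustrates what this commit enables from a DAG file: a DbtDag configured with the new ExecutionMode.AIRFLOW_ASYNC and the new async_op_args argument. This is not part of the diff; all ids, paths, profile names, and BigQuery connection values are placeholders, and it assumes DbtDag forwards async_op_args down to the converter introduced here.

# Hedged usage sketch only; placeholder values throughout.
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ProfileConfig, ProjectConfig
from cosmos.constants import ExecutionMode

async_dag = DbtDag(
    dag_id="jaffle_shop_async",                      # placeholder
    start_date=datetime(2024, 9, 1),
    schedule=None,
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),  # placeholder path
    profile_config=ProfileConfig(
        profile_name="jaffle_shop",                  # placeholder; must resolve to a BigQuery target
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/jaffle_shop/profiles.yml",
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    # Forwarded into the BigQueryInsertJobOperator-based run operator.
    async_op_args={"location": "US", "gcp_conn_id": "google_cloud_default"},
)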
31 changes: 27 additions & 4 deletions cosmos/airflow/graph.py
@@ -20,6 +20,7 @@
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from cosmos.settings import dbt_compile_task_id

logger = get_logger(__name__)

@@ -80,6 +81,7 @@ def create_test_task_metadata(
on_warning_callback: Callable[..., Any] | None = None,
node: DbtNode | None = None,
render_config: RenderConfig | None = None,
async_op_args: dict[str, Any] | None = None,
) -> TaskMetadata:
"""
Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
@@ -124,6 +126,7 @@ def create_test_task_metadata(
),
arguments=task_args,
extra_context=extra_context,
async_op_args=async_op_args,
)


@@ -133,6 +136,7 @@ def create_task_metadata(
args: dict[str, Any],
use_task_group: bool = False,
source_rendering_behavior: SourceRenderingBehavior = SourceRenderingBehavior.NONE,
async_op_args: dict[str, Any] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
@@ -190,6 +194,7 @@ def create_task_metadata(
),
arguments=args,
extra_context=extra_context,
async_op_args=async_op_args,
)
return task_metadata
else:
@@ -211,10 +216,10 @@ def generate_task_or_group(
source_rendering_behavior: SourceRenderingBehavior,
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
async_op_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None

use_task_group = (
node.resource_type in TESTABLE_DBT_RESOURCES
and test_behavior == TestBehavior.AFTER_EACH
@@ -227,6 +232,7 @@ def generate_task_or_group(
args=task_args,
use_task_group=use_task_group,
source_rendering_behavior=source_rendering_behavior,
async_op_args=async_op_args,
)

# In most cases, we'll map one DBT node to one Airflow task
@@ -243,6 +249,7 @@ def generate_task_or_group(
task_args=task_args,
node=node,
on_warning_callback=on_warning_callback,
async_op_args=async_op_args,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
task >> test_task
@@ -262,6 +269,7 @@ def build_airflow_graph(
render_config: RenderConfig,
task_group: TaskGroup | None = None,
on_warning_callback: Callable[..., Any] | None = None, # argument specific to the DBT test command
async_op_args: dict[str, Any] | None = None,
) -> None:
"""
Instantiate dbt `nodes` as Airflow tasks within the given `task_group` (optional) or `dag` (mandatory).
@@ -291,7 +299,6 @@
source_rendering_behavior = render_config.source_rendering_behavior
tasks_map = {}
task_or_group: TaskGroup | BaseOperator

for node_id, node in nodes.items():
conversion_function = node_converters.get(node.resource_type, generate_task_or_group)
if conversion_function != generate_task_or_group:
@@ -311,6 +318,7 @@
test_indirect_selection=test_indirect_selection,
on_warning_callback=on_warning_callback,
node=node,
async_op_args=async_op_args,
)
if task_or_group is not None:
logger.debug(f"Conversion of <{node.unique_id}> was successful!")
@@ -326,24 +334,39 @@
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_op_args=async_op_args,
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
leaves_ids = calculate_leaves(tasks_ids=list(tasks_map.keys()), nodes=nodes)
for leaf_node_id in leaves_ids:
tasks_map[leaf_node_id] >> test_task

create_airflow_task_dependencies(nodes, tasks_map)
if execution_mode == ExecutionMode.AIRFLOW_ASYNC:
compile_task_metadata = TaskMetadata(
id=dbt_compile_task_id,
owner="", # Set appropriate owner if needed
operator_class="cosmos.operators.airflow_async.DbtCompileAirflowAsyncOperator",
arguments=task_args,
extra_context={},
)
compile_airflow_task = create_airflow_task(compile_task_metadata, dag, task_group=None)
tasks_map[dbt_compile_task_id] = compile_airflow_task

create_airflow_task_dependencies(nodes, tasks_map, execution_mode)


def create_airflow_task_dependencies(
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]]
nodes: dict[str, DbtNode], tasks_map: dict[str, Union[TaskGroup, BaseOperator]], execution_mode: ExecutionMode
) -> None:
"""
Create the Airflow task dependencies between non-test nodes.
:param nodes: Dictionary mapping dbt nodes (node.unique_id to node)
:param tasks_map: Dictionary mapping dbt nodes (node.unique_id to Airflow task)
"""
for node_id, node in nodes.items():
if not node.depends_on and execution_mode == ExecutionMode.AIRFLOW_ASYNC:
tasks_map[dbt_compile_task_id] >> tasks_map[node_id]

for parent_node_id in node.depends_on:
# depending on the node type, it will not have mapped 1:1 to tasks_map
if (node_id in tasks_map) and (parent_node_id in tasks_map):
14 changes: 14 additions & 0 deletions cosmos/config.py
@@ -10,6 +10,7 @@
from pathlib import Path
from typing import Any, Callable, Iterator

import yaml
from airflow.version import version as airflow_version

from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
@@ -286,6 +287,19 @@ def validate_profiles_yml(self) -> None:
if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")

def get_profile_type(self):
if self.profile_mapping.dbt_profile_type:
return self.profile_mapping.dbt_profile_type

profile_path = self._get_profile_path()

with open(profile_path) as file:
profiles = yaml.safe_load(file)

profile = profiles[self.profile_name]
target_type = profile["outputs"][self.target_name]["type"]
return target_type

def _get_profile_path(self, use_mock_values: bool = False) -> Path:
"""
Handle the profile caching mechanism.
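
To make the lookup in get_profile_type() above concrete, this is the shape of the parsed profiles.yml it indexes into. The structure mirrors dbt's standard profiles.yml layout; all profile, target, and connection values here are illustrative and not taken from the commit.

# Illustrative only: the parsed profiles.yml that get_profile_type() walks.
# profile_name and target_name come from the ProfileConfig; "type" is what
# the method returns.
profiles = {
    "jaffle_shop": {                  # self.profile_name
        "target": "dev",
        "outputs": {
            "dev": {                  # self.target_name
                "type": "bigquery",   # returned as the profile type
                "method": "oauth",
                "project": "my-gcp-project",
                "dataset": "my_dataset",
            }
        },
    }
}

assert profiles["jaffle_shop"]["outputs"]["dev"]["type"] == "bigquery"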
1 change: 1 addition & 0 deletions cosmos/constants.py
@@ -86,6 +86,7 @@ class ExecutionMode(Enum):
"""

LOCAL = "local"
AIRFLOW_ASYNC = "airflow_async"
DOCKER = "docker"
KUBERNETES = "kubernetes"
AWS_EKS = "aws_eks"
3 changes: 3 additions & 0 deletions cosmos/converter.py
@@ -208,6 +208,7 @@ def __init__(
task_group: TaskGroup | None = None,
operator_args: dict[str, Any] | None = None,
on_warning_callback: Callable[..., Any] | None = None,
async_op_args: dict[str, Any] | None = None,
*args: Any,
**kwargs: Any,
) -> None:
@@ -257,6 +258,7 @@ def __init__(
cache_identifier=cache_identifier,
dbt_vars=dbt_vars,
airflow_metadata=cache._get_airflow_metadata(dag, task_group),
async_op_args=async_op_args,
)
self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode)

@@ -302,6 +304,7 @@ def __init__(
dbt_project_name=project_config.project_name,
on_warning_callback=on_warning_callback,
render_config=render_config,
async_op_args=async_op_args,
)

current_time = time.perf_counter()
7 changes: 6 additions & 1 deletion cosmos/core/airflow.py
@@ -1,3 +1,5 @@
from __future__ import annotations

import importlib

from airflow.models import BaseOperator
@@ -10,7 +12,7 @@
logger = get_logger(__name__)


def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
"""
Get the Airflow Operator class for a Task.
@@ -29,6 +31,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None
if task.owner != "":
task_kwargs["owner"] = task.owner

if task.async_op_args:
task_kwargs["async_op_args"] = task.async_op_args

airflow_task = Operator(
task_id=task.id,
dag=dag,
1 change: 1 addition & 0 deletions cosmos/core/graph/entities.py
@@ -61,3 +61,4 @@ class Task(CosmosEntity):
operator_class: str = "airflow.operators.empty.EmptyOperator"
arguments: Dict[str, Any] = field(default_factory=dict)
extra_context: Dict[str, Any] = field(default_factory=dict)
async_op_args: Dict[str, Any] = field(default_factory=dict)
3 changes: 2 additions & 1 deletion cosmos/dbt/graph.py
@@ -217,13 +217,15 @@ def __init__(
dbt_vars: dict[str, str] | None = None,
airflow_metadata: dict[str, str] | None = None,
operator_args: dict[str, Any] | None = None,
async_op_args: dict[str, Any] | None = None,
):
self.project = project
self.render_config = render_config
self.profile_config = profile_config
self.execution_config = execution_config
self.cache_dir = cache_dir
self.airflow_metadata = airflow_metadata or {}
self.async_op_args = async_op_args
if cache_identifier:
self.dbt_ls_cache_key = cache.create_cache_key(cache_identifier)
else:
@@ -467,7 +469,6 @@ def should_use_dbt_ls_cache(self) -> bool:

def load_via_dbt_ls_cache(self) -> bool:
"""(Try to) load dbt ls cache from an Airflow Variable"""

logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
if self.should_use_dbt_ls_cache():
project_path = self.project_path
138 changes: 138 additions & 0 deletions cosmos/operators/airflow_async.py
@@ -0,0 +1,138 @@
from __future__ import annotations

from typing import Any

from airflow.io.path import ObjectStoragePath
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos.operators.base import DbtCompileMixin
from cosmos.operators.local import (
DbtBuildLocalOperator,
DbtDepsLocalOperator,
DbtDocsAzureStorageLocalOperator,
DbtDocsCloudLocalOperator,
DbtDocsGCSLocalOperator,
DbtDocsLocalOperator,
DbtDocsS3LocalOperator,
DbtLocalBaseOperator,
DbtLSLocalOperator,
DbtRunOperationLocalOperator,
DbtSeedLocalOperator,
DbtSnapshotLocalOperator,
DbtSourceLocalOperator,
DbtTestLocalOperator,
)
from cosmos.settings import remote_target_path, remote_target_path_conn_id

_SUPPORTED_DATABASES = ["bigquery"]


class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
pass


class DbtLSAirflowAsyncOperator(DbtLSLocalOperator):
pass


class DbtSeedAirflowAsyncOperator(DbtSeedLocalOperator):
pass


class DbtSnapshotAirflowAsyncOperator(DbtSnapshotLocalOperator):
pass


class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
pass


class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
def __init__(self, *args, **kwargs):
self.configuration = {}
self.job_id = kwargs.get("job_id", {}) or ""
self.impersonation_chain = kwargs.get("impersonation_chain", {}) or ""
self.project_id = kwargs.get("project_id", {}) or ""

self.profile_config = kwargs.get("profile_config")
self.project_dir = kwargs.get("project_dir")

self.async_op_args = kwargs.get("async_op_args", {})
self.async_op_args["deferrable"] = True
super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)
self.profile_type = self.profile_config.get_profile_type()
if self.profile_type not in _SUPPORTED_DATABASES:
raise ValueError(f"Async runs are only supported for the following databases: {_SUPPORTED_DATABASES}")

self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()

def get_remote_sql(self):
project_name = str(self.project_dir).split("/")[-1]
model_name: str = self.task_id.split(".")[0]
if model_name.startswith("stg_"):
remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
else:
remote_model_path = f"{remote_target_path}/{project_name}/models/{model_name}.sql"

print("remote_model_path: ", remote_model_path)
object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp:
return fp.read()

def execute(self, context: Context) -> Any | None:
sql = self.get_remote_sql()
print("sql: ", sql)
self.configuration = {
"query": {
"query": sql,
"useLegacySql": False,
}
}
print("async_op_args: ", self.async_op_args)
super().execute(context)


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
pass


class DbtRunOperationAirflowAsyncOperator(DbtRunOperationLocalOperator):
pass


class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
pass


class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator):
pass


class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
pass


class DbtDocsAzureStorageAirflowAsyncOperator(DbtDocsAzureStorageLocalOperator):
pass


class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
pass


class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
pass


class DbtCompileAirflowAsyncOperator(DbtCompileMixin, DbtLocalBaseOperator):
"""
Executes a dbt core compile command.
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
kwargs["should_upload_compiled_sql"] = True
super().__init__(*args, **kwargs)

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())
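
As a rough illustration of DbtRunAirflowAsyncOperator above: get_remote_sql() joins the configured remote_target_path with the project and model names derived from project_dir and task_id, and execute() wraps the fetched SQL in a BigQuery query-job configuration. The bucket, project, model, and SQL values below are placeholders, not part of the commit.

# Placeholder values only; mirrors the path and configuration logic above.
remote_target_path = "gs://my-bucket/target"      # Cosmos remote_target_path setting
project_name = "jaffle_shop"                      # last segment of project_dir
model_name = "stg_orders"                         # task_id "stg_orders.run" -> "stg_orders"

# Staging models are looked up under models/staging/, everything else under models/.
remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
# -> gs://my-bucket/target/jaffle_shop/models/staging/stg_orders.sql

sql = "select * from raw.orders"                  # would be read from remote_model_path
configuration = {
    "query": {
        "query": sql,
        "useLegacySql": False,
    }
}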
