diff --git a/cosmos/config.py b/cosmos/config.py
index 2cebbf3cc..b92a25a20 100644
--- a/cosmos/config.py
+++ b/cosmos/config.py
@@ -10,6 +10,7 @@
 from pathlib import Path
 from typing import Any, Callable, Iterator
 
+import yaml
 from airflow.version import version as airflow_version
 
 from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled
@@ -286,6 +287,20 @@ def validate_profiles_yml(self) -> None:
         if self.profiles_yml_filepath and not Path(self.profiles_yml_filepath).exists():
             raise CosmosValueError(f"The file {self.profiles_yml_filepath} does not exist.")
 
+    def get_profile_type(self) -> str:
+        """Return the dbt profile type: from the profile mapping when one is set, otherwise from profiles.yml."""
+        if self.profile_mapping and self.profile_mapping.dbt_profile_type:
+            return self.profile_mapping.dbt_profile_type
+
+        profile_path = self._get_profile_path()
+
+        with open(profile_path) as file:
+            profiles = yaml.safe_load(file)
+
+        profile = profiles[self.profile_name]
+        target_type: str = profile["outputs"][self.target_name]["type"]
+        return target_type
+
     def _get_profile_path(self, use_mock_values: bool = False) -> Path:
         """
         Handle the profile caching mechanism.
diff --git a/cosmos/converter.py b/cosmos/converter.py
index fd077c465..e4e0b7f6b 100644
--- a/cosmos/converter.py
+++ b/cosmos/converter.py
@@ -207,6 +207,7 @@ def __init__(
         task_group: TaskGroup | None = None,
         operator_args: dict[str, Any] | None = None,
         on_warning_callback: Callable[..., Any] | None = None,
+        async_op_args: dict[str, Any] | None = None,
         *args: Any,
         **kwargs: Any,
     ) -> None:
@@ -256,6 +257,7 @@ def __init__(
             cache_identifier=cache_identifier,
             dbt_vars=dbt_vars,
             airflow_metadata=cache._get_airflow_metadata(dag, task_group),
+            async_op_args=async_op_args,
         )
 
         self.dbt_graph.load(method=render_config.load_method, execution_mode=execution_config.execution_mode)
@@ -301,6 +303,7 @@ def __init__(
             dbt_project_name=project_config.project_name,
             on_warning_callback=on_warning_callback,
             render_config=render_config,
+            async_op_args=async_op_args,
         )
 
         current_time = time.perf_counter()
diff --git a/cosmos/core/airflow.py b/cosmos/core/airflow.py
index 9e1d08ac1..67230fb39 100644
--- a/cosmos/core/airflow.py
+++ b/cosmos/core/airflow.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import importlib
 
 from airflow.models import BaseOperator
@@ -10,7 +12,7 @@
 logger = get_logger(__name__)
 
 
-def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
+def get_airflow_task(task: Task, dag: DAG, task_group: TaskGroup | None = None) -> BaseOperator:
     """
     Get the Airflow Operator class for a Task.
 
@@ -29,6 +31,9 @@ def get_airflow_task(task: Task, dag: DAG, task_group: "TaskGroup | None" = None) -> BaseOperator:
     if task.owner != "":
         task_kwargs["owner"] = task.owner
 
+    if task.async_op_args:
+        task_kwargs["async_op_args"] = task.async_op_args
+
     airflow_task = Operator(
         task_id=task.id,
         dag=dag,
diff --git a/cosmos/core/graph/entities.py b/cosmos/core/graph/entities.py
index 6bf9ff046..340cb7781 100644
--- a/cosmos/core/graph/entities.py
+++ b/cosmos/core/graph/entities.py
@@ -61,3 +61,4 @@ class Task(CosmosEntity):
     operator_class: str = "airflow.operators.empty.EmptyOperator"
     arguments: Dict[str, Any] = field(default_factory=dict)
     extra_context: Dict[str, Any] = field(default_factory=dict)
+    async_op_args: Dict[str, Any] = field(default_factory=dict)
diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py
index 1c0237e8f..0400edc1a 100644
--- a/cosmos/dbt/graph.py
+++ b/cosmos/dbt/graph.py
@@ -217,6 +217,7 @@ def __init__(
         dbt_vars: dict[str, str] | None = None,
         airflow_metadata: dict[str, str] | None = None,
         operator_args: dict[str, Any] | None = None,
+        async_op_args: dict[str, Any] | None = None,
     ):
         self.project = project
         self.render_config = render_config
@@ -224,6 +225,7 @@ def __init__(
         self.execution_config = execution_config
         self.cache_dir = cache_dir
         self.airflow_metadata = airflow_metadata or {}
+        self.async_op_args = async_op_args
         if cache_identifier:
             self.dbt_ls_cache_key = cache.create_cache_key(cache_identifier)
         else:
@@ -467,7 +469,6 @@ def should_use_dbt_ls_cache(self) -> bool:
 
     def load_via_dbt_ls_cache(self) -> bool:
         """(Try to) load dbt ls cache from an Airflow Variable"""
-        logger.info(f"Trying to parse the dbt project using dbt ls cache {self.dbt_ls_cache_key}...")
 
         if self.should_use_dbt_ls_cache():
             project_path = self.project_path
diff --git a/cosmos/operators/airflow_async.py b/cosmos/operators/airflow_async.py
index 05f762702..f278dcba8 100644
--- a/cosmos/operators/airflow_async.py
+++ b/cosmos/operators/airflow_async.py
@@ -1,19 +1,32 @@
+from __future__ import annotations
+
+from typing import Any
+
+from airflow.io.path import ObjectStoragePath
+from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
+from airflow.utils.context import Context
+
+from cosmos.exceptions import CosmosValueError
 from cosmos.operators.local import (
     DbtBuildLocalOperator,
     DbtCompileLocalOperator,
+    DbtDepsLocalOperator,
     DbtDocsAzureStorageLocalOperator,
+    DbtDocsCloudLocalOperator,
     DbtDocsGCSLocalOperator,
     DbtDocsLocalOperator,
     DbtDocsS3LocalOperator,
     DbtLSLocalOperator,
-    DbtRunLocalOperator,
     DbtRunOperationLocalOperator,
     DbtSeedLocalOperator,
     DbtSnapshotLocalOperator,
     DbtSourceLocalOperator,
     DbtTestLocalOperator,
 )
+from cosmos.settings import remote_target_path, remote_target_path_conn_id
+
+_SUPPORTED_DATABASES = ["bigquery"]
+
 
 class DbtBuildAirflowAsyncOperator(DbtBuildLocalOperator):
     pass
@@ -35,8 +48,50 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
     pass
 
 
-class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
-    pass
+class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
+    """Execute a dbt model by submitting its pre-compiled SQL to BigQuery as a deferrable insert job."""
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        self.configuration: dict[str, Any] = {}
+        self.job_id = kwargs.get("job_id") or ""
+        self.impersonation_chain = kwargs.get("impersonation_chain") or ""
+        self.project_id = kwargs.get("project_id") or ""
+
+        self.profile_config = kwargs.get("profile_config")
+        self.project_dir = kwargs.get("project_dir")
+
+        self.async_op_args = kwargs.get("async_op_args") or {}
+        self.async_op_args["deferrable"] = True
+        super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)
+        self.profile_type = self.profile_config.get_profile_type()
+        if self.profile_type not in _SUPPORTED_DATABASES:
+            raise CosmosValueError(f"Async execution is only supported for the following profile types: {_SUPPORTED_DATABASES}")
+
+        self.reattach_states: set[str] = self.async_op_args.get("reattach_states") or set()
+
+    def get_remote_sql(self) -> str:
+        project_name = str(self.project_dir).split("/")[-1]
+        model_name: str = self.task_id.split(".")[0]
+        if model_name.startswith("stg_"):
+            remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
+        else:
+            remote_model_path = f"{remote_target_path}/{project_name}/models/{model_name}.sql"
+
+        self.log.info("Fetching compiled SQL from remote model path %s", remote_model_path)
+        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
+        with object_storage_path.open() as fp:
+            return fp.read()
+
+    def execute(self, context: Context) -> Any | None:
+        sql = self.get_remote_sql()
+        self.log.info("Executing SQL: %s", sql)
+        self.configuration = {
+            "query": {
+                "query": sql,
+                "useLegacySql": False,
+            }
+        }
+        return super().execute(context)
 
 
 class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
@@ -51,6 +106,10 @@ class DbtDocsAirflowAsyncOperator(DbtDocsLocalOperator):
     pass
 
 
+class DbtDocsCloudAirflowAsyncOperator(DbtDocsCloudLocalOperator):
+    pass
+
+
 class DbtDocsS3AirflowAsyncOperator(DbtDocsS3LocalOperator):
     pass
 
@@ -65,3 +124,7 @@ class DbtDocsGCSAirflowAsyncOperator(DbtDocsGCSLocalOperator):
 
 class DbtCompileAirflowAsyncOperator(DbtCompileLocalOperator):
     pass
+
+
+class DbtDepsAirflowAsyncOperator(DbtDepsLocalOperator):
+    pass
diff --git a/cosmos/operators/local.py b/cosmos/operators/local.py
index 1083d5703..ad320d5fc 100644
--- a/cosmos/operators/local.py
+++ b/cosmos/operators/local.py
@@ -135,6 +135,7 @@ def __init__(
         should_store_compiled_sql: bool = True,
         should_upload_compiled_sql: bool = False,
         append_env: bool = True,
+        async_op_args: dict[str, Any] | None = None,
         **kwargs: Any,
     ) -> None:
         self.profile_config = profile_config
@@ -148,6 +149,7 @@ def __init__(
         self.invoke_dbt: Callable[..., FullOutputSubprocessResult | dbtRunnerResult]
         self.handle_exception: Callable[..., None]
         self._dbt_runner: dbtRunner | None = None
+        self.async_op_args = async_op_args
         if self.invocation_mode:
             self._set_invocation_methods()
         super().__init__(**kwargs)
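Reviewer note, not part of the diff: a minimal DAG-level sketch of how the new `async_op_args` plumbing is meant to be used. The project path, profile names, and connection ids below are hypothetical, and it assumes an execution mode (e.g. `ExecutionMode.AIRFLOW_ASYNC`) routes dbt `run` nodes to `DbtRunAirflowAsyncOperator`, with `remote_target_path`/`remote_target_path_conn_id` in `cosmos.settings` pointing at the store where `DbtCompileLocalOperator` uploaded the compiled SQL.

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig

DbtDag(
    dag_id="jaffle_shop_async",
    schedule=None,
    start_date=datetime(2024, 1, 1),
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="jaffle_shop",
        target_name="prod",
        profiles_yml_filepath="/usr/local/airflow/dbt/jaffle_shop/profiles.yml",
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    # Forwarded converter -> DbtGraph -> Task -> operator kwargs; DbtRunAirflowAsyncOperator
    # forces deferrable=True and passes the rest verbatim to BigQueryInsertJobOperator.
    async_op_args={"location": "US", "gcp_conn_id": "google_cloud_default"},
)
```

Because the profile must resolve to one of `_SUPPORTED_DATABASES`, `ProfileConfig.get_profile_type()` is called at operator construction time and raises `CosmosValueError` for anything other than `bigquery`.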