Commit

Remove oss execute method code
pankajastro committed Sep 30, 2024
1 parent ec765aa commit 35e58b6
Showing 1 changed file with 20 additions and 166 deletions.
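
In effect, the commit drops the execute logic that had been copied from the OSS BigQueryInsertJobOperator and lets DbtRunAirflowAsyncOperator delegate to its parent class. Pieced together from the added lines in the diff below, the resulting execute method is roughly the following sketch (not the verbatim file):

    def execute(self, context: Context) -> Any | None:
        # Fetch the compiled SQL for this model from remote storage.
        sql = self.get_remote_sql()
        sql = sql.replace("***", "astronomer-dag-authoring")
        sql = sql.replace('"', "")
        # Build the BigQuery job configuration, then let
        # BigQueryInsertJobOperator.execute() submit the job and defer.
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
        super().execute(context)
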
186 changes: 20 additions & 166 deletions cosmos/operators/airflow_async.py
@@ -1,20 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, BigQueryJob
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.triggers.bigquery import BigQueryInsertJobTrigger
from airflow.providers.google.cloud.utils.bigquery import convert_job_id
from airflow.utils.context import Context
from google.api_core.exceptions import Conflict
from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob

if TYPE_CHECKING:
    from google.cloud.bigquery import UnknownJob
from typing import Any

from airflow.io.path import ObjectStoragePath
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
from airflow.utils.context import Context

from cosmos.operators.base import DbtCompileMixin
from cosmos.operators.local import (
@@ -27,7 +17,6 @@
    DbtDocsS3LocalOperator,
    DbtLocalBaseOperator,
    DbtLSLocalOperator,
    DbtRunLocalOperator,
    DbtRunOperationLocalOperator,
    DbtSeedLocalOperator,
    DbtSnapshotLocalOperator,
@@ -59,10 +48,19 @@ class DbtSourceAirflowAsyncOperator(DbtSourceLocalOperator):
    pass


class DbtRunAirflowAsyncOperator(DbtRunLocalOperator):
class DbtRunAirflowAsyncOperator(BigQueryInsertJobOperator):
    def __init__(self, *args, **kwargs):
        # super(BigQueryInsertJobOperator, self).__init__(*args, **kwargs)
        super(DbtRunLocalOperator, self).__init__(*args, **kwargs)
        self.configuration = {}
        self.job_id = kwargs.get("job_id", {}) or ""
        self.impersonation_chain = kwargs.get("impersonation_chain", {}) or ""
        self.project_id = kwargs.get("project_id", {}) or ""

        self.profile_config = kwargs.get("profile_config")
        self.project_dir = kwargs.get("project_dir")

        self.async_op_args = kwargs.get("async_op_args", {})
        self.async_op_args["deferrable"] = True
        super().__init__(*args, configuration=self.configuration, task_id=kwargs.get("task_id"), **self.async_op_args)
        self.profile_type = self.profile_config.get_profile_type()
        if self.profile_type not in _SUPPORTED_DATABASES:
            raise f"Async run are only supported: {_SUPPORTED_DATABASES}"
@@ -73,170 +71,26 @@ def get_remote_sql(self):
        project_name = str(self.project_dir).split("/")[-1]
        model_name: str = self.task_id.split(".")[0]
        if model_name.startswith("stg_"):
            remote_model_path = f"{remote_target_path}/target/compiled/{project_name}/models/staging/{model_name}.sql"
            remote_model_path = f"{remote_target_path}/{project_name}/models/staging/{model_name}.sql"
        else:
            remote_model_path = f"{remote_target_path}/target/compiled/{project_name}/models/{model_name}.sql"
            remote_model_path = f"{remote_target_path}/{project_name}/models/{model_name}.sql"

        print("remote_model_path: ", remote_model_path)
        object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
        with object_storage_path.open() as fp:
            return fp.read()

    def execute(self, context: Context) -> Any | None:
        sql = self.get_remote_sql()
        sql = sql.replace("***", "astronomer-dag-authoring")
        sql = sql.replace('"', "")
        print("sql: ", sql)
        configuration = {
        self.configuration = {
            "query": {
                "query": sql,
                "useLegacySql": False,
            }
        }
print("async_op_args: ", self.async_op_args)
# super(BigQueryInsertJobOperator).execute(context)

conn_id = self.profile_config.profile_mapping.conn_id

hook = BigQueryHook(
gcp_conn_id=conn_id,
impersonation_chain=self.async_op_args.get("impersonation_chain"),
)

self.job_id = hook.generate_job_id(
job_id=self.async_op_args.get("job_id"),
dag_id=self.dag_id,
task_id=self.task_id,
logical_date=context["logical_date"],
configuration=configuration,
force_rerun=self.async_op_args.get("force_rerun"),
)

try:
self.log.info("Executing: %s'", configuration)
# Create a job
job: BigQueryJob | UnknownJob = hook.insert_job(
configuration=configuration,
project_id=self.async_op_args.get("project_id"),
location=self.async_op_args.get("location"),
job_id=self.job_id,
timeout=self.async_op_args.get("result_timeout"),
retry=self.async_op_args.get("result_retry"),
nowait=True,
)
except Conflict:
# If the job already exists retrieve it
job = hook.get_job(
project_id=self.async_op_args.get("project_id"),
location=self.async_op_args.get("location"),
job_id=self.job_id,
)

if job.state not in self.reattach_states:
# Same job configuration, so we need force_rerun
raise AirflowException(
f"Job with id: {self.job_id} already exists and is in {job.state} state. If you "
f"want to force rerun it consider setting `force_rerun=True`."
f"Or, if you want to reattach in this scenario add {job.state} to `reattach_states`"
)

else:
# Job already reached state DONE
if job.state == "DONE":
raise AirflowException("Job is already in state DONE. Can not reattach to this job.")

# We are reattaching to a job
self.log.info("Reattaching to existing Job in state %s", job.state)
self._handle_job_error(job)

job_types = {
LoadJob._JOB_TYPE: ["sourceTable", "destinationTable"],
CopyJob._JOB_TYPE: ["sourceTable", "destinationTable"],
ExtractJob._JOB_TYPE: ["sourceTable"],
QueryJob._JOB_TYPE: ["destinationTable"],
}

if self.async_op_args.get("project_id"):
for job_type, tables_prop in job_types.items():
job_configuration = job.to_api_repr()["configuration"]
if job_type in job_configuration:
for table_prop in tables_prop:
if table_prop in job_configuration[job_type]:
table = job_configuration[job_type][table_prop]
persist_kwargs = {
"context": context,
"task_instance": self,
"project_id": self.async_op_args.get("project_id"),
"table_id": table,
}
if not isinstance(table, str):
persist_kwargs["table_id"] = table["tableId"]
persist_kwargs["dataset_id"] = table["datasetId"]
persist_kwargs["project_id"] = table["projectId"]
BigQueryTableLink.persist(**persist_kwargs)

self.job_id = job.job_id

if self.async_op_args.get("project_id"):
job_id_path = convert_job_id(
job_id=self.job_id,
project_id=self.async_op_args.get("project_id"),
location=self.async_op_args.get("location"),
)
context["ti"].xcom_push(key="job_id_path", value=job_id_path)

if job.running():
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryInsertJobTrigger(
conn_id=conn_id,
job_id=self.job_id,
project_id=self.async_op_args.get("project_id"),
location=self.async_op_args.get("location"),
poll_interval=self.async_op_args.get("poll_interval"),
impersonation_chain=self.async_op_args.get("impersonation_chain"),
cancel_on_kill=self.async_op_args.get("cancel_on_kill"),
),
method_name="execute_complete",
)
self.log.info("Current state of job %s is %s", job.job_id, job.state)
self._handle_job_error(job)

def _handle_job_error(self, job: BigQueryJob | UnknownJob):
if job.error_result:
raise AirflowException(f"BigQuery job {job.job_id} failed: {job.error_result}")

def execute_complete(self, context: Context, event: dict[str, Any]) -> str | None:
"""
Act as a callback for when the trigger fires.
This returns immediately. It relies on trigger to throw an exception,
otherwise it assumes execution was successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])
self.log.info(
"%s completed with response %s ",
self.task_id,
event["message"],
)
# Save job_id as an attribute to be later used by listeners
self.job_id = event.get("job_id")
return self.job_id

def on_kill(self) -> None:
if self.job_id and self.async_op_args.get("cancel_on_kill"):
self.hook.cancel_job( # type: ignore[union-attr]
job_id=self.job_id,
project_id=self.async_op_args.get("project_id"),
location=self.async_op_args.get("location"),
)
else:
self.log.info(
"Skipping to cancel job: %s:%s.%s",
self.async_op_args.get("project_id"),
self.async_op_args.get("location"),
self.job_id,
)
super().execute(context)


class DbtTestAirflowAsyncOperator(DbtTestLocalOperator):
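
The diff above shows the new constructor reading profile_config, project_dir, and async_op_args from kwargs. A hypothetical instantiation of the reworked run operator, with argument names taken from that __init__ and purely illustrative values (profile_config is assumed to be a Cosmos ProfileConfig defined elsewhere):

    run_stg_customers = DbtRunAirflowAsyncOperator(
        task_id="stg_customers.run",  # get_remote_sql takes the model name from the part before the first "."
        project_dir="/usr/local/airflow/dbt/my_project",  # illustrative path to the dbt project
        profile_config=profile_config,  # assumed to be defined elsewhere
        async_op_args={
            "project_id": "my-gcp-project",  # illustrative GCP project
            "location": "US",  # illustrative BigQuery location
        },
    )
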
