From 8534c8124d132e8524227ac35d3574ba8f9d0522 Mon Sep 17 00:00:00 2001 From: CaptainOfHacks <39195263+CaptainOfHacks@users.noreply.github.com> Date: Wed, 20 Sep 2023 16:26:41 +0300 Subject: [PATCH] Update DagBatchPipelineOperator.py --- dags/operators/DagBatchPipelineOperator.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/dags/operators/DagBatchPipelineOperator.py b/dags/operators/DagBatchPipelineOperator.py index 64b4e5db..60ee7302 100644 --- a/dags/operators/DagBatchPipelineOperator.py +++ b/dags/operators/DagBatchPipelineOperator.py @@ -18,12 +18,11 @@ NOTICE_IDS_KEY = "notice_ids" START_WITH_STEP_NAME_KEY = "start_with_step_name" EXECUTE_ONLY_ONE_STEP_KEY = "execute_only_one_step" -DEFAULT_NUMBER_OF_CELERY_WORKERS = 6 # TODO: revise this config NOTICE_PROCESSING_PIPELINE_DAG_NAME = "notice_processing_pipeline" DEFAULT_START_WITH_TASK_ID = "notice_normalisation_pipeline" DEFAULT_PIPELINE_NAME_FOR_LOGS = "unknown_pipeline_name" AIRFLOW_NUMBER_OF_WORKERS = config.AIRFLOW_NUMBER_OF_WORKERS -MAX_BATCH_SIZE = 2000 +DEFAULT_BATCH_SIZE = 5000 class BatchPipelineCallable(Protocol): @@ -116,7 +115,7 @@ def multithread_notice_processor(notice_id: str): batch_event_message.end_record() logger.info(event_message=batch_event_message) if not processed_notice_ids: - raise Exception(f"No notice has been processed!") + raise Exception("No notice has been processed!") smart_xcom_push(key=NOTICE_IDS_KEY, value=processed_notice_ids) @@ -143,11 +142,7 @@ def execute(self, context: Any): self.execute_only_one_step = get_dag_param(key=EXECUTE_ONLY_ONE_STEP_KEY, default_value=False) notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY) if notice_ids: - if self.batch_size is None: - computed_batch_size = 1 + len(notice_ids) // DEFAULT_NUMBER_OF_CELERY_WORKERS - batch_size = computed_batch_size if computed_batch_size < MAX_BATCH_SIZE else MAX_BATCH_SIZE - else: - batch_size = self.batch_size + batch_size = self.batch_size or DEFAULT_BATCH_SIZE for notice_batch in chunks(notice_ids, chunk_size=batch_size): TriggerDagRunOperator( task_id=f'trigger_worker_dag_{uuid4().hex}',