diff --git a/dags/operators/DagBatchPipelineOperator.py b/dags/operators/DagBatchPipelineOperator.py
index 64b4e5db..f7dd27fb 100644
--- a/dags/operators/DagBatchPipelineOperator.py
+++ b/dags/operators/DagBatchPipelineOperator.py
@@ -18,7 +18,6 @@ NOTICE_IDS_KEY = "notice_ids"
 START_WITH_STEP_NAME_KEY = "start_with_step_name"
 EXECUTE_ONLY_ONE_STEP_KEY = "execute_only_one_step"
-DEFAULT_NUMBER_OF_CELERY_WORKERS = 6 # TODO: revise this config
 NOTICE_PROCESSING_PIPELINE_DAG_NAME = "notice_processing_pipeline"
 DEFAULT_START_WITH_TASK_ID = "notice_normalisation_pipeline"
 DEFAULT_PIPELINE_NAME_FOR_LOGS = "unknown_pipeline_name"
 
@@ -116,7 +115,7 @@ def multithread_notice_processor(notice_id: str):
         batch_event_message.end_record()
         logger.info(event_message=batch_event_message)
     if not processed_notice_ids:
-        raise Exception(f"No notice has been processed!")
+        raise Exception("No notice has been processed!")
     smart_xcom_push(key=NOTICE_IDS_KEY, value=processed_notice_ids)
 
 
@@ -143,9 +142,9 @@ def execute(self, context: Any):
         self.execute_only_one_step = get_dag_param(key=EXECUTE_ONLY_ONE_STEP_KEY, default_value=False)
         notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY)
         if notice_ids:
-            if self.batch_size is None:
-                computed_batch_size = 1 + len(notice_ids) // DEFAULT_NUMBER_OF_CELERY_WORKERS
-                batch_size = computed_batch_size if computed_batch_size < MAX_BATCH_SIZE else MAX_BATCH_SIZE
+            if self.batch_size is None:
+                batch_size = 1 + len(notice_ids) // AIRFLOW_NUMBER_OF_WORKERS
+                batch_size = batch_size if batch_size < MAX_BATCH_SIZE else MAX_BATCH_SIZE
             else:
                 batch_size = self.batch_size
             for notice_batch in chunks(notice_ids, chunk_size=batch_size):