diff --git a/dags/operators/DagBatchPipelineOperator.py b/dags/operators/DagBatchPipelineOperator.py index b8636ca0..f7dd27fb 100644 --- a/dags/operators/DagBatchPipelineOperator.py +++ b/dags/operators/DagBatchPipelineOperator.py @@ -22,7 +22,7 @@ DEFAULT_START_WITH_TASK_ID = "notice_normalisation_pipeline" DEFAULT_PIPELINE_NAME_FOR_LOGS = "unknown_pipeline_name" AIRFLOW_NUMBER_OF_WORKERS = config.AIRFLOW_NUMBER_OF_WORKERS -DEFAULT_BATCH_SIZE = 2000 +MAX_BATCH_SIZE = 2000 class BatchPipelineCallable(Protocol): @@ -142,7 +142,11 @@ def execute(self, context: Any): self.execute_only_one_step = get_dag_param(key=EXECUTE_ONLY_ONE_STEP_KEY, default_value=False) notice_ids = pull_dag_upstream(key=NOTICE_IDS_KEY) if notice_ids: - batch_size = self.batch_size or DEFAULT_BATCH_SIZE + if self.batch_size: + batch_size = 1 + len(notice_ids) // AIRFLOW_NUMBER_OF_WORKERS + batch_size = batch_size if batch_size < MAX_BATCH_SIZE else MAX_BATCH_SIZE + else: + batch_size = self.batch_size for notice_batch in chunks(notice_ids, chunk_size=batch_size): TriggerDagRunOperator( task_id=f'trigger_worker_dag_{uuid4().hex}',