From 093715104ccf0657c30c606d52267969c884d5ae Mon Sep 17 00:00:00 2001 From: CaptainOfHacks <39195263+CaptainOfHacks@users.noreply.github.com> Date: Sat, 23 Sep 2023 11:48:06 +0300 Subject: [PATCH] fix processed notice ids --- dags/operators/DagBatchPipelineOperator.py | 98 ++++++++++++------- dags/pipelines/notice_processor_pipelines.py | 6 +- ...ocess_unnormalised_notices_from_backlog.py | 3 +- 3 files changed, 67 insertions(+), 40 deletions(-) diff --git a/dags/operators/DagBatchPipelineOperator.py b/dags/operators/DagBatchPipelineOperator.py index 71b1aa32..3d212094 100644 --- a/dags/operators/DagBatchPipelineOperator.py +++ b/dags/operators/DagBatchPipelineOperator.py @@ -1,5 +1,5 @@ from concurrent.futures import ThreadPoolExecutor -from typing import Any, Protocol, List +from typing import Any, Protocol, List, Optional from uuid import uuid4 from airflow.models import BaseOperator from airflow.operators.trigger_dagrun import TriggerDagRunOperator @@ -51,6 +51,62 @@ def __init__(self, *args, self.notice_pipeline_callable = notice_pipeline_callable self.batch_pipeline_callable = batch_pipeline_callable + def single_notice_processor(self, notice_id: str, notice_repository: NoticeRepository, + pipeline_name: str) -> Optional[str]: + """ + This method can execute the notice_pipeline_callable for a single notice_id. + :param notice_id: The notice_id what will be processed. + :param notice_repository: The notice repository. + :param pipeline_name: The pipeline name for logs. + """ + logger = get_logger() + notice = None + processed_notice_id = None + try: + notice_event = NoticeEventMessage(notice_id=notice_id, + domain_action=pipeline_name) + notice_event.start_record() + notice = notice_repository.get(reference=notice_id) + result_notice_pipeline = self.notice_pipeline_callable(notice, notice_repository.mongodb_client) + if result_notice_pipeline.store_result: + notice_repository.update(notice=result_notice_pipeline.notice) + if result_notice_pipeline.processed: + processed_notice_id = notice_id + notice_event.end_record() + if notice.normalised_metadata: + notice_event.notice_form_number = notice.normalised_metadata.form_number + notice_event.notice_eforms_subtype = notice.normalised_metadata.eforms_subtype + notice_event.notice_status = str(notice.status) + logger.info(event_message=notice_event) + error_message = result_notice_pipeline.error_message + except Exception as exception_error_message: + error_message = str(exception_error_message) + if error_message: + notice_normalised_metadata = notice.normalised_metadata if notice else None + log_notice_error(message=error_message, notice_id=notice_id, domain_action=pipeline_name, + notice_form_number=notice_normalised_metadata.form_number if notice_normalised_metadata else None, + notice_status=notice.status if notice else None, + notice_eforms_subtype=notice_normalised_metadata.eforms_subtype if notice_normalised_metadata else None) + return processed_notice_id + + def multithread_notice_processor(self, notice_ids: list, mongodb_client: MongoClient, pipeline_name: str) -> list: + """ + This method can execute the notice_pipeline_callable for each notice_id in the notice_ids batch. + :param notice_ids: The notice_ids batch. + :param mongodb_client: The mongodb client. + :param pipeline_name: The pipeline name for logs. + """ + processed_notice_ids = [] + notice_repository = NoticeRepository(mongodb_client=mongodb_client) + with ThreadPoolExecutor() as executor: + futures = [executor.submit(self.single_notice_processor, notice_id, notice_repository, pipeline_name) + for notice_id in notice_ids] + for future in futures: + processed_notice_id = future.result() + if processed_notice_id: + processed_notice_ids.append(processed_notice_id) + return processed_notice_ids + def execute(self, context: Any): """ This method can execute the notice_pipeline_callable for each notice_id in the notice_ids batch or @@ -61,7 +117,6 @@ def execute(self, context: Any): if not notice_ids: raise Exception(f"XCOM key [{NOTICE_IDS_KEY}] is not present in context!") mongodb_client = MongoClient(config.MONGO_DB_AUTH_URL) - notice_repository = NoticeRepository(mongodb_client=mongodb_client) processed_notice_ids = [] pipeline_name = DEFAULT_PIPELINE_NAME_FOR_LOGS if self.notice_pipeline_callable: @@ -77,41 +132,12 @@ def execute(self, context: Any): handle_event_message_metadata_dag_context(batch_event_message, context) batch_event_message.start_record() if self.batch_pipeline_callable is not None: - processed_notice_ids.extend( - self.batch_pipeline_callable(notice_ids=notice_ids, mongodb_client=mongodb_client)) + processed_notice_ids = self.batch_pipeline_callable(notice_ids=notice_ids, mongodb_client=mongodb_client) elif self.notice_pipeline_callable is not None: - def multithread_notice_processor(notice_id: str): - notice = None - try: - notice_event = NoticeEventMessage(notice_id=notice_id, domain_action=pipeline_name) - notice_event.start_record() - notice = notice_repository.get(reference=notice_id) - result_notice_pipeline = self.notice_pipeline_callable(notice, mongodb_client) - if result_notice_pipeline.store_result: - notice_repository.update(notice=result_notice_pipeline.notice) - if result_notice_pipeline.processed: - processed_notice_ids.append(notice_id) - notice_event.end_record() - if notice.normalised_metadata: - notice_event.notice_form_number = notice.normalised_metadata.form_number - notice_event.notice_eforms_subtype = notice.normalised_metadata.eforms_subtype - notice_event.notice_status = str(notice.status) - logger.info(event_message=notice_event) - error_message = result_notice_pipeline.error_message - except Exception as exception_error_message: - error_message = str(exception_error_message) - if error_message: - notice_normalised_metadata = notice.normalised_metadata if notice else None - log_notice_error(message=error_message, notice_id=notice_id, domain_action=pipeline_name, - notice_form_number=notice_normalised_metadata.form_number if notice_normalised_metadata else None, - notice_status=notice.status if notice else None, - notice_eforms_subtype=notice_normalised_metadata.eforms_subtype if notice_normalised_metadata else None) - - with ThreadPoolExecutor() as executor: - futures = [executor.submit(multithread_notice_processor, notice_id) for notice_id in - notice_ids] - for future in futures: - future.result() + processed_notice_ids = self.multithread_notice_processor(notice_ids=notice_ids, + mongodb_client=mongodb_client, + pipeline_name=pipeline_name + ) batch_event_message.end_record() logger.info(event_message=batch_event_message) if not processed_notice_ids: diff --git a/dags/pipelines/notice_processor_pipelines.py b/dags/pipelines/notice_processor_pipelines.py index 3a43c42e..a8d57ece 100644 --- a/dags/pipelines/notice_processor_pipelines.py +++ b/dags/pipelines/notice_processor_pipelines.py @@ -9,7 +9,7 @@ def notice_normalisation_pipeline(notice: Notice, mongodb_client: MongoClient = None) -> NoticePipelineOutput: """ - + Notice normalisation pipeline. This pipeline is responsible for normalising the notice metadata. """ from ted_sws.data_sampler.services.notice_xml_indexer import index_notice from ted_sws.notice_metadata_processor.services.metadata_normalizer import normalise_notice @@ -19,8 +19,8 @@ def notice_normalisation_pipeline(notice: Notice, mongodb_client: MongoClient = normalised_notice = normalise_notice(notice=indexed_notice) return NoticePipelineOutput(notice=normalised_notice) except Exception as error_message: - return NoticePipelineOutput(notice=indexed_notice, processed=False, store_result=True, - error_message=str(error_message)) + return NoticePipelineOutput(notice=indexed_notice, processed=False, + store_result=True, error_message=str(error_message)) def notice_transformation_pipeline(notice: Notice, mongodb_client: MongoClient) -> NoticePipelineOutput: diff --git a/dags/reprocess_unnormalised_notices_from_backlog.py b/dags/reprocess_unnormalised_notices_from_backlog.py index 111cbf41..287ba11e 100644 --- a/dags/reprocess_unnormalised_notices_from_backlog.py +++ b/dags/reprocess_unnormalised_notices_from_backlog.py @@ -36,7 +36,8 @@ def reprocess_unnormalised_notices_from_backlog(): def select_all_raw_notices(): start_date = get_dag_param(key=START_DATE_DAG_PARAM) end_date = get_dag_param(key=END_DATE_DAG_PARAM) - notice_ids = notice_ids_selector_by_status(notice_statuses=[NoticeStatus.RAW], start_date=start_date, + notice_ids = notice_ids_selector_by_status(notice_statuses=[NoticeStatus.RAW, NoticeStatus.INDEXED], + start_date=start_date, end_date=end_date) push_dag_downstream(key=NOTICE_IDS_KEY, value=notice_ids)