diff --git a/src/build.py b/src/build.py
index 461342a..5de2c6c 100644
--- a/src/build.py
+++ b/src/build.py
@@ -19,6 +19,8 @@ def __init__(self):
         self._prawler = TodayILearnedPrawler()
         self._fetcher = WaPoFetcher()
         self._agents = ["agent_1", "agent_2"]
+        # Progress file that lets an interrupted build resume where it stopped.
+        self.checkpoint_file = '../reading_sets/pre-build/checkpoint.json'
 
     def _get_article_section(self, article_url):
         if article_url is not None:
@@ -52,11 +54,57 @@ def _populate_factual_section(self, reading_sets, conversation_id, agent):
             reading_sets[conversation_id][agent][key]["summarized_wiki_lead_section"] =\
                 self._wikidata.get_wiki_text("summarized_wiki_lead_section", wiki_id)
 
+    def _save_checkpoint(self, data, conversation_id):
+        """Persist build progress so an interrupted run can be resumed.
+
+        Args:
+            data: The reading sets processed so far; must be JSON-serializable.
+            conversation_id: Id of the last fully processed conversation.
+        """
+        checkpoint_data = {
+            "reading_sets": data,
+            "last_processed_id": conversation_id,
+        }
+        # Write to a temporary file and rename it over the checkpoint so a
+        # crash mid-write cannot leave a truncated, unparseable checkpoint.
+        tmp_path = self.checkpoint_file + ".tmp"
+        with open(tmp_path, "w") as checkpoint_file:
+            json.dump(checkpoint_data, checkpoint_file)
+        os.replace(tmp_path, self.checkpoint_file)
+
+    def _load_checkpoint(self):
+        """Load saved build progress, if any.
+
+        Returns:
+            A ``(reading_sets, last_processed_id)`` pair read from the
+            checkpoint file, or ``(None, None)`` when the file does not exist.
+        """
+        try:
+            with open(self.checkpoint_file, "r") as checkpoint_file:
+                checkpoint_data = json.load(checkpoint_file)
+        except FileNotFoundError:
+            return None, None
+        return checkpoint_data["reading_sets"], checkpoint_data["last_processed_id"]
+
     def _process(self, pre_build_file_path, post_build_file_path):
-        with open(pre_build_file_path, "r") as pre_build_file:
-            reading_sets = json.load(pre_build_file)
+        reading_sets, last_processed_id = self._load_checkpoint()
+
+        if reading_sets is None:
+            # Fresh run: load the pre-build file and process every conversation.
+            with open(pre_build_file_path, "r") as pre_build_file:
+                reading_sets = json.load(pre_build_file)
+            skipping = False
+        else:
+            # Resumed run: skip conversations up to and including the last one
+            # recorded in the checkpoint.
+            skipping = True
+            print(f"Starting from {last_processed_id}")
 
         for conversation_id in tqdm(reading_sets.keys()):
+            if skipping:
+                if conversation_id == last_processed_id:
+                    skipping = False
+                continue
             article_url = reading_sets[conversation_id]["article_url"]
             del reading_sets[conversation_id]["article_url"]
 
@@ -70,6 +118,15 @@ def _process(self, pre_build_file_path, post_build_file_path):
 
                 # populate factual section for each agent corresponding to the conversation_id
                 self._populate_factual_section(reading_sets, conversation_id, agent)
+
+            # Checkpoint once per conversation, after its sections are populated,
+            # so the recorded id always names a fully processed conversation.
+            self._save_checkpoint(reading_sets, conversation_id)
 
+        # Every conversation has been processed: remove the checkpoint so a
+        # later run starts from the pre-build file again.
+        if os.path.exists(self.checkpoint_file):
+            os.remove(self.checkpoint_file)
+
         with open(post_build_file_path, "w") as post_build_file:
             post_build_file.write(json.dumps(reading_sets, indent=2))