diff --git a/pyproject.toml b/pyproject.toml
index 96d29a9..6946a19 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@ readme = "README.md"
 requires-python = ">=3.11"
 
 [tool.bumpversion]
-current_version = "0.0.3"
+current_version = "0.0.4"
 commit = false
 tag = false
 
diff --git a/src/modeling/infer.py b/src/modeling/infer.py
index e686be3..bcaa32f 100644
--- a/src/modeling/infer.py
+++ b/src/modeling/infer.py
@@ -41,6 +41,11 @@ def save_visualizations(self, documents: list[str]):
         )
         fig.write_html(self.output_path / "vis_documents.html")
 
+    def log_preds(self, topics: list[int]):
+        """Write the predicted topic ids to ``preds.txt``, one per line."""
+        with (self.output_path / "preds.txt").open("w") as f:
+            f.write("\n".join(str(x) for x in topics))
+
     def log_info(self, documents: list[str]):
         with (self.output_path / "topic_info.csv").open("w") as f:
             f.write(self.model.get_topic_info().to_csv())
@@ -49,7 +54,10 @@ def log_info(self, documents: list[str]):
 
     def infer(self):
         documents = self.get_dataset(self.processing).tolist()
-        self.model.transform(documents)
+        # Only the topic assignments are logged; the second result is unused.
+        topics, _ = self.model.transform(documents)
         self.output_path.mkdir(exist_ok=True, parents=True)
+        # list() accepts a plain list as well as an array; .tolist() does not.
+        self.log_preds(list(topics))
         self.log_info(documents)
         self.save_visualizations(documents)