Skip to content

Commit

Permalink
Merge pull request #4 from sfp932705/preds
Browse files Browse the repository at this point in the history
Preds
  • Loading branch information
sfp932705 authored Jul 21, 2024
2 parents 94c384d + ce9591f commit 2c12f65
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"
requires-python = ">=3.11"

[tool.bumpversion]
current_version = "0.0.3"
current_version = "0.0.4"
commit = false
tag = false

Expand Down
7 changes: 6 additions & 1 deletion src/modeling/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def save_visualizations(self, documents: list[str]):
)
fig.write_html(self.output_path / "vis_documents.html")

def log_preds(self, topics: list[int]):
with (self.output_path / "preds.txt").open("w") as f:
f.write("\n".join(str(x) for x in topics))

def log_info(self, documents: list[str]):
with (self.output_path / "topic_info.csv").open("w") as f:
f.write(self.model.get_topic_info().to_csv())
Expand All @@ -49,7 +53,8 @@ def log_info(self, documents: list[str]):

def infer(self):
documents = self.get_dataset(self.processing).tolist()
self.model.transform(documents)
topics, probs = self.model.transform(documents)
self.output_path.mkdir(exist_ok=True, parents=True)
self.log_preds(topics.tolist())
self.log_info(documents)
self.save_visualizations(documents)

0 comments on commit 2c12f65

Please sign in to comment.