Skip to content

Commit

Permalink
write tensorboard to output directory
Browse files Browse the repository at this point in the history
  • Loading branch information
Lilferrit committed Sep 5, 2024
1 parent 4cb18e1 commit e36a4c9
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion casanovo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class Config:
max_length=int,
residues=dict,
n_log=int,
tb_summarywriter=str,
tb_summarywriter=bool,
train_label_smoothing=float,
warmup_iters=int,
cosine_schedule_period_iters=int,
Expand Down
4 changes: 2 additions & 2 deletions casanovo/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ random_seed: 454
# OUTPUT OPTIONS
# Logging frequency in training steps.
n_log: 1
# Tensorboard directory to use for keeping track of training metrics.
tb_summarywriter:
# Whether to create tensorboard directory
tb_summarywriter: false
# Model validation and checkpointing frequency in training steps.
val_check_interval: 50_000

Expand Down
15 changes: 13 additions & 2 deletions casanovo/denovo/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def __init__(
"""Initialize a ModelRunner"""
self.config = config
self.model_filename = model_filename
self.output_dir = output_dir

# Initialized later:
self.tmp_dir = None
Expand Down Expand Up @@ -268,6 +269,16 @@ def initialize_model(self, train: bool) -> None:
Determines whether to set the model up for model training or
evaluation / inference.
"""
tb_summarywriter = None
if self.config.tb_summarywriter:
if self.output_dir is None:
logger.warning(

Check warning on line 275 in casanovo/denovo/model_runner.py

View check run for this annotation

Codecov / codecov/patch

casanovo/denovo/model_runner.py#L275

Added line #L275 was not covered by tests
"Can not create tensorboard because the output directory "
"is not set in the model runner."
)
else:
tb_summarywriter = self.output_dir / "tensorboard"

model_params = dict(
dim_model=self.config.dim_model,
n_head=self.config.n_head,
Expand All @@ -284,7 +295,7 @@ def initialize_model(self, train: bool) -> None:
n_beams=self.config.n_beams,
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
tb_summarywriter=tb_summarywriter,
train_label_smoothing=self.config.train_label_smoothing,
warmup_iters=self.config.warmup_iters,
cosine_schedule_period_iters=self.config.cosine_schedule_period_iters,
Expand All @@ -303,7 +314,7 @@ def initialize_model(self, train: bool) -> None:
min_peptide_len=self.config.min_peptide_len,
top_match=self.config.top_match,
n_log=self.config.n_log,
tb_summarywriter=self.config.tb_summarywriter,
tb_summarywriter=tb_summarywriter,
train_label_smoothing=self.config.train_label_smoothing,
warmup_iters=self.config.warmup_iters,
cosine_schedule_period_iters=self.config.cosine_schedule_period_iters,
Expand Down

0 comments on commit e36a4c9

Please sign in to comment.