Skip to content

Commit

Permalink
add parameters to fit
Browse files Browse the repository at this point in the history
  • Loading branch information
erceksi committed Jul 9, 2024
1 parent c0e0de6 commit 129c0ee
Showing 1 changed file with 48 additions and 8 deletions.
56 changes: 48 additions & 8 deletions src/crested/tl/_crested.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,34 +133,42 @@ def _initialize_callbacks(
save_dir: os.PathLike,
model_checkpointing: bool,
model_checkpointing_best_only: bool,
model_checkpointing_metric: str,
model_checkpointing_mode: str,
early_stopping: bool,
early_stopping_patience: int | None,
early_stopping_metric: str,
early_stopping_mode: str,
learning_rate_reduce: bool,
learning_rate_reduce_patience: int | None,
learning_rate_reduce_metric: str,
learning_rate_reduce_mode: str,
custom_callbacks: list | None,
) -> list:
"""Initialize callbacks"""
callbacks = []
if early_stopping:
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
patience=early_stopping_patience,
mode="min",
monitor="val_loss",
mode=early_stopping_mode,
monitor=early_stopping_metric,
)
callbacks.append(early_stopping_callback)
if model_checkpointing:
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(save_dir, "checkpoints", "{epoch:02d}.keras"),
monitor="val_loss",
monitor=model_checkpointing_metric,
save_best_only=model_checkpointing_best_only,
mode=model_checkpointing_mode,
save_freq="epoch",
)
callbacks.append(model_checkpoint_callback)
if learning_rate_reduce:
learning_rate_reduce_callback = tf.keras.callbacks.ReduceLROnPlateau(
patience=learning_rate_reduce_patience,
monitor="val_loss",
monitor=learning_rate_reduce_metric,
factor=0.25,
mode=learning_rate_reduce_mode
)
callbacks.append(learning_rate_reduce_callback)
if custom_callbacks is not None:
Expand Down Expand Up @@ -213,10 +221,16 @@ def fit(
mixed_precision: bool = False,
model_checkpointing: bool = True,
model_checkpointing_best_only: bool = True,
model_checkpointing_metric: str = 'val_loss',
model_checkpointing_mode: str = 'min',
early_stopping: bool = True,
early_stopping_patience: int = 10,
early_stopping_metric: str = 'val_loss',
early_stopping_mode: str = 'min',
learning_rate_reduce: bool = True,
learning_rate_reduce_patience: int = 5,
learning_rate_reduce_metric: str = 'val_loss',
learning_rate_reduce_mode: str = 'min',
custom_callbacks: list | None = None,
) -> None:
"""
Expand All @@ -232,14 +246,26 @@ def fit(
Save model checkpoints.
model_checkpointing_best_only
Save only the best model checkpoint.
model_checkpointing_metric
Metric to monitor to choose best models.
model_checkpointing_mode
'max' if a high metric is better, 'min' if a low metric is better
early_stopping
Enable early stopping.
early_stopping_patience
Number of epochs with no improvement after which training will be stopped.
early_stopping_metric
Metric to monitor for early stopping.
early_stopping_mode
'max' if a high metric is better, 'min' if a low metric is better
learning_rate_reduce
Enable learning rate reduction.
learning_rate_reduce_patience
Number of epochs with no improvement after which learning rate will be reduced.
learning_rate_reduce_metric
Metric to monitor for reducing the learning rate.
learning_rate_reduce_mode
'max' if a high metric is better, 'min' if a low metric is better
custom_callbacks
List of custom callbacks to use during training.
"""
Expand All @@ -250,10 +276,16 @@ def fit(
self.save_dir,
model_checkpointing,
model_checkpointing_best_only,
model_checkpointing_metric,
model_checkpointing_mode,
early_stopping,
early_stopping_patience,
early_stopping_metric,
early_stopping_mode,
learning_rate_reduce,
learning_rate_reduce_patience,
learning_rate_reduce_metric,
learning_rate_reduce_mode,
custom_callbacks,
)

Expand Down Expand Up @@ -657,16 +689,18 @@ def calculate_contribution_scores_sequence(
--------
crested.pl.patterns.contribution_scores
"""
self._check_contrib_params(method)
if self.anndatamodule.predict_dataset is None:
self.anndatamodule.setup("predict")
self._check_contribution_scores_params(class_names)

if isinstance(sequences, str):
sequences = [sequences]

if isinstance(class_names, str):
class_names = [class_names]

self._check_contrib_params(method)
if self.anndatamodule.predict_dataset is None:
self.anndatamodule.setup("predict")
self._check_contribution_scores_params(class_names)


all_scores = []
all_one_hot_sequences = []
Expand Down Expand Up @@ -829,6 +863,9 @@ def enhancer_design_motif_implementation(
A list of designed sequences and if return_intermediate is True a list of dictionaries of intermediate
mutations and predictions
"""

self._check_contribution_scores_params([target_class])

all_class_names = list(self.anndatamodule.adata.obs_names)

target = all_class_names.index(target_class)
Expand Down Expand Up @@ -991,6 +1028,9 @@ def enhancer_design_in_silico_evolution(
A list of designed sequences and if return_intermediate is True a list of dictionaries of intermediate
mutations and predictions
"""

self._check_contribution_scores_params([target_class])

all_class_names = list(self.anndatamodule.adata.obs_names)

target = all_class_names.index(target_class)
Expand Down

0 comments on commit 129c0ee

Please sign in to comment.