Skip to content

Commit

Permalink
Merge pull request #74 from aertslab/dev
Browse files Browse the repository at this point in the history
update msecosine log loss function
  • Loading branch information
nkempynck authored Nov 28, 2024
2 parents af998dc + da4dfd9 commit 5a76acf
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/crested/tl/losses/_cosinemse_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class CosineMSELogLoss(keras.losses.Loss):
Name of the loss function.
reduction
Type of reduction to apply to loss.
multiplier
Scalar to multiply the predicted value with. When predicting mean coverage, multiply by 1000 to get actual count. Keep to 1 when predicting insertion counts.
Notes
-----
Expand All @@ -50,11 +52,13 @@ def __init__(
max_weight: float = 1.0,
name: str | None = "CosineMSELogLoss",
reduction: str = "sum_over_batch_size",
multiplier: float = 1000,
):
"""Initialize the loss function."""
super().__init__(name=name)
self.max_weight = max_weight
self.reduction = reduction
self.multiplier = multiplier

def call(self, y_true, y_pred):
"""Compute the loss value."""
Expand All @@ -64,13 +68,13 @@ def call(self, y_true, y_pred):
y_true1 = keras.utils.normalize(y_true, axis=-1)
y_pred1 = keras.utils.normalize(y_pred, axis=-1)

log_y_pred_pos = keras.ops.log(1 + 1000 * keras.ops.maximum(y_pred, 0))
log_y_pred_pos = keras.ops.log(1 + self.multiplier * keras.ops.maximum(y_pred, 0))
log_y_pred_neg = -keras.ops.log(
1 + keras.ops.abs(1000 * keras.ops.minimum(y_pred, 0))
1 + keras.ops.abs(self.multiplier * keras.ops.minimum(y_pred, 0))
)

log_y_pred = log_y_pred_pos + log_y_pred_neg
log_y_true = keras.ops.log(1 + 1000 * y_true)
log_y_true = keras.ops.log(1 + self.multiplier * y_true)

mse_loss = keras.ops.mean(keras.ops.square(log_y_pred - log_y_true))
weight = keras.ops.abs(mse_loss)
Expand Down

0 comments on commit 5a76acf

Please sign in to comment.