diff --git a/src/crested/tl/losses/_cosinemse_log.py b/src/crested/tl/losses/_cosinemse_log.py index c4e067d..9ea261a 100644 --- a/src/crested/tl/losses/_cosinemse_log.py +++ b/src/crested/tl/losses/_cosinemse_log.py @@ -30,6 +30,8 @@ class CosineMSELogLoss(keras.losses.Loss): Name of the loss function. reduction Type of reduction to apply to loss. + multiplier + Scalar to multiply the predicted value with. When predicting mean coverage, multiply by 1000 to get actual count. Keep to 1 when predicting insertion counts. Notes ----- @@ -50,11 +52,13 @@ def __init__( max_weight: float = 1.0, name: str | None = "CosineMSELogLoss", reduction: str = "sum_over_batch_size", + multiplier: float = 1000, ): """Initialize the loss function.""" super().__init__(name=name) self.max_weight = max_weight self.reduction = reduction + self.multiplier = multiplier def call(self, y_true, y_pred): """Compute the loss value.""" @@ -64,13 +68,13 @@ def call(self, y_true, y_pred): y_true1 = keras.utils.normalize(y_true, axis=-1) y_pred1 = keras.utils.normalize(y_pred, axis=-1) - log_y_pred_pos = keras.ops.log(1 + 1000 * keras.ops.maximum(y_pred, 0)) + log_y_pred_pos = keras.ops.log(1 + self.multiplier * keras.ops.maximum(y_pred, 0)) log_y_pred_neg = -keras.ops.log( - 1 + keras.ops.abs(1000 * keras.ops.minimum(y_pred, 0)) + 1 + keras.ops.abs(self.multiplier * keras.ops.minimum(y_pred, 0)) ) log_y_pred = log_y_pred_pos + log_y_pred_neg - log_y_true = keras.ops.log(1 + 1000 * y_true) + log_y_true = keras.ops.log(1 + self.multiplier * y_true) mse_loss = keras.ops.mean(keras.ops.square(log_y_pred - log_y_true)) weight = keras.ops.abs(mse_loss)