From 9091d0b3a0ebf136426f03411a9c0e67089ff12a Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 18 Nov 2024 09:47:36 +0100 Subject: [PATCH 1/3] poission loss func --- src/crested/tl/losses/__init__.py | 1 + src/crested/tl/losses/_poisson.py | 71 +++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) create mode 100644 src/crested/tl/losses/_poisson.py diff --git a/src/crested/tl/losses/__init__.py b/src/crested/tl/losses/__init__.py index 3324497..c1d24ad 100644 --- a/src/crested/tl/losses/__init__.py +++ b/src/crested/tl/losses/__init__.py @@ -2,3 +2,4 @@ from ._cosinemse import CosineMSELoss from ._cosinemse_log import CosineMSELogLoss +from ._poisson import PoissonLoss diff --git a/src/crested/tl/losses/_poisson.py b/src/crested/tl/losses/_poisson.py new file mode 100644 index 0000000..10f2bf9 --- /dev/null +++ b/src/crested/tl/losses/_poisson.py @@ -0,0 +1,71 @@ +import keras +import keras.ops as ops + +@keras.saving.register_keras_serializable(package="Losses") +class PoissonLoss(keras.losses.Loss): + """ + Custom Poisson loss for count data with optional log(x + 1) transformation. + + This loss function computes the Poisson loss, optionally applying + log(x + 1) transformations to predictions and/or targets to ensure + non-negativity. + + Parameters + ---------- + log_transform : bool + If True, applies log(x + 1) transformation to both predictions and targets. + eps : float + Small value to avoid log(0). + reduction : str + Type of reduction to apply to the loss. Default: "sum_over_batch_size". + """ + def __init__( + self, + log_transform: bool = True, + eps: float = 1e-7, + reduction: str = "sum_over_batch_size", + name: str = "PoissonLoss" + ): + super().__init__(name=name, reduction=reduction) + self.log_transform = log_transform + self.eps = eps + + def call(self, y_true, y_pred): + """ + Compute the Poisson loss. + + Parameters + ---------- + y_true : Tensor + True target values (counts or log(x + 1)-transformed counts). + y_pred : Tensor + Predicted values (counts or log(x + 1)-transformed counts). + + Returns + ------- + Tensor + The Poisson loss value for each sample. + """ + # Ensure predictions and targets are float32 + y_true = ops.cast(y_true, dtype="float32") + y_pred = ops.cast(y_pred, dtype="float32") + + # Apply log(x + 1) transformation if needed + if self.log_transform: + y_true = ops.log(y_true + 1.0) + y_pred = ops.log(y_pred + 1.0) + + # Compute Poisson loss for each class + loss = y_pred - y_true * ops.log(y_pred + self.eps) + + # Sum the loss across classes + return ops.sum(loss, axis=-1) + + def get_config(self): + """Return the configuration of the loss function.""" + config = super().get_config() + config.update({ + "log_transform": self.log_transform, + "eps": self.eps + }) + return config \ No newline at end of file From 9610701b632a40d87ea566058a136ddaf11ae7d6 Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 18 Nov 2024 09:50:04 +0100 Subject: [PATCH 2/3] poisson updated --- src/crested/tl/losses/_poisson.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/crested/tl/losses/_poisson.py b/src/crested/tl/losses/_poisson.py index 10f2bf9..0578a63 100644 --- a/src/crested/tl/losses/_poisson.py +++ b/src/crested/tl/losses/_poisson.py @@ -1,6 +1,7 @@ import keras import keras.ops as ops + @keras.saving.register_keras_serializable(package="Losses") class PoissonLoss(keras.losses.Loss): """ @@ -19,6 +20,7 @@ class PoissonLoss(keras.losses.Loss): reduction : str Type of reduction to apply to the loss. Default: "sum_over_batch_size". """ + def __init__( self, log_transform: bool = True, @@ -68,4 +70,4 @@ def get_config(self): "log_transform": self.log_transform, "eps": self.eps }) - return config \ No newline at end of file + return config From 96a20db2a257292ee19aebed02d799be8a6db5dc Mon Sep 17 00:00:00 2001 From: nkempynck Date: Mon, 18 Nov 2024 09:54:29 +0100 Subject: [PATCH 3/3] docstring to init poission --- src/crested/tl/losses/_poisson.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/crested/tl/losses/_poisson.py b/src/crested/tl/losses/_poisson.py index 0578a63..aeef9af 100644 --- a/src/crested/tl/losses/_poisson.py +++ b/src/crested/tl/losses/_poisson.py @@ -28,6 +28,21 @@ def __init__( reduction: str = "sum_over_batch_size", name: str = "PoissonLoss" ): + """ + Initialize the PoissonLoss class. + + Parameters + ---------- + log_transform : bool, optional + Whether to apply a log(x + 1) transformation to the inputs. Default is True. + eps : float, optional + A small epsilon value to avoid log(0). Default is 1e-7. + reduction : str, optional + The type of reduction to apply to the loss, e.g., "sum_over_batch_size". + Default is "sum_over_batch_size". + name : str, optional + Name of the loss function. Default is "PoissonLoss". + """ super().__init__(name=name, reduction=reduction) self.log_transform = log_transform self.eps = eps