From fa80296189aef6b7601f514c3d98c8c2d61926b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dylan=20Perdig=C3=A3o?= Date: Sat, 11 May 2024 16:11:12 +0100 Subject: [PATCH 1/2] population coding for ce_rate_loss --- snntorch/functional/loss.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/snntorch/functional/loss.py b/snntorch/functional/loss.py index cffa5d2c..693802af 100644 --- a/snntorch/functional/loss.py +++ b/snntorch/functional/loss.py @@ -94,17 +94,30 @@ class ce_rate_loss(LossFunctions): """ - def __init__(self, reduction='mean', weight=None): + def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None): super().__init__(reduction=reduction, weight=weight) + self.population_code = population_code + self.num_classes = num_classes self.__name__ = "ce_rate_loss" def _compute_loss(self, spk_out, targets): - device, num_steps, _ = self._prediction_check(spk_out) + device, num_steps, num_outputs = self._prediction_check(spk_out) + + if self.population_code: + for idx in range(self.num_classes): + spk_out[ + :, + :, + int(num_outputs * idx / self.num_classes) : int( + num_outputs * (idx + 1) / self.num_classes + ), + ] + weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device) + log_softmax_fn = nn.LogSoftmax(dim=-1) - loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight) + loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights) log_p_y = log_softmax_fn(spk_out) - loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1) loss = torch.zeros(loss_shape, dtype=dtype, device=device) From d540f8abada0b437b9b9f5f2fb776511bfe7e978 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dylan=20Perdig=C3=A3o?= Date: Sat, 11 May 2024 16:16:34 +0100 Subject: [PATCH 2/2] Update loss.py --- snntorch/functional/loss.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/snntorch/functional/loss.py b/snntorch/functional/loss.py index 693802af..f1e19da9 100644 --- a/snntorch/functional/loss.py +++ b/snntorch/functional/loss.py @@ -102,6 +102,7 @@ def __init__(self, population_code=False, num_classes=False, reduction='mean', w def _compute_loss(self, spk_out, targets): device, num_steps, num_outputs = self._prediction_check(spk_out) + log_softmax_fn = nn.LogSoftmax(dim=-1) if self.population_code: for idx in range(self.num_classes): @@ -113,10 +114,10 @@ def _compute_loss(self, spk_out, targets): ), ] weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device) - - log_softmax_fn = nn.LogSoftmax(dim=-1) - loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights) - + loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights) + else: + loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight) + log_p_y = log_softmax_fn(spk_out) loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1) loss = torch.zeros(loss_shape, dtype=dtype, device=device)