Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support of Population Coding for ce_rate_loss #321

Merged
merged 2 commits into from
May 12, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions snntorch/functional/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,31 @@ class ce_rate_loss(LossFunctions):

"""

def __init__(self, reduction='mean', weight=None):
def __init__(self, population_code=False, num_classes=False, reduction='mean', weight=None):
super().__init__(reduction=reduction, weight=weight)
self.population_code = population_code
self.num_classes = num_classes
self.__name__ = "ce_rate_loss"

def _compute_loss(self, spk_out, targets):
device, num_steps, _ = self._prediction_check(spk_out)
device, num_steps, num_outputs = self._prediction_check(spk_out)
log_softmax_fn = nn.LogSoftmax(dim=-1)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

if self.population_code:
for idx in range(self.num_classes):
spk_out[
:,
:,
int(num_outputs * idx / self.num_classes) : int(
num_outputs * (idx + 1) / self.num_classes
),
]
weights = torch.Tensor([self.weight[0] if i < int(num_outputs/self.num_classes) else self.weight[1] for i in range(num_outputs) ]).to(device)
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=weights)
else:
loss_fn = nn.NLLLoss(reduction=self._intermediate_reduction(), weight=self.weight)

log_p_y = log_softmax_fn(spk_out)

loss_shape = (spk_out.size(1)) if self._intermediate_reduction() == 'none' else (1)
loss = torch.zeros(loss_shape, dtype=dtype, device=device)

Expand Down
Loading