Skip to content

Commit

Permalink
Replace exact loss checks with parameter checks
Browse files Browse the repository at this point in the history
  • Loading branch information
saiftyfirst committed Oct 15, 2023
1 parent 9203870 commit 69160ef
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions tests/test_snntorch/functional/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch

torch.manual_seed(42)
tolerance = 1e-2
tolerance = 1e-5


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -40,9 +40,9 @@ def assert_approximate_equality(actual, expected):
class TestLoss:
def test_ce_rate_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_rate_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert loss.item() == pytest.approx(1.1099, abs=tolerance)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_ce_rate_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_rate_loss(reduction='none')
Expand All @@ -69,9 +69,9 @@ def test_ce_rate_loss_weighted(self, spike_predicted_, targets_labels_, class_we

def test_ce_count_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_count_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 1.1944)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_ce_count_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_count_loss(reduction='none')
Expand All @@ -98,9 +98,9 @@ def test_ce_count_loss_weighted(self, spike_predicted_, targets_labels_, class_w

def test_ce_max_membrane_loss_base(self, membrane_predicted_, targets_labels_):
loss_fn = sf.ce_max_membrane_loss()
loss = loss_fn(membrane_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 1.0639)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_ce_max_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_max_membrane_loss(reduction='none')
Expand All @@ -127,9 +127,9 @@ def test_ce_max_membrane_loss_weighted(self, spike_predicted_, targets_labels_,

def test_mse_count_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.mse_count_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 0.8148)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_mse_count_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_count_loss(reduction='none')
Expand All @@ -156,9 +156,9 @@ def test_mse_count_loss_weighted(self, spike_predicted_, targets_labels_, class_

def test_mse_membrane_loss_base(self, membrane_predicted_, targets_labels_):
loss_fn = sf.mse_membrane_loss()
loss = loss_fn(membrane_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 0.3214)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_mse_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_membrane_loss(reduction='none')
Expand All @@ -185,9 +185,9 @@ def test_mse_membrane_loss_weighted(self, spike_predicted_, targets_labels_, cla

def test_mse_temporal_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.mse_temporal_loss(on_target=1, off_target=0)
loss = loss_fn(spike_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 0.22222)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_mse_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_temporal_loss(reduction='none')
Expand All @@ -214,9 +214,9 @@ def test_mse_temporal_loss_weighted(self, spike_predicted_, targets_labels_, cla

def test_ce_temporal_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_temporal_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert_approximate_equality(loss.item(), 0.8364)
assert loss_fn.weight is None
assert loss_fn.reduction == 'mean'

def test_ce_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_temporal_loss(reduction='none')
Expand Down

0 comments on commit 69160ef

Please sign in to comment.