Skip to content

Commit

Permalink
Increase tolerence for approx. equality of loss
Browse files Browse the repository at this point in the history
  • Loading branch information
saiftyfirst committed Oct 15, 2023
1 parent 47df1df commit 9203870
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions tests/test_snntorch/functional/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
import snntorch as snn
import snntorch.functional as sf
import torch
import math

import snntorch.spikegen as spikegen

torch.manual_seed(42)
tolerance = 1e-3
tolerance = 1e-2


@pytest.fixture(scope="module")
Expand All @@ -30,17 +27,22 @@ def membrane_predicted_():
# shape: time_steps x batch_size x num_out_neurons
return torch.rand((3, 3, 3))


@pytest.fixture(scope="module")
def class_weights_():
return torch.tensor([0.35, 0.50, 0.15], dtype=torch.float32)


def assert_approximate_equality(actual, expected):
assert actual == pytest.approx(expected, abs=tolerance)


class TestLoss:
def test_ce_rate_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_rate_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(loss.item(), 1.1099, rel_tol=tolerance)
assert loss.item() == pytest.approx(1.1099, abs=tolerance)

def test_ce_rate_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_rate_loss(reduction='none')
Expand All @@ -49,7 +51,7 @@ def test_ce_rate_loss_unreduced(self, spike_predicted_, targets_labels_):
reduced_loss_fn = sf.ce_rate_loss()
reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_ce_rate_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.ce_rate_loss(weight=class_weights_)
Expand All @@ -63,13 +65,13 @@ def test_ce_rate_loss_weighted(self, spike_predicted_, targets_labels_, class_we
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_ce_count_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_count_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(loss.item(), 1.1944, rel_tol=tolerance)
assert_approximate_equality(loss.item(), 1.1944)

def test_ce_count_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_count_loss(reduction='none')
Expand All @@ -78,7 +80,7 @@ def test_ce_count_loss_unreduced(self, spike_predicted_, targets_labels_):
reduced_loss_fn = sf.ce_count_loss()
reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_ce_count_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.ce_count_loss(weight=class_weights_)
Expand All @@ -92,13 +94,13 @@ def test_ce_count_loss_weighted(self, spike_predicted_, targets_labels_, class_w
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_ce_max_membrane_loss_base(self, membrane_predicted_, targets_labels_):
loss_fn = sf.ce_max_membrane_loss()
loss = loss_fn(membrane_predicted_, targets_labels_)

assert math.isclose(loss.item(), 1.0639, rel_tol=1e-4)
assert_approximate_equality(loss.item(), 1.0639)

def test_ce_max_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_max_membrane_loss(reduction='none')
Expand All @@ -107,7 +109,7 @@ def test_ce_max_membrane_loss_unreduced(self, membrane_predicted_, targets_label
reduced_loss_fn = sf.ce_max_membrane_loss()
reduced_loss = reduced_loss_fn(membrane_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_ce_max_membrane_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.ce_max_membrane_loss(weight=class_weights_)
Expand All @@ -121,13 +123,13 @@ def test_ce_max_membrane_loss_weighted(self, spike_predicted_, targets_labels_,
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_mse_count_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.mse_count_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(loss.item(), 0.8148, rel_tol=tolerance)
assert_approximate_equality(loss.item(), 0.8148)

def test_mse_count_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_count_loss(reduction='none')
Expand All @@ -136,7 +138,7 @@ def test_mse_count_loss_unreduced(self, spike_predicted_, targets_labels_):
reduced_loss_fn = sf.mse_count_loss()
reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_mse_count_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.mse_count_loss(weight=class_weights_)
Expand All @@ -150,13 +152,13 @@ def test_mse_count_loss_weighted(self, spike_predicted_, targets_labels_, class_
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_mse_membrane_loss_base(self, membrane_predicted_, targets_labels_):
loss_fn = sf.mse_membrane_loss()
loss = loss_fn(membrane_predicted_, targets_labels_)

assert math.isclose(loss.item(), 0.3214, rel_tol=tolerance)
assert_approximate_equality(loss.item(), 0.3214)

def test_mse_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_membrane_loss(reduction='none')
Expand All @@ -165,7 +167,7 @@ def test_mse_membrane_loss_unreduced(self, membrane_predicted_, targets_labels_)
reduced_loss_fn = sf.mse_membrane_loss()
reduced_loss = reduced_loss_fn(membrane_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_mse_membrane_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.mse_membrane_loss(weight=class_weights_)
Expand All @@ -179,13 +181,13 @@ def test_mse_membrane_loss_weighted(self, spike_predicted_, targets_labels_, cla
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_mse_temporal_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.mse_temporal_loss(on_target=1, off_target=0)
loss = loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(loss.item(), 0.22222, rel_tol=1e-4)
assert_approximate_equality(loss.item(), 0.22222)

def test_mse_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.mse_temporal_loss(reduction='none')
Expand All @@ -194,7 +196,7 @@ def test_mse_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
reduced_loss_fn = sf.mse_temporal_loss()
reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_mse_temporal_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.mse_temporal_loss(weight=class_weights_)
Expand All @@ -208,13 +210,13 @@ def test_mse_temporal_loss_weighted(self, spike_predicted_, targets_labels_, cla
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).mean())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

def test_ce_temporal_loss_base(self, spike_predicted_, targets_labels_):
loss_fn = sf.ce_temporal_loss()
loss = loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(loss.item(), 0.8364, rel_tol=1e-4)
assert_approximate_equality(loss.item(), 0.8364)

def test_ce_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
unreduced_loss_fn = sf.ce_temporal_loss(reduction='none')
Expand All @@ -223,7 +225,7 @@ def test_ce_temporal_loss_unreduced(self, spike_predicted_, targets_labels_):
reduced_loss_fn = sf.ce_temporal_loss()
reduced_loss = reduced_loss_fn(spike_predicted_, targets_labels_)

assert math.isclose(unreduced_loss.mean().item(), reduced_loss.item(), rel_tol=tolerance)
assert_approximate_equality(unreduced_loss.mean().item(), reduced_loss.item())

def test_ce_temporal_loss_weighted(self, spike_predicted_, targets_labels_, class_weights_):
weighted_loss_fn = sf.ce_temporal_loss(weight=class_weights_)
Expand All @@ -237,4 +239,4 @@ def test_ce_temporal_loss_weighted(self, spike_predicted_, targets_labels_, clas
# expectation
expected_weighted_loss = ((vanilla_loss * weight_multiplier).sum() / weight_multiplier.sum())

assert math.isclose(weighted_loss.item(), expected_weighted_loss.item(), rel_tol=tolerance)
assert_approximate_equality(weighted_loss.item(), expected_weighted_loss.item())

0 comments on commit 9203870

Please sign in to comment.