Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Oct 15, 2024
1 parent 136a620 commit 8d3402c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 21 deletions.
21 changes: 0 additions & 21 deletions tests/test_miro.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
end-end test for mutual information regulation
"""
import pytest
from tests.utils_test import utils_test_algo


Expand All @@ -12,23 +11,3 @@ def test_miro():
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet"
utils_test_algo(args)

def test_miro2():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet \
--layers2extract_feats _net_invar_feat.net_torchvision.features.1"
utils_test_algo(args)

def test_miro3():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet \
--layers2extract_feats features"
with pytest.raises(RuntimeError):
utils_test_algo(args)
raise RuntimeError("This is a runtime error")
14 changes: 14 additions & 0 deletions tests/test_miro2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
end-end test for mutual information regulation
"""
from tests.utils_test import utils_test_algo


def test_miro2():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet \
--layers2extract_feats _net_invar_feat.net_torchvision.features.1"
utils_test_algo(args)
17 changes: 17 additions & 0 deletions tests/test_miro3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
end-end test for mutual information regulation
"""
import pytest
from tests.utils_test import utils_test_algo


def test_miro3():
"""
train with MIRO
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=miro --nname=alexnet \
--layers2extract_feats features"
with pytest.raises(RuntimeError):
utils_test_algo(args)
raise RuntimeError("This is a runtime error")

0 comments on commit 8d3402c

Please sign in to comment.