Skip to content

Commit

Permalink
split test_fbopt into two files
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Dec 11, 2024
1 parent 598a4cc commit 30fe2f5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 25 deletions.
25 changes: 0 additions & 25 deletions tests/test_fbopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,28 +27,3 @@ def test_diva_fbopt():
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3 --no_dump"
utils_test_algo(args)


def test_erm_fbopt():
"""
erm
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long
with pytest.raises(RuntimeError):
utils_test_algo(args)


def test_irm_fbopt():
"""
irm
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long
utils_test_algo(args)


def test_forcesetpoint_fbopt():
"""
diva
"""
args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump"
utils_test_algo(args)
37 changes: 37 additions & 0 deletions tests/test_fbopt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
unit and end-end test for deep all, mldg
"""
import pytest
from tests.utils_test import utils_test_algo


def test_erm_fbopt():
"""
erm
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=fbopt --nname=alexnet --epos=3 \
--no_dump" # pylint: disable=line-too-long
with pytest.raises(RuntimeError):
utils_test_algo(args)


def test_irm_fbopt():
"""
irm
"""
args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \
--trainer=fbopt_irm --nname=alexnet --epos=3 \
--no_dump" # pylint: disable=line-too-long
utils_test_algo(args)


def test_forcesetpoint_fbopt():
"""
diva
"""
args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen \
--trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 \
--mu_init=0.00001 --coeff_ma_setpoint=0.5 \
--coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump"
utils_test_algo(args)

0 comments on commit 30fe2f5

Please sign in to comment.