From 30fe2f5593f645ecc8096f380c44c4271538f802 Mon Sep 17 00:00:00 2001 From: smilesun Date: Wed, 11 Dec 2024 11:48:23 +0100 Subject: [PATCH] split test_fbopt into two files --- tests/test_fbopt.py | 25 ------------------------- tests/test_fbopt2.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 25 deletions(-) create mode 100644 tests/test_fbopt2.py diff --git a/tests/test_fbopt.py b/tests/test_fbopt.py index 03dad9e16..15306ebf4 100644 --- a/tests/test_fbopt.py +++ b/tests/test_fbopt.py @@ -27,28 +27,3 @@ def test_diva_fbopt(): """ args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=diva --gamma_y=1.0 --trainer=fbopt --nname=alexnet --epos=3 --no_dump" utils_test_algo(args) - - -def test_erm_fbopt(): - """ - erm - """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long - with pytest.raises(RuntimeError): - utils_test_algo(args) - - -def test_irm_fbopt(): - """ - irm - """ - args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm --trainer=fbopt_irm --nname=alexnet --epos=3 --no_dump" # pylint: disable=line-too-long - utils_test_algo(args) - - -def test_forcesetpoint_fbopt(): - """ - diva - """ - args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 --mu_init=0.00001 --coeff_ma_setpoint=0.5 --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump" - utils_test_algo(args) diff --git a/tests/test_fbopt2.py b/tests/test_fbopt2.py new file mode 100644 index 000000000..746924386 --- /dev/null +++ b/tests/test_fbopt2.py @@ -0,0 +1,37 @@ +""" +unit and end-end test for deep all, mldg +""" +import pytest +from tests.utils_test import utils_test_algo + + +def test_erm_fbopt(): + """ + erm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=fbopt --nname=alexnet --epos=3 \ + --no_dump" # pylint: disable=line-too-long + with pytest.raises(RuntimeError): + utils_test_algo(args) + + +def test_irm_fbopt(): + """ + irm + """ + args = "--te_d=caltech --task=mini_vlcs --debug --bs=2 --model=erm \ + --trainer=fbopt_irm --nname=alexnet --epos=3 \ + --no_dump" # pylint: disable=line-too-long + utils_test_algo(args) + + +def test_forcesetpoint_fbopt(): + """ + diva + """ + args = "--te_d=0 --tr_d 1 2 --task=mnistcolor10 --bs=16 --model=jigen \ + --trainer=fbopt --nname=conv_bn_pool_2 --epos=10 --es=0 \ + --mu_init=0.00001 --coeff_ma_setpoint=0.5 \ + --coeff_ma_output_state=0.99 --force_setpoint_change_once --no_dump" + utils_test_algo(args)