diff --git a/tests/test_msel_oracle.py b/tests/test_msel_oracle.py index 8a28567b0..3eaef27d2 100644 --- a/tests/test_msel_oracle.py +++ b/tests/test_msel_oracle.py @@ -17,6 +17,15 @@ from domainlab.utils.utils_cuda import get_device +def mk_model(task): + backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) + num_final_in = backbone.fc.in_features + backbone.fc = nn.Linear(num_final_in, task.dim_y) + + # specify model to use + model = mk_erm(list_str_y=task.list_str_y)(backbone) + return model + def mk_exp( task, model, @@ -70,10 +79,7 @@ def mk_exp( return exp -def test_msel_oracle(): - """ - return trainer, model, observer - """ +def mk_task(): task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") task.add_domain( name="domain1", @@ -90,8 +96,15 @@ def test_msel_oracle(): dset_tr=DsetMNISTColorSoloDefault(4), dset_val=DsetMNISTColorSoloDefault(5), ) + return task + +def test_msel_oracle(): + """ + return trainer, model, observer + """ # specify backbone to use + task = mk_task() backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -110,24 +123,8 @@ def test_msel_oracle1(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use + task = mk_task() + # specify backbone to use backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features backbone.fc = nn.Linear(num_final_in, task.dim_y) @@ -150,23 +147,7 @@ def test_msel_oracle2(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - + task = mk_task() # specify backbone to use backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) num_final_in = backbone.fc.in_features @@ -184,30 +165,8 @@ def test_msel_oracle3(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - - # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) + task = mk_task() + model = mk_model(task) exp = mk_exp( task, @@ -226,30 +185,9 @@ def test_msel_oracle4(): """ return trainer, model, observer """ - task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") - task.add_domain( - name="domain1", - dset_tr=DsetMNISTColorSoloDefault(0), - dset_val=DsetMNISTColorSoloDefault(1), - ) - task.add_domain( - name="domain2", - dset_tr=DsetMNISTColorSoloDefault(2), - dset_val=DsetMNISTColorSoloDefault(3), - ) - task.add_domain( - name="domain3", - dset_tr=DsetMNISTColorSoloDefault(4), - dset_val=DsetMNISTColorSoloDefault(5), - ) - + task = mk_task() + model = mk_model(task) # specify backbone to use - backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2) - num_final_in = backbone.fc.in_features - backbone.fc = nn.Linear(num_final_in, task.dim_y) - - # specify model to use - model = mk_erm(list_str_y=task.list_str_y)(backbone) exp = mk_exp( task, model, @@ -261,5 +199,5 @@ def test_msel_oracle4(): ) exp.execute(num_epochs=2) exp.trainer.observer.model_sel.msel.best_loss = 0 - exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True) + exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True) del exp