Skip to content

Commit

Permalink
clean test msel oracle
Browse files Browse the repository at this point in the history
  • Loading branch information
smilesun committed Jul 26, 2024
1 parent 589c428 commit c599946
Showing 1 changed file with 25 additions and 87 deletions.
112 changes: 25 additions & 87 deletions tests/test_msel_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@
from domainlab.utils.utils_cuda import get_device


def mk_model(task):
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
backbone.fc = nn.Linear(num_final_in, task.dim_y)

# specify model to use
model = mk_erm(list_str_y=task.list_str_y)(backbone)
return model

def mk_exp(
task,
model,
Expand Down Expand Up @@ -70,10 +79,7 @@ def mk_exp(
return exp


def test_msel_oracle():
"""
return trainer, model, observer
"""
def mk_task():
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(
name="domain1",
Expand All @@ -90,8 +96,15 @@ def test_msel_oracle():
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5),
)
return task


def test_msel_oracle():
"""
return trainer, model, observer
"""
# specify backbone to use
task = mk_task()
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
backbone.fc = nn.Linear(num_final_in, task.dim_y)
Expand All @@ -110,24 +123,8 @@ def test_msel_oracle1():
"""
return trainer, model, observer
"""
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(
name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1),
)
task.add_domain(
name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3),
)
task.add_domain(
name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5),
)

# specify backbone to use
task = mk_task()
# specify backbone to use
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
backbone.fc = nn.Linear(num_final_in, task.dim_y)
Expand All @@ -150,23 +147,7 @@ def test_msel_oracle2():
"""
return trainer, model, observer
"""
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(
name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1),
)
task.add_domain(
name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3),
)
task.add_domain(
name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5),
)

task = mk_task()
# specify backbone to use
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
Expand All @@ -184,30 +165,8 @@ def test_msel_oracle3():
"""
return trainer, model, observer
"""
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(
name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1),
)
task.add_domain(
name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3),
)
task.add_domain(
name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5),
)

# specify backbone to use
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
backbone.fc = nn.Linear(num_final_in, task.dim_y)

# specify model to use
model = mk_erm(list_str_y=task.list_str_y)(backbone)
task = mk_task()
model = mk_model(task)

exp = mk_exp(
task,
Expand All @@ -226,30 +185,9 @@ def test_msel_oracle4():
"""
return trainer, model, observer
"""
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(
name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1),
)
task.add_domain(
name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3),
)
task.add_domain(
name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5),
)

task = mk_task()
model = mk_model(task)
# specify backbone to use
backbone = torchvisionmodels.resnet.resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
num_final_in = backbone.fc.in_features
backbone.fc = nn.Linear(num_final_in, task.dim_y)

# specify model to use
model = mk_erm(list_str_y=task.list_str_y)(backbone)
exp = mk_exp(
task,
model,
Expand All @@ -261,5 +199,5 @@ def test_msel_oracle4():
)
exp.execute(num_epochs=2)
exp.trainer.observer.model_sel.msel.best_loss = 0
exp.trainer.observer.model_sel.msel.update(epoch = 1, clear_counter=True)
exp.trainer.observer.model_sel.msel.update(epoch=1, clear_counter=True)
del exp

0 comments on commit c599946

Please sign in to comment.