Skip to content

Commit

Permalink
Merge pull request #350 from marrlab/lb_fix_args
Browse files Browse the repository at this point in the history
fix chain_node_builder and diva test func
  • Loading branch information
smilesun authored Sep 11, 2023
2 parents 1420953 + 9a74bf8 commit 5d2bdc7
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 7 deletions.
5 changes: 5 additions & 0 deletions domainlab/compos/pcr/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ class RequestVAEBuilderCHW():
def __init__(self, i_c, i_h, i_w, args):
pass

class RequestVAEBuilderNN():
"""creates request when input does not come from command-line (args) but from test_exp file"""
@store_args
def __init__(self, net_class_d, net_x, net_class_y, i_c, i_h, i_w):
"""net_class_d, net_x and net_class_y are neural networks defined by the user"""

class RequestTask():
"""
Expand Down
9 changes: 9 additions & 0 deletions domainlab/compos/vae/compos/encoder_xyd_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def infer_zy_loc(self, tensor):
return zy_loc


class XYDEncoderParallelUser(XYDEncoderParallel):
"""
This class only reimplemented constructor of parent class
"""
@store_args
def __init__(self, net_class_d, net_x, net_class_y):
super().__init__(net_class_d, net_x, net_class_y)


class XYDEncoderParallelConvBnReluPool(XYDEncoderParallel):
"""
This class only reimplemented constructor of parent class
Expand Down
6 changes: 3 additions & 3 deletions domainlab/compos/vae/utils_request_chain_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from domainlab.compos.vae.zoo_vae_builders_classif import (
NodeVAEBuilderArg, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool)
from domainlab.compos.vae.zoo_vae_builders_classif_topic import \
NodeVAEBuilderImgTopic
NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool)
from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic


class VAEChainNodeGetter(object):
Expand Down Expand Up @@ -29,5 +28,6 @@ def __call__(self):
chain = NodeVAEBuilderImgConvBnPool(None)
chain = NodeVAEBuilderImgAlex(chain)
chain = NodeVAEBuilderArg(chain)
chain = NodeVAEBuilderUser(chain)
node = chain.handle(self.request)
return node
24 changes: 20 additions & 4 deletions domainlab/compos/vae/zoo_vae_builders_classif.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
DecoderConcatLatentFCReshapeConvGatedConv
from domainlab.compos.vae.compos.encoder_xyd_parallel import (
XYDEncoderParallelAlex, XYDEncoderParallelConvBnReluPool,
XYDEncoderParallelExtern)
XYDEncoderParallelExtern, XYDEncoderParallelUser)


class ChainNodeVAEBuilderClassifCondPriorBase(
ChainNodeVAEBuilderClassifCondPrior):
ChainNodeVAEBuilderClassifCondPrior):
"""
base class of AE builder
"""
Expand Down Expand Up @@ -40,7 +40,7 @@ def build_encoder(self):
def build_decoder(self):
"""build_decoder."""
decoder = DecoderConcatLatentFCReshapeConvGatedConv(
z_dim=self.zd_dim+self.zx_dim+self.zy_dim,
z_dim=self.zd_dim + self.zx_dim + self.zy_dim,
i_c=self.i_c, i_w=self.i_w,
i_h=self.i_h)
return decoder
Expand Down Expand Up @@ -70,14 +70,30 @@ def build_encoder(self):
return encoder


class NodeVAEBuilderUser(ChainNodeVAEBuilderClassifCondPriorBase):
"""Build encoders according to test_mk_exp file"""

def is_myjob(self, request):
flag = not hasattr(request, "args")
self.request = request
self.config_img(flag, request)
return flag

def build_encoder(self):
encoder = XYDEncoderParallelUser(self.request.net_class_d,
self.request.net_x,
self.request.net_class_y)
return encoder


class NodeVAEBuilderImgConvBnPool(ChainNodeVAEBuilderClassifCondPriorBase):
def is_myjob(self, request):
"""is_myjob.
:param request:
"""
flag = (request.args.nname == "conv_bn_pool_2" or
request.args.nname_dom == "conv_bn_pool_2") # @FIXME
request.args.nname_dom == "conv_bn_pool_2") # @FIXME
self.config_img(flag, request)
return flag

Expand Down
68 changes: 68 additions & 0 deletions tests/test_mk_exp_diva.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
make an experiment using "diva" model
"""

from domainlab.mk_exp import mk_exp
from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault
from domainlab.tasks.task_dset import mk_task_dset
from domainlab.models.model_diva import mk_diva
from domainlab.tasks.utils_task import ImSize
from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter
from domainlab.compos.pcr.request import RequestVAEBuilderNN
from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool


def test_mk_exp_diva():
"""
test mk experiment API for "diva" model and trainers "mldg", "dial"
"""
mk_exp_diva(trainer="mldg")
mk_exp_diva(trainer="dial")


def mk_exp_diva(trainer="mldg"):
"""
execute experiment with "diva" model and custom trainer
"""

# specify domain generalization task
task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task")
task.add_domain(name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1))
task.add_domain(name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3))
task.add_domain(name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5))

# specify parameters
list_str_y = [f"class{i}" for i in range(task.dim_y)]
list_d_tr = ["domain2", "domain3"]
zd_dim = 3
zy_dim = 10
zx_dim = 30
gamma_d = 1e5
gamma_y = 7e5
beta_d = 1e3
beta_x = 1e3
beta_y = 1e3
net_class_d = LSEncoderConvBnReluPool(
zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)
net_x = LSEncoderConvBnReluPool(
zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)
net_class_y = LSEncoderConvBnReluPool(
zy_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)

request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y,
task.isize.c, task.isize.h, task.isize.w)
chain_node_builder = VAEChainNodeGetter(request)()

# specify model to use
model = mk_diva()(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr, gamma_d,
gamma_y, beta_d, beta_x, beta_y)

# make trainer for model
exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32)
exp.execute(num_epochs=3)

0 comments on commit 5d2bdc7

Please sign in to comment.