diff --git a/domainlab/compos/pcr/request.py b/domainlab/compos/pcr/request.py index f883844f1..f7aba101b 100644 --- a/domainlab/compos/pcr/request.py +++ b/domainlab/compos/pcr/request.py @@ -6,6 +6,11 @@ class RequestVAEBuilderCHW(): def __init__(self, i_c, i_h, i_w, args): pass +class RequestVAEBuilderNN(): + """creates request when input does not come from command-line (args) but from test_exp file""" + @store_args + def __init__(self, net_class_d, net_x, net_class_y, i_c, i_h, i_w): + """net_class_d, net_x and net_class_y are neural networks defined by the user""" class RequestTask(): """ diff --git a/domainlab/compos/vae/compos/encoder_xyd_parallel.py b/domainlab/compos/vae/compos/encoder_xyd_parallel.py index 91afb4cc5..0544e9a2e 100644 --- a/domainlab/compos/vae/compos/encoder_xyd_parallel.py +++ b/domainlab/compos/vae/compos/encoder_xyd_parallel.py @@ -37,6 +37,15 @@ def infer_zy_loc(self, tensor): return zy_loc +class XYDEncoderParallelUser(XYDEncoderParallel): + """ + This class only reimplemented constructor of parent class + """ + @store_args + def __init__(self, net_class_d, net_x, net_class_y): + super().__init__(net_class_d, net_x, net_class_y) + + class XYDEncoderParallelConvBnReluPool(XYDEncoderParallel): """ This class only reimplemented constructor of parent class diff --git a/domainlab/compos/vae/utils_request_chain_builder.py b/domainlab/compos/vae/utils_request_chain_builder.py index b02406cb4..c740d5e94 100644 --- a/domainlab/compos/vae/utils_request_chain_builder.py +++ b/domainlab/compos/vae/utils_request_chain_builder.py @@ -1,7 +1,6 @@ from domainlab.compos.vae.zoo_vae_builders_classif import ( - NodeVAEBuilderArg, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool) -from domainlab.compos.vae.zoo_vae_builders_classif_topic import \ - NodeVAEBuilderImgTopic + NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool) +from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic class VAEChainNodeGetter(object): @@ -29,5 +28,6 @@ def __call__(self): chain = NodeVAEBuilderImgConvBnPool(None) chain = NodeVAEBuilderImgAlex(chain) chain = NodeVAEBuilderArg(chain) + chain = NodeVAEBuilderUser(chain) node = chain.handle(self.request) return node diff --git a/domainlab/compos/vae/zoo_vae_builders_classif.py b/domainlab/compos/vae/zoo_vae_builders_classif.py index db3298dc2..1adf43c7e 100644 --- a/domainlab/compos/vae/zoo_vae_builders_classif.py +++ b/domainlab/compos/vae/zoo_vae_builders_classif.py @@ -7,11 +7,11 @@ DecoderConcatLatentFCReshapeConvGatedConv from domainlab.compos.vae.compos.encoder_xyd_parallel import ( XYDEncoderParallelAlex, XYDEncoderParallelConvBnReluPool, - XYDEncoderParallelExtern) + XYDEncoderParallelExtern, XYDEncoderParallelUser) class ChainNodeVAEBuilderClassifCondPriorBase( - ChainNodeVAEBuilderClassifCondPrior): + ChainNodeVAEBuilderClassifCondPrior): """ base class of AE builder """ @@ -40,7 +40,7 @@ def build_encoder(self): def build_decoder(self): """build_decoder.""" decoder = DecoderConcatLatentFCReshapeConvGatedConv( - z_dim=self.zd_dim+self.zx_dim+self.zy_dim, + z_dim=self.zd_dim + self.zx_dim + self.zy_dim, i_c=self.i_c, i_w=self.i_w, i_h=self.i_h) return decoder @@ -70,6 +70,22 @@ def build_encoder(self): return encoder +class NodeVAEBuilderUser(ChainNodeVAEBuilderClassifCondPriorBase): + """Build encoders according to test_mk_exp file""" + + def is_myjob(self, request): + flag = not hasattr(request, "args") + self.request = request + self.config_img(flag, request) + return flag + + def build_encoder(self): + encoder = XYDEncoderParallelUser(self.request.net_class_d, + self.request.net_x, + self.request.net_class_y) + return encoder + + class NodeVAEBuilderImgConvBnPool(ChainNodeVAEBuilderClassifCondPriorBase): def is_myjob(self, request): """is_myjob. @@ -77,7 +93,7 @@ def is_myjob(self, request): :param request: """ flag = (request.args.nname == "conv_bn_pool_2" or - request.args.nname_dom == "conv_bn_pool_2") # @FIXME + request.args.nname_dom == "conv_bn_pool_2") # @FIXME self.config_img(flag, request) return flag diff --git a/tests/test_mk_exp_diva.py b/tests/test_mk_exp_diva.py new file mode 100644 index 000000000..83db8ae6a --- /dev/null +++ b/tests/test_mk_exp_diva.py @@ -0,0 +1,68 @@ +""" +make an experiment using "diva" model +""" + +from domainlab.mk_exp import mk_exp +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault +from domainlab.tasks.task_dset import mk_task_dset +from domainlab.models.model_diva import mk_diva +from domainlab.tasks.utils_task import ImSize +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter +from domainlab.compos.pcr.request import RequestVAEBuilderNN +from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool + + +def test_mk_exp_diva(): + """ + test mk experiment API for "diva" model and trainers "mldg", "dial" + """ + mk_exp_diva(trainer="mldg") + mk_exp_diva(trainer="dial") + + +def mk_exp_diva(trainer="mldg"): + """ + execute experiment with "diva" model and custom trainer + """ + + # specify domain generalization task + task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task") + task.add_domain(name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1)) + task.add_domain(name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3)) + task.add_domain(name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5)) + + # specify parameters + list_str_y = [f"class{i}" for i in range(task.dim_y)] + list_d_tr = ["domain2", "domain3"] + zd_dim = 3 + zy_dim = 10 + zx_dim = 30 + gamma_d = 1e5 + gamma_y = 7e5 + beta_d = 1e3 + beta_x = 1e3 + beta_y = 1e3 + net_class_d = LSEncoderConvBnReluPool( + zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + net_x = LSEncoderConvBnReluPool( + zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + net_class_y = LSEncoderConvBnReluPool( + zy_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1) + + request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y, + task.isize.c, task.isize.h, task.isize.w) + chain_node_builder = VAEChainNodeGetter(request)() + + # specify model to use + model = mk_diva()(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr, gamma_d, + gamma_y, beta_d, beta_x, beta_y) + + # make trainer for model + exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) + exp.execute(num_epochs=3)