diff --git a/domainlab/compos/vae/compos/encoder_xydt_elevator.py b/domainlab/compos/vae/compos/encoder_xydt_elevator.py index 9d7b093e5..9345f9622 100644 --- a/domainlab/compos/vae/compos/encoder_xydt_elevator.py +++ b/domainlab/compos/vae/compos/encoder_xydt_elevator.py @@ -82,6 +82,15 @@ def __init__(self, device, topic_dim, zd_dim, super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy) +class XYDTEncoderArgUser(XYDTEncoderElevator): + """ + This class only reimplemented constructor of parent class + """ + @store_args + def __init__(self, net_class_d, net_x, net_class_y): + super().__init__(net_class_d, net_x, net_class_y) + + # To remove class XYDTEncoderConvBnReluPool(XYDTEncoderElevator): """ diff --git a/domainlab/compos/vae/utils_request_chain_builder.py b/domainlab/compos/vae/utils_request_chain_builder.py index c740d5e94..fe5d5b5b6 100644 --- a/domainlab/compos/vae/utils_request_chain_builder.py +++ b/domainlab/compos/vae/utils_request_chain_builder.py @@ -1,6 +1,7 @@ from domainlab.compos.vae.zoo_vae_builders_classif import ( NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool) -from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic +from domainlab.compos.vae.zoo_vae_builders_classif_topic import ( + NodeVAEBuilderImgTopic, NodeVAEBuilderImgTopicUser) class VAEChainNodeGetter(object): @@ -24,6 +25,7 @@ def __call__(self): """ if self.topic_dim is not None: chain = NodeVAEBuilderImgTopic(None) + chain = NodeVAEBuilderImgTopicUser(chain) else: chain = NodeVAEBuilderImgConvBnPool(None) chain = NodeVAEBuilderImgAlex(chain) diff --git a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py index 5f630c4c0..ec1396a99 100644 --- a/domainlab/compos/vae/zoo_vae_builders_classif_topic.py +++ b/domainlab/compos/vae/zoo_vae_builders_classif_topic.py @@ -3,8 +3,8 @@ """ from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import \ DecoderConcatLatentFCReshapeConvGatedConv -from domainlab.compos.vae.compos.encoder_xydt_elevator import XYDTEncoderArg -from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg +from domainlab.compos.vae.compos.encoder_xydt_elevator import (XYDTEncoderArg, XYDTEncoderArgUser) +from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg, NodeVAEBuilderUser class NodeVAEBuilderImgTopic(NodeVAEBuilderArg): @@ -14,8 +14,7 @@ def is_myjob(self, request): :param request: """ - self.args = request.args - flag = True # @FIXME + flag = hasattr(request, "args") self.config_img(flag, request) return flag @@ -44,3 +43,41 @@ def build_decoder(self, topic_dim): i_c=self.i_c, i_w=self.i_w, i_h=self.i_h) return decoder + + +class NodeVAEBuilderImgTopicUser(NodeVAEBuilderUser): + """NodeVAEBuilderImgTopic if user input does not come from command line""" + + def is_myjob(self, request): + """is_myjob. + + :param request: + """ + flag = not hasattr(request, "args") + self.request = request + self.config_img(flag, request) + return flag + + + def build_encoder(self, device, topic_dim): + """build_encoder. + + :param device: + :param topic_dim: + """ + encoder = XYDTEncoderArgUser(self.request.net_class_d, + self.request.net_x, + self.request.net_class_y) + return encoder + + + def build_decoder(self, topic_dim): + """build_decoder. + + :param topic_dim: + """ + decoder = DecoderConcatLatentFCReshapeConvGatedConv( + z_dim=self.zd_dim+self.zx_dim+self.zy_dim+topic_dim, + i_c=self.i_c, i_w=self.i_w, + i_h=self.i_h) + return decoder diff --git a/tests/test_mk_exp_hduva.py b/tests/test_mk_exp_hduva.py new file mode 100644 index 000000000..aa6fd8f29 --- /dev/null +++ b/tests/test_mk_exp_hduva.py @@ -0,0 +1,75 @@ +""" +make an experiment +""" +from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool +from domainlab.mk_exp import mk_exp +from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault +from domainlab.tasks.task_dset import mk_task_dset +from domainlab.models.model_hduva import mk_hduva +from domainlab.tasks.utils_task import ImSize +from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter +from domainlab.compos.pcr.request import RequestVAEBuilderNN + + +def test_mk_exp_hduva(): + """ + test mk experiment API with "hduva" model and trainers "mldg", "diva" + """ + + mk_exp_hduva(trainer="mldg") + mk_exp_hduva(trainer="diva") + + +def mk_exp_hduva(trainer="mldg"): + """ + execute experiment with "hduva" model and custom trainer + """ + + # specify domain generalization task + task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task") + task.add_domain(name="domain1", + dset_tr=DsetMNISTColorSoloDefault(0), + dset_val=DsetMNISTColorSoloDefault(1)) + task.add_domain(name="domain2", + dset_tr=DsetMNISTColorSoloDefault(2), + dset_val=DsetMNISTColorSoloDefault(3)) + task.add_domain(name="domain3", + dset_tr=DsetMNISTColorSoloDefault(4), + dset_val=DsetMNISTColorSoloDefault(5)) + + # specify backbone to use + zy_dim = 10 + zd_dim = 3 + list_str_y = [f"class{i}" for i in range(task.dim_y)] + list_d_tr = ["domain2", "domain3"] + gamma_d = 1e5 + gamma_y = 7e5 + beta_d = 1e3 + beta_x = 1e3 + beta_y = 1e3 + beta_t = 1e3 + device = "cpu" + zx_dim = 0 + topic_dim = 3 + net_class_y = LSEncoderConvBnReluPool(zy_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + net_x = LSEncoderConvBnReluPool(zx_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + # FIXME + net_class_d = LSEncoderConvBnReluPool(zd_dim, + task.isize.c, task.isize.w, task.isize.h, + conv_stride=1) + + request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y, + task.isize.c, task.isize.h, task.isize.w) + chain_node_builder = VAEChainNodeGetter(request, topic_dim)() + + # specify model to use + model = mk_hduva()(chain_node_builder, zy_dim, zd_dim, list_str_y, list_d_tr, gamma_d, gamma_y, + beta_d, beta_x, beta_y, beta_t, device, zx_dim, topic_dim) + + # make trainer for model + exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32) + exp.execute(num_epochs=3)