Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test exp hduva #363

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions domainlab/compos/vae/compos/encoder_xydt_elevator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def __init__(self, device, topic_dim, zd_dim,
super().__init__(net_infer_zd_topic, net_infer_zx, net_infer_zy)


class XYDTEncoderArgUser(XYDTEncoderElevator):
"""
This class only reimplemented constructor of parent class
"""
@store_args
def __init__(self, net_class_d, net_x, net_class_y):
super().__init__(net_class_d, net_x, net_class_y)


# To remove
class XYDTEncoderConvBnReluPool(XYDTEncoderElevator):
"""
Expand Down
4 changes: 3 additions & 1 deletion domainlab/compos/vae/utils_request_chain_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from domainlab.compos.vae.zoo_vae_builders_classif import (
NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool)
from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic
from domainlab.compos.vae.zoo_vae_builders_classif_topic import (
NodeVAEBuilderImgTopic, NodeVAEBuilderImgTopicUser)


class VAEChainNodeGetter(object):
Expand All @@ -24,6 +25,7 @@ def __call__(self):
"""
if self.topic_dim is not None:
chain = NodeVAEBuilderImgTopic(None)
chain = NodeVAEBuilderImgTopicUser(chain)
else:
chain = NodeVAEBuilderImgConvBnPool(None)
chain = NodeVAEBuilderImgAlex(chain)
Expand Down
45 changes: 41 additions & 4 deletions domainlab/compos/vae/zoo_vae_builders_classif_topic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
"""
from domainlab.compos.vae.compos.decoder_concat_vec_reshape_conv_gated_conv import \
DecoderConcatLatentFCReshapeConvGatedConv
from domainlab.compos.vae.compos.encoder_xydt_elevator import XYDTEncoderArg
from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg
from domainlab.compos.vae.compos.encoder_xydt_elevator import (XYDTEncoderArg, XYDTEncoderArgUser)
from domainlab.compos.vae.zoo_vae_builders_classif import NodeVAEBuilderArg, NodeVAEBuilderUser


class NodeVAEBuilderImgTopic(NodeVAEBuilderArg):
Expand All @@ -14,8 +14,7 @@ def is_myjob(self, request):

:param request:
"""
self.args = request.args
flag = True # @FIXME
flag = hasattr(request, "args")
self.config_img(flag, request)
return flag

Expand Down Expand Up @@ -44,3 +43,41 @@ def build_decoder(self, topic_dim):
i_c=self.i_c, i_w=self.i_w,
i_h=self.i_h)
return decoder


class NodeVAEBuilderImgTopicUser(NodeVAEBuilderUser):
"""NodeVAEBuilderImgTopic if user input does not come from command line"""

def is_myjob(self, request):
"""is_myjob.

:param request:
"""
flag = not hasattr(request, "args")
self.request = request
self.config_img(flag, request)
return flag


def build_encoder(self, device, topic_dim):
"""build_encoder.

:param device:
:param topic_dim:
"""
encoder = XYDTEncoderArgUser(self.request.net_class_d,
self.request.net_x,
self.request.net_class_y)
return encoder


def build_decoder(self, topic_dim):
"""build_decoder.

:param topic_dim:
"""
decoder = DecoderConcatLatentFCReshapeConvGatedConv(
z_dim=self.zd_dim+self.zx_dim+self.zy_dim+topic_dim,
i_c=self.i_c, i_w=self.i_w,
i_h=self.i_h)
return decoder
75 changes: 75 additions & 0 deletions tests/test_mk_exp_hduva.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
make an experiment
"""
from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool
from domainlab.mk_exp import mk_exp
from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault
from domainlab.tasks.task_dset import mk_task_dset
from domainlab.models.model_hduva import mk_hduva
from domainlab.tasks.utils_task import ImSize
from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter
from domainlab.compos.pcr.request import RequestVAEBuilderNN


def test_mk_exp_hduva():
"""
test mk experiment API with "hduva" model and trainers "mldg", "diva"
"""

mk_exp_hduva(trainer="mldg")
mk_exp_hduva(trainer="diva")


def mk_exp_hduva(trainer="mldg"):
"""
execute experiment with "hduva" model and custom trainer
"""

# specify domain generalization task
task = mk_task_dset(isize=ImSize(3, 28, 28), dim_y=10, taskna="custom_task")
task.add_domain(name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1))
task.add_domain(name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3))
task.add_domain(name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5))

# specify backbone to use
zy_dim = 10
zd_dim = 3
list_str_y = [f"class{i}" for i in range(task.dim_y)]
list_d_tr = ["domain2", "domain3"]
gamma_d = 1e5
gamma_y = 7e5
beta_d = 1e3
beta_x = 1e3
beta_y = 1e3
beta_t = 1e3
device = "cpu"
zx_dim = 0
topic_dim = 3
net_class_y = LSEncoderConvBnReluPool(zy_dim,
task.isize.c, task.isize.w, task.isize.h,
conv_stride=1)
net_x = LSEncoderConvBnReluPool(zx_dim,
task.isize.c, task.isize.w, task.isize.h,
conv_stride=1)
# FIXME
net_class_d = LSEncoderConvBnReluPool(zd_dim,
task.isize.c, task.isize.w, task.isize.h,
conv_stride=1)

request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y,
task.isize.c, task.isize.h, task.isize.w)
chain_node_builder = VAEChainNodeGetter(request, topic_dim)()

# specify model to use
model = mk_hduva()(chain_node_builder, zy_dim, zd_dim, list_str_y, list_d_tr, gamma_d, gamma_y,
beta_d, beta_x, beta_y, beta_t, device, zx_dim, topic_dim)

# make trainer for model
exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32)
exp.execute(num_epochs=3)
Loading