Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/cf_bugfix_benchmarkyaml' into cf…
Browse files Browse the repository at this point in the history
…_bugfix_benchmarkyaml
  • Loading branch information
Car-la-F committed Sep 14, 2023
2 parents c277335 + ec12e6f commit 6209bdc
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 10 deletions.
2 changes: 1 addition & 1 deletion domainlab/algos/msels/c_msel_val.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def update(self):
else:
self.es_c += 1
logger = Logger.get_logger()
logger.debug("early stop counter: ", self.es_c)
logger.debug(f"early stop counter: {self.es_c}")
logger.debug(f"val acc:{self.tr_obs.metric_te['acc']}, "
f"best validation acc: {self.best_val_acc}")
flag = False # do not update best model
Expand Down
2 changes: 1 addition & 1 deletion domainlab/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def mk_parser_main():
arg_group_task.add_argument('--san_num', type=int, default=8,
help='number of images to be dumped for the sanity check')

arg_group_task.add_argument('--loglevel', type=str, default='INFO',
arg_group_task.add_argument('--loglevel', type=str, default='DEBUG',
help='sets the loglevel of the logger')

# args for variational auto encoder
Expand Down
5 changes: 5 additions & 0 deletions domainlab/compos/pcr/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ class RequestVAEBuilderCHW():
def __init__(self, i_c, i_h, i_w, args):
pass

class RequestVAEBuilderNN():
"""creates request when input does not come from command-line (args) but from test_exp file"""
@store_args
def __init__(self, net_class_d, net_x, net_class_y, i_c, i_h, i_w):
"""net_class_d, net_x and net_class_y are neural networks defined by the user"""

class RequestTask():
"""
Expand Down
9 changes: 9 additions & 0 deletions domainlab/compos/vae/compos/encoder_xyd_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ def infer_zy_loc(self, tensor):
return zy_loc


class XYDEncoderParallelUser(XYDEncoderParallel):
"""
This class only reimplemented constructor of parent class
"""
@store_args
def __init__(self, net_class_d, net_x, net_class_y):
super().__init__(net_class_d, net_x, net_class_y)


class XYDEncoderParallelConvBnReluPool(XYDEncoderParallel):
"""
This class only reimplemented constructor of parent class
Expand Down
6 changes: 3 additions & 3 deletions domainlab/compos/vae/utils_request_chain_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from domainlab.compos.vae.zoo_vae_builders_classif import (
NodeVAEBuilderArg, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool)
from domainlab.compos.vae.zoo_vae_builders_classif_topic import \
NodeVAEBuilderImgTopic
NodeVAEBuilderArg, NodeVAEBuilderUser, NodeVAEBuilderImgAlex, NodeVAEBuilderImgConvBnPool)
from domainlab.compos.vae.zoo_vae_builders_classif_topic import NodeVAEBuilderImgTopic


class VAEChainNodeGetter(object):
Expand Down Expand Up @@ -29,5 +28,6 @@ def __call__(self):
chain = NodeVAEBuilderImgConvBnPool(None)
chain = NodeVAEBuilderImgAlex(chain)
chain = NodeVAEBuilderArg(chain)
chain = NodeVAEBuilderUser(chain)
node = chain.handle(self.request)
return node
24 changes: 20 additions & 4 deletions domainlab/compos/vae/zoo_vae_builders_classif.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
DecoderConcatLatentFCReshapeConvGatedConv
from domainlab.compos.vae.compos.encoder_xyd_parallel import (
XYDEncoderParallelAlex, XYDEncoderParallelConvBnReluPool,
XYDEncoderParallelExtern)
XYDEncoderParallelExtern, XYDEncoderParallelUser)


class ChainNodeVAEBuilderClassifCondPriorBase(
ChainNodeVAEBuilderClassifCondPrior):
ChainNodeVAEBuilderClassifCondPrior):
"""
base class of AE builder
"""
Expand Down Expand Up @@ -40,7 +40,7 @@ def build_encoder(self):
def build_decoder(self):
"""build_decoder."""
decoder = DecoderConcatLatentFCReshapeConvGatedConv(
z_dim=self.zd_dim+self.zx_dim+self.zy_dim,
z_dim=self.zd_dim + self.zx_dim + self.zy_dim,
i_c=self.i_c, i_w=self.i_w,
i_h=self.i_h)
return decoder
Expand Down Expand Up @@ -70,14 +70,30 @@ def build_encoder(self):
return encoder


class NodeVAEBuilderUser(ChainNodeVAEBuilderClassifCondPriorBase):
"""Build encoders according to test_mk_exp file"""

def is_myjob(self, request):
flag = not hasattr(request, "args")
self.request = request
self.config_img(flag, request)
return flag

def build_encoder(self):
encoder = XYDEncoderParallelUser(self.request.net_class_d,
self.request.net_x,
self.request.net_class_y)
return encoder


class NodeVAEBuilderImgConvBnPool(ChainNodeVAEBuilderClassifCondPriorBase):
def is_myjob(self, request):
"""is_myjob.
:param request:
"""
flag = (request.args.nname == "conv_bn_pool_2" or
request.args.nname_dom == "conv_bn_pool_2") # @FIXME
request.args.nname_dom == "conv_bn_pool_2") # @FIXME
self.config_img(flag, request)
return flag

Expand Down
40 changes: 39 additions & 1 deletion domainlab/models/model_diva.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,46 @@

def mk_diva(parent_class=VAEXYDClassif):
"""
DIVA with arbitrary task loss
Instantiate a domain invariant variational autoencoder (DIVA) with arbitrary task loss.
Details:
This method is creating a generative model based on a variational autoencoder, which can
reconstruct the input images. Here for, three different encoders with latent variables are
trained, each representing a latent subspace for the domain, class and residual features
information, respectively. The latent subspaces serve for disentangling the respective
sources of variation. To reconstruct the input image, the three latent variables are fed
into a decoder.
Additionally, two classifiers are trained, which predict the domain and the class label.
For more details, see:
Ilse, Maximilian, et al. "Diva: Domain invariant variational autoencoders."
Medical Imaging with Deep Learning. PMLR, 2020.
Args:
parent_class: Class object determining the task type. Defaults to VAEXYDClassif.
Returns:
ModelDIVA: model inheriting from parent class.
Input Parameters:
zd_dim: size of latent space for domain-specific information,
zy_dim: size of latent space for class-specific information,
zx_dim: size of latent space for residual variance,
chain_node_builder: creates the neural network specified by the user; object of the class
"VAEChainNodeGetter" (see domainlab/compos/vae/utils_request_chain_builder.py)
being initialized by entering a user request,
list_str_y: list of labels,
list_d_tr: list of training domains,
gamma_d: weighting term for d classifier,
gamma_y: weighting term for y classifier,
beta_d: weighting term for domain encoder,
beta_x: weighting term for residual variation encoder,
beta_y: weighting term for class encoder
Usage:
For a concrete example, see:
https://github.com/marrlab/DomainLab/blob/master/tests/test_mk_exp_diva.py
"""

class ModelDIVA(parent_class):
"""
DIVA
Expand Down
68 changes: 68 additions & 0 deletions tests/test_mk_exp_diva.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
make an experiment using "diva" model
"""

from domainlab.mk_exp import mk_exp
from domainlab.dsets.dset_mnist_color_solo_default import DsetMNISTColorSoloDefault
from domainlab.tasks.task_dset import mk_task_dset
from domainlab.models.model_diva import mk_diva
from domainlab.tasks.utils_task import ImSize
from domainlab.compos.vae.utils_request_chain_builder import VAEChainNodeGetter
from domainlab.compos.pcr.request import RequestVAEBuilderNN
from domainlab.compos.vae.compos.encoder import LSEncoderConvBnReluPool


def test_mk_exp_diva():
"""
test mk experiment API for "diva" model and trainers "mldg", "dial"
"""
mk_exp_diva(trainer="mldg")
mk_exp_diva(trainer="dial")


def mk_exp_diva(trainer="mldg"):
"""
execute experiment with "diva" model and custom trainer
"""

# specify domain generalization task
task = mk_task_dset(dim_y=10, isize=ImSize(3, 28, 28), taskna="custom_task")
task.add_domain(name="domain1",
dset_tr=DsetMNISTColorSoloDefault(0),
dset_val=DsetMNISTColorSoloDefault(1))
task.add_domain(name="domain2",
dset_tr=DsetMNISTColorSoloDefault(2),
dset_val=DsetMNISTColorSoloDefault(3))
task.add_domain(name="domain3",
dset_tr=DsetMNISTColorSoloDefault(4),
dset_val=DsetMNISTColorSoloDefault(5))

# specify parameters
list_str_y = [f"class{i}" for i in range(task.dim_y)]
list_d_tr = ["domain2", "domain3"]
zd_dim = 3
zy_dim = 10
zx_dim = 30
gamma_d = 1e5
gamma_y = 7e5
beta_d = 1e3
beta_x = 1e3
beta_y = 1e3
net_class_d = LSEncoderConvBnReluPool(
zd_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)
net_x = LSEncoderConvBnReluPool(
zx_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)
net_class_y = LSEncoderConvBnReluPool(
zy_dim, task.isize.c, task.isize.w, task.isize.h, conv_stride=1)

request = RequestVAEBuilderNN(net_class_d, net_x, net_class_y,
task.isize.c, task.isize.h, task.isize.w)
chain_node_builder = VAEChainNodeGetter(request)()

# specify model to use
model = mk_diva()(chain_node_builder, zd_dim, zy_dim, zx_dim, list_str_y, list_d_tr, gamma_d,
gamma_y, beta_d, beta_x, beta_y)

# make trainer for model
exp = mk_exp(task, model, trainer=trainer, test_domain="domain1", batchsize=32)
exp.execute(num_epochs=3)

0 comments on commit 6209bdc

Please sign in to comment.