From 02265eac70b58d8cac65b2f2b615fedf1994e03c Mon Sep 17 00:00:00 2001
From: abse4411
Date: Fri, 3 Nov 2023 18:34:05 +0800
Subject: [PATCH] Add the method "Parameter Exchange for Robust Dynamic Domain
 Generalization"

---
 README.md                     |   2 +
 dassl/config/defaults.py      |   3 +
 dassl/engine/dg/__init__.py   |   1 +
 dassl/engine/dg/robust_ddg.py | 107 ++++++++++++++++++++++++++++++++++
 4 files changed, 113 insertions(+)
 create mode 100644 dassl/engine/dg/robust_ddg.py

diff --git a/README.md b/README.md
index 6f3ec6f..63c7e8d 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@ A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU
 We don't provide detailed documentations for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purpose and as a researcher, we think it's important to be able to read source code and we highly encourage you to do so---definitely not because we are lazy. :-)
 
 ## What's new
+- **[Oct 2023]** A new domain generalization method, "[Parameter Exchange for Robust Dynamic Domain Generalization](https://dl.acm.org/doi/10.1145/3581783.3612318)" (ACM MM'23), which builds on [DDG](https://arxiv.org/abs/2205.13913), is added to this repo. See [here](https://github.com/MetaVisionLab/PE) for more details.
 - **[Oct 2022]** New paper "[On-Device Domain Generalization](https://arxiv.org/abs/2209.07521)" is out! Code, models and datasets: https://github.com/KaiyangZhou/on-device-dg.
@@ -55,6 +56,7 @@ Dassl has implemented the following methods:
   - [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]
 
 - Domain generalization
+  - [Parameter Exchange for Robust Dynamic Domain Generalization (ACM MM'23)](https://dl.acm.org/doi/10.1145/3581783.3612318) [[dassl/engine/dg/robust_ddg.py](dassl/engine/dg/robust_ddg.py)]
   - [Dynamic Domain Generalization (IJCAI'22)](https://arxiv.org/abs/2205.13913) [[dassl/modeling/backbone/resnet_dynamic.py](dassl/modeling/backbone/resnet_dynamic.py)] [[dassl/engine/dg/domain_mix.py](dassl/engine/dg/domain_mix.py)]
   - [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]
   - [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]
diff --git a/dassl/config/defaults.py b/dassl/config/defaults.py
index cd873e9..3b12672 100644
--- a/dassl/config/defaults.py
+++ b/dassl/config/defaults.py
@@ -281,6 +281,9 @@
 _C.TRAINER.DOMAINMIX.TYPE = "crossdomain"
 _C.TRAINER.DOMAINMIX.ALPHA = 1.0
 _C.TRAINER.DOMAINMIX.BETA = 1.0
+# RobustDDG
+_C.TRAINER.ROBUSTDDG = CN()
+_C.TRAINER.ROBUSTDDG.TYPE = "Cross-Kernel"  # choices: Cross-Instance, Cross-Kernel
 
 ######
 # SSL
diff --git a/dassl/engine/dg/__init__.py b/dassl/engine/dg/__init__.py
index 23146a4..51b5be6 100644
--- a/dassl/engine/dg/__init__.py
+++ b/dassl/engine/dg/__init__.py
@@ -3,3 +3,4 @@
 from .vanilla import Vanilla
 from .crossgrad import CrossGrad
 from .domain_mix import DomainMix
+from .robust_ddg import RobustDDG
diff --git a/dassl/engine/dg/robust_ddg.py b/dassl/engine/dg/robust_ddg.py
new file mode 100644
index 0000000..12cc234
--- /dev/null
+++ b/dassl/engine/dg/robust_ddg.py
@@ -0,0 +1,107 @@
+import torch
+from torch.nn import functional as F
+from torch.optim.swa_utils import AveragedModel
+
+from dassl.engine import TRAINER_REGISTRY, TrainerX
+from dassl.metrics import compute_accuracy
+from dassl.modeling.ops import Conv2dDynamic
+
+__all__ = ["RobustDDG"]
+
+PE_CI = "Cross-Instance"
+PE_CK = "Cross-Kernel"
+PE_TYPES = [PE_CI, PE_CK]
+TARGET_PE = None
+PE_ON = False
+
+
+def shuffle_column(data):
+    cdata = data.clone()
+    B = data.shape[0]
+    C = data.shape[1]
+    for i in range(B):
+        ridxs = torch.randperm(C)  # a fresh permutation for each instance
+        cdata[i] = data[i][ridxs]
+    return cdata
+
+
+def pe_forward(self, x, attention_x=None):
+    attention_x = x if attention_x is None else attention_x
+    y = self.attention(attention_x)
+
+    if PE_ON:
+        if TARGET_PE == PE_CI:
+            # CI-PE: exchange whole attention vectors across instances
+            rand_idxs = torch.randperm(y.size(0), device=y.device)
+            y = y[rand_idxs]
+        elif TARGET_PE == PE_CK:
+            # CK-PE: permute each instance's weights across kernel templates
+            y = shuffle_column(y)
+        else:
+            raise ValueError(f"Available PEs are: {PE_TYPES}")
+
+    out = self.conv(x)
+
+    for i, template in enumerate(self.kernel_templates):
+        out += self.kernel_templates[template](x) * y[:, i].view(-1, 1, 1, 1)
+
+    return out
+
+
+@TRAINER_REGISTRY.register()
+class RobustDDG(TrainerX):
+    """RobustDDG.
+
+    Parameter Exchange for Robust Dynamic Domain Generalization.
+
+    https://github.com/MetaVisionLab/PE
+    """
+
+    def __init__(self, cfg):
+        super(RobustDDG, self).__init__(cfg)
+        self.available_backbones = \
+            ["resnet18_dynamic", "resnet50_dynamic", "resnet101_dynamic"]
+        assert cfg.MODEL.BACKBONE.NAME in self.available_backbones, \
+            f"PE method supports these backbones: {self.available_backbones}"
+        self.swa_model = AveragedModel(self.model)
+        self.register_model("swa_model", self.swa_model, None, None)
+        # The PE type can be changed via TRAINER.ROBUSTDDG.TYPE in the config
+        assert cfg.TRAINER.ROBUSTDDG.TYPE in PE_TYPES, \
+            f"Available PEs are: {PE_TYPES}"
+        global TARGET_PE, PE_ON
+        TARGET_PE = cfg.TRAINER.ROBUSTDDG.TYPE
+        PE_ON = False
+        # Inject PE by overriding the forward pass of Conv2dDynamic
+        Conv2dDynamic.forward = pe_forward
+
+    def model_inference(self, input):
+        global PE_ON
+        PE_ON = False  # PE is always disabled for inference
+        # Use the weight-averaged SWA model for inference
+        return self.swa_model(input)
+
+    def forward_backward(self, batch):
+        images, labels, _ = self.parse_batch_train(batch)
+        raw_output = self.model(images)
+        global PE_ON
+        PE_ON = True  # second forward pass with parameter exchange enabled
+        perturbed_output = self.model(images)
+        PE_ON = False
+        loss = F.cross_entropy(raw_output, labels) \
+            + F.cross_entropy(perturbed_output, labels)
+        self.model_backward_and_update(loss)
+
+        # Update BN statistics for the SWA model
+        with torch.no_grad():
+            self.swa_model(images)
+
+        loss_summary = {
+            "loss": loss.item(),
+            "acc": compute_accuracy(raw_output, labels)[0].item()
+        }
+
+        if (self.batch_idx + 1) == self.num_batches:
+            self.swa_model.update_parameters(self.model)  # end-of-epoch SWA update
+            self.update_lr()
+
+        return loss_summary
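
A note on the two PE variants: pe_forward() differs only in how it permutes the dynamic-convolution attention weights y, a tensor of shape [batch_size, num_templates]. The toy snippet below is plain PyTorch with made-up tensor sizes, purely to illustrate the difference outside of any network.

    import torch

    torch.manual_seed(0)
    # Dummy attention weights: 4 instances x 3 kernel templates (sizes are hypothetical).
    y = torch.arange(12.0).reshape(4, 3)

    # Cross-Instance PE: one permutation of the batch dimension, so each
    # instance receives another instance's entire attention vector.
    y_ci = y[torch.randperm(y.size(0))]

    # Cross-Kernel PE: an independent permutation of the template dimension
    # for every instance (this is what shuffle_column() implements above).
    y_ck = y.clone()
    for i in range(y.size(0)):
        y_ck[i] = y[i][torch.randperm(y.size(1))]

    print(y_ci)  # rows reordered
    print(y_ck)  # entries reordered within each row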
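Independently of PE, the trainer keeps a stochastic weight averaging (SWA) copy of the network via torch.optim.swa_utils.AveragedModel and serves all inference from it. A minimal, self-contained reminder of how that standard PyTorch API behaves (the toy linear model is hypothetical):

    import torch
    from torch import nn
    from torch.optim.swa_utils import AveragedModel

    model = nn.Linear(8, 2)           # toy stand-in for the real network
    swa_model = AveragedModel(model)  # keeps a running average of the weights

    for _ in range(3):
        # ... training steps would update `model` here ...
        swa_model.update_parameters(model)  # fold current weights into the average

    x = torch.randn(4, 8)
    out = swa_model(x)  # forward through the averaged weights

update_parameters() folds the current weights into a running average, which is why RobustDDG calls it once per epoch and refreshes the BN statistics of the averaged copy with extra no-grad forward passes.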
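Finally, a minimal sketch of how one might run the new trainer through Dassl's Python API rather than tools/train.py. It assumes the get_cfg_default()/build_trainer() entry points used by tools/train.py; the dataset root is a placeholder, and PACS is just an example DG benchmark.

    # Usage sketch under the assumptions above; paths are hypothetical.
    from dassl.config import get_cfg_default
    from dassl.engine import build_trainer

    cfg = get_cfg_default()
    cfg.DATASET.NAME = "PACS"
    cfg.DATASET.ROOT = "/path/to/datasets"  # hypothetical path
    cfg.DATASET.SOURCE_DOMAINS = ["art_painting", "cartoon", "photo"]
    cfg.DATASET.TARGET_DOMAINS = ["sketch"]
    cfg.MODEL.BACKBONE.NAME = "resnet18_dynamic"  # PE needs a dynamic backbone
    cfg.TRAINER.NAME = "RobustDDG"
    cfg.TRAINER.ROBUSTDDG.TYPE = "Cross-Instance"  # or "Cross-Kernel" (the default)
    cfg.freeze()

    trainer = build_trainer(cfg)
    trainer.train()

Any of the three dynamic backbones asserted in __init__ works here; a non-dynamic backbone fails the assertion, since without kernel templates there is nothing for PE to exchange.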