Implementation for "Parameter Exchange for Robust Dynamic Domain Generalization" #62

Open · wants to merge 1 commit into base: master
2 changes: 2 additions & 0 deletions README.md
@@ -20,6 +20,7 @@ A drawback of Dassl is that it doesn't (yet? hmm) support distributed multi-GPU
We don't provide detailed documentation for Dassl, unlike another [project](https://kaiyangzhou.github.io/deep-person-reid/) of ours. This is because Dassl is developed for research purposes, and as researchers we think it's important to be able to read source code, so we highly encourage you to do so---definitely not because we are lazy. :-)

## What's new
- **[Oct 2023]** A new domain generalization method, "[Parameter Exchange for Robust Dynamic Domain Generalization](https://dl.acm.org/doi/10.1145/3581783.3612318)" (ACM MM'23), built on [DDG](https://arxiv.org/abs/2205.13913), has been added to this repo. See [here](https://github.com/MetaVisionLab/PE) for more details.
- **[Oct 2022]** New paper "[On-Device Domain Generalization](https://arxiv.org/abs/2209.07521)" is out! Code, models and datasets: https://github.com/KaiyangZhou/on-device-dg.

<details>
@@ -55,6 +56,7 @@ Dassl has implemented the following methods:
- [Moment Matching for Multi-Source Domain Adaptation (ICCV'19)](https://arxiv.org/abs/1812.01754) [[dassl/engine/da/m3sda.py](dassl/engine/da/m3sda.py)]

- Domain generalization
- [Parameter Exchange for Robust Dynamic Domain Generalization (ACM MM'23)](https://dl.acm.org/doi/10.1145/3581783.3612318) [[dassl/engine/dg/robust_ddg.py](dassl/engine/dg/robust_ddg.py)]
- [Dynamic Domain Generalization (IJCAI'22)](https://arxiv.org/abs/2205.13913) [[dassl/modeling/backbone/resnet_dynamic.py](dassl/modeling/backbone/resnet_dynamic.py)] [[dassl/engine/dg/domain_mix.py](dassl/engine/dg/domain_mix.py)]
- [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization (CVPR'22)](https://arxiv.org/abs/2203.07740) [[dassl/modeling/ops/efdmix.py](dassl/modeling/ops/efdmix.py)]
- [Domain Generalization with MixStyle (ICLR'21)](https://openreview.net/forum?id=6xHJ37MVxxp) [[dassl/modeling/ops/mixstyle.py](dassl/modeling/ops/mixstyle.py)]
3 changes: 3 additions & 0 deletions dassl/config/defaults.py
@@ -281,6 +281,9 @@
_C.TRAINER.DOMAINMIX.TYPE = "crossdomain"
_C.TRAINER.DOMAINMIX.ALPHA = 1.0
_C.TRAINER.DOMAINMIX.BETA = 1.0
# RobustDDG
_C.TRAINER.ROBUSTDDG = CN()
_C.TRAINER.ROBUSTDDG.TYPE = 'Cross-Kernel' # Cross-Instance, Cross-Kernel

######
# SSL
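The new default can be overridden like any other Dassl/yacs option (e.g. via `merge_from_list` or a config file). A minimal self-contained sketch of the validation the trainer applies, with plain Python standing in for the yacs `CfgNode` and a hypothetical `resolve_pe_type` helper:

```python
# Hypothetical helper mirroring the trainer's assert on TRAINER.ROBUSTDDG.TYPE
PE_TYPES = ["Cross-Instance", "Cross-Kernel"]

def resolve_pe_type(cfg_value):
    """Reject any PE type outside PE_TYPES, as the trainer's assert does."""
    if cfg_value not in PE_TYPES:
        raise ValueError(f"Available PEs are: {PE_TYPES}")
    return cfg_value

print(resolve_pe_type("Cross-Kernel"))    # the default
print(resolve_pe_type("Cross-Instance"))  # the alternative
```

Any other string (e.g. a typo like `"cross-kernel"`) fails fast with a `ValueError` rather than silently training without parameter exchange.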
1 change: 1 addition & 0 deletions dassl/engine/dg/__init__.py
@@ -3,3 +3,4 @@
from .vanilla import Vanilla
from .crossgrad import CrossGrad
from .domain_mix import DomainMix
from .robust_ddg import RobustDDG
107 changes: 107 additions & 0 deletions dassl/engine/dg/robust_ddg.py
@@ -0,0 +1,107 @@
import torch
from torch.nn import functional as F
from torch.optim.swa_utils import AveragedModel

from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
from dassl.modeling.ops import Conv2dDynamic

__all__ = ["RobustDDG"]

PE_CI = "Cross-Instance"
PE_CK = "Cross-Kernel"
PE_TYPES = [PE_CI, PE_CK]
TARGET_PE = None
PE_ON = False


def shuffle_column(data):
    cdata = data.clone()
    B = data.shape[0]
    C = data.shape[1]
    for i in range(B):
        ridxs = torch.randperm(C)
        cdata[i] = data[i][ridxs]
    return cdata


def pe_forward(self, x, attention_x=None):
    attention_x = x if attention_x is None else attention_x
    y = self.attention(attention_x)

    if PE_ON:
        if TARGET_PE == PE_CI:
            # CI-PE: permute attention vectors across the batch
            rand_idxs = torch.randperm(y.size(0), device=y.device)
            y = y[rand_idxs]
        elif TARGET_PE == PE_CK:
            # CK-PE: permute attention weights within each instance
            y = shuffle_column(y)
        else:
            raise ValueError(f"Available PEs are: {PE_TYPES}")

    out = self.conv(x)

    for i, template in enumerate(self.kernel_templates):
        out += self.kernel_templates[template](x) * y[:, i].view(-1, 1, 1, 1)

    return out


@TRAINER_REGISTRY.register()
class RobustDDG(TrainerX):
    """RobustDDG.

    Parameter Exchange for Robust Dynamic Domain Generalization.

    https://github.com/MetaVisionLab/PE
    """

    def __init__(self, cfg):
        super(RobustDDG, self).__init__(cfg)
        self.available_backbones = \
            ["resnet18_dynamic", "resnet50_dynamic", "resnet101_dynamic"]
        assert cfg.MODEL.BACKBONE.NAME in self.available_backbones, \
            f"PE method supports these backbones: {self.available_backbones}"
        self.swa_model = AveragedModel(self.model)
        self.register_model("swa_model", self.swa_model, None, None)
        # the PE type can be changed via TRAINER.ROBUSTDDG.TYPE
        assert cfg.TRAINER.ROBUSTDDG.TYPE in PE_TYPES, \
            f"Available PEs are: {PE_TYPES}"
        global TARGET_PE, PE_ON
        TARGET_PE = cfg.TRAINER.ROBUSTDDG.TYPE
        PE_ON = False
        # inject PE by monkey-patching the dynamic convolution's forward
        Conv2dDynamic.forward = pe_forward

    def model_inference(self, input):
        global PE_ON
        PE_ON = False  # always disabled at inference time
        # use the SWA model for inference
        return self.swa_model(input)

    def forward_backward(self, batch):
        images, labels, _ = self.parse_batch_train(batch)
        raw_output = self.model(images)
        global PE_ON
        PE_ON = True
        perturbed_output = self.model(images)
        PE_ON = False
        loss = F.cross_entropy(raw_output, labels) \
            + F.cross_entropy(perturbed_output, labels)
        self.model_backward_and_update(loss)

        # update BN statistics for the SWA model
        with torch.no_grad():
            self.swa_model(images)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(raw_output, labels)[0].item()
        }

        if (self.batch_idx + 1) == self.num_batches:
            self.swa_model.update_parameters(self.model)
            self.update_lr()

        return loss_summary
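The two exchange strategies in `pe_forward` can be illustrated without PyTorch. A minimal list-based sketch (hypothetical helper names; plain Python lists stand in for the `[B, C]` attention tensor `y`):

```python
import random

def cross_instance_pe(y, rng):
    # CI-PE: permute whole rows, so instances swap entire attention vectors
    idx = list(range(len(y)))
    rng.shuffle(idx)
    return [y[i] for i in idx]

def cross_kernel_pe(y, rng):
    # CK-PE: independently permute the kernel weights inside each instance
    out = []
    for row in y:
        row = list(row)
        rng.shuffle(row)
        out.append(row)
    return out

rng = random.Random(0)
y = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
ci = cross_instance_pe(y, rng)  # rows are a permutation of the original rows
ck = cross_kernel_pe(y, rng)    # each row keeps its values, reordered
```

The invariants match the tensor code: CI-PE preserves each attention vector intact while reassigning it to a different instance, and CK-PE preserves each instance's multiset of kernel weights while reordering them, which is what makes both perturbations label-preserving during training.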