update domain adaptation classification; add digits to classification
JunguangJiang committed Sep 21, 2021
1 parent cb9f8c9 commit a6d9a2b
Showing 37 changed files with 680 additions and 1,159 deletions.
1 change: 1 addition & 0 deletions common/vision/datasets/__init__.py
@@ -13,6 +13,7 @@
 from .coco70 import COCO70
 from .oxfordpet import OxfordIIITPet
 from .pacs import PACS
+from .digits import *

 __all__ = ['ImageList', 'Office31', 'OfficeHome', "VisDA2017", "OfficeCaltech", "DomainNet", "ImageNetR", "ImageNetSketch",
            "Aircraft", "cub200", "StanfordCars", "StanfordDogs", "COCO70", "OxfordIIITPet", "PACS"]
19 changes: 17 additions & 2 deletions common/vision/datasets/digits.py
@@ -119,7 +119,7 @@ class SVHN(D.SVHN):
         downloaded again.
     """
-    def __init__(self, root, mode="RGB", **kwargs):
+    def __init__(self, root, mode="L", **kwargs):
         super(SVHN, self).__init__(root, **kwargs)
         assert mode in ['L', 'RGB']
         self.mode = mode
@@ -144,4 +144,19 @@ def __getitem__(self, index: int) -> Tuple[Any, Any]:
         if self.target_transform is not None:
             target = self.target_transform(target)

-        return img, target
+        return img, target
+
+
+class MNISTRGB(MNIST):
+    def __init__(self, root, **kwargs):
+        super(MNISTRGB, self).__init__(root, mode='RGB', **kwargs)
+
+
+class USPSRGB(USPS):
+    def __init__(self, root, **kwargs):
+        super(USPSRGB, self).__init__(root, mode='RGB', **kwargs)
+
+
+class SVHNRGB(SVHN):
+    def __init__(self, root, **kwargs):
+        super(SVHNRGB, self).__init__(root, mode='RGB', **kwargs)
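
Together with the `from .digits import *` re-export added to common/vision/datasets/__init__.py above, these *RGB wrappers pin every digits dataset to three-channel output, so a single preprocessing pipeline can serve both domains in SVHN-to-MNIST-style adaptation. A minimal usage sketch; the split/train/download keywords are assumed to be forwarded to the underlying torchvision datasets and are not part of this commit:

    from torchvision import transforms
    from common.vision.datasets.digits import MNISTRGB, SVHNRGB

    # One transform works for both domains, since each now yields 3-channel images.
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
    ])

    source = SVHNRGB("data/svhn", split="train", download=True, transform=transform)
    target = MNISTRGB("data/mnist", train=True, download=True, transform=transform)
    img, label = source[0]   # img: 3 x 32 x 32 tensor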
1 change: 1 addition & 0 deletions common/vision/models/__init__.py
@@ -1,4 +1,5 @@
 from .resnet import *
 from .ibn import *
+from .digits import *

 __all__ = ['resnet', 'digits', 'ibn']
102 changes: 34 additions & 68 deletions common/vision/models/digits.py
@@ -1,82 +1,58 @@
 import torch.nn as nn

-class LeNet:
-    def __init__(self, num_classes=10):
-        self.num_classes = num_classes
-        self.bottleneck_dim = 50 * 4 * 4
-
-    def backbone(self):
-        return nn.Sequential(
+class LeNet(nn.Sequential):
+    def __init__(self, num_classes=10):
+        super(LeNet, self).__init__(
             nn.Conv2d(1, 20, kernel_size=5),
             nn.MaxPool2d(2),
             nn.ReLU(),
             nn.Conv2d(20, 50, kernel_size=5),
             nn.Dropout2d(p=0.5),
             nn.MaxPool2d(2),
             nn.ReLU(),
-        )
-
-    def bottleneck(self):
-        return nn.Flatten(start_dim=1)
-
-    def head(self):
-        return nn.Sequential(
-            nn.Linear(self.bottleneck_dim, 500),
+            nn.Flatten(start_dim=1),
+            nn.Linear(50 * 4 * 4, 500),
             nn.ReLU(),
             nn.Dropout(p=0.5),
-            nn.Linear(500, self.num_classes)
         )
+        self.num_classes = num_classes
+        self.out_features = 500

-    def complete(self):
-        return nn.Sequential(
-            self.backbone(),
-            self.bottleneck(),
-            self.head()
-        )
+    def copy_head(self):
+        return nn.Linear(500, self.num_classes)

-class DTN:
+class DTN(nn.Sequential):
     def __init__(self, num_classes=10):
-        self.num_classes = num_classes
-        self.bottleneck_dim = 256 * 4 * 4
-
-    def backbone(self):
-        return nn.Sequential(
-            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
-            nn.BatchNorm2d(64),
-            nn.Dropout2d(0.1),
-            nn.ReLU(),
-            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
-            nn.BatchNorm2d(128),
-            nn.Dropout2d(0.3),
-            nn.ReLU(),
-            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
-            nn.BatchNorm2d(256),
-            nn.Dropout2d(0.5),
-            nn.ReLU(),
+        super(DTN, self).__init__(
+            nn.Conv2d(3, 64, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm2d(64),
+            nn.Dropout2d(0.1),
+            nn.ReLU(),
+            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm2d(128),
+            nn.Dropout2d(0.3),
+            nn.ReLU(),
+            nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
+            nn.BatchNorm2d(256),
+            nn.Dropout2d(0.5),
+            nn.ReLU(),
+            nn.Flatten(start_dim=1),
+            nn.Linear(256 * 4 * 4, 512),
+            nn.BatchNorm1d(512),
+            nn.ReLU(),
+            nn.Dropout(),
         )
+        self.num_classes = num_classes
+        self.out_features = 512

-    def bottleneck(self):
-        return nn.Flatten(start_dim=1)
-
-    def head(self):
-        return nn.Sequential(
-            nn.Linear(self.bottleneck_dim, 512),
-            nn.BatchNorm1d(512),
-            nn.ReLU(),
-            nn.Dropout(),
-            nn.Linear(512, self.num_classes)
-        )
+    def copy_head(self):
+        return nn.Linear(512, self.num_classes)

-    def complete(self):
-        return nn.Sequential(
-            self.backbone(),
-            self.bottleneck(),
-            self.head()
-        )

-def lenet(**kwargs):
+def lenet(pretrained=False, **kwargs):
     """LeNet model from
     `"Gradient-based learning applied to document recognition" <http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf>`_
@@ -86,16 +62,11 @@ def lenet(**kwargs):
     .. note::
         The input image size must be 28 x 28.
-    Examples::
-        >>> # Get the whole LeNet model
-        >>> model = lenet().complete()
-        >>> # Or combine it by yourself
-        >>> model = nn.Sequential(lenet().backbone(), lenet().bottleneck(), lenet().head())
     """
     return LeNet(**kwargs)


-def dtn(**kwargs):
+def dtn(pretrained=False, **kwargs):
     """ DTN model
     Args:
@@ -104,10 +75,5 @@ def dtn(**kwargs):
     .. note::
         The input image size must be 32 x 32.
-    Examples::
-        >>> # Get the whole DTN model
-        >>> model = dtn().complete()
-        >>> # Or combine it by yourself
-        >>> model = nn.Sequential(dtn().backbone(), dtn().bottleneck(), dtn().head())
     """
     return DTN(**kwargs)
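
After this refactor, lenet() and dtn() return complete feature extractors (plain nn.Sequential modules exposing out_features), and copy_head() stamps out a fresh classifier on each call, presumably so adaptation methods can hold several independently initialized heads of the same shape. A sketch of the new composition, standing in for the deleted docstring examples:

    import torch
    import torch.nn as nn
    from common.vision.models.digits import lenet, dtn

    backbone = lenet()                            # backbone.out_features == 500
    model = nn.Sequential(backbone, backbone.copy_head())
    logits = model(torch.randn(4, 1, 28, 28))     # 1-channel 28 x 28 in -> (4, 10) out

    backbone = dtn()                              # backbone.out_features == 512
    model = nn.Sequential(backbone, backbone.copy_head())
    logits = model(torch.randn(4, 3, 32, 32))     # 3-channel 32 x 32 in -> (4, 10) out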
7 changes: 4 additions & 3 deletions dalib/adaptation/mcd.py
@@ -56,12 +56,13 @@ def __init__(self, in_features: int, num_classes: int, bottleneck_dim: Optional[
         super(ImageClassifierHead, self).__init__()
         self.num_classes = num_classes
         if pool_layer is None:
-            pool_layer = nn.Sequential(
+            self.pool_layer = nn.Sequential(
                 nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                 nn.Flatten()
             )
+        else:
+            self.pool_layer = pool_layer
         self.head = nn.Sequential(
-            pool_layer,
             nn.Dropout(0.5),
             nn.Linear(in_features, bottleneck_dim),
             nn.BatchNorm1d(bottleneck_dim),
@@ -74,4 +75,4 @@ def __init__(self, in_features: int, num_classes: int, bottleneck_dim: Optional[
         )

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
-        return self.head(inputs)
+        return self.head(self.pool_layer(inputs))
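
The pooling stage is now kept on self.pool_layer instead of being baked into self.head, and forward applies it explicitly, presumably so the head alone can be reused on already-pooled features. A rough usage sketch; the feature dimensions are illustrative, and the 31-way output assumes the hidden tail of self.head (collapsed in the diff above) ends in nn.Linear(..., num_classes):

    import torch
    from dalib.adaptation.mcd import ImageClassifierHead

    head = ImageClassifierHead(in_features=2048, num_classes=31, bottleneck_dim=1024)
    feature_maps = torch.randn(8, 2048, 7, 7)   # e.g. unpooled ResNet conv5 output
    logits = head(feature_maps)                 # pooled to (8, 2048), then (8, 31)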
56 changes: 2 additions & 54 deletions dalib/adaptation/mdd.py
@@ -234,8 +234,8 @@ def shift_log(x: torch.Tensor, offset: Optional[float] = 1e-6) -> torch.Tensor:

 class GeneralModule(nn.Module):
     def __init__(self, backbone: nn.Module, num_classes: int, bottleneck: nn.Module,
-                 head: nn.Module, adv_head: nn.Module,
-                 grl: Optional[WarmStartGradientReverseLayer] = None, finetune: Optional[bool] = True):
+                 head: nn.Module, adv_head: nn.Module, grl: Optional[WarmStartGradientReverseLayer] = None,
+                 finetune: Optional[bool] = True):
         super(GeneralModule, self).__init__()
         self.backbone = backbone
         self.num_classes = num_classes
@@ -442,55 +442,3 @@ def __init__(self, backbone: nn.Module, num_factors: int, bottleneck = None, hea
         super(ImageRegressor, self).__init__(backbone, num_factors, bottleneck,
                                              head, adv_head, grl_layer, finetune)
         self.num_factors = num_factors


-class SequenceClassifier(GeneralModule):
-
-    def __init__(self, backbone: nn.Module, num_classes: int,
-                 bottleneck_dim: Optional[int] = 1024, width: Optional[int] = 1024,
-                 grl: Optional[WarmStartGradientReverseLayer] = None, finetune=True):
-        grl_layer = WarmStartGradientReverseLayer(alpha=1.0, lo=0.0, hi=0.1, max_iters=1000,
-                                                  auto_step=False) if grl is None else grl
-
-        # bottleneck = nn.Sequential(
-        #     nn.Linear(backbone.out_features, bottleneck_dim),
-        #     nn.BatchNorm1d(bottleneck_dim),
-        #     nn.ReLU(),
-        #     nn.Dropout(0.5)
-        # )
-        # bottleneck[0].weight.data.normal_(0, 0.005)
-        # bottleneck[0].bias.data.fill_(0.1)
-        bottleneck = nn.Identity()
-        bottleneck_dim = backbone.out_features
-
-        # The classifier head used for final predictions.
-        head = nn.Sequential(
-            nn.Linear(bottleneck_dim, width),
-            nn.ReLU(),
-            nn.Dropout(0.5),
-            nn.Linear(width, num_classes)
-        )
-        # The adversarial classifier head
-        adv_head = nn.Sequential(
-            nn.Linear(bottleneck_dim, width),
-            nn.ReLU(),
-            nn.Dropout(0.5),
-            nn.Linear(width, num_classes)
-        )
-        for dep in range(2):
-            head[dep * 3].weight.data.normal_(0, 0.01)
-            head[dep * 3].bias.data.fill_(0.0)
-            adv_head[dep * 3].weight.data.normal_(0, 0.01)
-            adv_head[dep * 3].bias.data.fill_(0.0)
-        super(SequenceClassifier, self).__init__(backbone, num_classes, bottleneck,
-                                                 head, adv_head, grl_layer, finetune)
-
-    def forward(self, *args, **kwargs):
-        """"""
-        hidden_state = self.backbone(*args, **kwargs)[0]  # (bs, seq_len, dim)
-        pooled_output = hidden_state[:, 0]  # (bs, dim)
-        features = self.bottleneck(pooled_output)
-        outputs = self.head(features)
-        features_adv = self.grl_layer(features)
-        outputs_adv = self.adv_head(features_adv)
-        return outputs, outputs_adv
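
The deleted SequenceClassifier, which pooled the first token of the backbone's sequence output, followed the same dual-head pattern as the classes that remain in mdd.py: the main head sees the features directly, while the adversarial head sees them through a gradient-reversal layer. A stripped-down illustration of that flow; GradReverse below is a stand-in for dalib's WarmStartGradientReverseLayer, which additionally warms up its reversal coefficient over training:

    import torch
    import torch.nn as nn

    class GradReverse(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, grad_output):
            return grad_output.neg()   # reverse gradients flowing into the backbone

    features = torch.randn(8, 256, requires_grad=True)
    head = nn.Linear(256, 10)          # main classifier, trained normally
    adv_head = nn.Linear(256, 10)      # adversarial classifier

    outputs = head(features)
    outputs_adv = adv_head(GradReverse.apply(features))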
34 changes: 1 addition & 33 deletions dalib/adaptation/regda.py
@@ -3,42 +3,10 @@
 import torch.nn as nn
 import numpy as np

-from ..modules.gl import WarmStartGradientLayer
+from dalib.modules.gl import WarmStartGradientLayer
 from common.utils.metric.keypoint_detection import get_max_preds


-class LabelGenerator1d:
-
-    def __init__(self, width=64, sigma=2):
-        self.width = width
-        self.sigma = sigma
-        heatmaps = np.zeros((width, width), dtype=np.float32)
-        tmp_size = sigma * 3
-        for mu_x in range(width):
-            for mu_y in range(width):
-                # Check that any part of the gaussian is in-bounds
-                ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
-                br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
-
-                # Generate gaussian
-                size = 2 * tmp_size + 1
-                x = np.arange(0, size, 1, np.float32)
-                x0 = size // 2
-                # The gaussian is not normalized, we want the center value to equal 1
-                g = np.exp(- ((x - x0) ** 2) / (2 * sigma ** 2))
-
-                # Usable gaussian range
-                g_x = max(0, -ul[0]), min(br[0], width) - ul[0]
-                # Image range
-                img_x = max(0, ul[0]), min(br[0], width)
-
-                heatmaps[mu_x][img_x[0]:img_x[1]] = g[g_x[0]:g_x[1]]
-        self.heatmaps = heatmaps
-
-    def __call__(self, y):
-        return torch.from_numpy(self.heatmaps[y[:], :].copy()).to(y.device)
-
-
 class PseudoLabelGenerator2d(nn.Module):
     """
     Generate ground truth heatmap and ground false heatmap from a prediction.
1 change: 0 additions & 1 deletion examples/domain_adaptation/digits/README.md

This file was deleted.
