diff --git a/README.md b/README.md index 306cf2d..e460a72 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Incremental Learning +**This repository is not up to date with my local version, when I'll have finished working on my paper, I'll update this repo.** + *Also called lifelong learning, or continual learning.* This repository will store all my implementations of Incremental Learning's papers. diff --git a/inclearn/__init__.py b/inclearn/__init__.py index 7b762e5..19d72dd 100644 --- a/inclearn/__init__.py +++ b/inclearn/__init__.py @@ -1 +1 @@ -from inclearn import data, factory, models, resnet, utils, results_utils +from inclearn import parser, train diff --git a/inclearn/__main__.py b/inclearn/__main__.py index 73ea3d4..148ba9f 100644 --- a/inclearn/__main__.py +++ b/inclearn/__main__.py @@ -6,5 +6,6 @@ if args["seed_range"] is not None: args["seed"] = list(range(args["seed_range"][0], args["seed_range"][1] + 1)) + print("Seed range", args["seed"]) train(args) diff --git a/inclearn/convnet/__init__.py b/inclearn/convnet/__init__.py new file mode 100644 index 0000000..f731e81 --- /dev/null +++ b/inclearn/convnet/__init__.py @@ -0,0 +1 @@ +from . import cifar_resnet, densenet, my_resnet, resnet diff --git a/inclearn/convnet/cifar_resnet.py b/inclearn/convnet/cifar_resnet.py new file mode 100644 index 0000000..86d57c4 --- /dev/null +++ b/inclearn/convnet/cifar_resnet.py @@ -0,0 +1,197 @@ +''' Incremental-Classifier Learning + Authors : Khurram Javed, Muhammad Talha Paracha + Maintainer : Khurram Javed + Lab : TUKL-SEECS R&D Lab + Email : 14besekjaved@seecs.edu.pk ''' + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init + + +class DownsampleA(nn.Module): + def __init__(self, nIn, nOut, stride): + super(DownsampleA, self).__init__() + assert stride == 2 + self.avg = nn.AvgPool2d(kernel_size=1, stride=stride) + + def forward(self, x): + x = self.avg(x) + return torch.cat((x, x.mul(0)), 1) + + +class ResNetBasicblock(nn.Module): + expansion = 1 + """ + RexNet basicblock (https://github.com/facebook/fb.resnet.torch/blob/master/models/resnet.lua) + """ + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(ResNetBasicblock, self).__init__() + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + self.downsample = downsample + self.featureSize = 64 + + def forward(self, x): + residual = x + + basicblock = self.conv_a(x) + basicblock = self.bn_a(basicblock) + basicblock = F.relu(basicblock, inplace=True) + + basicblock = self.conv_b(basicblock) + basicblock = self.bn_b(basicblock) + + if self.downsample is not None: + residual = self.downsample(x) + + return F.relu(residual + basicblock, inplace=True) + + +class CifarResNet(nn.Module): + """ + ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, block, depth, num_classes, channels=3): + """ Constructor + Args: + depth: number of layers. 
+ num_classes: number of classes + base_width: base width + """ + super(CifarResNet, self).__init__() + + self.featureSize = 64 + # Model type specifies number of layers for CIFAR-10 and CIFAR-100 model + assert (depth - 2) % 6 == 0, 'depth should be one of 20, 32, 44, 56, 110' + layer_blocks = (depth - 2) // 6 + + self.num_classes = num_classes + + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(block, 16, layer_blocks, 1) + self.stage_2 = self._make_layer(block, 32, layer_blocks, 2) + self.stage_3 = self._make_layer(block, 64, layer_blocks, 2) + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + # m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + init.kaiming_normal(m.weight) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = DownsampleA(self.inplanes, planes * block.expansion, stride) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x, feature=False, T=1, labels=False, scale=None, keep=None): + + x = self.conv_1_3x3(x) + x = F.relu(self.bn_1(x), inplace=True) + x = self.stage_1(x) + x = self.stage_2(x) + x = self.stage_3(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + + def forwardFeature(self, x): + pass + + +def resnet20(num_classes=10): + """Constructs a ResNet-20 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 20, num_classes) + return model + + +def resnet10mnist(num_classes=10): + """Constructs a ResNet-20 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 10, num_classes, 1) + return model + + +def resnet20mnist(num_classes=10): + """Constructs a ResNet-20 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 20, num_classes, 1) + return model + + +def resnet32mnist(num_classes=10, channels=1): + model = CifarResNet(ResNetBasicblock, 32, num_classes, channels) + return model + + +def resnet32(num_classes=10): + """Constructs a ResNet-32 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 32, num_classes) + return model + + +def resnet44(num_classes=10): + """Constructs a ResNet-44 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 44, num_classes) + return model + + +def resnet56(num_classes=10): + """Constructs a ResNet-56 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 56, num_classes) + return model + + +def resnet110(num_classes=10): + """Constructs a ResNet-110 model for CIFAR-10 (by default) + Args: + num_classes (uint): number of classes + """ + model = CifarResNet(ResNetBasicblock, 
110, num_classes) + return model diff --git a/inclearn/convnet/densenet.py b/inclearn/convnet/densenet.py new file mode 100644 index 0000000..a5b7454 --- /dev/null +++ b/inclearn/convnet/densenet.py @@ -0,0 +1,181 @@ +import re +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] + +model_urls = { + 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', + 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', + 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', + 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', +} + + +class _DenseLayer(nn.Sequential): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): + super(_DenseLayer, self).__init__() + self.add_module('norm1', nn.BatchNorm2d(num_input_features)), + self.add_module('relu1', nn.ReLU(inplace=True)), + self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * + growth_rate, kernel_size=1, stride=1, + bias=False)), + self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), + self.add_module('relu2', nn.ReLU(inplace=True)), + self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, + kernel_size=3, stride=1, padding=1, + bias=False)), + self.drop_rate = drop_rate + + def forward(self, x): + new_features = super(_DenseLayer, self).forward(x) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, + training=self.training) + return torch.cat([x, new_features], 1) + + +class _DenseBlock(nn.Sequential): + def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, + bn_size, drop_rate) + self.add_module('denselayer%d' % (i + 1), layer) + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features, num_output_features): + super(_Transition, self).__init__() + self.add_module('norm', nn.BatchNorm2d(num_input_features)) + self.add_module('relu', nn.ReLU(inplace=True)) + self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, + kernel_size=1, stride=1, bias=False)) + self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) + + +class DenseNet(nn.Module): + r"""Densenet-BC model class, based on + `"Densely Connected Convolutional Networks" `_ + + Args: + growth_rate (int) - how many filters to add each layer (`k` in paper) + block_config (list of 4 ints) - how many layers in each pooling block + num_init_features (int) - the number of filters to learn in the first convolution layer + bn_size (int) - multiplicative factor for number of bottle neck layers + (i.e. 
bn_size * k features in the bottleneck layer) + drop_rate (float) - dropout rate after each dense layer + num_classes (int) - number of classification classes + """ + + def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), + num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, + **kwargs): + + super(DenseNet, self).__init__() + + # First convolution + self.features = nn.Sequential(OrderedDict([ + ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, + padding=3, bias=False)), + ('norm0', nn.BatchNorm2d(num_init_features)), + ('relu0', nn.ReLU(inplace=True)), + ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) + + # Each denseblock + num_features = num_init_features + for i, num_layers in enumerate(block_config): + block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, + bn_size=bn_size, growth_rate=growth_rate, + drop_rate=drop_rate) + self.features.add_module('denseblock%d' % (i + 1), block) + num_features = num_features + num_layers * growth_rate + if i != len(block_config) - 1: + trans = _Transition(num_input_features=num_features, + num_output_features=num_features // 2) + self.features.add_module('transition%d' % (i + 1), trans) + num_features = num_features // 2 + + # Final batch norm + self.features.add_module('norm5', nn.BatchNorm2d(num_features)) + + self.out_dim = num_features + + # Official init from torch repo. + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.constant_(m.bias, 0) + + def forward(self, x): + features = self.features(x) + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) + return out + + +def _load_state_dict(model, model_url, progress): + pass + +def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, + **kwargs): + model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) + if pretrained: + _load_state_dict(model, model_urls[arch], progress) + return model + + +def densenet121(pretrained=False, progress=True, **kwargs): + r"""Densenet-121 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _densenet('densenet121', 32, (6, 12, 24, 16), 64, pretrained, progress, + **kwargs) + + +def densenet161(pretrained=False, progress=True, **kwargs): + r"""Densenet-161 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _densenet('densenet161', 48, (6, 12, 36, 24), 96, pretrained, progress, + **kwargs) + + +def densenet169(pretrained=False, progress=True, **kwargs): + r"""Densenet-169 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _densenet('densenet169', 32, (6, 12, 32, 32), 64, pretrained, progress, + **kwargs) + + +def densenet201(pretrained=False, progress=True, **kwargs): + r"""Densenet-201 model from + `"Densely Connected Convolutional Networks" `_ + + Args: + pretrained 
(bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ + return _densenet('densenet201', 32, (6, 12, 48, 32), 64, pretrained, progress, + **kwargs) diff --git a/inclearn/convnet/my_resnet.py b/inclearn/convnet/my_resnet.py new file mode 100644 index 0000000..3a3c22b --- /dev/null +++ b/inclearn/convnet/my_resnet.py @@ -0,0 +1,134 @@ +''' Incremental-Classifier Learning + Authors : Khurram Javed, Muhammad Talha Paracha + Maintainer : Khurram Javed + Lab : TUKL-SEECS R&D Lab + Email : 14besekjaved@seecs.edu.pk ''' + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import init + + +class DownsampleStride(nn.Module): + def __init__(self, n=2): + super(DownsampleStride, self).__init__() + self._n = n + + def forward(self, x): + return x[..., ::2, ::2] + + +class ResidualBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, increase_dim=False, last=False): + super(ResidualBlock, self).__init__() + + self.increase_dim = increase_dim + + if increase_dim: + first_stride = 2 + planes = inplanes * 2 + else: + first_stride = 1 + planes = inplanes + + self.conv_a = nn.Conv2d(inplanes, planes, kernel_size=3, stride=first_stride, padding=1, bias=False) + self.bn_a = nn.BatchNorm2d(planes) + + self.conv_b = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_b = nn.BatchNorm2d(planes) + + if increase_dim: + self.downsample = DownsampleStride() + self.pad = lambda x: torch.cat((x, x.mul(0)), 1) + self.last = last + + def forward(self, x): + y = self.conv_a(x) + y = self.bn_a(y) + y = F.relu(y, inplace=True) + + y = self.conv_b(y) + y = self.bn_b(y) + + if self.increase_dim: + x = self.downsample(x) + x = self.pad(x) + + if x.shape != y.shape: + import pdb; pdb.set_trace() + + y = x + y + + if self.last: + y = F.relu(y, inplace=True) + + return y + + +class CifarResNet(nn.Module): + """ + ResNet optimized for the Cifar Dataset, as specified in + https://arxiv.org/abs/1512.03385.pdf + """ + + def __init__(self, n=5, channels=3): + """ Constructor + Args: + depth: number of layers. 
+ num_classes: number of classes + base_width: base width + """ + super(CifarResNet, self).__init__() + + self.conv_1_3x3 = nn.Conv2d(channels, 16, kernel_size=3, stride=1, padding=1, bias=False) + self.bn_1 = nn.BatchNorm2d(16) + + self.inplanes = 16 + self.stage_1 = self._make_layer(16, increase_dim=False, n=n) + self.stage_2 = self._make_layer(16, increase_dim=True, n=n-1) + self.stage_3 = self._make_layer(32, increase_dim=True, n=n-2) + self.stage_4 = ResidualBlock(64, increase_dim=False, last=True) + + self.avgpool = nn.AvgPool2d(8) + self.out_dim = 64 + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, planes, increase_dim=False, last=False, n=None): + layers = [] + + if increase_dim: + layers.append( + ResidualBlock(planes, increase_dim=True) + ) + planes = 2 * planes + + for i in range(n): + layers.append(ResidualBlock(planes)) + + return nn.Sequential(*layers) + + def forward(self, x, feature=False, T=1, labels=False, scale=None, keep=None): + x = self.conv_1_3x3(x) + x = F.relu(self.bn_1(x), inplace=True) + x = self.stage_1(x) + x = self.stage_2(x) + x = self.stage_3(x) + x = self.stage_4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + + +def resnet_rebuffi(n=5): + return CifarResNet(n=n) diff --git a/inclearn/resnet.py b/inclearn/convnet/resnet.py similarity index 100% rename from inclearn/resnet.py rename to inclearn/convnet/resnet.py diff --git a/inclearn/data.py b/inclearn/data.py deleted file mode 100644 index f172c9b..0000000 --- a/inclearn/data.py +++ /dev/null @@ -1,183 +0,0 @@ -import numpy as np -import torch -from PIL import Image -from torch.utils.data import DataLoader -from torch.utils.data.sampler import SubsetRandomSampler -from torchvision import datasets, transforms - -# -------- -# Datasets -# -------- - -class IncrementalDataset(torch.utils.data.Dataset): - _base_dataset = None - _train_transforms = [] - _common_transforms = [transforms.ToTensor()] - - def __init__(self, data_path="data", train=True, randomize_class=False, - increment=10, shuffle=True, workers=10, batch_size=128, - classes_order=None): - self._train = train - self._increment = increment - - dataset = self._base_dataset( - data_path, - train=train, - download=True, - ) - self._data = self._preprocess_initial_data(dataset.data) - self._targets = np.array(dataset.targets) - - if classes_order is None: - self.classes_order = np.sort(np.unique(self._targets)) - - if randomize_class: - np.random.shuffle(self.classes_order) - else: - self.classes_order = classes_order - - trsf = self._train_transforms if train else [] - trsf = trsf + self._common_transforms - self._transforms = transforms.Compose(trsf) - - self._shuffle = shuffle - self._workers = workers - self._batch_size = batch_size - - self._memory_idxes = [] - - print("Classes order: ", self.classes_order) - self.set_classes_range(0, self._increment) - - def get_loader(self, validation_split=0.): - if validation_split: - indices = np.arange(len(self)) - np.random.shuffle(indices) - split_idx = int(len(self) * validation_split) - val_indices = indices[:split_idx] - train_indices = indices[split_idx:] - print("Val {}; Train {}.".format(val_indices.shape[0], train_indices.shape[0])) - - train_loader = self._get_loader(SubsetRandomSampler(train_indices)) - val_loader = self._get_loader(SubsetRandomSampler(val_indices)) - 
return train_loader, val_loader - - return self._get_loader(), None - - def _get_loader(self, sampler=None): - return DataLoader( - dataset=self, - batch_size=self._batch_size, - shuffle=False if sampler else self._shuffle, - num_workers=self._workers, - sampler=sampler - ) - - @property - def total_n_classes(self): - return len(np.unique(self._targets)) - - def _preprocess_initial_data(self, data): - return data - - def set_classes_range(self, low=0, high=None): - self._low_range = low - self._high_range = high - - if low == high: - high = high + 1 - - classes = self.classes_order[low:high] - idxes = np.where(np.isin(self._targets, classes))[0] - - self._mapping = {fake_idx: real_idx for fake_idx, real_idx in enumerate(idxes)} - if low != high - 1: - self._update_memory_mapping() - - def set_idxes(self, idxes): - self._mapping = {fake_idx: real_idx for fake_idx, real_idx in enumerate(idxes)} - - def _update_memory_mapping(self): - if len(self._memory_idxes): - examplars_mapping = { - fake_idx: real_idx - for fake_idx, real_idx in zip( - range(len(self._mapping), len(self._memory_idxes)+len(self._mapping)), - self._memory_idxes - ) - } - for k, v in examplars_mapping.items(): - assert k not in self._mapping - self._mapping[k] = v - - def set_memory(self, idxes): - print("Setting {} memory examplars.".format(len(idxes))) - self._memory_idxes = idxes - - def get_true_index(self, fake_idx): - return self._mapping[fake_idx] - - def __len__(self): - return len(self._mapping) - - def __getitem__(self, idx): - real_idx = self._mapping[idx] - x, real_y = self._data[real_idx], self._targets[real_idx] - - x = Image.fromarray(x) - x = self._transforms(x) - - y = np.where(self.classes_order == real_y)[0][0] - - return (real_idx, idx), x, y - - -class iCIFAR10(IncrementalDataset): - _base_dataset = datasets.cifar.CIFAR10 - _train_transforms = [ - transforms.RandomCrop(32, padding=4), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(brightness=63 / 255) - ] - _common_transforms = [ - transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), - (0.2023, 0.1994, 0.2010)) - ] - - -class iCIFAR100(iCIFAR10): - _base_dataset = datasets.cifar.CIFAR100 - _common_transforms = [ - transforms.ToTensor(), - transforms.Normalize((0.5071, 0.4867, 0.4408), - (0.2675, 0.2565, 0.2761)), - ] - - -class iMNIST(IncrementalDataset): - _base_dataset = datasets.MNIST - _train_transforms = [ - transforms.RandomCrop(28, padding=4), - transforms.RandomHorizontalFlip() - ] - _common_transforms = [ - transforms.ToTensor() - ] - - -class iPermutedMNIST(iMNIST): - def _preprocess_initial_data(self, data): - b, w, h, c = data.shape - data = data.reshape(b, -1, c) - - permutation = np.random.permutation(w * h) - - data = data[:, permutation, :] - - return data.reshape(b, w, h, c) - - -# -------------- -# Data utilities -# -------------- diff --git a/inclearn/factory.py b/inclearn/factory.py deleted file mode 100644 index c2c4d82..0000000 --- a/inclearn/factory.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import optim - -from inclearn import data, models, resnet - - -def get_optimizer(params, optimizer, lr, weight_decay=0.0): - if optimizer == "adam": - return optim.Adam(params, lr=lr, weight_decay=weight_decay) - elif optimizer == "sgd": - return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9) - - raise NotImplementedError - - -def get_resnet(resnet_type, **kwargs): - if resnet_type == "resnet18": - return resnet.resnet18(**kwargs) - elif resnet_type == "resnet34": - return 
resnet.resnet101(**kwargs) - - raise NotImplementedError(resnet_type) - - -def get_model(args): - if args["model"] == "icarl": - return models.ICarl(args) - elif args["model"] == "lwf": - return models.LwF(args) - elif args["model"] == "e2e": - return models.End2End(args) - - raise NotImplementedError(args["model"]) - - -def get_data(args, train=True, classes_order=None): - dataset_name = args["dataset"].lower() - - if dataset_name in ("icifar100", "cifar100"): - dataset = data.iCIFAR100 - elif dataset_name in ("icifar10", "cifar10"): - dataset = data.iCIFAR10 - else: - raise NotImplementedError(dataset_name) - - return dataset(increment=args["increment"], - train=train, - randomize_class=args["random_classes"], - classes_order=classes_order) - - -def set_device(args): - device_type = args["device"] - - if device_type == -1: - device = torch.device("cpu") - else: - device = torch.device("cuda:{}".format(device_type)) - - args["device"] = device diff --git a/inclearn/lib/__init__.py b/inclearn/lib/__init__.py index 139880d..b7d2d1d 100644 --- a/inclearn/lib/__init__.py +++ b/inclearn/lib/__init__.py @@ -1 +1,2 @@ -from . import callbacks, metrics +from . import (callbacks, factory, herding, metrics, network, results_utils, + utils) diff --git a/inclearn/lib/callbacks.py b/inclearn/lib/callbacks.py index 9986ba5..9dc9198 100644 --- a/inclearn/lib/callbacks.py +++ b/inclearn/lib/callbacks.py @@ -1,7 +1,28 @@ +import copy + import torch -class GaussianNoiseAnnealing: +class Callback: + def __init__(self): + self._iteration = 0 + self._in_training = True + + @property + def in_training(self): + return self._in_training + + def on_epoch_begin(self): + pass + + def on_epoch_end(self, metric=None): + self._iteration += 1 + + def before_step(self): + pass + + +class GaussianNoiseAnnealing(Callback): """Add gaussian noise to the gradients. Add gaussian noise to the gradients with the given mean & std. 
The std will @@ -19,20 +40,48 @@ def __init__(self, parameters, eta=0.3, gamma=0.55): self._eta = eta self._gamma = gamma - self._iteration = 0 + super(GaussianNoiseAnnealing, self).__init__() - def add_noise(self): + def before_step(self): variance = self._eta / ((1 + self._iteration) ** self._gamma) for param in self._parameters: - # L2 regularization on gradients - param.grad.add_(0.0001, torch.norm(param.grad, p=2)) - # Noise on gradients: noise = torch.randn(param.grad.shape, device=param.grad.device) * variance param.grad.add_(noise) - param.grad.clamp_(min=-5, max=5) - def step(self): - self._iteration += 1 +class EarlyStopping(Callback): + def __init__(self, network, minimize_metric=True, patience=5, epsilon=1e-3): + self._patience = patience + self._wait = 0 + + if minimize_metric: + self._cmp_fun = lambda old, new: (old - epsilon) > new + self._best = float('inf') + else: + self._cmp_fun = lambda old, new: (old + epsilon) < new + self._best = float("-inf") + + self.network = network + + self._record = [] + + super(EarlyStopping, self).__init__() + + def on_epoch_end(self, metric): + self._record.append(metric) + + if self._cmp_fun(self._best, metric): + self._best = metric + self._wait = 0 + + self.network = copy.deepcopy(self.network) + else: + self._wait += 1 + if self._wait == self._patience: + print("Early stopping, metric is: {}.".format(metric)) + print(self._record[-self._patience:]) + self._in_training = False + + super(EarlyStopping, self).on_epoch_end(metric=metric) diff --git a/inclearn/lib/data.py b/inclearn/lib/data.py new file mode 100644 index 0000000..b8d682e --- /dev/null +++ b/inclearn/lib/data.py @@ -0,0 +1,324 @@ +import random + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import DataLoader +from torch.utils.data.sampler import SubsetRandomSampler +from torchvision import datasets, transforms + +# -------- +# Datasets +# -------- + + +class IncrementalDataset: + + def __init__( + self, + dataset_name, + random_order=False, + shuffle=True, + workers=10, + batch_size=128, + seed=1, + increment=10, + validation_split=0. 
+ ): + datasets = _get_datasets(dataset_name) + self._setup_data( + datasets, + random_order=random_order, + seed=seed, + increment=increment, + validation_split=validation_split + ) + self.train_transforms = datasets[0].train_transforms # FIXME handle multiple datasets + self.common_transforms = datasets[0].common_transforms + + self._current_task = 0 + + self._batch_size = batch_size + self._workers = workers + self._shuffle = shuffle + + @property + def n_tasks(self): + return len(self.increments) + + def new_task(self, memory=None): + if self._current_task >= len(self.increments): + raise Exception("No more tasks.") + + min_class = sum(self.increments[:self._current_task]) + max_class = sum(self.increments[:self._current_task + 1]) + x_train, y_train = self._select( + self.data_train, self.targets_train, low_range=min_class, high_range=max_class + ) + x_val, y_val = self._select( + self.data_val, self.targets_val, low_range=min_class, high_range=max_class + ) + x_test, y_test = self._select(self.data_test, self.targets_test, high_range=max_class) + + if memory is not None: + data_memory, targets_memory = memory + print("Set memory of size: {}.".format(data_memory.shape[0])) + x_train = np.concatenate((x_train, data_memory)) + y_train = np.concatenate((y_train, targets_memory)) + + train_loader = self._get_loader(x_train, y_train, mode="train") + val_loader = self._get_loader(x_val, y_val, mode="train") if len(x_val) > 0 else None + test_loader = self._get_loader(x_test, y_test, mode="test") + + task_info = { + "min_class": min_class, + "max_class": max_class, + "increment": self.increments[self._current_task], + "task": self._current_task, + "max_task": len(self.increments), + "n_train_data": x_train.shape[0], + "n_test_data": x_test.shape[0] + } + + self._current_task += 1 + + return task_info, train_loader, val_loader, test_loader + + def get_custom_loader(self, class_indexes, mode="test", data_source="train"): + """Returns a custom loader. + + :param class_indexes: A list of class indexes that we want. + :param mode: Various mode for the transformations applied on it. + :param data_source: Whether to fetch from the train, val, or test set. + :return: The raw data and a loader. 
+ """ + if not isinstance(class_indexes, list): # TODO: deprecated, should always give a list + class_indexes = [class_indexes] + + if data_source == "train": + x, y = self.data_train, self.targets_train + elif data_source == "val": + x, y = self.data_val, self.targets_val + elif data_source == "test": + x, y = self.data_test, self.targets_test + else: + raise ValueError("Unknown data source <{}>.".format(data_source)) + + data, targets = [], [] + for class_index in class_indexes: + class_data, class_targets = self._select( + x, y, low_range=class_index, high_range=class_index + 1 + ) + data.append(class_data) + targets.append(class_targets) + + data = np.concatenate(data) + targets = np.concatenate(targets) + + return data, self._get_loader(data, targets, shuffle=False, mode=mode) + + def _select(self, x, y, low_range=0, high_range=0): + idxes = np.where(np.logical_and(y >= low_range, y < high_range))[0] + return x[idxes], y[idxes] + + def _get_loader(self, x, y, shuffle=True, mode="train"): + if mode == "train": + trsf = transforms.Compose([*self.train_transforms, *self.common_transforms]) + elif mode == "test": + trsf = transforms.Compose(self.common_transforms) + elif mode == "flip": + trsf = transforms.Compose( + [transforms.RandomHorizontalFlip(p=1.), *self.common_transforms] + ) + else: + raise NotImplementedError("Unknown mode {}.".format(mode)) + + return DataLoader( + DummyDataset(x, y, trsf), + batch_size=self._batch_size, + shuffle=shuffle, + num_workers=self._workers + ) + + def _setup_data(self, datasets, random_order=False, seed=1, increment=10, validation_split=0.): + # FIXME: handles online loading of images + self.data_train, self.targets_train = [], [] + self.data_test, self.targets_test = [], [] + self.data_val, self.targets_val = [], [] + self.increments = [] + self.class_order = [] + + current_class_idx = 0 # When using multiple datasets + for dataset in datasets: + train_dataset = dataset.base_dataset("data", train=True, download=True) + test_dataset = dataset.base_dataset("data", train=False, download=True) + + x_train, y_train = train_dataset.data, np.array(train_dataset.targets) + x_val, y_val, x_train, y_train = self._split_per_class( + x_train, y_train, validation_split + ) + x_test, y_test = test_dataset.data, np.array(test_dataset.targets) + + order = [i for i in range(len(np.unique(y_train)))] + if random_order: + random.seed(seed) # Ensure that following order is determined by seed: + random.shuffle(order) + elif dataset.class_order is not None: + order = dataset.class_order + + self.class_order.append(order) + + y_train = self._map_new_class_index(y_train, order) + y_val = self._map_new_class_index(y_val, order) + y_test = self._map_new_class_index(y_test, order) + + y_train += current_class_idx + y_val += current_class_idx + y_test += current_class_idx + + current_class_idx += len(order) + if len(datasets) > 1: + self.increments.append(len(order)) + else: + self.increments = [increment for _ in range(len(order) // increment)] + + self.data_train.append(x_train) + self.targets_train.append(y_train) + self.data_val.append(x_val) + self.targets_val.append(y_val) + self.data_test.append(x_test) + self.targets_test.append(y_test) + + self.data_train = np.concatenate(self.data_train) + self.targets_train = np.concatenate(self.targets_train) + self.data_val = np.concatenate(self.data_val) + self.targets_val = np.concatenate(self.targets_val) + self.data_test = np.concatenate(self.data_test) + self.targets_test = np.concatenate(self.targets_test) + + @staticmethod 
+ def _map_new_class_index(y, order): + """Transforms targets for new class order.""" + return np.array(list(map(lambda x: order.index(x), y))) + + @staticmethod + def _split_per_class(x, y, validation_split=0.): + """Splits train data for a subset of validation data. + + Split is done so that each class has a much data. + """ + shuffled_indexes = np.random.permutation(x.shape[0]) + x = x[shuffled_indexes] + y = y[shuffled_indexes] + + x_val, y_val = [], [] + x_train, y_train = [], [] + + for class_id in np.unique(y): + class_indexes = np.where(y == class_id)[0] + nb_val_elts = int(class_indexes.shape[0] * validation_split) + + val_indexes = class_indexes[:nb_val_elts] + train_indexes = class_indexes[nb_val_elts:] + + x_val.append(x[val_indexes]) + y_val.append(y[val_indexes]) + x_train.append(x[train_indexes]) + y_train.append(y[train_indexes]) + + x_val, y_val = np.concatenate(x_val), np.concatenate(y_val) + x_train, y_train = np.concatenate(x_train), np.concatenate(y_train) + + return x_val, y_val, x_train, y_train + + +class DummyDataset(torch.utils.data.Dataset): + + def __init__(self, x, y, trsf): + self.x, self.y = x, y + self.trsf = trsf + + def __len__(self): + return self.x.shape[0] + + def __getitem__(self, idx): + x, y = self.x[idx], self.y[idx] + + x = Image.fromarray(x) + x = self.trsf(x) + + return x, y + + +def _get_datasets(dataset_names): + return [_get_dataset(dataset_name) for dataset_name in dataset_names.split("-")] + + +def _get_dataset(dataset_name): + dataset_name = dataset_name.lower().strip() + + if dataset_name == "cifar10": + return iCIFAR10 + elif dataset_name == "cifar100": + return iCIFAR100 + else: + raise NotImplementedError("Unknown dataset {}.".format(dataset_name)) + + +class DataHandler: + base_dataset = None + train_transforms = [] + common_transforms = [transforms.ToTensor()] + class_order = None + + +class iCIFAR10(DataHandler): + base_dataset = datasets.cifar.CIFAR10 + train_transforms = [ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(brightness=63 / 255) + ] + common_transforms = [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ] + + +class iCIFAR100(iCIFAR10): + base_dataset = datasets.cifar.CIFAR100 + common_transforms = [ + transforms.ToTensor(), + transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)), + ] + class_order = [ + 87, 0, 52, 58, 44, 91, 68, 97, 51, 15, 94, 92, 10, 72, 49, 78, 61, 14, 8, 86, 84, 96, 18, + 24, 32, 45, 88, 11, 4, 67, 69, 66, 77, 47, 79, 93, 29, 50, 57, 83, 17, 81, 41, 12, 37, 59, + 25, 20, 80, 73, 1, 28, 6, 46, 62, 82, 53, 9, 31, 75, 38, 63, 33, 74, 27, 22, 36, 3, 16, 21, + 60, 19, 70, 90, 89, 43, 5, 42, 65, 76, 40, 30, 23, 85, 2, 95, 56, 48, 71, 64, 98, 13, 99, 7, + 34, 55, 54, 26, 35, 39 + ] + + +class iMNIST(DataHandler): + base_dataset = datasets.MNIST + train_transforms = [transforms.RandomCrop(28, padding=4), transforms.RandomHorizontalFlip()] + common_transforms = [transforms.ToTensor()] + + +class iPermutedMNIST(iMNIST): + + def _preprocess_initial_data(self, data): + b, w, h, c = data.shape + data = data.reshape(b, -1, c) + + permutation = np.random.permutation(w * h) + + data = data[:, permutation, :] + + return data.reshape(b, w, h, c) + + +# -------------- +# Data utilities +# -------------- diff --git a/inclearn/lib/factory.py b/inclearn/lib/factory.py new file mode 100644 index 0000000..25516c5 --- /dev/null +++ b/inclearn/lib/factory.py @@ -0,0 +1,67 @@ +import torch +from 
torch import optim + +from inclearn import models +from inclearn.convnet import densenet, my_resnet, resnet +from inclearn.lib import data + + +def get_optimizer(params, optimizer, lr, weight_decay=0.0): + if optimizer == "adam": + return optim.Adam(params, lr=lr, weight_decay=weight_decay) + elif optimizer == "sgd": + return optim.SGD(params, lr=lr, weight_decay=weight_decay, momentum=0.9) + + raise NotImplementedError + + +def get_convnet(convnet_type, **kwargs): + if convnet_type == "resnet18": + return resnet.resnet18(**kwargs) + elif convnet_type == "resnet34": + return resnet.resnet34(**kwargs) + elif convnet_type == "rebuffi": + return my_resnet.resnet_rebuffi() + elif convnet_type == "densenet121": + return densenet.densenet121(**kwargs) + + raise NotImplementedError("Unknwon convnet type {}.".format(convnet_type)) + + +def get_model(args): + if args["model"] == "icarl": + return models.ICarl(args) + elif args["model"] == "lwf": + return models.LwF(args) + elif args["model"] == "e2e": + return models.End2End(args) + elif args["model"] == "medic": + return models.Medic(args) + elif args["model"] == "focusforget": + return models.FocusForget(args) + elif args["model"] == "fixed": + return models.FixedRepresentation(args) + + raise NotImplementedError(args["model"]) + + +def get_data(args): + return data.IncrementalDataset( + dataset_name=args["dataset"], + random_order=args["random_classes"], + shuffle=True, + batch_size=args["batch_size"], + workers=args["workers"], + validation_split=args["validation"] + ) + + +def set_device(args): + device_type = args["device"] + + if device_type == -1: + device = torch.device("cpu") + else: + device = torch.device("cuda:{}".format(device_type)) + + args["device"] = device diff --git a/inclearn/lib/herding.py b/inclearn/lib/herding.py new file mode 100644 index 0000000..9251fc0 --- /dev/null +++ b/inclearn/lib/herding.py @@ -0,0 +1,14 @@ +import torch +from torch.nn import functional as F + + +def closest_to_mean(features): + F.normalize(features) + class_mean = torch.mean(features, dim=0, keepdim=False) + + return + + + +def l2_distance(x, y): + return (x - y).norm() diff --git a/inclearn/lib/network.py b/inclearn/lib/network.py new file mode 100644 index 0000000..008f740 --- /dev/null +++ b/inclearn/lib/network.py @@ -0,0 +1,101 @@ +import copy + +import torch +from torch import nn + +from inclearn.lib import factory + + +class BasicNet(nn.Module): + + def __init__( + self, convnet_type, use_bias=False, init="kaiming", use_multi_fc=False, device=None + ): + super(BasicNet, self).__init__() + + self.use_bias = use_bias + self.init = init + self.use_multi_fc = use_multi_fc + + self.convnet = factory.get_convnet(convnet_type, nf=64, zero_init_residual=True) + self.classifier = None + + self.n_classes = 0 + self.device = device + + self.to(self.device) + + def forward(self, x): + if self.classifier is None: + raise Exception("Add some classes before training.") + + features = self.convnet(x) + + if self.use_multi_fc: + logits = [] + for clf_name in self.classifier: + logits.append(self.__getattr__(clf_name)(features)) + logits = torch.cat(logits, 1) + else: + logits = self.classifier(features) + + return logits + + @property + def features_dim(self): + return self.convnet.out_dim + + def extract(self, x): + return self.convnet(x) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + self.eval() + + return self + + def copy(self): + return copy.deepcopy(self) + + def add_classes(self, n_classes): + if 
self.use_multi_fc: + self._add_classes_multi_fc(n_classes) + else: + self._add_classes_single_fc(n_classes) + + self.n_classes += n_classes + + def _add_classes_multi_fc(self, n_classes): + if self.classifier is None: + self.classifier = [] + + new_classifier = self._gen_classifier(n_classes) + name = "_clf_{}".format(len(self.classifier)) + self.__setattr__(name, new_classifier) + self.classifier.append(name) + + def _add_classes_single_fc(self, n_classes): + if self.classifier is not None: + weight = copy.deepcopy(self.classifier.weight.data) + if self.use_bias: + bias = copy.deepcopy(self.classifier.bias.data) + + classifier = self._gen_classifier(self.n_classes + n_classes) + + if self.classifier is not None: + classifier.weight.data[:self.n_classes] = weight + if self.use_bias: + classifier.bias.data[:self.n_classes] = bias + + del self.classifier + self.classifier = classifier + + def _gen_classifier(self, n_classes): + classifier = nn.Linear(self.convnet.out_dim, n_classes, bias=self.use_bias).to(self.device) + if self.init == "kaiming": + nn.init.kaiming_normal_(classifier.weight, nonlinearity="linear") + if self.use_bias: + nn.init.constant_(classifier.bias, 0.) + + return classifier diff --git a/inclearn/results_utils.py b/inclearn/lib/results_utils.py similarity index 68% rename from inclearn/results_utils.py rename to inclearn/lib/results_utils.py index 2efb3ff..739de27 100644 --- a/inclearn/results_utils.py +++ b/inclearn/lib/results_utils.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt -from inclearn import utils +from inclearn.lib import utils def get_template_results(args): @@ -15,7 +15,7 @@ def get_template_results(args): def save_results(results, label): del results["config"]["device"] - folder_path = os.path.join("results", label) + folder_path = os.path.join("results", "{}_{}".format(utils.get_date(), label)) if not os.path.exists(folder_path): os.makedirs(folder_path) @@ -40,32 +40,32 @@ def extract(paths, avg_inc=False): with open(path) as f: data = json.load(f) - accs = [100 * task["total"] for task in data["results"]] + if isinstance(data["results"][0], dict): + accs = [100 * task["total"] for task in data["results"]] + elif isinstance(data["results"][0], float): + accs = [100 * task_acc for task_acc in data["results"]] + else: + raise NotImplementedError(type(data["results"][0])) if avg_inc: - accs = compute_avg_inc_acc(accs) + raise NotImplementedError("Deprecated") runs_accs.append(accs) return runs_accs -def compute_avg_inc_acc(accs): +def compute_avg_inc_acc(results): """Computes the average incremental accuracy as defined in iCaRL. The average incremental accuracies at task X are the average of accuracies at task 0, 1, ..., and X. - :param accs: A list of accuracies. - :return: A list of average incremental accuracies. + :param accs: A list of dict for per-class accuracy at each step. + :return: A float. """ - avg_inc_accs = [] - - for i in range(len(accs)): - sub_accs = [accs[j] for j in range(0, i + 1)] - avg_inc_accs.append(sum(sub_accs) / len(sub_accs)) - - return avg_inc_accs + tasks_accuracy = [r["total"] for r in results] + return sum(tasks_accuracy) / len(tasks_accuracy) def aggregate(runs_accs): @@ -118,6 +118,10 @@ def compute_unique_score(runs_accs, skip_first=False): return str(round(mean_of_mean, 2)), std +def get_max_label_length(results): + return max(len(r.get("label", r["path"])) for r in results) + + def plot(results, increment, total, title="", path_to_save=None): """Plotting utilities to visualize several experiments. 
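
As a quick check of the reworked compute_avg_inc_acc above, a minimal sketch of its contract (the values below are illustrative; only the "total" key of each results entry is actually read):

# Illustrative values only: each entry is one evaluation step's accuracy dict,
# of which compute_avg_inc_acc reads just the "total" key.
results = [
    {"total": 0.90},  # accuracy over seen classes after task 0
    {"total": 0.80},  # after task 1
    {"total": 0.70},  # after task 2
]

tasks_accuracy = [r["total"] for r in results]
avg_inc_acc = sum(tasks_accuracy) / len(tasks_accuracy)  # == 0.80
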
@@ -132,11 +136,18 @@ def plot(results, increment, total, title="", path_to_save=None): x = list(range(increment, total + 1, increment)) + max_label_length = get_max_label_length(results) + 4 + for result in results: path = result["path"] - label = result["label"] + label = result.get("label", path) + from_paper = "[paper] " if result.get("from_paper", False) else "[me] " avg_inc = result.get("average_incremental", False) skip_first = result.get("skip_first", False) + kwargs = result.get("kwargs", {}) + + if result.get("hidden", False): + continue if "*" in path: path = glob.glob(path) @@ -148,19 +159,35 @@ def plot(results, increment, total, title="", path_to_save=None): unique_score, unique_std = compute_unique_score(runs_accs, skip_first=skip_first) - plt.errorbar(x, means, stds, label=label + " ({})".format(unique_score + unique_std), - marker="o", markersize=3) + label = "{mode}{label}(avg: {avg}, last: {last})".format( + mode=from_paper, + label=label.ljust(max_label_length, " "), + avg=unique_score + unique_std, + last=round(means[-1], 2) + ) + + try: + mean_stds_sub = [i - j for i, j in zip(means, stds)] + mean_stds_add = [i + j for i, j in zip(means, stds)] + plt.plot(x, means, label=label, **kwargs) + plt.fill_between(x, mean_stds_sub, mean_stds_add, alpha=0.3) + except Exception: + print(x) + print(means) + print(stds) + print(label) + raise plt.legend(loc="upper right") plt.xlabel("Number of classes") - plt.ylabel("Average Incremental Accuracy") + plt.ylabel("Accuracy over seen classes") plt.title(title) - for i in range(10, total + 1, 10): + for i in range(10, 100 + 1, 10): plt.axhline(y=i, color='black', linestyle='dashed', linewidth=1, alpha=0.2) plt.yticks([i for i in range(10, total + 1, 10)]) - plt.xticks([i for i in range(10, len(x) * increment + 1, 10)]) + plt.xticks([i for i in range(increment, len(x) * increment + 1, increment)]) if path_to_save: - plt.savefig(path_to_save) + plt.savefig(path_to_save, dpi=1200) plt.show() diff --git a/inclearn/utils.py b/inclearn/lib/utils.py similarity index 75% rename from inclearn/utils.py rename to inclearn/lib/utils.py index caaf1e0..f1b192b 100644 --- a/inclearn/utils.py +++ b/inclearn/lib/utils.py @@ -5,11 +5,13 @@ def to_onehot(targets, n_classes): - return torch.eye(n_classes)[targets] + onehot = torch.zeros(targets.shape[0], n_classes).to(targets.device) + onehot.scatter_(dim=1, index=targets.long().view(-1, 1), value=1.) + return onehot def _check_loss(loss): - return not torch.isnan(loss) and loss >= 0. + return not bool(torch.isnan(loss).item()) and bool((loss >= 0.).item()) def compute_accuracy(ypred, ytrue, task_size=10): diff --git a/inclearn/models/__init__.py b/inclearn/models/__init__.py index 92de89d..2fdd81c 100644 --- a/inclearn/models/__init__.py +++ b/inclearn/models/__init__.py @@ -1,5 +1,7 @@ -from .e2e import End2End +#from .e2e import End2End +#from .fixedrepresentation import FixedRepresentation +#from .focusforget import FocusForget from .icarl import ICarl -from .lwf import LwF -__all__ = ["ICarl", "LwF", "End2End"] +#from .lwf import LwF +#from .medic import Medic diff --git a/inclearn/models/base.py b/inclearn/models/base.py index 6ef888d..0b94e26 100644 --- a/inclearn/models/base.py +++ b/inclearn/models/base.py @@ -1,12 +1,10 @@ import abc import logging -from torch import nn - LOGGER = logging.Logger("IncLearn", level="INFO") -class IncrementalLearner(abc.ABC, nn.Module): +class IncrementalLearner(abc.ABC): """Base incremental learner. 
Methods are called in this order (& repeated for each new task): @@ -19,16 +17,16 @@ class IncrementalLearner(abc.ABC, nn.Module): """ def __init__(self, *args, **kwargs): - nn.Module.__init__(self, *args, **kwargs) + pass - def set_task_info( - self, task, total_n_classes, increment, n_train_data, n_test_data - ): + def set_task_info(self, task, total_n_classes, increment, n_train_data, n_test_data, + n_tasks): self._task = task self._task_size = increment self._total_n_classes = total_n_classes self._n_train_data = n_train_data self._n_test_data = n_test_data + self._n_tasks = n_tasks def before_task(self, train_loader, val_loader): LOGGER.info("Before task") @@ -40,18 +38,24 @@ def train_task(self, train_loader, val_loader): self.train() self._train_task(train_loader, val_loader) - def after_task(self, data_loader): + def after_task(self, inc_dataset): LOGGER.info("after task") self.eval() - self._after_task(data_loader) + self._after_task(inc_dataset) def eval_task(self, data_loader): LOGGER.info("eval task") self.eval() return self._eval_task(data_loader) - def get_memory_indexes(self): - return [] + def get_memory(self): + return None + + def eval(self): + raise NotImplementedError + + def train(self): + raise NotImplementedError def _before_task(self, data_loader): pass diff --git a/inclearn/models/e2e.py b/inclearn/models/e2e.py index a703c9f..b256c34 100644 --- a/inclearn/models/e2e.py +++ b/inclearn/models/e2e.py @@ -1,13 +1,15 @@ import numpy as np import torch +import tqdm from torch import nn from torch.nn import functional as F -from tqdm import trange from inclearn import factory, utils -from inclearn.lib import callbacks +from inclearn.lib import callbacks, network from inclearn.models.base import IncrementalLearner +tqdm.monitor_interval = 0 + class End2End(IncrementalLearner): """Implementation of End-to-End Increment Learning. @@ -28,28 +30,28 @@ def __init__(self, args): self._lr_decay = args["lr_decay"] self._k = args["memory_size"] - self._n_classes = args["increment"] + self._n_classes = 0 - self._temperature = 2.#args["temperature"] + self._temperature = args["temperature"] - self._features_extractor = factory.get_resnet( - args["convnet"], nf=64, zero_init_residual=True - ) - self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=True) + self._network = network.BasicNet(args["convnet"], use_bias=True, use_multi_fc=True, + device=self._device) self._examplars = {} + self._old_model = [] - self.to(self._device) - - def forward(self, x): - x = self._features_extractor(x) - x = self._classifier(x) - return x + self._task_idxes = [] # ---------- # Public API # ---------- + def eval(self): + self._network.eval() + + def train(self): + self._network.train() + def _before_task(self, train_loader, val_loader): """Set up before the task training can begin. @@ -59,13 +61,12 @@ def _before_task(self, train_loader, val_loader): :param train_loader: The training dataloader. :param val_loader: The validation dataloader. """ - if self._task == 0: - self._previous_preds = None - else: - print("Computing previous predictions...") - self._previous_preds = self._compute_predictions(train_loader) + self._network.add_classes(self._task_size) + + self._task_idxes.append([self._n_classes + i for i in range(self._task_size)]) - self._add_n_classes(self._task_size) + self._n_classes += self._task_size + print("Now {} examplars per class.".format(self._m)) def _train_task(self, train_loader, val_loader): """Train & fine-tune model. 
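
For reference, a minimal sketch of the driver loop implied by IncrementalLearner's documented call order and the new IncrementalDataset/factory API. This is hypothetical glue code (the real training loop is not part of this diff); run() and args are illustrative names, with args standing for the parsed-arguments dict:

# Hypothetical glue code showing the documented call order
# (set_task_info -> before_task -> train_task -> after_task -> eval_task)
# wired to IncrementalDataset.new_task() and get_memory().
from inclearn.lib import factory


def run(args):
    factory.set_device(args)                # turns args["device"] into a torch.device
    inc_dataset = factory.get_data(args)    # lib.data.IncrementalDataset
    model = factory.get_model(args)         # e.g. models.ICarl

    memory = None
    for _ in range(inc_dataset.n_tasks):
        task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(memory)

        model.set_task_info(
            task=task_info["task"],
            total_n_classes=task_info["max_class"],
            increment=task_info["increment"],
            n_train_data=task_info["n_train_data"],
            n_test_data=task_info["n_test_data"],
            n_tasks=task_info["max_task"],
        )

        model.before_task(train_loader, val_loader)
        model.train_task(train_loader, val_loader)
        model.after_task(inc_dataset)

        ypred, ytrue = model.eval_task(test_loader)
        memory = model.get_memory()
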
@@ -80,37 +81,40 @@ def _train_task(self, train_loader, val_loader): :param train_loader: A DataLoader. :param val_loader: A DataLoader, can be None. """ - # Training on all new + examplars - self._best_acc = float("-inf") + if self._task == 0: + epochs = 90 + optimizer = factory.get_optimizer(self._network.parameters(), self._opt_name, 0.1, 0.001) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 60], gamma=0.1) + self._train(train_loader, val_loader, epochs, optimizer, scheduler) + return + # Training on all new + examplars print("Training") self._finetuning = False - optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.1, 0.001) - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [50, 60], gamma=0.2) - self._train(train_loader, val_loader, 70, optimizer, scheduler) - - if self._task == 0: - print("best", self._best_acc) - return + epochs = 60 + optimizer = factory.get_optimizer(self._network.parameters(), self._opt_name, 0.1, 0.001) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40, 50], gamma=0.1) + self._train(train_loader, val_loader, epochs, optimizer, scheduler) # Fine-tuning on sub-set new + examplars print("Fine-tuning") + self._old_model = self._network.copy().freeze() + self._finetuning = True self._build_examplars(train_loader, n_examplars=self._k // (self._n_classes - self._task_size)) train_loader.dataset.set_idxes(self.examplars) # Fine-tuning only on balanced dataset - self._previous_preds = self._compute_predictions(train_loader) - optimizer = factory.get_optimizer(self.parameters(), self._opt_name, 0.01, 0.001) - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 40], gamma=0.2) - self._train(train_loader, val_loader, 50, optimizer, scheduler) - - print("best", self._best_acc) + optimizer = factory.get_optimizer(self._network.parameters(), self._opt_name, 0.01, 0.001) + scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1) + self._train(train_loader, val_loader, 40, optimizer, scheduler) def _after_task(self, data_loader): self._reduce_examplars() self._build_examplars(data_loader) + self._old_model = self._network.copy().freeze() + def _eval_task(self, data_loader): ypred, ytrue = self._classify(data_loader) assert ypred.shape == ytrue.shape @@ -125,29 +129,35 @@ def get_memory_indexes(self): # ----------- def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler): + self._callbacks = [ + callbacks.GaussianNoiseAnnealing(self._network.parameters()), + #callbacks.EarlyStopping(self._network, minimize_metric=False) + ] + self._best_acc = float("-inf") + print("nb ", len(train_loader.dataset)) - prog_bar = trange(n_epochs, desc="Losses.") + prog_bar = tqdm.trange(n_epochs, desc="Losses.") val_acc = 0. + train_acc = 0. for epoch in prog_bar: - if epoch % 10 == 0 and val_loader: - ypred, ytrue = self._classify(val_loader) - val_acc = (ypred == ytrue).sum() / len(ytrue) - self._best_acc = max(self._best_acc, val_acc) + for cb in self._callbacks: + cb.on_epoch_begin() + + scheduler.step() _clf_loss, _distil_loss = 0., 0. 
c = 0 - scheduler.step() - for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1): optimizer.zero_grad() - c += len(idxes) + c += 1 inputs, targets = inputs.to(self._device), targets.to(self._device) - logits = self.forward(inputs) + logits = self._network(inputs) clf_loss, distil_loss = self._compute_loss( + inputs, logits, targets, idxes, @@ -160,6 +170,12 @@ def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler): loss = clf_loss + distil_loss loss.backward() + + #if self._task != 0: + # for param in self._network.parameters(): + # param.grad = param.grad * (self._temperature ** 2) + for cb in self._callbacks: + cb.before_step() optimizer.step() _clf_loss += clf_loss.item() @@ -167,21 +183,39 @@ def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler): if i % 10 == 0 or i >= len(train_loader): prog_bar.set_description( - "Clf loss: {}; Distill loss: {}; Val acc: {}".format( + "Clf: {}; Distill: {}; Train: {}; Val: {}".format( round(clf_loss.item(), 3), round(distil_loss.item(), 3), + round(train_acc, 3), round(val_acc, 3) ) ) + if val_loader: + ypred, ytrue = self._classify(val_loader) + val_acc = (ypred == ytrue).sum() / len(ytrue) + self._best_acc = max(self._best_acc, val_acc) + ypred, ytrue = self._classify(train_loader) + train_acc = (ypred == ytrue).sum() / len(ytrue) + + for cb in self._callbacks: + cb.on_epoch_end(metric=val_acc) prog_bar.set_description( - "Clf loss: {}; Distill loss: {}; Val acc: {}".format( + "Clf: {}; Distill: {}; Train: {}; Val: {}".format( round(_clf_loss / c, 3), round(_distil_loss / c, 3), - round(val_acc, 3) + round(train_acc, 3), + round(val_acc, 3), ) ) - def _compute_loss(self, logits, targets, idxes): + for cb in self._callbacks: + if not cb.in_training: + self._network = cb.network + return + + print("best", self._best_acc) + + def _compute_loss(self, inputs, logits, targets, idxes): """Computes the classification loss & the distillation loss. Distillation loss is null at the first task. @@ -197,13 +231,24 @@ def _compute_loss(self, logits, targets, idxes): if self._task == 0: distil_loss = torch.zeros(1, device=self._device) else: - if not self._finetuning: - logits = logits[..., :self._new_task_index] - - distil_loss = F.binary_cross_entropy( - F.softmax(logits / self._temperature, dim=1), - F.softmax(self._previous_preds[idxes] / self._temperature, dim=1) - ) + if self._finetuning: + # We only do distillation on current task during the distillation + # phase: + last_index = len(self._task_idxes) + else: + last_index = len(self._task_idxes) - 1 + + distil_loss = 0. 
+ #with torch.no_grad(): + previous_logits = self._old_model(inputs) + + for i in range(last_index): + task_idxes = self._task_idxes[i] + + distil_loss += F.binary_cross_entropy( + F.softmax(logits[..., task_idxes] / self._temperature, dim=1), + F.softmax(previous_logits[..., task_idxes] / self._temperature, dim=1) + ) return clf_loss, distil_loss @@ -219,7 +264,7 @@ def _compute_predictions(self, loader): inputs = inputs.to(self._device) idxes = idxes[1].to(self._device) - logits[idxes] = self.forward(inputs).detach() + logits[idxes] = self._network(inputs).detach() return logits @@ -235,7 +280,7 @@ def _classify(self, loader): for _, inputs, targets in loader: inputs = inputs.to(self._device) - logits = self.forward(inputs) + logits = F.softmax(self._network(inputs), dim=1) preds = logits.argmax(dim=1).cpu().numpy() ypred.extend(preds) @@ -248,23 +293,13 @@ def _m(self): """Returns the number of examplars per class.""" return self._k // self._n_classes - def _add_n_classes(self, n): - self._n_classes += n - - weights = self._classifier.weight.data.clone() - self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, - bias=True).to(self._device) - self._classifier.weight.data[:self._n_classes - n] = weights - - print("Now {} examplars per class.".format(self._m)) - def _extract_features(self, loader): features = [] idxes = [] for (real_idxes, _), inputs, _ in loader: inputs = inputs.to(self._device) - features.append(self._features_extractor(inputs).detach()) + features.append(self._network.extract(inputs).detach()) idxes.extend(real_idxes.numpy().tolist()) features = torch.cat(features) diff --git a/inclearn/models/fixedrepresentation.py b/inclearn/models/fixedrepresentation.py new file mode 100644 index 0000000..76b86ab --- /dev/null +++ b/inclearn/models/fixedrepresentation.py @@ -0,0 +1,115 @@ +import logging + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from tqdm import trange + +from inclearn import factory, utils +from inclearn.models.base import IncrementalLearner + +LOGGER = logging.Logger("IncLearn", level="INFO") + + +class FixedRepresentation(IncrementalLearner): + """Base incremental learner. + + Methods are called in this order (& repeated for each new task): + + 1. set_task_info + 2. before_task + 3. train_task + 4. after_task + 5. 
eval_task + """ + + def __init__(self, args): + super().__init__() + + self._epochs = 70 + + self._n_classes = args["increment"] + self._device = args["device"] + + self._features_extractor = factory.get_resnet(args["convnet"], nf=64, + zero_init_residual=True) + + self._classifiers = [nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False).to(self._device)] + torch.nn.init.kaiming_normal_(self._classifiers[0].weight) + self.add_module("clf_" + str(self._n_classes), self._classifiers[0]) + + self.to(self._device) + + def forward(self, x): + feats = self._features_extractor(x) + + logits = [] + for clf in self._classifiers: + logits.append(clf(feats)) + + return torch.cat(logits, dim=1) + + def _before_task(self, data_loader, val_loader): + if self._task != 0: + self._add_n_classes(self._task_size) + + self._optimizer = factory.get_optimizer( + filter(lambda x: x.requires_grad, self.parameters()), + "sgd", 0.1) + self._scheduler = torch.optim.lr_scheduler.MultiStepLR(self._optimizer, [50, 60], gamma=0.2) + + def _get_params(self): + return [self._features_extractor.parameters()] + + def _train_task(self, train_loader, val_loader): + for _ in trange(self._epochs): + self._scheduler.step() + for _, inputs, targets in train_loader: + self._optimizer.zero_grad() + + inputs, targets = inputs.to(self._device), targets.to(self._device) + + logits = self.forward(inputs) + loss = F.cross_entropy(logits, targets) + loss.backward() + self._optimizer.step() + + def _after_task(self, data_loader): + pass + + def _eval_task(self, loader): + ypred = [] + ytrue = [] + + for _, inputs, targets in loader: + inputs = inputs.to(self._device) + logits = self.forward(inputs) + preds = logits.argmax(dim=1).cpu().numpy() + + ypred.extend(preds) + ytrue.extend(targets) + + ypred, ytrue = np.array(ypred), np.array(ytrue) + print(np.bincount(ypred)) + return ypred, ytrue + + def _add_n_classes(self, n): + self._n_classes += n + + self._classifiers.append(nn.Linear( + self._features_extractor.out_dim, self._task_size, + bias=False + ).to(self._device)) + nn.init.kaiming_normal_(self._classifiers[-1].weight) + self.add_module("clf_" + str(self._n_classes), self._classifiers[-1]) + + for param in self._features_extractor.parameters(): + param.requires_grad = False + + for clf in self._classifiers[:-1]: + for param in clf.parameters(): + param.requires_grad = False + for param in self._classifiers[-1].parameters(): + for param in clf.parameters(): + param.requires_grad = True diff --git a/inclearn/models/icarl.py b/inclearn/models/icarl.py index 8480e9a..9356a45 100644 --- a/inclearn/models/icarl.py +++ b/inclearn/models/icarl.py @@ -1,12 +1,14 @@ import numpy as np import torch -from torch import nn +from scipy.spatial.distance import cdist from torch.nn import functional as F -from tqdm import trange +from tqdm import tqdm -from inclearn import factory, utils +from inclearn.lib import factory, network, utils from inclearn.models.base import IncrementalLearner +EPSILON = 1e-8 + class ICarl(IncrementalLearner): """Implementation of iCarl. @@ -18,6 +20,7 @@ class ICarl(IncrementalLearner): :param args: An argparse parsed arguments object. 
""" + def __init__(self, args): super().__init__() @@ -30,174 +33,113 @@ def __init__(self, args): self._scheduling = args["scheduling"] self._lr_decay = args["lr_decay"] - self._k = args["memory_size"] - self._n_classes = args["increment"] + self._memory_size = args["memory_size"] + self._n_classes = 0 - self._features_extractor = factory.get_resnet(args["convnet"], nf=64, - zero_init_residual=True) - self._classifier = nn.Linear(self._features_extractor.out_dim, self._n_classes, bias=False) - torch.nn.init.kaiming_normal_(self._classifier.weight) + self._network = network.BasicNet(args["convnet"], device=self._device, use_bias=True) self._examplars = {} self._means = None + self._old_model = None + self._clf_loss = F.binary_cross_entropy_with_logits self._distil_loss = F.binary_cross_entropy_with_logits - self.to(self._device) + self._herding_matrix = [] - def forward(self, x): - x = self._features_extractor(x) - x = self._classifier(x) - return x + def eval(self): + self._network.eval() + + def train(self): + self._network.train() # ---------- # Public API # ---------- def _before_task(self, train_loader, val_loader): - if self._task == 0: - self._previous_preds = None - else: - print("Computing previous predictions...") - self._previous_preds = self._compute_predictions(train_loader) - if val_loader: - self._previous_preds_val = self._compute_predictions(val_loader) - - self._add_n_classes(self._task_size) + self._n_classes += self._task_size + self._network.add_classes(self._task_size) + print("Now {} examplars per class.".format(self._memory_per_class)) self._optimizer = factory.get_optimizer( - self.parameters(), - self._opt_name, - self._lr, - self._weight_decay + self._network.parameters(), self._opt_name, self._lr, self._weight_decay ) self._scheduler = torch.optim.lr_scheduler.MultiStepLR( - self._optimizer, - self._scheduling, - gamma=self._lr_decay + self._optimizer, self._scheduling, gamma=self._lr_decay ) def _train_task(self, train_loader, val_loader): - for p in self.parameters(): - p.register_hook(lambda grad: torch.clamp(grad, -5, 5)) - print("nb ", len(train_loader.dataset)) - prog_bar = trange(self._n_epochs, desc="Losses.") - - val_loss = 0. - for epoch in prog_bar: - _clf_loss, _distil_loss = 0., 0. - c = 0 + for epoch in range(self._n_epochs): + _loss, val_loss = 0., 0. 
self._scheduler.step() - for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1): + prog_bar = tqdm(train_loader) + for i, (inputs, targets) in enumerate(prog_bar, start=1): self._optimizer.zero_grad() - c += len(idxes) - inputs, targets = inputs.to(self._device), targets.to(self._device) - targets = utils.to_onehot(targets, self._n_classes).to(self._device) - logits = self.forward(inputs) - - clf_loss, distil_loss = self._compute_loss( - logits, - targets, - idxes, - ) - - if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss): - import pdb; pdb.set_trace() + loss = self._forward_loss(inputs, targets) - loss = clf_loss + distil_loss + if not utils._check_loss(loss): + import pdb + pdb.set_trace() loss.backward() self._optimizer.step() - _clf_loss += clf_loss.item() - _distil_loss += distil_loss.item() + _loss += loss.item() - if i % 10 == 0 or i >= len(train_loader): - prog_bar.set_description( - "Clf loss: {}; Distill loss: {}; Val loss: {}".format( - round(clf_loss.item(), 3), - round(distil_loss.item(), 3), + if val_loader is not None and i == len(train_loader): + for inputs, targets in val_loader: + val_loss += self._forward_loss(inputs, targets).item() + + prog_bar.set_description( + "Task {}/{}, Epoch {}/{} => Clf loss: {}, Val loss: {}".format( + self._task + 1, self._n_tasks, + epoch + 1, self._n_epochs, + round(_loss / i, 3), round(val_loss, 3) - )) + ) + ) + + def _forward_loss(self, inputs, targets): + inputs, targets = inputs.to(self._device), targets.to(self._device) + targets = utils.to_onehot(targets, self._n_classes).to(self._device) + logits = self._network(inputs) - if val_loader is not None: - val_loss = self._compute_val_loss(val_loader) - prog_bar.set_description( - "Clf loss: {}; Distill loss: {}; Val loss: {}".format( - round(_clf_loss / c, 3), - round(_distil_loss / c, 3), - round(val_loss, 2) - )) + return self._compute_loss(inputs, logits, targets) - def _after_task(self, data_loader): - self._reduce_examplars() - self._build_examplars(data_loader) + def _after_task(self, inc_dataset): + self.build_examplars(inc_dataset) + + self._old_model = self._network.copy().freeze() def _eval_task(self, data_loader): - ypred, ytrue = self._classify(data_loader) - assert ypred.shape == ytrue.shape + ypred, ytrue = compute_accuracy(self._network, data_loader, self._class_means) return ypred, ytrue - def get_memory_indexes(self): - return self.examplars - # ----------- # Private API # ----------- - def _compute_val_loss(self, val_loader): - total_loss = 0. 
- c = 0 - - for idx, (idxes, inputs, targets) in enumerate(val_loader, start=1): - self._optimizer.zero_grad() - - c += len(idxes) - - inputs, targets = inputs.to(self._device), targets.to(self._device) - targets = utils.to_onehot(targets, self._n_classes).to(self._device) - logits = self.forward(inputs) - - clf_loss, distil_loss = self._compute_loss( - logits, - targets, - idxes[1], - train=False - ) - - if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss): - import pdb; pdb.set_trace() - - total_loss += (clf_loss + distil_loss).item() - - return total_loss - - def _compute_loss(self, logits, targets, idxes, train=True): - if self._task == 0: - # First task, only doing classification loss - clf_loss = self._clf_loss(logits, targets) - distil_loss = torch.zeros(1, device=self._device) + def _compute_loss(self, inputs, logits, targets): + if self._old_model is None: + loss = F.binary_cross_entropy_with_logits(logits, targets) else: - clf_loss = self._clf_loss( - logits[..., self._new_task_index:], - targets[..., self._new_task_index:] - ) + old_targets = torch.sigmoid(self._old_model(inputs).detach()) - previous_preds = self._previous_preds if train else self._previous_preds_val - distil_loss = self._distil_loss( - logits[..., :self._new_task_index], - previous_preds[idxes, :self._new_task_index] - ) + new_targets = targets.clone() + new_targets[..., :-self._task_size] = old_targets + + loss = F.binary_cross_entropy_with_logits(logits, new_targets) - return clf_loss, distil_loss + return loss def _compute_predictions(self, data_loader): preds = torch.zeros(self._n_train_data, self._n_classes, device=self._device) @@ -206,18 +148,21 @@ def _compute_predictions(self, data_loader): inputs = inputs.to(self._device) idxes = idxes[1].to(self._device) - preds[idxes] = self.forward(inputs).detach() + preds[idxes] = self._network(inputs).detach() return torch.sigmoid(preds) def _classify(self, data_loader): if self._means is None: - raise ValueError("Cannot classify without built examplar means," - " Have you forgotten to call `before_task`?") + raise ValueError( + "Cannot classify without built examplar means," + " Have you forgotten to call `before_task`?" 
+ ) if self._means.shape[0] != self._n_classes: raise ValueError( "The number of examplar means ({}) is inconsistent".format(self._means.shape[0]) + - " with the number of classes ({}).".format(self._n_classes)) + " with the number of classes ({}).".format(self._n_classes) + ) ypred = [] ytrue = [] @@ -225,7 +170,7 @@ def _classify(self, data_loader): for _, inputs, targets in data_loader: inputs = inputs.to(self._device) - features = self._features_extractor(inputs).detach() + features = self._network.extract(inputs).detach() preds = self._get_closest(self._means, F.normalize(features)) ypred.extend(preds) @@ -234,119 +179,113 @@ def _classify(self, data_loader): return np.array(ypred), np.array(ytrue) @property - def _m(self): + def _memory_per_class(self): """Returns the number of examplars per class.""" - return self._k // self._n_classes + return self._memory_size // self._n_classes - def _add_n_classes(self, n): - self._n_classes += n + # ----------------- + # Memory management + # ----------------- - weights = self._classifier.weight.data - self._classifier = nn.Linear( - self._features_extractor.out_dim, self._n_classes, - bias=False - ).to(self._device) - torch.nn.init.kaiming_normal_(self._classifier.weight) + def build_examplars(self, inc_dataset): + print("Building & updating memory.") - self._classifier.weight.data[: self._n_classes - n] = weights + self._data_memory, self._targets_memory = [], [] + self._class_means = np.zeros((100, self._network.features_dim)) - print("Now {} examplars per class.".format(self._m)) + for class_idx in range(self._n_classes): + inputs, loader = inc_dataset.get_custom_loader(class_idx, mode="test") + features, targets = extract_features( + self._network, loader + ) + features_flipped, _ = extract_features( + self._network, inc_dataset.get_custom_loader(class_idx, mode="flip")[1] + ) - def _extract_features(self, loader): - features = [] - idxes = [] + if class_idx >= self._n_classes - self._task_size: + self._herding_matrix.append(select_examplars( + features, self._memory_per_class + )) - for (real_idxes, _), inputs, _ in loader: - inputs = inputs.to(self._device) - features.append(self._features_extractor(inputs).detach()) - idxes.extend(real_idxes.numpy().tolist()) + examplar_mean, alph = compute_examplar_mean( + features, features_flipped, self._herding_matrix[class_idx], self._memory_per_class + ) + self._data_memory.append(inputs[np.where(alph == 1)[0]]) + self._targets_memory.append(targets[np.where(alph == 1)[0]]) - features = F.normalize(torch.cat(features), dim=1) - mean = torch.mean(features, dim=0, keepdim=False) + self._class_means[class_idx, :] = examplar_mean - return features, mean, idxes + self._data_memory = np.concatenate(self._data_memory) + self._targets_memory = np.concatenate(self._targets_memory) - @staticmethod - def _remove_row(matrix, idxes, row_idx): - new_matrix = torch.cat((matrix[:row_idx, ...], matrix[row_idx + 1:, ...])) - del matrix - return new_matrix, idxes[:row_idx] + idxes[row_idx + 1:] + def get_memory(self): + return self._data_memory, self._targets_memory - @staticmethod - def _get_closest(centers, features): - pred_labels = [] - features = features - for feature in features: - distances = ICarl._dist(centers, feature) - pred_labels.append(distances.argmin().item()) +def extract_features(model, loader): + targets, features = [], [] - return np.array(pred_labels) + for _inputs, _targets in loader: + _targets = _targets.numpy() + _features = model.extract(_inputs.to(model.device)).detach().cpu().numpy() - 
@staticmethod - def _get_closest_features(center, features): - distances = ICarl._dist(center, features) - return distances.argmin().item() + features.append(_features) + targets.append(_targets) - @staticmethod - def _dist(a, b): - return torch.pow(a - b, 2).sum(-1) + return np.concatenate(features), np.concatenate(targets) - def _build_examplars(self, loader): - means = [] - lo, hi = 0, self._task * self._task_size - print("Updating examplars for classes {} -> {}.".format(lo, hi)) - for class_idx in range(lo, hi): - loader.dataset.set_idxes(self._examplars[class_idx]) - _, examplar_mean, _ = self._extract_features(loader) - means.append(F.normalize(examplar_mean, dim=0)) +def select_examplars(features, nb_max): + D = features.T + D = D / (np.linalg.norm(D, axis=0) + EPSILON) + mu = np.mean(D, axis=1) + herding_matrix = np.zeros((features.shape[0],)) - lo, hi = self._task * self._task_size, self._n_classes - print("Building examplars for classes {} -> {}.".format(lo, hi)) - for class_idx in range(lo, hi): - examplars_idxes = [] + w_t = mu + iter_herding, iter_herding_eff = 0, 0 - loader.dataset.set_classes_range(class_idx, class_idx) + while not ( + np.sum(herding_matrix != 0) == min(nb_max, features.shape[0]) + ) and iter_herding_eff < 1000: + tmp_t = np.dot(w_t, D) + ind_max = np.argmax(tmp_t) + iter_herding_eff += 1 + if herding_matrix[ind_max] == 0: + herding_matrix[ind_max] = 1 + iter_herding + iter_herding += 1 - features, class_mean, idxes = self._extract_features(loader) - examplars_mean = torch.zeros(self._features_extractor.out_dim, device=self._device) + w_t = w_t + mu - D[:, ind_max] - class_mean = F.normalize(class_mean, dim=0) + return herding_matrix - for i in range(min(self._m, features.shape[0])): - tmp = F.normalize( - (features + examplars_mean) / (i + 1), - dim=1 - ) - distances = self._dist(class_mean, tmp) - idxes_winner = distances.argsort().cpu().numpy() - for idx in idxes_winner: - real_idx = idxes[idx] - if real_idx in examplars_idxes: - continue +def compute_examplar_mean(feat_norm, feat_flip, herding_mat, nb_max): + D = feat_norm.T + D = D / (np.linalg.norm(D, axis=0) + EPSILON) - examplars_idxes.append(real_idx) - examplars_mean += features[idx] - break + D2 = feat_flip.T + D2 = D2 / (np.linalg.norm(D2, axis=0) + EPSILON) - means.append(F.normalize(examplars_mean / len(examplars_idxes), dim=0)) - self._examplars[class_idx] = examplars_idxes + alph = herding_mat + alph = (alph > 0) * (alph < nb_max + 1) * 1. - self._means = torch.stack(means) + alph_mean = alph / np.sum(alph) - @property - def examplars(self): - return np.array( - [ - examplar_idx - for class_examplars in self._examplars.values() - for examplar_idx in class_examplars - ] - ) + mean = (np.dot(D, alph_mean) + np.dot(D2, alph_mean)) / 2 + mean /= np.linalg.norm(mean) + + return mean, alph + + +def compute_accuracy(model, loader, class_means): + features, targets_ = extract_features(model, loader) + + targets = np.zeros((targets_.shape[0], 100), np.float32) + targets[range(len(targets_)), targets_.astype('int32')] = 1. 
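+    # Nearest-mean-of-exemplars: L2-normalize the features, then score every sample by its (negated) squared distance to each class mean.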
+ features = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T + + # Compute score for iCaRL + sqd = cdist(class_means, features, 'sqeuclidean') + score_icarl = (-sqd).T - def _reduce_examplars(self): - print("Reducing examplars.") - for class_idx in range(len(self._examplars)): - self._examplars[class_idx] = self._examplars[class_idx][: self._m] + return np.argsort(score_icarl, axis=1)[:, -1], targets_ diff --git a/inclearn/models/lwf.py b/inclearn/models/lwf.py index aff45a2..fdd9b06 100644 --- a/inclearn/models/lwf.py +++ b/inclearn/models/lwf.py @@ -128,15 +128,15 @@ def _add_n_classes(self, n, convnet=None): self._n_classes += n def _compute_predictions(self, data_loader): - preds = torch.zeros(self._n_train_data, self._n_classes, device=self._device) + logits = torch.zeros(self._n_train_data, self._n_classes, device=self._device) for idxes, inputs, _ in data_loader: inputs = inputs.to(self._device) idxes = idxes[1].to(self._device) - preds[idxes] = self.forward(inputs).detach() + logits[idxes] = self.forward(inputs).detach() ** (1 / self._temperature) - return F.softmax(preds, dim=1) + return logits def _compute_loss(self, logits, targets, idxes): if self._task == 0: @@ -149,8 +149,7 @@ def _compute_loss(self, logits, targets, idxes): clf_loss = F.cross_entropy(logits[..., self._new_task_index:], targets, ignore_index=-1) distil_loss = F.binary_cross_entropy( - torch.softmax(logits[..., :self._new_task_index]**(1 / self._temperature), dim=1), - self._previous_preds[idxes]**(1 / self._temperature)) + F.softmax(logits[..., :self._new_task_index], dim=1), self._previous_preds[idxes]) return clf_loss, distil_loss diff --git a/inclearn/models/medic.py b/inclearn/models/medic.py new file mode 100644 index 0000000..f550adc --- /dev/null +++ b/inclearn/models/medic.py @@ -0,0 +1,300 @@ +import numpy as np +import torch +import tqdm +from torch import nn +from torch.nn import functional as F + +from inclearn import factory, utils +from inclearn.lib import callbacks, network +from inclearn.models.base import IncrementalLearner + +tqdm.monitor_interval = 0 + + +class Medic(IncrementalLearner): + """Implementation of: + + - Incremental Learning with Maximum Entropy Regularization: Rethinking + Forgetting and Intransigence. + + :param args: An argparse parsed arguments object. + """ + + def __init__(self, args): + super().__init__() + + self._device = args["device"] + self._opt_name = args["optimizer"] + self._lr = args["lr"] + self._weight_decay = args["weight_decay"] + self._n_epochs = args["epochs"] + + self._scheduling = args["scheduling"] + self._lr_decay = args["lr_decay"] + + self._k = args["memory_size"] + self._n_classes = 0 + self._epochs = args["epochs"] + + self._network = network.BasicNet( + args["convnet"], use_bias=True, use_multi_fc=False, device=self._device + ) + + self._examplars = {} + self._old_model = [] + + self._task_idxes = [] + + # ---------- + # Public API + # ---------- + + def eval(self): + self._network.eval() + + def train(self): + self._network.train() + + def _before_task(self, train_loader, val_loader): + """Set up before the task training can begin. + + 1. Precomputes previous model probabilities. + 2. Extend the classifier to support new classes. + + :param train_loader: The training dataloader. + :param val_loader: The validation dataloader. 
+ """ + self._network.add_classes(self._task_size) + + self._task_idxes.append([self._n_classes + i for i in range(self._task_size)]) + + self._n_classes += self._task_size + print("Now {} examplars per class.".format(self._m)) + + def _train_task(self, train_loader, val_loader): + """Train & fine-tune model. + + :param train_loader: A DataLoader. + :param val_loader: A DataLoader, can be None. + """ + optimizer = factory.get_optimizer( + self._network.parameters(), self._opt_name, self._lr, self._weight_decay + ) + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, self._scheduling, gamma=self._lr_decay + ) + self._train(train_loader, val_loader, self._epochs, optimizer, scheduler) + + def _after_task(self, data_loader): + self._reduce_examplars() + self._build_examplars(data_loader) + self._old_model = self._network.copy().freeze() + + def _eval_task(self, data_loader): + ypred, ytrue = self._classify(data_loader) + assert ypred.shape == ytrue.shape + + return ypred, ytrue + + def get_memory_indexes(self): + return self.examplars + + # ----------- + # Private API + # ----------- + + def _train(self, train_loader, val_loader, n_epochs, optimizer, scheduler): + self._best_acc = float("-inf") + + print("nb ", len(train_loader.dataset)) + + val_acc = 0. + train_acc = 0. + for epoch in range(n_epochs): + scheduler.step() + + _clf_loss, _distil_loss = 0., 0. + c = 0 + + for i, ((_, idxes), inputs, targets) in enumerate(train_loader, start=1): + optimizer.zero_grad() + + c += 1 + inputs, targets = inputs.to(self._device), targets.to(self._device) + logits = self._network(inputs) + + clf_loss, distil_loss = self._compute_loss( + inputs, + logits, + targets, + idxes, + ) + + if not utils._check_loss(clf_loss) or not utils._check_loss(distil_loss): + import pdb + pdb.set_trace() + + loss = clf_loss + distil_loss + + loss.backward() + + optimizer.step() + + _clf_loss += clf_loss.item() + _distil_loss += distil_loss.item() + + if val_loader: + self._network.eval() + ypred, ytrue = self._classify(val_loader) + val_acc = (ypred == ytrue).sum() / len(ytrue) + self._best_acc = max(self._best_acc, val_acc) + ypred, ytrue = self._classify(train_loader) + train_acc = (ypred == ytrue).sum() / len(ytrue) + self._network.train() + + print("Epoch {}/{}; Clf: {}; Distill: {}; Train: {}; Val: {}".format( + epoch, n_epochs, + round(_clf_loss / c, 3), + round(_distil_loss / c, 3), + round(train_acc, 3), + round(val_acc, 3), + ) + ) + + print("best", self._best_acc) + + def _compute_loss(self, inputs, logits, targets, idxes): + """Computes the classification loss & the distillation loss. + + Distillation loss is null at the first task. + + :param logits: Logits produced the model. + :param targets: The targets. + :param idxes: The real indexes of the just-processed images. Needed to + match the previous predictions. + :return: A tuple of the classification loss and the distillation loss. + """ + clf_loss = F.cross_entropy(logits, targets) + + if self._task == 0: + distil_loss = torch.zeros(1, device=self._device) + else: + last_index = len(self._task_idxes) - 1 + + distil_loss = 0. 
+ with torch.no_grad(): + previous_logits = self._old_model(inputs) + + for i in range(last_index): + task_idxes = self._task_idxes[i] + + ce_loss = F.binary_cross_entropy( + F.softmax(logits[..., task_idxes], dim=1), + F.softmax(previous_logits[..., task_idxes], dim=1) + ) + entropy_loss = self.entropy(logits[..., task_idxes]) + + mer_loss = ce_loss - entropy_loss + if mer_loss < 0: + import pdb; pdb.set_trace() + + distil_loss += mer_loss + + return clf_loss, distil_loss + + @staticmethod + def entropy(p): + e = F.softmax(p, dim=1) * F.log_softmax(p, dim=1) + return -1.0 * e.mean() + + def _compute_predictions(self, loader): + """Precomputes the logits before a task. + + :param data_loader: A DataLoader. + :return: A tensor storing the whole current dataset logits. + """ + logits = torch.zeros(self._n_train_data, self._n_classes, device=self._device) + + for idxes, inputs, _ in loader: + inputs = inputs.to(self._device) + idxes = idxes[1].to(self._device) + + logits[idxes] = self._network(inputs).detach() + + return logits + + def _classify(self, loader): + """Classify the images given by the data loader. + + :param data_loader: A DataLoader. + :return: A numpy array of the predicted targets and a numpy array of the + ground-truth targets. + """ + ypred = [] + ytrue = [] + + for _, inputs, targets in loader: + inputs = inputs.to(self._device) + logits = F.softmax(self._network(inputs), dim=1) + preds = logits.argmax(dim=1).cpu().numpy() + + ypred.extend(preds) + ytrue.extend(targets) + + return np.array(ypred), np.array(ytrue) + + @property + def _m(self): + """Returns the number of examplars per class.""" + return self._k // self._n_classes + + def _build_examplars(self, loader, n_examplars=None): + """Builds new examplars. + + :param loader: A DataLoader. + :param n_examplars: Maximum number of examplars to create. + """ + n_examplars = n_examplars or self._m + + lo, hi = self._task * self._task_size, self._n_classes + print("Building examplars for classes {} -> {}.".format(lo, hi)) + for class_idx in range(lo, hi): + loader.dataset.set_classes_range(class_idx, class_idx) + self._examplars[class_idx] = self._build_class_examplars(loader, n_examplars) + + def _build_class_examplars(self, loader, n_examplars): + """Build examplars for a single class. + + Examplars are selected as the closest to the class mean. + + :param loader: DataLoader that provides images for a single class. + :param n_examplars: Maximum number of examplars to create. + :return: The real indexes of the chosen examplars. + """ + idxes = [] + for (real_idxes, _), _, _ in loader: + idxes.extend(real_idxes.numpy().tolist()) + idxes = np.array(idxes) + + nb_examplars = min(n_examplars, len(idxes)) + + np.random.shuffle(idxes) + return idxes[:nb_examplars] + + @property + def examplars(self): + """Returns all the real examplars indexes. + + :return: A numpy array of indexes. 
+ """ + return np.array( + [ + examplar_idx for class_examplars in self._examplars.values() + for examplar_idx in class_examplars + ] + ) + + def _reduce_examplars(self): + print("Reducing examplars.") + for class_idx in range(len(self._examplars)): + self._examplars[class_idx] = self._examplars[class_idx][:self._m] diff --git a/inclearn/parser.py b/inclearn/parser.py index 5c86ba4..d1eabd5 100644 --- a/inclearn/parser.py +++ b/inclearn/parser.py @@ -22,13 +22,13 @@ def get_parser(): help="Temperature used to soften the predictions.") # Data related: - parser.add_argument("-d", "--dataset", default="iCIFAR100", type=str, + parser.add_argument("-d", "--dataset", default="cifar100", type=str, help="Dataset to test on.") parser.add_argument("-inc", "--increment", default=10, type=int, help="Number of class to add per task.") parser.add_argument("-b", "--batch-size", default=128, type=int, help="Batch size.") - parser.add_argument("-w", "--workers", default=10, type=int, + parser.add_argument("-w", "--workers", default=1, type=int, help="Number of workers preprocessing the data.") parser.add_argument("-v", "--validation", default=0., type=float, help="Validation split (0. <= x <= 1.).") @@ -44,7 +44,7 @@ def get_parser(): help="Weight decay.") parser.add_argument("-sc", "--scheduling", default=[50, 64], nargs="*", type=int, help="Epoch step where to reduce the learning rate.") - parser.add_argument("-lr-decay", "--lr-decay", default=1/5, type=int, + parser.add_argument("-lr-decay", "--lr-decay", default=1/5, type=float, help="LR multiplied by it.") parser.add_argument("-opt", "--optimizer", default="sgd", type=str, help="Optimizer to use.") @@ -61,4 +61,5 @@ def get_parser(): parser.add_argument("-seed-range", "--seed-range", type=int, nargs=2, help="Seed range going from first number to second (both included).") + return parser diff --git a/inclearn/train.py b/inclearn/train.py index a806c95..16dedba 100644 --- a/inclearn/train.py +++ b/inclearn/train.py @@ -1,10 +1,11 @@ import copy import random +import time import numpy as np import torch -from inclearn import factory, results_utils, utils +from inclearn.lib import factory, results_utils, utils def train(args): @@ -14,7 +15,10 @@ def train(args): for seed in seed_list: args["seed"] = seed args["device"] = device + + start_time = time.time() _train(args) + print("Training finished in {}s.".format(int(time.time() - start_time))) def _train(args): @@ -22,53 +26,56 @@ def _train(args): factory.set_device(args) - train_set = factory.get_data(args, train=True) - test_set = factory.get_data(args, train=False, classes_order=train_set.classes_order) - - train_loader, val_loader = train_set.get_loader(args["validation"]) - test_loader, _ = test_set.get_loader() - #val_loader = test_loader + inc_dataset = factory.get_data(args) + args["classes_order"] = inc_dataset.class_order model = factory.get_model(args) results = results_utils.get_template_results(args) - for task in range(0, train_set.total_n_classes // args["increment"]): - if args["max_task"] == task: - break + memory = None - # Setting current task's classes: - train_set.set_classes_range(low=task * args["increment"], - high=(task + 1) * args["increment"]) - test_set.set_classes_range(high=(task + 1) * args["increment"]) + for _ in range(inc_dataset.n_tasks): + task_info, train_loader, val_loader, test_loader = inc_dataset.new_task(memory) + if task_info["task"] == args["max_task"]: + break model.set_task_info( - task, - train_set.total_n_classes, - args["increment"], - len(train_set), - 
len(test_set) + task=task_info["task"], + total_n_classes=task_info["max_class"], + increment=task_info["increment"], + n_train_data=task_info["n_train_data"], + n_test_data=task_info["n_test_data"], + n_tasks=task_info["max_task"] ) + model.eval() model.before_task(train_loader, val_loader) - print("train", task * args["increment"], (task + 1) * args["increment"]) + print("Train on {}->{}.".format(task_info["min_class"], task_info["max_class"])) + model.train() model.train_task(train_loader, val_loader) - model.after_task(train_loader) + model.eval() + model.after_task(inc_dataset) + print("Eval on {}->{}.".format(0, task_info["max_class"])) ypred, ytrue = model.eval_task(test_loader) acc_stats = utils.compute_accuracy(ypred, ytrue, task_size=args["increment"]) print(acc_stats) results["results"].append(acc_stats) - memory_indexes = model.get_memory_indexes() - train_set.set_memory(memory_indexes) + memory = model.get_memory() + + print( + "Average Incremental Accuracy: {}.".format( + results_utils.compute_avg_inc_acc(results["results"]) + ) + ) if args["name"]: results_utils.save_results(results, args["name"]) del model - del train_set - del test_set + del inc_dataset torch.cuda.empty_cache()
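The `_compute_loss` hunk that loops over `self._task_idxes` (just before the fixedrepresentation.py file above) reduces to a per-task, temperature-softened distillation term against the frozen previous network. A minimal standalone sketch of that idea; the function name and the default temperature are illustrative, not part of the repo's API:

```python
import torch.nn.functional as F


def per_task_distillation(new_logits, old_logits, task_idxes, temperature=2.0):
    """Sum a soft distillation term over each old task's logit slice.

    new_logits / old_logits: (batch, n_classes) tensors from the current network
    and the frozen copy of the previous one; task_idxes: one list of class
    indexes per task already learned.
    """
    loss = new_logits.new_zeros(())
    for idxes in task_idxes:
        new_soft = F.softmax(new_logits[..., idxes] / temperature, dim=1)
        old_soft = F.softmax(old_logits[..., idxes] / temperature, dim=1)
        loss = loss + F.binary_cross_entropy(new_soft, old_soft)
    return loss
```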
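In fixedrepresentation.py, `_add_n_classes` is meant to freeze the feature extractor and every classifier head except the newly added one; the last loop of that hunk appears to iterate over the wrong variable. A sketch of the intended freezing logic, with an illustrative helper name:

```python
def freeze_all_but_last_head(features_extractor, classifiers):
    # Backbone is frozen once the first task has been learned.
    for param in features_extractor.parameters():
        param.requires_grad = False
    # Old heads stay fixed...
    for clf in classifiers[:-1]:
        for param in clf.parameters():
            param.requires_grad = False
    # ...only the newly added head is trained.
    for param in classifiers[-1].parameters():
        param.requires_grad = True
```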
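The rewritten iCaRL `_compute_loss` folds distillation into a single `binary_cross_entropy_with_logits` by replacing the old-class part of the one-hot targets with the sigmoid outputs of the frozen previous network. A condensed sketch (the helper name and signature are illustrative):

```python
import torch
import torch.nn.functional as F


def icarl_loss(network, old_model, inputs, onehot_targets, task_size):
    """onehot_targets: (batch, n_classes); task_size: classes added this task."""
    logits = network(inputs)
    if old_model is None:  # first task: plain multi-label BCE on one-hot targets
        return F.binary_cross_entropy_with_logits(logits, onehot_targets)

    with torch.no_grad():
        old_targets = torch.sigmoid(old_model(inputs))  # (batch, n_classes - task_size)

    targets = onehot_targets.clone()
    targets[..., :-task_size] = old_targets  # old classes follow the frozen model
    return F.binary_cross_entropy_with_logits(logits, targets)
```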
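`select_examplars` implements herding: it greedily picks the points whose running mean stays closest to the class mean, and returns each point's selection rank (0 meaning never picked). A small usage sketch with random features, assuming the helper is importable from the updated icarl module; shapes and counts are illustrative only:

```python
import numpy as np

from inclearn.models.icarl import select_examplars  # assumed import path

rng = np.random.default_rng(0)
features = rng.normal(size=(500, 64)).astype("float32")  # 500 images of one class, 64-d features

ranks = select_examplars(features, nb_max=20)             # rank k > 0 means "picked k-th", 0 means unused
kept = np.where((ranks > 0) & (ranks < 20 + 1))[0]        # same mask compute_examplar_mean builds
print(len(kept), "exemplars kept for this class")
```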
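`compute_accuracy` classifies with the nearest-mean-of-exemplars rule in L2-normalized feature space. The same rule, condensed (illustrative helper name):

```python
import numpy as np
from scipy.spatial.distance import cdist

EPSILON = 1e-8


def nearest_mean_predict(features, class_means):
    """features: (n_samples, d) raw features; class_means: (n_classes, d) unit-norm means."""
    feats = (features.T / (np.linalg.norm(features.T, axis=0) + EPSILON)).T  # L2-normalize rows
    sq_dists = cdist(class_means, feats, "sqeuclidean")                      # (n_classes, n_samples)
    return sq_dists.argmin(axis=0)                                           # closest mean per sample
```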
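Medic's distillation term is the knowledge-distillation cross-entropy per old task minus an entropy bonus on the new logits (the maximum-entropy regularization of the paper it implements). A sketch mirroring that hunk, with illustrative function names:

```python
import torch.nn.functional as F


def entropy(logits):
    probs = F.softmax(logits, dim=1)
    return -(probs * F.log_softmax(logits, dim=1)).mean()


def mer_distillation(new_logits, old_logits, old_task_idxes):
    """Per old task: distillation cross-entropy minus an entropy bonus."""
    loss = new_logits.new_zeros(())
    for idxes in old_task_idxes:
        kd = F.binary_cross_entropy(
            F.softmax(new_logits[..., idxes], dim=1),
            F.softmax(old_logits[..., idxes], dim=1),
        )
        loss = loss + kd - entropy(new_logits[..., idxes])
    return loss
```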