diff --git a/README.md b/README.md index 63c18f1..b7a0385 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,11 @@ Note that the package is not intended for general purpose domain adaptation. Ins We will release the first stable version of the package on PyPI. Until then, you can install directly from the main repo: ```bash +# Latest stable release: pip install git+git://github.com/bethgelab/robustness.git + +# Release candidate: +pip install git+git://github.com/bethgelab/robustness.git@rc ``` Here is an example for how to use `robusta` for batchnorm adaptation & robust pseudo-labeling. diff --git a/examples/batchnorm/README.md b/examples/batchnorm/README.md index 8b0352b..ed170bb 100644 --- a/examples/batchnorm/README.md +++ b/examples/batchnorm/README.md @@ -5,7 +5,7 @@ Steffen Schneider*, Evgenia Rusak*, Luisa Eck, Oliver Bringmann, Wieland Brendel Website: [domainadaptation.org/batchnorm](https://domainadaptation.org/batchnorm) This repository contains evaluation code for the paper *Improving robustness against common corruptions by covariate shift adaptation*. -We will release the code in the upcoming weeks. To get notified, watch and/or star this repository to get notified of updates! +The repository is updated frequently. To get notified, watch and/or star this repository! Today's state-of-the-art machine vision models are vulnerable to image corruptions like blurring or compression artefacts, limiting their performance in many real-world applications. We here argue that popular benchmarks to measure model robustness against common corruptions (like ImageNet-C) underestimate model robustness in many (but not all) application scenarios. The key insight is that in many scenarios, multiple unlabeled examples of the corruptions are available and can be used for unsupervised online adaptation. Replacing the activation statistics estimated by batch normalization on the training set with the statistics of the corrupted images consistently improves the robustness across 25 different popular computer vision models. Using the corrected statistics, ResNet-50 reaches 62.2% mCE on ImageNet-C compared to 76.7% without adaptation. With the more robust AugMix model, we improve the state of the art from 56.5% mCE to 51.0% mCE. Even adapting to a single sample improves robustness for the ResNet-50 and AugMix models, and 32 samples are sufficient to improve the current state of the art for a ResNet-50 architecture. We argue that results with adapted statistics should be included whenever reporting scores in corruption benchmarks and other out-of-distribution generalization settings @@ -26,6 +26,7 @@ With a simple recalculation of batch normalization statistics, we improve the me | [DeepAugment+AugMix](https://github.com/hendrycks/imagenet-r) | 53.6 | 48.4 |45.4| | [DeepAug+AM+RNXt101](https://github.com/hendrycks/imagenet-r) | **44.5** |**40.7** | **38.0** | + ### Results for models trained with [Fixup](https://github.com/hongyi-zhang/Fixup) and [GroupNorm](https://github.com/ppwwyyxx/GroupNorm-reproduce) on ImageNet-C Fixup and GN trained models perform better than non-adapted BN models but worse than adapted BN models. @@ -36,6 +37,26 @@ Fixup and GN trained models perform better than non-adapted BN models but worse |ResNet-101 |68.2 |67.6 |69.0 |**59.1**| |ResNet-152 |67.6 |65.4 |69.3 |**58.0**| +### To reproduce the first table above + +Run [`scripts/paper/table1.sh`](scripts/paper/table1.sh): +```sh +row="2" # This is the row to compute from the table +docker run -v "$IMAGENET_C_PATH":/ImageNet-C:ro \ + -v "$CHECKPOINT_PATH":/checkpoints:ro \ + -v .:/batchnorm \ + -v ..:/deps \ + -it georgepachitariu/robustness:latest \ + bash /batchnorm/scripts/paper/table1.sh $row 2>&1 +``` +The script file requires 2 dependencies: +1. `IMANGENETC_PATH="/ImageNet-C"` + This is the path where you store the ImageNet-C dataset. The dataset is described [here](https://github.com/hendrycks/robustness) and you can download it from [here](https://zenodo.org/record/2235448#.YJjcNyaxWcw). + +2. `CHECKPOINT_PATH="/checkpoints"` + This is the path where you store our checkpoints. + You can download them from here: TODO. + ## News diff --git a/examples/batchnorm/scripts/paper/table1.sbatch b/examples/batchnorm/scripts/paper/table1.sbatch index 83ec57f..250af0a 100644 --- a/examples/batchnorm/scripts/paper/table1.sbatch +++ b/examples/batchnorm/scripts/paper/table1.sbatch @@ -7,7 +7,7 @@ scontrol show job "$SLURM_JOB_ID" # The image georgepachitariu/robustness was created using -# the Dockerfile from parent folder. +# the Dockerfile from main repository folder. row="2" # This is the row in the table singularity exec --nv -B /scratch_local \ -B "$IMAGENET_C_PATH":/ImageNet-C:ro \ diff --git a/robusta/__init__.py b/robusta/__init__.py index 903dbcf..1f8bffa 100644 --- a/robusta/__init__.py +++ b/robusta/__init__.py @@ -17,7 +17,6 @@ # This licence notice applies to all originally written code by the # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - """A package for robustness and adaptation on ImageNet scale.""" from robusta import batchnorm diff --git a/robusta/batchnorm/bn.py b/robusta/batchnorm/bn.py index 52368cf..b79187d 100644 --- a/robusta/batchnorm/bn.py +++ b/robusta/batchnorm/bn.py @@ -17,7 +17,6 @@ # This licence notice applies to all originally written code by the # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - """ Batch norm variants """ @@ -39,6 +38,7 @@ def adapt_bayesian(model: nn.Module, prior: float): class PartlyAdaptiveBN(nn.Module): + @staticmethod def find_bns(parent, estimate_mean, estimate_var): replace_mods = [] @@ -52,8 +52,7 @@ def find_bns(parent, estimate_mean, estimate_var): else: replace_mods.extend( PartlyAdaptiveBN.find_bns(child, estimate_mean, - estimate_var) - ) + estimate_var)) return replace_mods @@ -129,6 +128,7 @@ def forward(self, input): class EMABatchNorm(nn.Module): + @staticmethod def reset_stats(module): module.reset_running_stats() @@ -205,23 +205,19 @@ def __init__(self, layer, prior): self.layer = layer self.layer.eval() - self.norm = nn.BatchNorm2d( - self.layer.num_features, affine=False, momentum=1.0 - ) + self.norm = nn.BatchNorm2d(self.layer.num_features, + affine=False, + momentum=1.0) self.prior = prior def forward(self, input): self.norm(input) - running_mean = ( - self.prior * self.layer.running_mean - + (1 - self.prior) * self.norm.running_mean - ) - running_var = ( - self.prior * self.layer.running_var - + (1 - self.prior) * self.norm.running_var - ) + running_mean = (self.prior * self.layer.running_mean + + (1 - self.prior) * self.norm.running_mean) + running_var = (self.prior * self.layer.running_var + + (1 - self.prior) * self.norm.running_var) return F.batch_norm( input, diff --git a/robusta/batchnorm/stages.py b/robusta/batchnorm/stages.py index fc2aa2c..d29ffcb 100644 --- a/robusta/batchnorm/stages.py +++ b/robusta/batchnorm/stages.py @@ -17,12 +17,12 @@ # This licence notice applies to all originally written code by the # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - """ Helper functions for stages ablations """ import torchvision from torch import nn + def split_model(model): if not isinstance(model, torchvision.models.ResNet): print("Only resnet models defined for this analysis so far") diff --git a/robusta/datasets/__init__.py b/robusta/datasets/__init__.py index df5d7bc..5a466c8 100644 --- a/robusta/datasets/__init__.py +++ b/robusta/datasets/__init__.py @@ -18,7 +18,6 @@ # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - from robusta.datasets import base from robusta.datasets import imagenet200 from robusta.datasets import imageneta diff --git a/robusta/datasets/base.py b/robusta/datasets/base.py index 2ed6f2a..c9fa481 100644 --- a/robusta/datasets/base.py +++ b/robusta/datasets/base.py @@ -21,10 +21,11 @@ import torchvision.datasets import torchvision.transforms + class TorchvisionTransform(torchvision.transforms.Compose): """Standard torchvision transform for cropped and non-cropped datasets.""" - def __init__(self, resize = False): + def __init__(self, resize=False): self.resize = resize self.mean = [0.485, 0.456, 0.406] @@ -33,31 +34,285 @@ def __init__(self, resize = False): torchvision.transforms.Resize(256), torchvision.transforms.CenterCrop(224), torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - self.mean, self.std - ) + torchvision.transforms.Normalize(self.mean, self.std) ]) + class ImageNetRobustnessDataset(torchvision.datasets.ImageFolder): - def __init__(self, dataset_dir, transform = None, **kwargs): + def __init__(self, dataset_dir, transform=None, **kwargs): if transform == "torchvision": transform = TorchvisionTransform() - super().__init__(dataset_dir, transform = transform, **kwargs) + super().__init__(dataset_dir, transform=transform, **kwargs) def accuracy_metric(self, logits, targets): - raise NotImplementedError() + pred = logits.data.max(1)[1] + return pred.eq(targets.data).sum().item() -class RemappedImageNet(): - def __init__(self): - super().__init__() +class RemappedImageNet(ImageNetRobustnessDataset): + """This dataset is used for the ImageNet-A, ImageNet-R and the ImageNet200 datasets which + require a remapping of 1000 ImageNet classes to the 200 classes of ImageNet-R. + """ - def map_logits(self, logits): - output = logits[:, imagenet_r_mask] + def map_logits(self, logits, mask): + output = logits[:, mask] return output - def accuracy_metric(self, logits, targets): - logits200 = self.map_logits(logits) - super().accuracy_metric(logits200, targets) + def accuracy_metric(self, logits, targets, mask): + logits200 = self.map_logits(logits, mask) + return super().accuracy_metric(logits200, targets) + + +class ImageNetRClasses(): + + def get_imagenet_wnids(): + imagenet_r_wnids = { + 'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', + 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', + 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', + 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', + 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', + 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', + 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', + 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', + 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', + 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', + 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', + 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', + 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', + 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', + 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', + 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', + 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', + 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', + 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', + 'n02391049', 'n02395406', 'n02398521', 'n02410509', 'n02423022', + 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', + 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', + 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', + 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', + 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', + 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', + 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', + 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', + 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', + 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', + 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', + 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', + 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', + 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', + 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', + 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', + 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', + 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', + 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', + 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677' + } + return imagenet_r_wnids + def get_class_mask(): + all_wnids = [ + 'n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', + 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', + 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', + 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', + 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', + 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', + 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', + 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', + 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', + 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', + 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', + 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', + 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', + 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', + 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', + 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', + 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', + 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', + 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', + 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', + 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', + 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', + 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', + 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', + 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', + 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', + 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', + 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', + 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', + 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', + 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', + 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', + 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', + 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', + 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', + 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', + 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', + 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', + 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', + 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', + 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', + 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', + 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', + 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', + 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', + 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', + 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', + 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', + 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', + 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', + 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', + 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', + 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', + 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', + 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', + 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', + 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', + 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', + 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', + 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', + 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', + 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', + 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', + 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', + 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', + 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', + 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', + 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', + 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', + 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', + 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', + 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', + 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', + 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', + 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', + 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', + 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', + 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', + 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', + 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', + 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', + 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', + 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', + 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', + 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', + 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', + 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', + 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', + 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', + 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', + 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', + 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', + 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', + 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', + 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', + 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', + 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', + 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', + 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', + 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', + 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', + 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', + 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', + 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', + 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', + 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', + 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', + 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', + 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', + 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', + 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', + 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', + 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', + 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', + 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', + 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', + 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', + 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', + 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', + 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', + 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', + 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', + 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', + 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', + 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', + 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', + 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', + 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', + 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', + 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', + 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', + 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', + 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', + 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', + 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', + 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', + 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', + 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', + 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', + 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', + 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', + 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', + 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', + 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', + 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', + 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', + 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', + 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', + 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', + 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', + 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', + 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', + 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', + 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', + 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', + 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', + 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', + 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', + 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', + 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', + 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', + 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', + 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', + 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', + 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', + 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', + 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', + 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', + 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', + 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', + 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', + 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', + 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', + 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', + 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', + 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', + 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', + 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', + 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', + 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', + 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', + 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', + 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', + 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', + 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', + 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', + 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', + 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', + 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', + 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', + 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', + 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', + 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', + 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', + 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', + 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', + 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', + 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', + 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', + 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141' + ] + imagenet_r_wnids = ImageNetRClasses.get_imagenet_wnids() + return [wnid in imagenet_r_wnids for wnid in all_wnids] diff --git a/robusta/datasets/imagenet200.py b/robusta/datasets/imagenet200.py index 89165ee..6447e3c 100644 --- a/robusta/datasets/imagenet200.py +++ b/robusta/datasets/imagenet200.py @@ -17,7 +17,6 @@ # This licence notice applies to all originally written code by the # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - """ ImageNet-R Reference: https://github.com/hendrycks/imagenet-r @@ -47,16 +46,14 @@ SOFTWARE. """ -from copy import copy import os import torch import torch.nn as nn import torch.nn.functional as F import numpy as np -from tqdm import tqdm import shutil - from robusta.datasets.base import RemappedImageNet +from robusta.datasets.base import ImageNetRClasses class ImageNet200(RemappedImageNet): @@ -65,28 +62,27 @@ class ImageNet200(RemappedImageNet): Reference: https://github.com/hendrycks/imagenet-r """ - def get_class_mask(self): - all_wnids = ['n01440764', 'n01443537', 'n01484850', 'n01491361', 'n01494475', 'n01496331', 'n01498041', 'n01514668', 'n01514859', 'n01518878', 'n01530575', 'n01531178', 'n01532829', 'n01534433', 'n01537544', 'n01558993', 'n01560419', 'n01580077', 'n01582220', 'n01592084', 'n01601694', 'n01608432', 'n01614925', 'n01616318', 'n01622779', 'n01629819', 'n01630670', 'n01631663', 'n01632458', 'n01632777', 'n01641577', 'n01644373', 'n01644900', 'n01664065', 'n01665541', 'n01667114', 'n01667778', 'n01669191', 'n01675722', 'n01677366', 'n01682714', 'n01685808', 'n01687978', 'n01688243', 'n01689811', 'n01692333', 'n01693334', 'n01694178', 'n01695060', 'n01697457', 'n01698640', 'n01704323', 'n01728572', 'n01728920', 'n01729322', 'n01729977', 'n01734418', 'n01735189', 'n01737021', 'n01739381', 'n01740131', 'n01742172', 'n01744401', 'n01748264', 'n01749939', 'n01751748', 'n01753488', 'n01755581', 'n01756291', 'n01768244', 'n01770081', 'n01770393', 'n01773157', 'n01773549', 'n01773797', 'n01774384', 'n01774750', 'n01775062', 'n01776313', 'n01784675', 'n01795545', 'n01796340', 'n01797886', 'n01798484', 'n01806143', 'n01806567', 'n01807496', 'n01817953', 'n01818515', 'n01819313', 'n01820546', 'n01824575', 'n01828970', 'n01829413', 'n01833805', 'n01843065', 'n01843383', 'n01847000', 'n01855032', 'n01855672', 'n01860187', 'n01871265', 'n01872401', 'n01873310', 'n01877812', 'n01882714', 'n01883070', 'n01910747', 'n01914609', 'n01917289', 'n01924916', 'n01930112', 'n01943899', 'n01944390', 'n01945685', 'n01950731', 'n01955084', 'n01968897', 'n01978287', 'n01978455', 'n01980166', 'n01981276', 'n01983481', 'n01984695', 'n01985128', 'n01986214', 'n01990800', 'n02002556', 'n02002724', 'n02006656', 'n02007558', 'n02009229', 'n02009912', 'n02011460', 'n02012849', 'n02013706', 'n02017213', 'n02018207', 'n02018795', 'n02025239', 'n02027492', 'n02028035', 'n02033041', 'n02037110', 'n02051845', 'n02056570', 'n02058221', 'n02066245', 'n02071294', 'n02074367', 'n02077923', 'n02085620', 'n02085782', 'n02085936', 'n02086079', 'n02086240', 'n02086646', 'n02086910', 'n02087046', 'n02087394', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02088632', 'n02089078', 'n02089867', 'n02089973', 'n02090379', 'n02090622', 'n02090721', 'n02091032', 'n02091134', 'n02091244', 'n02091467', 'n02091635', 'n02091831', 'n02092002', 'n02092339', 'n02093256', 'n02093428', 'n02093647', 'n02093754', 'n02093859', 'n02093991', 'n02094114', 'n02094258', 'n02094433', 'n02095314', 'n02095570', 'n02095889', 'n02096051', 'n02096177', 'n02096294', 'n02096437', 'n02096585', 'n02097047', 'n02097130', 'n02097209', 'n02097298', 'n02097474', 'n02097658', 'n02098105', 'n02098286', 'n02098413', 'n02099267', 'n02099429', 'n02099601', 'n02099712', 'n02099849', 'n02100236', 'n02100583', 'n02100735', 'n02100877', 'n02101006', 'n02101388', 'n02101556', 'n02102040', 'n02102177', 'n02102318', 'n02102480', 'n02102973', 'n02104029', 'n02104365', 'n02105056', 'n02105162', 'n02105251', 'n02105412', 'n02105505', 'n02105641', 'n02105855', 'n02106030', 'n02106166', 'n02106382', 'n02106550', 'n02106662', 'n02107142', 'n02107312', 'n02107574', 'n02107683', 'n02107908', 'n02108000', 'n02108089', 'n02108422', 'n02108551', 'n02108915', 'n02109047', 'n02109525', 'n02109961', 'n02110063', 'n02110185', 'n02110341', 'n02110627', 'n02110806', 'n02110958', 'n02111129', 'n02111277', 'n02111500', 'n02111889', 'n02112018', 'n02112137', 'n02112350', 'n02112706', 'n02113023', 'n02113186', 'n02113624', 'n02113712', 'n02113799', 'n02113978', 'n02114367', 'n02114548', 'n02114712', 'n02114855', 'n02115641', 'n02115913', 'n02116738', 'n02117135', 'n02119022', 'n02119789', 'n02120079', 'n02120505', 'n02123045', 'n02123159', 'n02123394', 'n02123597', 'n02124075', 'n02125311', 'n02127052', 'n02128385', 'n02128757', 'n02128925', 'n02129165', 'n02129604', 'n02130308', 'n02132136', 'n02133161', 'n02134084', 'n02134418', 'n02137549', 'n02138441', 'n02165105', 'n02165456', 'n02167151', 'n02168699', 'n02169497', 'n02172182', 'n02174001', 'n02177972', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02229544', 'n02231487', 'n02233338', 'n02236044', 'n02256656', 'n02259212', 'n02264363', 'n02268443', 'n02268853', 'n02276258', 'n02277742', 'n02279972', 'n02280649', 'n02281406', 'n02281787', 'n02317335', 'n02319095', 'n02321529', 'n02325366', 'n02326432', 'n02328150', 'n02342885', 'n02346627', 'n02356798', 'n02361337', 'n02363005', 'n02364673', 'n02389026', 'n02391049', 'n02395406', 'n02396427', 'n02397096', 'n02398521', 'n02403003', 'n02408429', 'n02410509', 'n02412080', 'n02415577', 'n02417914', 'n02422106', 'n02422699', 'n02423022', 'n02437312', 'n02437616', 'n02441942', 'n02442845', 'n02443114', 'n02443484', 'n02444819', 'n02445715', 'n02447366', 'n02454379', 'n02457408', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02483708', 'n02484975', 'n02486261', 'n02486410', 'n02487347', 'n02488291', 'n02488702', 'n02489166', 'n02490219', 'n02492035', 'n02492660', 'n02493509', 'n02493793', 'n02494079', 'n02497673', 'n02500267', 'n02504013', 'n02504458', 'n02509815', 'n02510455', 'n02514041', 'n02526121', 'n02536864', 'n02606052', 'n02607072', 'n02640242', 'n02641379', 'n02643566', 'n02655020', 'n02666196', 'n02667093', 'n02669723', 'n02672831', 'n02676566', 'n02687172', 'n02690373', 'n02692877', 'n02699494', 'n02701002', 'n02704792', 'n02708093', 'n02727426', 'n02730930', 'n02747177', 'n02749479', 'n02769748', 'n02776631', 'n02777292', 'n02782093', 'n02783161', 'n02786058', 'n02787622', 'n02788148', 'n02790996', 'n02791124', 'n02791270', 'n02793495', 'n02794156', 'n02795169', 'n02797295', 'n02799071', 'n02802426', 'n02804414', 'n02804610', 'n02807133', 'n02808304', 'n02808440', 'n02814533', 'n02814860', 'n02815834', 'n02817516', 'n02823428', 'n02823750', 'n02825657', 'n02834397', 'n02835271', 'n02837789', 'n02840245', 'n02841315', 'n02843684', 'n02859443', 'n02860847', 'n02865351', 'n02869837', 'n02870880', 'n02871525', 'n02877765', 'n02879718', 'n02883205', 'n02892201', 'n02892767', 'n02894605', 'n02895154', 'n02906734', 'n02909870', 'n02910353', 'n02916936', 'n02917067', 'n02927161', 'n02930766', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02951585', 'n02963159', 'n02965783', 'n02966193', 'n02966687', 'n02971356', 'n02974003', 'n02977058', 'n02978881', 'n02979186', 'n02980441', 'n02981792', 'n02988304', 'n02992211', 'n02992529', 'n02999410', 'n03000134', 'n03000247', 'n03000684', 'n03014705', 'n03016953', 'n03017168', 'n03018349', 'n03026506', 'n03028079', 'n03032252', 'n03041632', - 'n03042490', 'n03045698', 'n03047690', 'n03062245', 'n03063599', 'n03063689', 'n03065424', 'n03075370', 'n03085013', 'n03089624', 'n03095699', 'n03100240', 'n03109150', 'n03110669', 'n03124043', 'n03124170', 'n03125729', 'n03126707', 'n03127747', 'n03127925', 'n03131574', 'n03133878', 'n03134739', 'n03141823', 'n03146219', 'n03160309', 'n03179701', 'n03180011', 'n03187595', 'n03188531', 'n03196217', 'n03197337', 'n03201208', 'n03207743', 'n03207941', 'n03208938', 'n03216828', 'n03218198', 'n03220513', 'n03223299', 'n03240683', 'n03249569', 'n03250847', 'n03255030', 'n03259280', 'n03271574', 'n03272010', 'n03272562', 'n03290653', 'n03291819', 'n03297495', 'n03314780', 'n03325584', 'n03337140', 'n03344393', 'n03345487', 'n03347037', 'n03355925', 'n03372029', 'n03376595', 'n03379051', 'n03384352', 'n03388043', 'n03388183', 'n03388549', 'n03393912', 'n03394916', 'n03400231', 'n03404251', 'n03417042', 'n03424325', 'n03425413', 'n03443371', 'n03444034', 'n03445777', 'n03445924', 'n03447447', 'n03447721', 'n03450230', 'n03452741', 'n03457902', 'n03459775', 'n03461385', 'n03467068', 'n03476684', 'n03476991', 'n03478589', 'n03481172', 'n03482405', 'n03483316', 'n03485407', 'n03485794', 'n03492542', 'n03494278', 'n03495258', 'n03496892', 'n03498962', 'n03527444', 'n03529860', 'n03530642', 'n03532672', 'n03534580', 'n03535780', 'n03538406', 'n03544143', 'n03584254', 'n03584829', 'n03590841', 'n03594734', 'n03594945', 'n03595614', 'n03598930', 'n03599486', 'n03602883', 'n03617480', 'n03623198', 'n03627232', 'n03630383', 'n03633091', 'n03637318', 'n03642806', 'n03649909', 'n03657121', 'n03658185', 'n03661043', 'n03662601', 'n03666591', 'n03670208', 'n03673027', 'n03676483', 'n03680355', 'n03690938', 'n03691459', 'n03692522', 'n03697007', 'n03706229', 'n03709823', 'n03710193', 'n03710637', 'n03710721', 'n03717622', 'n03720891', 'n03721384', 'n03724870', 'n03729826', 'n03733131', 'n03733281', 'n03733805', 'n03742115', 'n03743016', 'n03759954', 'n03761084', 'n03763968', 'n03764736', 'n03769881', 'n03770439', 'n03770679', 'n03773504', 'n03775071', 'n03775546', 'n03776460', 'n03777568', 'n03777754', 'n03781244', 'n03782006', 'n03785016', 'n03786901', 'n03787032', 'n03788195', 'n03788365', 'n03791053', 'n03792782', 'n03792972', 'n03793489', 'n03794056', 'n03796401', 'n03803284', 'n03804744', 'n03814639', 'n03814906', 'n03825788', 'n03832673', 'n03837869', 'n03838899', 'n03840681', 'n03841143', 'n03843555', 'n03854065', 'n03857828', 'n03866082', 'n03868242', 'n03868863', 'n03871628', 'n03873416', 'n03874293', 'n03874599', 'n03876231', 'n03877472', 'n03877845', 'n03884397', 'n03887697', 'n03888257', 'n03888605', 'n03891251', 'n03891332', 'n03895866', 'n03899768', 'n03902125', 'n03903868', 'n03908618', 'n03908714', 'n03916031', 'n03920288', 'n03924679', 'n03929660', 'n03929855', 'n03930313', 'n03930630', 'n03933933', 'n03935335', 'n03937543', 'n03938244', 'n03942813', 'n03944341', 'n03947888', 'n03950228', 'n03954731', 'n03956157', 'n03958227', 'n03961711', 'n03967562', 'n03970156', 'n03976467', 'n03976657', 'n03977966', 'n03980874', 'n03982430', 'n03983396', 'n03991062', 'n03992509', 'n03995372', 'n03998194', 'n04004767', 'n04005630', 'n04008634', 'n04009552', 'n04019541', 'n04023962', 'n04026417', 'n04033901', 'n04033995', 'n04037443', 'n04039381', 'n04040759', 'n04041544', 'n04044716', 'n04049303', 'n04065272', 'n04067472', 'n04069434', 'n04070727', 'n04074963', 'n04081281', 'n04086273', 'n04090263', 'n04099969', 'n04111531', 'n04116512', 'n04118538', 'n04118776', 'n04120489', 'n04125021', 'n04127249', 'n04131690', 'n04133789', 'n04136333', 'n04141076', 'n04141327', 'n04141975', 'n04146614', 'n04147183', 'n04149813', 'n04152593', 'n04153751', 'n04154565', 'n04162706', 'n04179913', 'n04192698', 'n04200800', 'n04201297', 'n04204238', 'n04204347', 'n04208210', 'n04209133', 'n04209239', 'n04228054', 'n04229816', 'n04235860', 'n04238763', 'n04239074', 'n04243546', 'n04251144', 'n04252077', 'n04252225', 'n04254120', 'n04254680', 'n04254777', 'n04258138', 'n04259630', 'n04263257', 'n04264628', 'n04265275', 'n04266014', 'n04270147', 'n04273569', 'n04275548', 'n04277352', 'n04285008', 'n04286575', 'n04296562', 'n04310018', 'n04311004', 'n04311174', 'n04317175', 'n04325704', 'n04326547', 'n04328186', 'n04330267', 'n04332243', 'n04335435', 'n04336792', 'n04344873', 'n04346328', 'n04347754', 'n04350905', 'n04355338', 'n04355933', 'n04356056', 'n04357314', 'n04366367', 'n04367480', 'n04370456', 'n04371430', 'n04371774', 'n04372370', 'n04376876', 'n04380533', 'n04389033', 'n04392985', 'n04398044', 'n04399382', 'n04404412', 'n04409515', 'n04417672', 'n04418357', 'n04423845', 'n04428191', 'n04429376', 'n04435653', 'n04442312', 'n04443257', 'n04447861', 'n04456115', 'n04458633', 'n04461696', 'n04462240', 'n04465501', 'n04467665', 'n04476259', 'n04479046', 'n04482393', 'n04483307', 'n04485082', 'n04486054', 'n04487081', 'n04487394', 'n04493381', 'n04501370', 'n04505470', 'n04507155', 'n04509417', 'n04515003', 'n04517823', 'n04522168', 'n04523525', 'n04525038', 'n04525305', 'n04532106', 'n04532670', 'n04536866', 'n04540053', 'n04542943', 'n04548280', 'n04548362', 'n04550184', 'n04552348', 'n04553703', 'n04554684', 'n04557648', 'n04560804', 'n04562935', 'n04579145', 'n04579432', 'n04584207', 'n04589890', 'n04590129', 'n04591157', 'n04591713', 'n04592741', 'n04596742', 'n04597913', 'n04599235', 'n04604644', 'n04606251', 'n04612504', 'n04613696', 'n06359193', 'n06596364', 'n06785654', 'n06794110', 'n06874185', 'n07248320', 'n07565083', 'n07579787', 'n07583066', 'n07584110', 'n07590611', 'n07613480', 'n07614500', 'n07615774', 'n07684084', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07711569', 'n07714571', 'n07714990', 'n07715103', 'n07716358', 'n07716906', 'n07717410', 'n07717556', 'n07718472', 'n07718747', 'n07720875', 'n07730033', 'n07734744', 'n07742313', 'n07745940', 'n07747607', 'n07749582', 'n07753113', 'n07753275', 'n07753592', 'n07754684', 'n07760859', 'n07768694', 'n07802026', 'n07831146', 'n07836838', 'n07860988', 'n07871810', 'n07873807', 'n07875152', 'n07880968', 'n07892512', 'n07920052', 'n07930864', 'n07932039', 'n09193705', 'n09229709', 'n09246464', 'n09256479', 'n09288635', 'n09332890', 'n09399592', 'n09421951', 'n09428293', 'n09468604', 'n09472597', 'n09835506', 'n10148035', 'n10565667', 'n11879895', 'n11939491', 'n12057211', 'n12144580', 'n12267677', 'n12620546', 'n12768682', 'n12985857', 'n12998815', 'n13037406', 'n13040303', 'n13044778', 'n13052670', 'n13054560', 'n13133613', 'n15075141'] - imagenet_r_wnids = {'n01443537', 'n01484850', 'n01494475', 'n01498041', 'n01514859', 'n01518878', 'n01531178', 'n01534433', 'n01614925', 'n01616318', 'n01630670', 'n01632777', 'n01644373', 'n01677366', 'n01694178', 'n01748264', 'n01770393', 'n01774750', 'n01784675', 'n01806143', 'n01820546', 'n01833805', 'n01843383', 'n01847000', 'n01855672', 'n01860187', 'n01882714', 'n01910747', 'n01944390', 'n01983481', 'n01986214', 'n02007558', 'n02009912', 'n02051845', 'n02056570', 'n02066245', 'n02071294', 'n02077923', 'n02085620', 'n02086240', 'n02088094', 'n02088238', 'n02088364', 'n02088466', 'n02091032', 'n02091134', 'n02092339', 'n02094433', 'n02096585', 'n02097298', 'n02098286', 'n02099601', 'n02099712', 'n02102318', 'n02106030', 'n02106166', 'n02106550', 'n02106662', 'n02108089', 'n02108915', 'n02109525', 'n02110185', 'n02110341', 'n02110958', 'n02112018', 'n02112137', 'n02113023', 'n02113624', 'n02113799', 'n02114367', 'n02117135', 'n02119022', 'n02123045', 'n02128385', 'n02128757', 'n02129165', 'n02129604', 'n02130308', 'n02134084', 'n02138441', 'n02165456', 'n02190166', 'n02206856', 'n02219486', 'n02226429', 'n02233338', 'n02236044', 'n02268443', 'n02279972', 'n02317335', 'n02325366', 'n02346627', 'n02356798', 'n02363005', 'n02364673', 'n02391049', 'n02395406', 'n02398521', 'n02410509', - 'n02423022', 'n02437616', 'n02445715', 'n02447366', 'n02480495', 'n02480855', 'n02481823', 'n02483362', 'n02486410', 'n02510455', 'n02526121', 'n02607072', 'n02655020', 'n02672831', 'n02701002', 'n02749479', 'n02769748', 'n02793495', 'n02797295', 'n02802426', 'n02808440', 'n02814860', 'n02823750', 'n02841315', 'n02843684', 'n02883205', 'n02906734', 'n02909870', 'n02939185', 'n02948072', 'n02950826', 'n02951358', 'n02966193', 'n02980441', 'n02992529', 'n03124170', 'n03272010', 'n03345487', 'n03372029', 'n03424325', 'n03452741', 'n03467068', 'n03481172', 'n03494278', 'n03495258', 'n03498962', 'n03594945', 'n03602883', 'n03630383', 'n03649909', 'n03676483', 'n03710193', 'n03773504', 'n03775071', 'n03888257', 'n03930630', 'n03947888', 'n04086273', 'n04118538', 'n04133789', 'n04141076', 'n04146614', 'n04147183', 'n04192698', 'n04254680', 'n04266014', 'n04275548', 'n04310018', 'n04325704', 'n04347754', 'n04389033', 'n04409515', 'n04465501', 'n04487394', 'n04522168', 'n04536866', 'n04552348', 'n04591713', 'n07614500', 'n07693725', 'n07695742', 'n07697313', 'n07697537', 'n07714571', 'n07714990', 'n07718472', 'n07720875', 'n07734744', 'n07742313', 'n07745940', 'n07749582', 'n07753275', 'n07753592', 'n07768694', 'n07873807', 'n07880968', 'n07920052', 'n09472597', 'n09835506', 'n10565667', 'n12267677'} - return [wnid in imagenet_r_wnids for wnid in all_wnids] + mask = ImageNetRClasses.get_class_mask() def create_symlinks_to_imagenet(self): if not os.path.exists(self.imagenet_200_location): os.makedirs(self.imagenet_200_location) - folders_of_interest = self.get_class_mask() + folders_of_interest = ImageNetRClasses.get_imagenet_wnids() for folder in folders_of_interest: os.symlink(self.imagenet_1k_location + folder, - self.imagenet_200_location+folder, target_is_directory=True) + self.imagenet_200_location + folder, + target_is_directory=True) else: print('Folder containing IID validation images already exists') - def __init__(self, imagenet_directory, imagenet_200_directory= "/tmp/in200", transform = None): + def __init__(self, + imagenet_directory, + imagenet_200_directory="/tmp/in200", + transform=None): self.imagenet_1k_location = imagenet_directory self.imagenet_200_location = imagenet_200_directory self.create_symlinks_to_imagenet() - super().__init__( - self.imagenet_200_location, - transform= transform - ) \ No newline at end of file + super().__init__(self.imagenet_200_location, transform=transform) + + def accuracy_metric(self, logits, targets): + return super().accuracy_metric(logits, targets, ImageNet200.mask) diff --git a/robusta/datasets/imageneta.py b/robusta/datasets/imageneta.py index ff1d512..a9bd102 100644 --- a/robusta/datasets/imageneta.py +++ b/robusta/datasets/imageneta.py @@ -17,7 +17,6 @@ # This licence notice applies to all originally written code by the # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. - """ The ImageNet-A dataset. Reference: https://github.com/hendrycks/natural-adv-examples @@ -55,17 +54,1028 @@ import torch.nn.functional as F import numpy as np -class ImageNetA: +from robusta.datasets.base import RemappedImageNet + + +class ImageNetAClasses(): + + def get_class_mask(): + # dict definition taken from https://github.com/hendrycks/natural-adv-examples/blob/master/eval.py#L12 + thousand_k_to_200 = { + 0: -1, + 1: -1, + 2: -1, + 3: -1, + 4: -1, + 5: -1, + 6: 0, + 7: -1, + 8: -1, + 9: -1, + 10: -1, + 11: 1, + 12: -1, + 13: 2, + 14: -1, + 15: 3, + 16: -1, + 17: 4, + 18: -1, + 19: -1, + 20: -1, + 21: -1, + 22: 5, + 23: 6, + 24: -1, + 25: -1, + 26: -1, + 27: 7, + 28: -1, + 29: -1, + 30: 8, + 31: -1, + 32: -1, + 33: -1, + 34: -1, + 35: -1, + 36: -1, + 37: 9, + 38: -1, + 39: 10, + 40: -1, + 41: -1, + 42: 11, + 43: -1, + 44: -1, + 45: -1, + 46: -1, + 47: 12, + 48: -1, + 49: -1, + 50: 13, + 51: -1, + 52: -1, + 53: -1, + 54: -1, + 55: -1, + 56: -1, + 57: 14, + 58: -1, + 59: -1, + 60: -1, + 61: -1, + 62: -1, + 63: -1, + 64: -1, + 65: -1, + 66: -1, + 67: -1, + 68: -1, + 69: -1, + 70: 15, + 71: 16, + 72: -1, + 73: -1, + 74: -1, + 75: -1, + 76: 17, + 77: -1, + 78: -1, + 79: 18, + 80: -1, + 81: -1, + 82: -1, + 83: -1, + 84: -1, + 85: -1, + 86: -1, + 87: -1, + 88: -1, + 89: 19, + 90: 20, + 91: -1, + 92: -1, + 93: -1, + 94: 21, + 95: -1, + 96: 22, + 97: 23, + 98: -1, + 99: 24, + 100: -1, + 101: -1, + 102: -1, + 103: -1, + 104: -1, + 105: 25, + 106: -1, + 107: 26, + 108: 27, + 109: -1, + 110: 28, + 111: -1, + 112: -1, + 113: 29, + 114: -1, + 115: -1, + 116: -1, + 117: -1, + 118: -1, + 119: -1, + 120: -1, + 121: -1, + 122: -1, + 123: -1, + 124: 30, + 125: 31, + 126: -1, + 127: -1, + 128: -1, + 129: -1, + 130: 32, + 131: -1, + 132: 33, + 133: -1, + 134: -1, + 135: -1, + 136: -1, + 137: -1, + 138: -1, + 139: -1, + 140: -1, + 141: -1, + 142: -1, + 143: 34, + 144: 35, + 145: -1, + 146: -1, + 147: -1, + 148: -1, + 149: -1, + 150: 36, + 151: 37, + 152: -1, + 153: -1, + 154: -1, + 155: -1, + 156: -1, + 157: -1, + 158: -1, + 159: -1, + 160: -1, + 161: -1, + 162: -1, + 163: -1, + 164: -1, + 165: -1, + 166: -1, + 167: -1, + 168: -1, + 169: -1, + 170: -1, + 171: -1, + 172: -1, + 173: -1, + 174: -1, + 175: -1, + 176: -1, + 177: -1, + 178: -1, + 179: -1, + 180: -1, + 181: -1, + 182: -1, + 183: -1, + 184: -1, + 185: -1, + 186: -1, + 187: -1, + 188: -1, + 189: -1, + 190: -1, + 191: -1, + 192: -1, + 193: -1, + 194: -1, + 195: -1, + 196: -1, + 197: -1, + 198: -1, + 199: -1, + 200: -1, + 201: -1, + 202: -1, + 203: -1, + 204: -1, + 205: -1, + 206: -1, + 207: 38, + 208: -1, + 209: -1, + 210: -1, + 211: -1, + 212: -1, + 213: -1, + 214: -1, + 215: -1, + 216: -1, + 217: -1, + 218: -1, + 219: -1, + 220: -1, + 221: -1, + 222: -1, + 223: -1, + 224: -1, + 225: -1, + 226: -1, + 227: -1, + 228: -1, + 229: -1, + 230: -1, + 231: -1, + 232: -1, + 233: -1, + 234: 39, + 235: 40, + 236: -1, + 237: -1, + 238: -1, + 239: -1, + 240: -1, + 241: -1, + 242: -1, + 243: -1, + 244: -1, + 245: -1, + 246: -1, + 247: -1, + 248: -1, + 249: -1, + 250: -1, + 251: -1, + 252: -1, + 253: -1, + 254: 41, + 255: -1, + 256: -1, + 257: -1, + 258: -1, + 259: -1, + 260: -1, + 261: -1, + 262: -1, + 263: -1, + 264: -1, + 265: -1, + 266: -1, + 267: -1, + 268: -1, + 269: -1, + 270: -1, + 271: -1, + 272: -1, + 273: -1, + 274: -1, + 275: -1, + 276: -1, + 277: 42, + 278: -1, + 279: -1, + 280: -1, + 281: -1, + 282: -1, + 283: 43, + 284: -1, + 285: -1, + 286: -1, + 287: 44, + 288: -1, + 289: -1, + 290: -1, + 291: 45, + 292: -1, + 293: -1, + 294: -1, + 295: 46, + 296: -1, + 297: -1, + 298: 47, + 299: -1, + 300: -1, + 301: 48, + 302: -1, + 303: -1, + 304: -1, + 305: -1, + 306: 49, + 307: 50, + 308: 51, + 309: 52, + 310: 53, + 311: 54, + 312: -1, + 313: 55, + 314: 56, + 315: 57, + 316: -1, + 317: 58, + 318: -1, + 319: 59, + 320: -1, + 321: -1, + 322: -1, + 323: 60, + 324: 61, + 325: -1, + 326: 62, + 327: 63, + 328: -1, + 329: -1, + 330: 64, + 331: -1, + 332: -1, + 333: -1, + 334: 65, + 335: 66, + 336: 67, + 337: -1, + 338: -1, + 339: -1, + 340: -1, + 341: -1, + 342: -1, + 343: -1, + 344: -1, + 345: -1, + 346: -1, + 347: 68, + 348: -1, + 349: -1, + 350: -1, + 351: -1, + 352: -1, + 353: -1, + 354: -1, + 355: -1, + 356: -1, + 357: -1, + 358: -1, + 359: -1, + 360: -1, + 361: 69, + 362: -1, + 363: 70, + 364: -1, + 365: -1, + 366: -1, + 367: -1, + 368: -1, + 369: -1, + 370: -1, + 371: -1, + 372: 71, + 373: -1, + 374: -1, + 375: -1, + 376: -1, + 377: -1, + 378: 72, + 379: -1, + 380: -1, + 381: -1, + 382: -1, + 383: -1, + 384: -1, + 385: -1, + 386: 73, + 387: -1, + 388: -1, + 389: -1, + 390: -1, + 391: -1, + 392: -1, + 393: -1, + 394: -1, + 395: -1, + 396: -1, + 397: 74, + 398: -1, + 399: -1, + 400: 75, + 401: 76, + 402: 77, + 403: -1, + 404: 78, + 405: -1, + 406: -1, + 407: 79, + 408: -1, + 409: -1, + 410: -1, + 411: 80, + 412: -1, + 413: -1, + 414: -1, + 415: -1, + 416: 81, + 417: 82, + 418: -1, + 419: -1, + 420: 83, + 421: -1, + 422: -1, + 423: -1, + 424: -1, + 425: 84, + 426: -1, + 427: -1, + 428: 85, + 429: -1, + 430: 86, + 431: -1, + 432: -1, + 433: -1, + 434: -1, + 435: -1, + 436: -1, + 437: 87, + 438: 88, + 439: -1, + 440: -1, + 441: -1, + 442: -1, + 443: -1, + 444: -1, + 445: 89, + 446: -1, + 447: -1, + 448: -1, + 449: -1, + 450: -1, + 451: -1, + 452: -1, + 453: -1, + 454: -1, + 455: -1, + 456: 90, + 457: 91, + 458: -1, + 459: -1, + 460: -1, + 461: 92, + 462: 93, + 463: -1, + 464: -1, + 465: -1, + 466: -1, + 467: -1, + 468: -1, + 469: -1, + 470: 94, + 471: -1, + 472: 95, + 473: -1, + 474: -1, + 475: -1, + 476: -1, + 477: -1, + 478: -1, + 479: -1, + 480: -1, + 481: -1, + 482: -1, + 483: 96, + 484: -1, + 485: -1, + 486: 97, + 487: -1, + 488: 98, + 489: -1, + 490: -1, + 491: -1, + 492: 99, + 493: -1, + 494: -1, + 495: -1, + 496: 100, + 497: -1, + 498: -1, + 499: -1, + 500: -1, + 501: -1, + 502: -1, + 503: -1, + 504: -1, + 505: -1, + 506: -1, + 507: -1, + 508: -1, + 509: -1, + 510: -1, + 511: -1, + 512: -1, + 513: -1, + 514: 101, + 515: -1, + 516: 102, + 517: -1, + 518: -1, + 519: -1, + 520: -1, + 521: -1, + 522: -1, + 523: -1, + 524: -1, + 525: -1, + 526: -1, + 527: -1, + 528: 103, + 529: -1, + 530: 104, + 531: -1, + 532: -1, + 533: -1, + 534: -1, + 535: -1, + 536: -1, + 537: -1, + 538: -1, + 539: 105, + 540: -1, + 541: -1, + 542: 106, + 543: 107, + 544: -1, + 545: -1, + 546: -1, + 547: -1, + 548: -1, + 549: 108, + 550: -1, + 551: -1, + 552: 109, + 553: -1, + 554: -1, + 555: -1, + 556: -1, + 557: 110, + 558: -1, + 559: -1, + 560: -1, + 561: 111, + 562: 112, + 563: -1, + 564: -1, + 565: -1, + 566: -1, + 567: -1, + 568: -1, + 569: 113, + 570: -1, + 571: -1, + 572: 114, + 573: 115, + 574: -1, + 575: 116, + 576: -1, + 577: -1, + 578: -1, + 579: 117, + 580: -1, + 581: -1, + 582: -1, + 583: -1, + 584: -1, + 585: -1, + 586: -1, + 587: -1, + 588: -1, + 589: 118, + 590: -1, + 591: -1, + 592: -1, + 593: -1, + 594: -1, + 595: -1, + 596: -1, + 597: -1, + 598: -1, + 599: -1, + 600: -1, + 601: -1, + 602: -1, + 603: -1, + 604: -1, + 605: -1, + 606: 119, + 607: 120, + 608: -1, + 609: 121, + 610: -1, + 611: -1, + 612: -1, + 613: -1, + 614: 122, + 615: -1, + 616: -1, + 617: -1, + 618: -1, + 619: -1, + 620: -1, + 621: -1, + 622: -1, + 623: -1, + 624: -1, + 625: -1, + 626: 123, + 627: 124, + 628: -1, + 629: -1, + 630: -1, + 631: -1, + 632: -1, + 633: -1, + 634: -1, + 635: -1, + 636: -1, + 637: -1, + 638: -1, + 639: -1, + 640: 125, + 641: 126, + 642: 127, + 643: 128, + 644: -1, + 645: -1, + 646: -1, + 647: -1, + 648: -1, + 649: -1, + 650: -1, + 651: -1, + 652: -1, + 653: -1, + 654: -1, + 655: -1, + 656: -1, + 657: -1, + 658: 129, + 659: -1, + 660: -1, + 661: -1, + 662: -1, + 663: -1, + 664: -1, + 665: -1, + 666: -1, + 667: -1, + 668: 130, + 669: -1, + 670: -1, + 671: -1, + 672: -1, + 673: -1, + 674: -1, + 675: -1, + 676: -1, + 677: 131, + 678: -1, + 679: -1, + 680: -1, + 681: -1, + 682: 132, + 683: -1, + 684: 133, + 685: -1, + 686: -1, + 687: 134, + 688: -1, + 689: -1, + 690: -1, + 691: -1, + 692: -1, + 693: -1, + 694: -1, + 695: -1, + 696: -1, + 697: -1, + 698: -1, + 699: -1, + 700: -1, + 701: 135, + 702: -1, + 703: -1, + 704: 136, + 705: -1, + 706: -1, + 707: -1, + 708: -1, + 709: -1, + 710: -1, + 711: -1, + 712: -1, + 713: -1, + 714: -1, + 715: -1, + 716: -1, + 717: -1, + 718: -1, + 719: 137, + 720: -1, + 721: -1, + 722: -1, + 723: -1, + 724: -1, + 725: -1, + 726: -1, + 727: -1, + 728: -1, + 729: -1, + 730: -1, + 731: -1, + 732: -1, + 733: -1, + 734: -1, + 735: -1, + 736: 138, + 737: -1, + 738: -1, + 739: -1, + 740: -1, + 741: -1, + 742: -1, + 743: -1, + 744: -1, + 745: -1, + 746: 139, + 747: -1, + 748: -1, + 749: 140, + 750: -1, + 751: -1, + 752: 141, + 753: -1, + 754: -1, + 755: -1, + 756: -1, + 757: -1, + 758: 142, + 759: -1, + 760: -1, + 761: -1, + 762: -1, + 763: 143, + 764: -1, + 765: 144, + 766: -1, + 767: -1, + 768: 145, + 769: -1, + 770: -1, + 771: -1, + 772: -1, + 773: 146, + 774: 147, + 775: -1, + 776: 148, + 777: -1, + 778: -1, + 779: 149, + 780: 150, + 781: -1, + 782: -1, + 783: -1, + 784: -1, + 785: -1, + 786: 151, + 787: -1, + 788: -1, + 789: -1, + 790: -1, + 791: -1, + 792: 152, + 793: -1, + 794: -1, + 795: -1, + 796: -1, + 797: 153, + 798: -1, + 799: -1, + 800: -1, + 801: -1, + 802: 154, + 803: 155, + 804: 156, + 805: -1, + 806: -1, + 807: -1, + 808: -1, + 809: -1, + 810: -1, + 811: -1, + 812: -1, + 813: 157, + 814: -1, + 815: 158, + 816: -1, + 817: -1, + 818: -1, + 819: -1, + 820: 159, + 821: -1, + 822: -1, + 823: 160, + 824: -1, + 825: -1, + 826: -1, + 827: -1, + 828: -1, + 829: -1, + 830: -1, + 831: 161, + 832: -1, + 833: 162, + 834: -1, + 835: 163, + 836: -1, + 837: -1, + 838: -1, + 839: 164, + 840: -1, + 841: -1, + 842: -1, + 843: -1, + 844: -1, + 845: 165, + 846: -1, + 847: 166, + 848: -1, + 849: -1, + 850: 167, + 851: -1, + 852: -1, + 853: -1, + 854: -1, + 855: -1, + 856: -1, + 857: -1, + 858: -1, + 859: 168, + 860: -1, + 861: -1, + 862: 169, + 863: -1, + 864: -1, + 865: -1, + 866: -1, + 867: -1, + 868: -1, + 869: -1, + 870: 170, + 871: -1, + 872: -1, + 873: -1, + 874: -1, + 875: -1, + 876: -1, + 877: -1, + 878: -1, + 879: 171, + 880: 172, + 881: -1, + 882: -1, + 883: -1, + 884: -1, + 885: -1, + 886: -1, + 887: -1, + 888: 173, + 889: -1, + 890: 174, + 891: -1, + 892: -1, + 893: -1, + 894: -1, + 895: -1, + 896: -1, + 897: 175, + 898: -1, + 899: -1, + 900: 176, + 901: -1, + 902: -1, + 903: -1, + 904: -1, + 905: -1, + 906: -1, + 907: 177, + 908: -1, + 909: -1, + 910: -1, + 911: -1, + 912: -1, + 913: 178, + 914: -1, + 915: -1, + 916: -1, + 917: -1, + 918: -1, + 919: -1, + 920: -1, + 921: -1, + 922: -1, + 923: -1, + 924: 179, + 925: -1, + 926: -1, + 927: -1, + 928: -1, + 929: -1, + 930: -1, + 931: -1, + 932: 180, + 933: 181, + 934: 182, + 935: -1, + 936: -1, + 937: 183, + 938: -1, + 939: -1, + 940: -1, + 941: -1, + 942: -1, + 943: 184, + 944: -1, + 945: 185, + 946: -1, + 947: 186, + 948: -1, + 949: -1, + 950: -1, + 951: 187, + 952: -1, + 953: -1, + 954: 188, + 955: -1, + 956: 189, + 957: 190, + 958: -1, + 959: 191, + 960: -1, + 961: -1, + 962: -1, + 963: -1, + 964: -1, + 965: -1, + 966: -1, + 967: -1, + 968: -1, + 969: -1, + 970: -1, + 971: 192, + 972: 193, + 973: -1, + 974: -1, + 975: -1, + 976: -1, + 977: -1, + 978: -1, + 979: -1, + 980: 194, + 981: 195, + 982: -1, + 983: -1, + 984: 196, + 985: -1, + 986: 197, + 987: 198, + 988: 199, + 989: -1, + 990: -1, + 991: -1, + 992: -1, + 993: -1, + 994: -1, + 995: -1, + 996: -1, + 997: -1, + 998: -1, + 999: -1 + } - def __init__(self): - thousand_k_to_200 = {0: -1, 1: -1, 2: -1, 3: -1, 4: -1, 5: -1, 6: 0, 7: -1, 8: -1, 9: -1, 10: -1, 11: 1, 12: -1, 13: 2, 14: -1, 15: 3, 16: -1, 17: 4, 18: -1, 19: -1, 20: -1, 21: -1, 22: 5, 23: 6, 24: -1, 25: -1, 26: -1, 27: 7, 28: -1, 29: -1, 30: 8, 31: -1, 32: -1, 33: -1, 34: -1, 35: -1, 36: -1, 37: 9, 38: -1, 39: 10, 40: -1, 41: -1, 42: 11, 43: -1, 44: -1, 45: -1, 46: -1, 47: 12, 48: -1, 49: -1, 50: 13, 51: -1, 52: -1, 53: -1, 54: -1, 55: -1, 56: -1, 57: 14, 58: -1, 59: -1, 60: -1, 61: -1, 62: -1, 63: -1, 64: -1, 65: -1, 66: -1, 67: -1, 68: -1, 69: -1, 70: 15, 71: 16, 72: -1, 73: -1, 74: -1, 75: -1, 76: 17, 77: -1, 78: -1, 79: 18, 80: -1, 81: -1, 82: -1, 83: -1, 84: -1, 85: -1, 86: -1, 87: -1, 88: -1, 89: 19, 90: 20, 91: -1, 92: -1, 93: -1, 94: 21, 95: -1, 96: 22, 97: 23, 98: -1, 99: 24, 100: -1, 101: -1, 102: -1, 103: -1, 104: -1, 105: 25, 106: -1, 107: 26, 108: 27, 109: -1, 110: 28, 111: -1, 112: -1, 113: 29, 114: -1, 115: -1, 116: -1, 117: -1, 118: -1, 119: -1, 120: -1, 121: -1, 122: -1, 123: -1, 124: 30, 125: 31, 126: -1, 127: -1, 128: -1, 129: -1, 130: 32, 131: -1, 132: 33, 133: -1, 134: -1, 135: -1, 136: -1, 137: -1, 138: -1, 139: -1, 140: -1, 141: -1, 142: -1, 143: 34, 144: 35, 145: -1, 146: -1, 147: -1, 148: -1, 149: -1, 150: 36, 151: 37, 152: -1, 153: -1, 154: -1, 155: -1, 156: -1, 157: -1, 158: -1, 159: -1, 160: -1, 161: -1, 162: -1, 163: -1, 164: -1, 165: -1, 166: -1, 167: -1, 168: -1, 169: -1, 170: -1, 171: -1, 172: -1, 173: -1, 174: -1, 175: -1, 176: -1, 177: -1, 178: -1, 179: -1, 180: -1, 181: -1, 182: -1, 183: -1, 184: -1, 185: -1, 186: -1, 187: -1, 188: -1, 189: -1, 190: -1, 191: -1, 192: -1, 193: -1, 194: -1, 195: -1, 196: -1, 197: -1, 198: -1, 199: -1, 200: -1, 201: -1, 202: -1, 203: -1, 204: -1, 205: -1, 206: -1, 207: 38, 208: -1, 209: -1, 210: -1, 211: -1, 212: -1, 213: -1, 214: -1, 215: -1, 216: -1, 217: -1, 218: -1, 219: -1, 220: -1, 221: -1, 222: -1, 223: -1, 224: -1, 225: -1, 226: -1, 227: -1, 228: -1, 229: -1, 230: -1, 231: -1, 232: -1, 233: -1, 234: 39, 235: 40, 236: -1, 237: -1, 238: -1, 239: -1, 240: -1, 241: -1, 242: -1, 243: -1, 244: -1, 245: -1, 246: -1, 247: -1, 248: -1, 249: -1, 250: -1, 251: -1, 252: -1, 253: -1, 254: 41, 255: -1, 256: -1, 257: -1, 258: -1, 259: -1, 260: -1, 261: -1, 262: -1, 263: -1, 264: -1, 265: -1, 266: -1, 267: -1, 268: -1, 269: -1, 270: -1, 271: -1, 272: -1, 273: -1, 274: -1, 275: -1, 276: -1, 277: 42, 278: -1, 279: -1, 280: -1, 281: -1, 282: -1, 283: 43, 284: -1, 285: -1, 286: -1, 287: 44, 288: -1, 289: -1, 290: -1, 291: 45, 292: -1, 293: -1, 294: -1, 295: 46, 296: -1, 297: -1, 298: 47, 299: -1, 300: -1, 301: 48, 302: -1, 303: -1, 304: -1, 305: -1, 306: 49, 307: 50, 308: 51, 309: 52, 310: 53, 311: 54, 312: -1, 313: 55, 314: 56, 315: 57, 316: -1, 317: 58, 318: -1, 319: 59, 320: -1, 321: -1, 322: -1, 323: 60, 324: 61, 325: -1, 326: 62, 327: 63, 328: -1, 329: -1, 330: 64, 331: -1, 332: -1, 333: -1, 334: 65, 335: 66, 336: 67, 337: -1, 338: -1, 339: -1, 340: -1, 341: -1, 342: -1, 343: -1, 344: -1, 345: -1, 346: -1, 347: 68, 348: -1, 349: -1, 350: -1, 351: -1, 352: -1, 353: -1, 354: -1, 355: -1, 356: -1, 357: -1, 358: -1, 359: -1, 360: -1, 361: 69, 362: -1, 363: 70, 364: -1, 365: -1, 366: -1, 367: -1, 368: -1, 369: -1, 370: -1, 371: -1, 372: 71, 373: -1, 374: -1, 375: -1, 376: -1, 377: -1, 378: 72, 379: -1, 380: -1, 381: -1, 382: -1, 383: -1, 384: -1, 385: -1, 386: 73, 387: -1, 388: -1, 389: -1, 390: -1, 391: -1, 392: -1, 393: -1, 394: -1, 395: -1, 396: -1, 397: 74, 398: -1, 399: -1, 400: 75, 401: 76, 402: 77, 403: -1, 404: 78, 405: -1, 406: -1, 407: 79, 408: -1, 409: -1, 410: -1, 411: 80, 412: -1, 413: -1, 414: -1, 415: -1, 416: 81, 417: 82, 418: -1, 419: -1, 420: 83, 421: -1, 422: -1, 423: -1, 424: -1, 425: 84, 426: -1, 427: -1, 428: 85, 429: -1, 430: 86, 431: -1, 432: -1, 433: -1, 434: -1, 435: -1, 436: -1, 437: 87, 438: 88, 439: -1, 440: -1, 441: -1, 442: -1, 443: -1, 444: -1, 445: 89, 446: -1, 447: -1, 448: -1, 449: -1, 450: -1, 451: -1, 452: -1, 453: -1, 454: -1, 455: -1, 456: 90, 457: 91, 458: -1, 459: -1, 460: -1, 461: 92, 462: 93, 463: -1, 464: -1, 465: -1, 466: -1, 467: -1, 468: -1, 469: -1, 470: 94, 471: -1, 472: 95, 473: -1, 474: -1, 475: -1, 476: -1, 477: -1, 478: -1, 479: -1, 480: -1, 481: -1, 482: -1, 483: 96, 484: -1, 485: -1, 486: 97, 487: -1, 488: 98, 489: -1, 490: -1, 491: -1, 492: 99, 493: -1, 494: -1, 495: -1, 496: 100, 497: -1, 498: -1, 499: -1, 500: -1, 501: -1, 502: -1, 503: -1, 504: -1, 505: -1, 506: -1, 507: -1, 508: -1, 509: -1, 510: -1, 511: -1, 512: -1, 513: -1, 514: 101, 515: -1, 516: 102, 517: -1, 518: -1, 519: -1, 520: -1, 521: -1, 522: -1, 523: -1, 524: -1, 525: -1, 526: -1, 527: -1, 528: 103, 529: -1, 530: 104, 531: -1, 532: -1, 533: -1, 534: -1, 535: -1, 536: -1, 537: -1, 538: -1, 539: 105, 540: -1, 541: -1, 542: 106, 543: 107, 544: -1, 545: -1, 546: -1, 547: -1, 548: -1, 549: 108, 550: -1, 551: -1, 552: 109, 553: -1, 554: -1, 555: -1, 556: -1, 557: 110, 558: -1, 559: -1, 560: -1, 561: 111, 562: 112, 563: -1, 564: -1, 565: -1, 566: -1, 567: -1, 568: -1, 569: 113, 570: -1, 571: -1, 572: 114, 573: 115, 574: -1, 575: 116, 576: -1, 577: -1, 578: -1, 579: 117, 580: -1, 581: -1, 582: -1, 583: -1, 584: -1, 585: -1, 586: -1, 587: -1, 588: -1, 589: 118, 590: -1, 591: -1, 592: -1, 593: -1, 594: -1, 595: -1, 596: -1, 597: -1, 598: -1, 599: -1, 600: -1, 601: -1, 602: -1, 603: -1, 604: -1, 605: -1, 606: 119, 607: 120, 608: -1, 609: 121, 610: -1, 611: -1, 612: -1, 613: -1, 614: 122, 615: -1, 616: -1, 617: -1, 618: -1, 619: -1, 620: -1, 621: -1, 622: -1, 623: -1, 624: -1, 625: -1, 626: 123, 627: 124, 628: -1, 629: -1, 630: -1, 631: -1, 632: -1, 633: -1, 634: -1, 635: -1, 636: -1, 637: -1, 638: -1, 639: -1, 640: 125, 641: 126, 642: 127, 643: 128, 644: -1, 645: -1, 646: -1, 647: -1, 648: -1, 649: -1, 650: -1, 651: -1, 652: -1, 653: -1, 654: -1, 655: -1, 656: -1, 657: -1, 658: 129, 659: -1, 660: -1, 661: -1, 662: -1, 663: -1, 664: -1, 665: -1, 666: -1, 667: -1, 668: 130, 669: -1, 670: -1, 671: -1, 672: -1, 673: -1, 674: -1, 675: -1, 676: -1, 677: 131, 678: -1, 679: -1, 680: -1, 681: -1, 682: 132, 683: -1, 684: 133, 685: -1, 686: -1, 687: 134, 688: -1, 689: -1, 690: -1, 691: -1, 692: -1, 693: -1, 694: -1, 695: -1, 696: -1, 697: -1, 698: -1, 699: -1, 700: -1, 701: 135, 702: -1, 703: -1, 704: 136, 705: -1, 706: -1, 707: -1, 708: -1, 709: -1, 710: -1, 711: -1, 712: -1, 713: -1, 714: -1, 715: -1, 716: -1, 717: -1, 718: -1, 719: 137, 720: -1, 721: -1, 722: -1, 723: -1, 724: -1, 725: -1, 726: -1, 727: -1, 728: -1, 729: -1, 730: -1, 731: -1, 732: -1, 733: -1, 734: -1, 735: -1, 736: 138, 737: -1, 738: -1, 739: -1, 740: -1, 741: -1, 742: -1, 743: -1, 744: -1, 745: -1, 746: 139, 747: -1, 748: -1, 749: 140, 750: -1, 751: -1, 752: 141, 753: -1, 754: -1, 755: -1, 756: -1, 757: -1, 758: 142, 759: -1, 760: -1, 761: -1, 762: -1, 763: 143, 764: -1, 765: 144, 766: -1, 767: -1, 768: 145, 769: -1, 770: -1, 771: -1, 772: -1, 773: 146, 774: 147, 775: -1, 776: 148, 777: -1, 778: -1, 779: 149, 780: 150, 781: -1, 782: -1, 783: -1, 784: -1, 785: -1, 786: 151, 787: -1, 788: -1, 789: -1, 790: -1, 791: -1, 792: 152, 793: -1, 794: -1, 795: -1, 796: -1, 797: 153, 798: -1, 799: -1, 800: -1, 801: -1, 802: 154, 803: 155, 804: 156, 805: -1, 806: -1, 807: -1, 808: -1, 809: -1, 810: -1, 811: -1, 812: -1, 813: 157, 814: -1, 815: 158, 816: -1, 817: -1, 818: -1, 819: -1, 820: 159, 821: -1, 822: -1, 823: 160, 824: -1, 825: -1, 826: -1, 827: -1, 828: -1, 829: -1, 830: -1, 831: 161, 832: -1, 833: 162, 834: -1, 835: 163, 836: -1, 837: -1, 838: -1, 839: 164, 840: -1, 841: -1, 842: -1, 843: -1, 844: -1, 845: 165, 846: -1, 847: 166, 848: -1, 849: -1, 850: 167, 851: -1, 852: -1, 853: -1, 854: -1, 855: -1, 856: -1, 857: -1, 858: -1, 859: 168, 860: -1, 861: -1, 862: 169, 863: -1, 864: -1, 865: -1, 866: -1, 867: -1, 868: -1, 869: -1, 870: 170, 871: -1, 872: -1, 873: -1, 874: -1, 875: -1, 876: -1, 877: -1, 878: -1, 879: 171, 880: 172, 881: -1, 882: -1, 883: -1, 884: -1, 885: -1, 886: -1, 887: -1, 888: 173, 889: -1, 890: 174, 891: -1, 892: -1, 893: -1, 894: -1, 895: -1, 896: -1, 897: 175, 898: -1, 899: -1, 900: 176, 901: -1, 902: -1, 903: -1, 904: -1, 905: -1, 906: -1, 907: 177, 908: -1, 909: -1, 910: -1, 911: -1, 912: -1, 913: 178, 914: -1, 915: -1, 916: -1, 917: -1, 918: -1, 919: -1, 920: -1, 921: -1, 922: -1, 923: -1, 924: 179, 925: -1, 926: -1, 927: -1, 928: -1, 929: -1, 930: -1, 931: -1, 932: 180, 933: 181, 934: 182, 935: -1, 936: -1, 937: 183, 938: -1, 939: -1, 940: -1, 941: -1, 942: -1, 943: 184, 944: -1, 945: 185, 946: -1, 947: 186, 948: -1, 949: -1, 950: -1, 951: 187, 952: -1, 953: -1, 954: 188, 955: -1, 956: 189, 957: 190, 958: -1, 959: 191, 960: -1, 961: -1, 962: -1, 963: -1, 964: -1, 965: -1, 966: -1, 967: -1, 968: -1, 969: -1, 970: -1, 971: 192, 972: 193, 973: -1, 974: -1, 975: -1, 976: -1, 977: -1, 978: -1, 979: -1, 980: 194, 981: 195, 982: -1, 983: -1, 984: 196, 985: -1, 986: 197, 987: 198, 988: 199, 989: -1, 990: -1, 991: -1, 992: -1, 993: -1, 994: -1, 995: -1, 996: -1, 997: -1, 998: -1, 999: -1} - indices_in_1k = [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1] + return [k for k in thousand_k_to_200 if thousand_k_to_200[k] != -1] - mean = [0.485, 0.456, 0.406] - std = [0.229, 0.224, 0.225] - test_transform = trn.Compose( - [trn.Resize(256), trn.CenterCrop(224), trn.ToTensor(), trn.Normalize(mean, std)]) +class ImageNetA(RemappedImageNet): + """This class implements the ImageNet-A dataset from https://arxiv.org/abs/1907.07174, + https://github.com/hendrycks/natural-adv-examples. It contains natural images that + were sampled such that vanilla ResNet50 classifiers obtained chance level performance + on them. The functionality of this dataset is implemented in robusta.datasets.base.py. + For the evaluation, one needs to remap the predictions from the 1000 ImageNet classes to + the 200 ImageNet-A classes which is done in the RemappedImageNet class.""" - output = net(data)[:,indices_in_1k] + mask = ImageNetAClasses.get_class_mask() + def accuracy_metric(self, logits, targets): + return super().accuracy_metric(logits, targets, ImageNetA.mask) diff --git a/robusta/datasets/imagenetc.py b/robusta/datasets/imagenetc.py index f10d0e6..a621bcd 100644 --- a/robusta/datasets/imagenetc.py +++ b/robusta/datasets/imagenetc.py @@ -10,28 +10,27 @@ class ImageNetC(torchvision.datasets.ImageFolder): num_classes = 1000 image_size = (224, 224) - train_corruptions = ("brightness", "elastic_transform", "impulse_noise", - "pixelate", "snow", "zoom_blur", - "contrast", "fog", "gaussian_noise", - "jpeg_compression", "defocus_blur", "frost", - "glass_blur", "motion_blur", "shot_noise") + train_corruptions = ("brightness", "elastic_transform", "impulse_noise", + "pixelate", "snow", "zoom_blur", "contrast", "fog", + "gaussian_noise", "jpeg_compression", "defocus_blur", + "frost", "glass_blur", "motion_blur", "shot_noise") test_corruptions = ("gaussian_blur", "saturate", "spatter", "speckle_noise") severities = ("1", "2", "3", "4", "5") - def __init__(self, root: str, corruption: str, severity: str, + def __init__(self, + root: str, + corruption: str, + severity: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, loader: Callable[[str], Any] = default_loader, - is_valid_file: Optional[Callable[[str], bool]] = None - ): - super(ImageNetC, self).__init__( - root=os.path.join(root, corruption, str(severity)), - transform=transform, - target_transform=target_transform, - loader=loader, - is_valid_file=is_valid_file - ) + is_valid_file: Optional[Callable[[str], bool]] = None): + super(ImageNetC, + self).__init__(root=os.path.join(root, corruption, str(severity)), + transform=transform, + target_transform=target_transform, + loader=loader, + is_valid_file=is_valid_file) assert corruption in ImageNetC.train_corruptions or \ corruption in ImageNetC.test_corruptions assert str(severity) in ImageNetC.severities - diff --git a/robusta/datasets/imagenetr.py b/robusta/datasets/imagenetr.py index e7ebe71..51fe7bf 100644 --- a/robusta/datasets/imagenetr.py +++ b/robusta/datasets/imagenetr.py @@ -18,9 +18,20 @@ # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. -import torchvision +from robusta.datasets.base import ImageNetRClasses from robusta.datasets.base import RemappedImageNet -class ImageNetR(torchvision.datasets.ImageFolder, RemappedImageNet): - pass +class ImageNetR(RemappedImageNet): + """This class implements the ImageNet-R dataset from https://arxiv.org/abs/2006.16241, + https://github.com/hendrycks/imagenet-r. It contains different renditions of 200 ImageNet + classes. The functionality of this dataset is implemented in robusta.datasets.base.py. + For the evaluation, one needs to remap the predictions from the 1000 ImageNet classes to + the 200 ImageNet-R classes which is done in the RemappedImageNet class.""" + + @property + def mask(self): + return ImageNetRClasses.get_class_mask() + + def accuracy_metric(self, logits, targets): + return super().accuracy_metric(logits, targets, ImageNetR.mask) diff --git a/robusta/models/BiT_models.py b/robusta/models/BiT_models.py index 5c5361b..cb47600 100644 --- a/robusta/models/BiT_models.py +++ b/robusta/models/BiT_models.py @@ -23,23 +23,32 @@ class StdConv2d(nn.Conv2d): + def forward(self, x): w = self.weight v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) w = (w - m) / torch.sqrt(v + 1e-10) - return F.conv2d( - x, w, self.bias, self.stride, self.padding, self.dilation, self.groups - ) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) def conv3x3(cin, cout, stride=1, groups=1, bias=False): - return StdConv2d( - cin, cout, kernel_size=3, stride=stride, padding=1, bias=bias, groups=groups - ) + return StdConv2d(cin, + cout, + kernel_size=3, + stride=stride, + padding=1, + bias=bias, + groups=groups) def conv1x1(cin, cout, stride=1, bias=False): - return StdConv2d(cin, cout, kernel_size=1, stride=stride, padding=0, bias=bias) + return StdConv2d(cin, + cout, + kernel_size=1, + stride=stride, + padding=0, + bias=bias) def tf2th(conv_weights): @@ -64,7 +73,8 @@ def __init__(self, cin, cout=None, cmid=None, stride=1): self.gn1 = nn.GroupNorm(32, cin) self.conv1 = conv1x1(cin, cmid) self.gn2 = nn.GroupNorm(32, cmid) - self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! + self.conv2 = conv3x3(cmid, cmid, + stride) # Original code has it on conv1!! self.gn3 = nn.GroupNorm(32, cmid) self.conv3 = conv1x1(cmid, cout) self.relu = nn.ReLU(inplace=True) @@ -91,9 +101,12 @@ def forward(self, x): def load_from(self, weights, prefix=""): convname = "standardized_conv2d" with torch.no_grad(): - self.conv1.weight.copy_(tf2th(weights[f"{prefix}a/{convname}/kernel"])) - self.conv2.weight.copy_(tf2th(weights[f"{prefix}b/{convname}/kernel"])) - self.conv3.weight.copy_(tf2th(weights[f"{prefix}c/{convname}/kernel"])) + self.conv1.weight.copy_( + tf2th(weights[f"{prefix}a/{convname}/kernel"])) + self.conv2.weight.copy_( + tf2th(weights[f"{prefix}b/{convname}/kernel"])) + self.conv3.weight.copy_( + tf2th(weights[f"{prefix}c/{convname}/kernel"])) self.gn1.weight.copy_(tf2th(weights[f"{prefix}a/group_norm/gamma"])) self.gn2.weight.copy_(tf2th(weights[f"{prefix}b/group_norm/gamma"])) self.gn3.weight.copy_(tf2th(weights[f"{prefix}c/group_norm/gamma"])) @@ -108,153 +121,111 @@ def load_from(self, weights, prefix=""): class ResNetV2(nn.Module): """Implementation of Pre-activation (v2) ResNet mode.""" - def __init__(self, block_units, width_factor, head_size=21843, zero_head=False): + def __init__(self, + block_units, + width_factor, + head_size=21843, + zero_head=False): super().__init__() wf = width_factor # shortcut 'cause we'll use it a lot. # The following will be unreadable if we split lines. # pylint: disable=line-too-long self.root = nn.Sequential( - OrderedDict( - [ - ( - "conv", - StdConv2d( - 3, 64 * wf, kernel_size=7, stride=2, padding=3, bias=False - ), - ), - ("pad", nn.ConstantPad2d(1, 0)), - ("pool", nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), - # The following is subtly not the same! - # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), - ] - ) - ) + OrderedDict([ + ( + "conv", + StdConv2d(3, + 64 * wf, + kernel_size=7, + stride=2, + padding=3, + bias=False), + ), + ("pad", nn.ConstantPad2d(1, 0)), + ("pool", nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), + # The following is subtly not the same! + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) self.body = nn.Sequential( - OrderedDict( - [ - ( - "block1", - nn.Sequential( - OrderedDict( - [ - ( - "unit01", - PreActBottleneck( - cin=64 * wf, cout=256 * wf, cmid=64 * wf - ), - ) - ] - + [ - ( - f"unit{i:02d}", - PreActBottleneck( - cin=256 * wf, cout=256 * wf, cmid=64 * wf - ), - ) - for i in range(2, block_units[0] + 1) - ], - ) - ), - ), - ( - "block2", - nn.Sequential( - OrderedDict( - [ - ( - "unit01", - PreActBottleneck( - cin=256 * wf, - cout=512 * wf, - cmid=128 * wf, - stride=2, - ), - ) - ] - + [ - ( - f"unit{i:02d}", - PreActBottleneck( - cin=512 * wf, cout=512 * wf, cmid=128 * wf - ), - ) - for i in range(2, block_units[1] + 1) - ], - ) - ), - ), - ( - "block3", - nn.Sequential( - OrderedDict( - [ - ( - "unit01", - PreActBottleneck( - cin=512 * wf, - cout=1024 * wf, - cmid=256 * wf, - stride=2, - ), - ) - ] - + [ - ( - f"unit{i:02d}", - PreActBottleneck( - cin=1024 * wf, cout=1024 * wf, cmid=256 * wf - ), - ) - for i in range(2, block_units[2] + 1) - ], - ) - ), - ), - ( - "block4", - nn.Sequential( - OrderedDict( - [ - ( - "unit01", - PreActBottleneck( - cin=1024 * wf, - cout=2048 * wf, - cmid=512 * wf, - stride=2, - ), - ) - ] - + [ - ( - f"unit{i:02d}", - PreActBottleneck( - cin=2048 * wf, cout=2048 * wf, cmid=512 * wf - ), - ) - for i in range(2, block_units[3] + 1) - ], - ) - ), - ), - ] - ) - ) + OrderedDict([ + ( + "block1", + nn.Sequential( + OrderedDict([( + "unit01", + PreActBottleneck( + cin=64 * wf, cout=256 * wf, cmid=64 * wf), + )] + [( + f"unit{i:02d}", + PreActBottleneck( + cin=256 * wf, cout=256 * wf, cmid=64 * wf), + ) for i in range(2, block_units[0] + 1)],)), + ), + ( + "block2", + nn.Sequential( + OrderedDict([( + "unit01", + PreActBottleneck( + cin=256 * wf, + cout=512 * wf, + cmid=128 * wf, + stride=2, + ), + )] + [( + f"unit{i:02d}", + PreActBottleneck( + cin=512 * wf, cout=512 * wf, cmid=128 * wf), + ) for i in range(2, block_units[1] + 1)],)), + ), + ( + "block3", + nn.Sequential( + OrderedDict([( + "unit01", + PreActBottleneck( + cin=512 * wf, + cout=1024 * wf, + cmid=256 * wf, + stride=2, + ), + )] + [( + f"unit{i:02d}", + PreActBottleneck( + cin=1024 * wf, cout=1024 * wf, cmid=256 * wf), + ) for i in range(2, block_units[2] + 1)],)), + ), + ( + "block4", + nn.Sequential( + OrderedDict([( + "unit01", + PreActBottleneck( + cin=1024 * wf, + cout=2048 * wf, + cmid=512 * wf, + stride=2, + ), + )] + [( + f"unit{i:02d}", + PreActBottleneck( + cin=2048 * wf, cout=2048 * wf, cmid=512 * wf), + ) for i in range(2, block_units[3] + 1)],)), + ), + ])) # pylint: enable=line-too-long self.zero_head = zero_head self.head = nn.Sequential( - OrderedDict( - [ - ("gn", nn.GroupNorm(32, 2048 * wf)), - ("relu", nn.ReLU(inplace=True)), - ("avg", nn.AdaptiveAvgPool2d(output_size=1)), - ("conv", nn.Conv2d(2048 * wf, head_size, kernel_size=1, bias=True)), - ] - ) - ) + OrderedDict([ + ("gn", nn.GroupNorm(32, 2048 * wf)), + ("relu", nn.ReLU(inplace=True)), + ("avg", nn.AdaptiveAvgPool2d(output_size=1)), + ("conv", + nn.Conv2d(2048 * wf, head_size, kernel_size=1, bias=True)), + ])) def forward(self, x): x = self.head(self.body(self.root(x))) @@ -266,7 +237,8 @@ def load_from(self, weights, prefix="resnet/"): self.root.conv.weight.copy_( tf2th(weights[f"{prefix}root_block/standardized_conv2d/kernel"]) ) # pylint: disable=line-too-long - self.head.gn.weight.copy_(tf2th(weights[f"{prefix}group_norm/gamma"])) + self.head.gn.weight.copy_( + tf2th(weights[f"{prefix}group_norm/gamma"])) self.head.gn.bias.copy_(tf2th(weights[f"{prefix}group_norm/beta"])) if self.zero_head: nn.init.zeros_(self.head.conv.weight) @@ -275,26 +247,25 @@ def load_from(self, weights, prefix="resnet/"): self.head.conv.weight.copy_( tf2th(weights[f"{prefix}head/conv2d/kernel"]) ) # pylint: disable=line-too-long - self.head.conv.bias.copy_(tf2th(weights[f"{prefix}head/conv2d/bias"])) + self.head.conv.bias.copy_( + tf2th(weights[f"{prefix}head/conv2d/bias"])) for bname, block in self.body.named_children(): for uname, unit in block.named_children(): unit.load_from(weights, prefix=f"{prefix}{bname}/{uname}/") -KNOWN_MODELS = OrderedDict( - [ - ("BiT-M-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), - ("BiT-M-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), - ("BiT-M-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), - ("BiT-M-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), - ("BiT-M-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), - ("BiT-M-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), - ("BiT-S-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), - ("BiT-S-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), - ("BiT-S-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), - ("BiT-S-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), - ("BiT-S-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), - ("BiT-S-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), - ] -) +KNOWN_MODELS = OrderedDict([ + ("BiT-M-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ("BiT-M-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ("BiT-M-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ("BiT-M-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ("BiT-M-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ("BiT-M-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), + ("BiT-S-R50x1", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ("BiT-S-R50x3", lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ("BiT-S-R101x1", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ("BiT-S-R101x3", lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ("BiT-S-R152x2", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ("BiT-S-R152x4", lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), +]) diff --git a/robusta/models/fixup.py b/robusta/models/fixup.py index 777f202..69ebd87 100644 --- a/robusta/models/fixup.py +++ b/robusta/models/fixup.py @@ -40,7 +40,6 @@ import torch.nn as nn import numpy as np - __all__ = [ "fixup_resnet18", "fixup_resnet34", @@ -52,14 +51,21 @@ def conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) def conv1x1(in_planes, out_planes, stride=1): """1x1 convolution""" - return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) class FixupBasicBlock(nn.Module): @@ -139,11 +145,17 @@ def forward(self, x): class FixupResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000): super(FixupResNet, self).__init__() self.num_layers = sum(layers) self.inplanes = 64 - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) self.bias1 = nn.Parameter(torch.zeros(1)) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -160,56 +172,42 @@ def __init__(self, block, layers, num_classes=1000): nn.init.normal_( m.conv1.weight, mean=0, - std=np.sqrt( - 2 - / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:])) - ) - * self.num_layers ** (-0.5), + std=np.sqrt(2 / (m.conv1.weight.shape[0] * + np.prod(m.conv1.weight.shape[2:]))) * + self.num_layers**(-0.5), ) nn.init.constant_(m.conv2.weight, 0) if m.downsample is not None: nn.init.normal_( m.downsample.weight, mean=0, - std=np.sqrt( - 2 - / ( - m.downsample.weight.shape[0] - * np.prod(m.downsample.weight.shape[2:]) - ) - ), + std=np.sqrt(2 / + (m.downsample.weight.shape[0] * + np.prod(m.downsample.weight.shape[2:]))), ) elif isinstance(m, FixupBottleneck): nn.init.normal_( m.conv1.weight, mean=0, - std=np.sqrt( - 2 - / (m.conv1.weight.shape[0] * np.prod(m.conv1.weight.shape[2:])) - ) - * self.num_layers ** (-0.25), + std=np.sqrt(2 / (m.conv1.weight.shape[0] * + np.prod(m.conv1.weight.shape[2:]))) * + self.num_layers**(-0.25), ) nn.init.normal_( m.conv2.weight, mean=0, - std=np.sqrt( - 2 - / (m.conv2.weight.shape[0] * np.prod(m.conv2.weight.shape[2:])) - ) - * self.num_layers ** (-0.25), + std=np.sqrt(2 / (m.conv2.weight.shape[0] * + np.prod(m.conv2.weight.shape[2:]))) * + self.num_layers**(-0.25), ) nn.init.constant_(m.conv3.weight, 0) if m.downsample is not None: nn.init.normal_( m.downsample.weight, mean=0, - std=np.sqrt( - 2 - / ( - m.downsample.weight.shape[0] - * np.prod(m.downsample.weight.shape[2:]) - ) - ), + std=np.sqrt(2 / + (m.downsample.weight.shape[0] * + np.prod(m.downsample.weight.shape[2:]))), ) elif isinstance(m, nn.Linear): nn.init.constant_(m.weight, 0) @@ -218,7 +216,8 @@ def __init__(self, block, layers, num_classes=1000): def _make_layer(self, block, planes, blocks, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: - downsample = conv1x1(self.inplanes, planes * block.expansion, stride) + downsample = conv1x1(self.inplanes, planes * block.expansion, + stride) layers = [] layers.append(block(self.inplanes, planes, stride, downsample)) diff --git a/robusta/models/imagenet_model.py b/robusta/models/imagenet_model.py index d7a21e1..850e1cd 100644 --- a/robusta/models/imagenet_model.py +++ b/robusta/models/imagenet_model.py @@ -4,6 +4,7 @@ class ZeroOneResNet50_parallel(nn.Module): + def __init__(self, device="cuda", pretrained=True): super().__init__() self.resnet = models.resnet50(pretrained=pretrained) @@ -23,9 +24,11 @@ def forward(self, input): class ZeroOneInceptionV3(nn.Module): + def __init__(self, device="cuda", pretrained=False): super().__init__() - self.inception = models.inception_v3(pretrained=True, transform_input=True) + self.inception = models.inception_v3(pretrained=True, + transform_input=True) self.mean = nn.Parameter( torch.FloatTensor([0.485, 0.456, 0.406])[None, :, None, None], requires_grad=False, diff --git a/robusta/models/resnet_gn.py b/robusta/models/resnet_gn.py index 48220ca..d7ff996 100644 --- a/robusta/models/resnet_gn.py +++ b/robusta/models/resnet_gn.py @@ -2,15 +2,17 @@ import math import torch.utils.model_zoo as model_zoo - __all__ = ["ResNet", "resnet50", "resnet101", "resnet152"] def conv3x3(in_planes, out_planes, stride=1): "3x3 convolution with padding" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) class Bottleneck(nn.Module): @@ -20,9 +22,12 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.GroupNorm(32, planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + self.conv2 = nn.Conv2d(planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) self.bn2 = nn.GroupNorm(32, planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.GroupNorm(32, planes * 4) @@ -70,10 +75,16 @@ def gn_init(m, zero_init=False): class ResNet(nn.Module): + def __init__(self, block, layers, num_classes=1000): self.inplanes = 64 super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d(3, + 64, + kernel_size=7, + stride=2, + padding=3, + bias=False) self.bn1 = nn.GroupNorm(32, 64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) diff --git a/robusta/selflearning/__init__.py b/robusta/selflearning/__init__.py index ec0a4c0..9e8b816 100644 --- a/robusta/selflearning/__init__.py +++ b/robusta/selflearning/__init__.py @@ -18,11 +18,7 @@ # authors. Code taken from other open-source projects is indicated. # See NOTICE for a list of all third-party licences used in the project. -import torch.nn - -from robusta.selflearning import functional -from robusta.selflearning.nn import EntropyLoss -from robusta.selflearning.nn import GeneralizedCrossEntropy +import torch def _iter_params(model): diff --git a/robusta/selflearning/functional.py b/robusta/selflearning/functional.py index ed86b90..cf8456b 100644 --- a/robusta/selflearning/functional.py +++ b/robusta/selflearning/functional.py @@ -20,7 +20,8 @@ import torch.nn.functional as F -def gce(logits, target, q = 0.8): + +def gce(logits, target, q=0.8): """ Generalized cross entropy. Reference: https://arxiv.org/abs/1805.07836 @@ -30,10 +31,11 @@ def gce(logits, target, q = 0.8): loss = (1. - probs_with_correct_idx**q) / q return loss.mean() -def entropy(logits, target, q = 0.8): + +def entropy(logits, target, q=0.8): """ Entropy. """ log_probs = F.log_softmax(logits, dim=1) probs = F.softmax(logits, dim=1) - return -(probs * log_probs).sum(dim=-1).mean() \ No newline at end of file + return -(probs * log_probs).sum(dim=-1).mean() diff --git a/robusta/selflearning/nn.py b/robusta/selflearning/nn.py index e0117a0..8c24de2 100644 --- a/robusta/selflearning/nn.py +++ b/robusta/selflearning/nn.py @@ -19,28 +19,30 @@ # See NOTICE for a list of all third-party licences used in the project. from torch import nn -import robusta.selflearning.functional as RF +import functional as RF + class GeneralizedCrossEntropy(nn.Module): - def __init__(self, q = 0.8): + def __init__(self, q=0.8): super().__init__() self.q = q - def forward(self, logits, target = None): + def forward(self, logits, target=None): if target is None: - target = logits.argmax(dim = 1) + target = logits.argmax(dim=1) return RF.gce(logits, target, self.q) + class EntropyLoss(nn.Module): - def __init__(self, stop_teacher_gradient = False): + def __init__(self, stop_teacher_gradient=False): super().__init__() self.stop_teacher_gradient = stop_teacher_gradient - def forward(self, logits, target = None): + def forward(self, logits, target=None): if target is None: target = logits if self.top_teacher_gradient: target = target.detach() - return RF.entropy(logits, target, self.q) \ No newline at end of file + return RF.entropy(logits, target, self.q)