Robust segvit compatibility + new dataset. #1183

Open · wants to merge 170 commits into base: main

Changes from 134 commits (170 commits total)
fe21619
Merge branch 'google:main' into segmenter
ekellbuch Nov 11, 2021
902c118
init deterministic.py for segmenter loader using scenic loader of cit…
ekellbuch Nov 15, 2021
82bd6c9
fix bug in segmenter: vit+ backbone classifier was not inherited
ekellbuch Nov 15, 2021
fe1f68d
include data and models in deterministic.py
ekellbuch Nov 16, 2021
fe3f1a2
Merge branch 'google:main' into citytrain
ekellbuch Nov 16, 2021
27e2cf0
include segmenter data loader and trainer using scenic's infra
ekellbuch Nov 16, 2021
49db9dd
include flags to debug on mac using only a subset of training data av…
ekellbuch Nov 16, 2021
e56df3b
this commit includes a trainable implementation segmenter on cityscap…
ekellbuch Nov 16, 2021
7f6eb46
Merge branch 'google:main' into citytrain
ekellbuch Nov 16, 2021
08aad99
include configs to debug model on gcp vms
ekellbuch Nov 17, 2021
e38fd6a
Merge branch 'citytrain' of https://github.com/ekellbuch/uncertainty-…
ekellbuch Nov 17, 2021
3977b94
Merge branch 'google:main' into citytrain
ekellbuch Nov 17, 2021
85f3a6f
update patch size
ekellbuch Nov 17, 2021
e33696e
add config to train model on all data for 1 epoch
ekellbuch Nov 17, 2021
2b81aef
add init_from flag to allow model to load pretrained checkpoints
ekellbuch Nov 17, 2021
5de076c
include flag to check for pretrained backbone
ekellbuch Nov 17, 2021
afa8fec
include config for 100 epochs
ekellbuch Nov 17, 2021
88ee518
Merge branch 'google:main' into load_check
ekellbuch Nov 17, 2021
6757cbb
Merge branch 'load_check' of https://github.com/ekellbuch/uncertainty…
ekellbuch Nov 17, 2021
245de00
(1) vit backbone classifier is fixed to 'gap' for segmenter model (2)…
ekellbuch Nov 17, 2021
062172b
update model
ekellbuch Nov 29, 2021
7c02d78
add code to preload weights -- which can fail if frozen dictionary is…
ekellbuch Nov 29, 2021
d73defb
Merge branch 'citytrain' of https://github.com/ekellbuch/uncertainty-…
ekellbuch Nov 29, 2021
a55c555
Merge branch 'master' into load_check
ekellbuch Nov 29, 2021
1dc1998
(1) load weights of network pretrained on imagenet
ekellbuch Nov 30, 2021
6ed9fa7
fix toy config to debug segmenter with pretrained backbone
ekellbuch Nov 30, 2021
b731032
update config files for pretrained weights run in vm
ekellbuch Nov 30, 2021
b747e07
add config files to compare model trained from scratch, init using de…
ekellbuch Nov 30, 2021
f7c9351
add call to tensorboard to compare different runs in an experiment
ekellbuch Nov 30, 2021
805abb1
Merge branch 'google:main' into load_check
ekellbuch Nov 30, 2021
bf0d783
add config files for tpu pods
ekellbuch Dec 3, 2021
5fc0560
update tpu pod config
ekellbuch Dec 3, 2021
e143662
add run file for model with 256x256
ekellbuch Dec 6, 2021
62329a8
add run file for model with 512x512 img
ekellbuch Dec 6, 2021
8e7761a
update config files for 512 run
ekellbuch Dec 6, 2021
66a7712
add script to train model w different splits -- compatible w scenic/c…
ekellbuch Dec 7, 2021
6697ee3
add run with 10 and 100% of data
ekellbuch Dec 7, 2021
6cc984c
fix bug with deterministic 10% train split, add config to train from …
ekellbuch Dec 7, 2021
51f0fe2
update experimental configs
ekellbuch Dec 10, 2021
1fa4602
update script to run experiments
ekellbuch Dec 10, 2021
ef401ec
fix local debugger for deterministic experiments
ekellbuch Dec 10, 2021
a92a9c6
updated readme
ekellbuch Jan 4, 2022
5254780
add eval code to store model outputs
ekellbuch Jan 6, 2022
9afb73f
add call to eval multiple splits
ekellbuch Jan 6, 2022
e4a1a6e
update logit dimensionality
ekellbuch Jan 6, 2022
7f916b0
update eval code to write directly to bucket
ekellbuch Jan 6, 2022
ba64522
add uncertainty calculations during model eval
ekellbuch Jan 11, 2022
1895cd3
add code to store metrics for all runs
ekellbuch Jan 11, 2022
6e0eb0e
fix bug in uncertainty calculation
ekellbuch Jan 11, 2022
e328239
update custom models
ekellbuch Jan 18, 2022
b02d0f1
ignore cases where mask is 0, still need to find a nice workaround
ekellbuch Jan 18, 2022
a49b80e
add comment to pretrainer utils download which has config inputs not …
ekellbuch Jan 31, 2022
74dd7ff
update debug config to include base model for vit-l32
ekellbuch Jan 31, 2022
868c8c6
add code to debug a toy model on vm
ekellbuch Jan 31, 2022
f287731
include training config for vit-l32 model, directly read positional e…
ekellbuch Feb 1, 2022
d044637
add experiment config for different splits using vit-l32 models
ekellbuch Feb 1, 2022
77d1f1e
add config file with default batch size matches # tpu
ekellbuch Feb 2, 2022
b309462
update default debugging params to be compatible with osx
ekellbuch Feb 8, 2022
728e502
add code to (1) run ensemble (2) read metrics from ensemble
ekellbuch Feb 8, 2022
14fcb46
include new experiments to eval changes for learning_rate, # training…
ekellbuch Feb 8, 2022
f7fa226
add code to eval vitl32 models
ekellbuch Feb 8, 2022
9fa2047
add code to train multiple hyperparams simultaneously
ekellbuch Feb 8, 2022
cac843f
add code compatible with wandb for hyperparam search
ekellbuch Feb 8, 2022
e46879f
update readme
ekellbuch Feb 8, 2022
6733d65
add config file compatible with wandb
ekellbuch Feb 8, 2022
a06a446
update bash script to run with bash and not sh
ekellbuch Feb 8, 2022
f2641ff
update code so wandb inherits tensorboard logs
ekellbuch Feb 8, 2022
97aca33
add early stopping flag
ekellbuch Feb 9, 2022
6add978
add deterministic
ekellbuch Feb 9, 2022
a9240f1
fix bug where checkpoints were not stored in gs when wandb was used:…
ekellbuch Feb 10, 2022
62d27fb
integrate code for hyperparameter sweep in wandb
ekellbuch Feb 10, 2022
56b368d
Merge branch 'google:main' into hyper_tune
ekellbuch Feb 10, 2022
9197124
add vit_batchensemble test to debug model loading
ekellbuch Feb 11, 2022
500644f
Merge branch 'google:main' into batch_ensemble_upstream
ekellbuch Feb 11, 2022
4135b49
update config file for ensemble run
ekellbuch Feb 11, 2022
71baf2c
add option to turn on stochastic_depth to vit transformer encoder
ekellbuch Feb 11, 2022
c728128
update segmenter code so it uses vit model module
ekellbuch Feb 11, 2022
34a61f1
add stochastic layer to batch ensemble model
ekellbuch Feb 11, 2022
22af10a
update ensemble config names
ekellbuch Feb 11, 2022
1dda6bd
add segmenter_be_model
ekellbuch Feb 11, 2022
c01a8c7
add note about updating targets
ekellbuch Feb 11, 2022
5a64c87
merge
ekellbuch Sep 12, 2022
394b2e0
add baseline changes to train segmenter model on ade20k dataset on mac
ekellbuch Sep 26, 2022
cf6a4af
update run_toy_mac to select between city or ade20k
ekellbuch Sep 26, 2022
fefe26b
remove duplicated eval configs and call to numpy call to masked array
ekellbuch Sep 26, 2022
9aae57c
update default config for toy_model
ekellbuch Sep 26, 2022
3ff6f86
add street_hazards to custom_segmentation_trainer
ekellbuch Sep 26, 2022
5a3a902
add street_hazards configurations
ekellbuch Sep 27, 2022
e819957
remove call to wandb
ekellbuch Sep 27, 2022
087af3a
update bug in checkpoint name
ekellbuch Sep 27, 2022
01977d2
set wandb default = False
ekellbuch Sep 27, 2022
ba6e1e4
debug loading gcp checkpoint from vision_transformer
ekellbuch Sep 27, 2022
f6cb15f
update config to match train_target_size
ekellbuch Sep 27, 2022
b1b0cb1
add wandb config options to all config files
ekellbuch Sep 27, 2022
f7aee61
update pavpu calculation based on softmax
ekellbuch Sep 28, 2022
ffd92da
add -1*mlogit as an ood metric
ekellbuch Sep 29, 2022
6710faa
add street hazards gp config
ekellbuch Sep 29, 2022
fd93a84
add class to compute metrics offline
ekellbuch Sep 29, 2022
89e4f88
update gp hparam search to only include mean_field_factor
ekellbuch Sep 29, 2022
9589a11
fix bug where the config called was deterministic
ekellbuch Sep 29, 2022
f6c2aca
update batch_size to fit in memory
ekellbuch Sep 29, 2022
128eaa2
update img dimension used for experiment
ekellbuch Sep 29, 2022
d0e9fed
remove old cityscapes dataset
ekellbuch Sep 29, 2022
2bd0347
Merge branch 'google:main' into run_open
ekellbuch Sep 29, 2022
c0e09c6
remove local changes not needed in main repository
ekellbuch Sep 29, 2022
9e29d26
remove unnecessary changes to main branch
ekellbuch Sep 29, 2022
939446a
add configs for runs with different seeds for segmenter and segmenter…
ekellbuch Sep 29, 2022
973dcd9
update uncertainty matrix calculation to support multi host and multi…
ekellbuch Oct 3, 2022
2f13f76
add deterministic call for wandb
ekellbuch Oct 3, 2022
3a4022a
update batch size for deterministic city
ekellbuch Oct 3, 2022
63c8156
add custom_segmentation_trainer with store_logits code
ekellbuch Oct 3, 2022
a3dc36b
add segmm torch model eval
ekellbuch Oct 4, 2022
48fec99
update be_eval to load checkpoint even when running locally
ekellbuch Oct 4, 2022
c756e02
add configuration to eval models on cityscapes
ekellbuch Oct 4, 2022
f9484f2
add wandb yaml files to call eval experiments
ekellbuch Oct 4, 2022
1306955
fix syntax error in wandb yaml files
ekellbuch Oct 4, 2022
69769c9
fix bug in name of checkpoint
ekellbuch Oct 4, 2022
972a86a
add eval config files for default (non-opt) parameter for different m…
ekellbuch Oct 4, 2022
9ee5181
add wandb yaml files to call eval loaders
ekellbuch Oct 4, 2022
c2469a5
add store_logits flag to eval config files
ekellbuch Oct 4, 2022
9e817a6
add wandb yaml files to store logits
ekellbuch Oct 4, 2022
5d75b11
update name of default directory
ekellbuch Oct 4, 2022
eba4cd5
when computing ood metrics skip images wo any ood pixels or images wh…
ekellbuch Oct 4, 2022
ea02a02
add wandb yaml config where we use 1-msp as ood score
ekellbuch Oct 4, 2022
2fd0bcd
add wandb yaml files with ood_score=msp
ekellbuch Oct 4, 2022
19a204d
update test for multihost metrics to exclude images where all pixels …
ekellbuch Oct 4, 2022
c9cf756
add wandb yaml file to evaluate deterministic ade20k model
ekellbuch Oct 4, 2022
c012056
clean up readme
ekellbuch Nov 2, 2022
bd6b60d
clean readme
ekellbuch Nov 2, 2022
1c3e2cf
Merge branch 'google:main' into run_open
ekellbuch Nov 2, 2022
28a6532
add multihost ece calculation
ekellbuch Nov 3, 2022
8078ead
update script file to run toy experiments (run_toy_mac.sh)
ekellbuch Nov 3, 2022
ccd3cf1
add config file and wandb yaml file to call experiments
ekellbuch Nov 3, 2022
efa4a66
update batch size to fix oom issues
ekellbuch Nov 3, 2022
71716bf
(1) add comments to the multihost class and update trainer to use thi…
ekellbuch Nov 3, 2022
74caf1e
merge checkpoint code
ekellbuch Nov 3, 2022
68f1239
use (-1) as a scaling factor for ood
ekellbuch Nov 3, 2022
7aa4efd
use ood score without (-1) factor
ekellbuch Nov 3, 2022
e20c435
remove dropout factor in deterministic_eval
ekellbuch Nov 3, 2022
748e4b0
add models to evaluate the performance of the gp model
ekellbuch Nov 3, 2022
289caf0
add code to train het model on street hazards dataset
ekellbuch Nov 3, 2022
0545d59
fix default lr for het model given det model results
ekellbuch Nov 3, 2022
ce513b9
add config file to train be model on street hazards
ekellbuch Nov 3, 2022
7b16a72
add calibration AUCROC metric
ekellbuch Nov 4, 2022
df0b06d
use softmax call
ekellbuch Nov 7, 2022
ec3e5c1
Merge branch 'google:main' into run_open
ekellbuch Nov 7, 2022
e393f3c
remove unnecessary call to del params
ekellbuch Nov 7, 2022
e8caba5
Merge branch 'run_open' of https://github.com/ekellbuch/uncertainty-b…
ekellbuch Nov 7, 2022
9bcacca
add script code to call eval in all models for all datasets
ekellbuch Nov 8, 2022
dc9c4c0
update the prefix for the ood metrics to easily plot these
ekellbuch Nov 8, 2022
b665aa0
update prefix to group metrics according to corruption/level instead …
ekellbuch Nov 9, 2022
3c73e26
update trainer to call the same data loader for any dataset + include…
ekellbuch Nov 9, 2022
4be5347
update cityscapes config to use the same data loader as ade20k and st…
ekellbuch Nov 9, 2022
bed1293
use updated data loader for all datasets
ekellbuch Nov 10, 2022
1be3cec
add street_hazards_corrupted
ekellbuch Nov 10, 2022
92ef059
Merge branch 'dev1' into run_open
ekellbuch Nov 12, 2022
54c2d93
update config to use msp for default ood metric
ekellbuch Nov 12, 2022
a1793ae
add script to train plex model with different seeds
ekellbuch Nov 12, 2022
aec65b5
update order for call to config
ekellbuch Nov 12, 2022
ff87a4e
add bash script to train model with runlocal
ekellbuch Nov 12, 2022
ce04a6e
include wandb in debug call
ekellbuch Nov 12, 2022
653a6b0
add wandb option to code call
ekellbuch Nov 12, 2022
e0f92fa
update multihost calculation for auc
ekellbuch Nov 13, 2022
b33beb4
add del code to free memory
ekellbuch Nov 13, 2022
9b085c5
set toy default code to run cityscapes
ekellbuch Nov 13, 2022
b3a722f
update default call for auc calculation
ekellbuch Nov 13, 2022
d6c8cd6
remove call to computescoreaucmetric
ekellbuch Nov 13, 2022
7e82e2f
remove ood calculation based on -1*score given that tf.keras.auc does…
ekellbuch Nov 13, 2022
87fc8b0
slim trainer
ekellbuch Nov 13, 2022
ad4e53f
Merge branch 'dev1' into run_open
ekellbuch Nov 14, 2022
23 changes: 16 additions & 7 deletions experimental/robust_segvit/README.md
@@ -1,16 +1,25 @@
# Robust segvit

*Robust_segvit* is a codebase to evaluate the robustness of semantic segmentation models.
**Robust_segvit** is a codebase to evaluate the robustness of semantic segmentation models. The code is built on top of [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic).

Robust_segvit is developed in [JAX](https://github.com/google/jax) and uses [Flax](https://github.com/google/flax), [uncertainty_baselines](https://github.com/google/uncertainty-baselines) and [Scenic](https://github.com/google-research/scenic).
## Installation
Robust_segvit is developed in [JAX](https://github.com/google/jax)/[Flax](https://github.com/google/flax).

## Code structure
See uncertainty_baselines/google/experimental/cityscapes.
To run the code: <br>
1. Install [uncertainty_baselines](https://github.com/google/uncertainty-baselines). <br>
2. Install [Scenic](https://github.com/google-research/scenic). <br>
3. Follow the instructions for a toy run in [./run_deterministic_mac.sh]().

## Datasets
The experiment configurations for the different datasets are in:

## Cityscapes
- configs/cityscapes: Cityscapes dataset. <br>
- configs/ade20k_ind: ADE20k_ind dataset. <br>
- configs/street_hazards: Street Hazards dataset. <br>

We investigate the performance of different reliability methods on image segmentation tasks. <br>
## Comments:
- The checkpoint used for finetuning is the same as in the original segmenter model: [vit_large_patch16_384](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py)

[x] configs/cityscapes: contains experiment configurations for the cityscapes dataset. <br>
## Citing work:

If you reference this code, please cite [our paper](https://github.com/google/uncertainty-baselines). <br>
Empty file.
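The dataset-to-config mapping the updated README describes can be sketched as a small lookup; the directory names are taken from the README, while the helper function itself is hypothetical:

```python
# Hypothetical helper mirroring the README's layout: each dataset maps to a
# config directory under experimental/robust_segvit/configs/.
DATASET_CONFIG_DIRS = {
    'cityscapes': 'configs/cityscapes',
    'ade20k_ind': 'configs/ade20k_ind',
    'street_hazards': 'configs/street_hazards',
}

def config_dir(dataset: str) -> str:
    # Look up the experiment-config directory for a dataset.
    return DATASET_CONFIG_DIRS[dataset]
```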
29 changes: 25 additions & 4 deletions experimental/robust_segvit/configs/ade20k_ind/be.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long

import ml_collections
import os
import datetime

_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,24 @@

# Model specs.
LOAD_PRETRAINED_BACKBONE = True
BACKBONE_ORIGIN = 'big_vision'
BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
UPSTREAM_TASK = 'i21k+imagenet2012'
UPSTREAM_TASK = 'augreg+i21k+imagenet2012'


# Upstream
MODEL_PATHS = {

# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', # pylint: disable=g-long-lambda

}


@@ -180,9 +185,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size

config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
config.eval_robustness_configs.method_name = 'mlogit'

# wandb.ai configurations.
config.use_wandb = False
config.wandb_dir = 'wandb'
config.wandb_project = 'rdl-debug'
config.wandb_entity = 'ekellbuch'
config.wandb_exp_name = None # Give experiment a name.
config.wandb_exp_name = (
os.path.splitext(os.path.basename(__file__))[0] + '_' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
config.wandb_exp_group = None # Give experiment a group name.

if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
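The `be.py` diff above switches the pretrained backbone by changing the tuple key used against `MODEL_PATHS`. A minimal standalone sketch of that lookup — paths copied from the diff, with the variable wiring simplified from the actual config:

```python
# Sketch of the checkpoint selection in be.py: the pretrained-backbone path is
# keyed by (origin, vit_size, stride, resnet_size, classifier, upstream_task).
MODEL_PATHS = {
    ('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
        'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
    ('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
        'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1'
        '-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}

BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE, STRIDE, RESNET_SIZE, CLASSIFIER = 'L', 16, None, 'token'
UPSTREAM_TASK = 'augreg+i21k+imagenet2012'

# Changing UPSTREAM_TASK (as the diff does) selects a different checkpoint.
pretrained_path = MODEL_PATHS[
    (BACKBONE_ORIGIN, VIT_SIZE, STRIDE, RESNET_SIZE, CLASSIFIER, UPSTREAM_TASK)]
```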
38 changes: 25 additions & 13 deletions experimental/robust_segvit/configs/ade20k_ind/be_eval.py
@@ -20,6 +20,8 @@
# pylint: enable=line-too-long

import ml_collections
import datetime
import os

_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -42,14 +44,14 @@
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
EXPERIMENTID = '43838358-2'
EXPERIMENTID = '45349725-1'

target_size = (640, 640)

# Upstream
CHECKPOINT_PATHS = {
('ub', 'L', 16, None, 'token', '43838358-2'):
'gs://ub-ekb/checkpoints_to_upload/ade20k/43838358-2',
('ub', 'L', 16, None, 'token', '45349725-1'):
'gs://ub-checkpoints/45349725-ade20k_ind_segmenter_be/1',
}


@@ -162,17 +164,33 @@ def get_config(runlocal=''):
config.eval_mode = True
config.eval_configs = ml_collections.ConfigDict()
config.eval_configs.mode = 'standard'
config.eval_covariate_shift = True
config.eval_label_shift = True
config.model.input_shape = target_size
config.eval_configs.store_logits = False

# Eval parameters for robustness
config.eval_label_shift = True
config.eval_covariate_shift = True
config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
config.eval_robustness_configs.method_name = 'msp'
config.eval_robustness_configs.num_top_k = 5
config.eval_robustness_configs.method_name = 'nmlogit'
config.eval_robustness_configs.num_top_k = 1

# Load checkpoint
config.checkpoint_configs = ml_collections.ConfigDict()
config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
config.checkpoint_configs.classifier = 'token'

# wandb.ai configurations.
config.use_wandb = False
config.wandb_dir = 'wandb'
config.wandb_project = 'rdl-debug'
config.wandb_entity = 'ekellbuch'
config.wandb_exp_name = None # Give experiment a name.
config.wandb_exp_name = (
os.path.splitext(os.path.basename(__file__))[0] + '_' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
config.wandb_exp_group = None # Give experiment a group name.

if runlocal:
config.count_flops = False
@@ -183,12 +201,6 @@ def get_config(runlocal=''):
config.dataset_configs.train_split = f'train[:{TRAIN_SAMPLES}]'
config.dataset_configs.validation_split = f'validation[:{TRAIN_SAMPLES}]'
config.num_train_examples = TRAIN_SAMPLES
else:
# Load checkpoint
config.checkpoint_configs = ml_collections.ConfigDict()
config.checkpoint_configs.checkpoint_format = CHECKPOINT_ORIGIN
config.checkpoint_configs.checkpoint_path = CHECKPOINT_PATH
config.checkpoint_configs.classifier = 'token'
return config


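The `config.wandb_exp_name` default added across these configs concatenates the config file's stem with a timestamp. A standalone sketch of that expression — the helper name is an assumption, since the configs inline it:

```python
import datetime
import os

def default_exp_name(config_file: str) -> str:
    # Mirror the configs' wandb_exp_name default: config-file stem + timestamp.
    stem = os.path.splitext(os.path.basename(config_file))[0]
    stamp = datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S')
    return stem + '_' + stamp

name = default_exp_name('configs/ade20k_ind/be_eval.py')
```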
28 changes: 24 additions & 4 deletions experimental/robust_segvit/configs/ade20k_ind/deterministic.py
@@ -22,6 +22,8 @@
# pylint: enable=line-too-long

import ml_collections
import os
import datetime

_CITYSCAPES_FINE_TRAIN_SIZE = 2975
_CITYSCAPES_COARSE_TRAIN_SIZE = 19998
@@ -40,21 +42,23 @@

# Model specs.
LOAD_PRETRAINED_BACKBONE = True
BACKBONE_ORIGIN = 'big_vision'
BACKBONE_ORIGIN = 'vision_transformer'
VIT_SIZE = 'L'
STRIDE = 16
RESNET_SIZE = None
CLASSIFIER = 'token'
target_size = (640, 640)
UPSTREAM_TASK = 'i21k+imagenet2012'
UPSTREAM_TASK = 'augreg+i21k+imagenet2012'


# Upstream
MODEL_PATHS = {

# Imagenet 21k + finetune in imagenet2012 with perf 0.85 adap_res 384
('big_vision', 'L', 16, None, 'token', 'i21k+imagenet2012'):
'gs://vit_models/imagenet21k%2Bimagenet2012/ViT-L_16.npz',
('vision_transformer', 'L', 16, None, 'token', 'i21k+imagenet2012'):
'gs://vit_models/imagenet21k+imagenet2012/ViT-L_16.npz',
('vision_transformer', 'L', 16, None, 'token', 'augreg+i21k+imagenet2012'):
'gs://vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
}


@@ -174,9 +178,25 @@ def get_config(runlocal=''):
config.eval_label_shift = False
config.model.input_shape = target_size

config.eval_robustness_configs = ml_collections.ConfigDict()
config.eval_robustness_configs.auc_online = True
config.eval_robustness_configs.method_name = 'mlogit'

# wandb.ai configurations.
config.use_wandb = False
config.wandb_dir = 'wandb'
config.wandb_project = 'rdl-debug'
config.wandb_entity = 'ekellbuch'
config.wandb_exp_name = None # Give experiment a name.
config.wandb_exp_name = (
os.path.splitext(os.path.basename(__file__))[0] + '_' +
datetime.datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
config.wandb_exp_group = None # Give experiment a group name.

if runlocal:
config.count_flops = False
config.dataset_configs.train_target_size = (128, 128)
config.model.input_shape = config.dataset_configs.train_target_size
config.batch_size = 8
config.num_training_epochs = 5
config.warmup_steps = 0
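The `runlocal` branch shared by these configs shrinks the run so a debug pass fits on a laptop. A flattened sketch of that branch — `SimpleNamespace` stands in for `ml_collections.ConfigDict`, and the real configs nest these fields under `dataset_configs` and `model`:

```python
from types import SimpleNamespace

def apply_runlocal_overrides(config):
    # Sketch of the runlocal overrides: disable FLOP counting, shrink the
    # input resolution and batch, and run only a few epochs with no warmup.
    config.count_flops = False
    config.train_target_size = (128, 128)
    config.input_shape = config.train_target_size
    config.batch_size = 8
    config.num_training_epochs = 5
    config.warmup_steps = 0
    return config

cfg = apply_runlocal_overrides(SimpleNamespace())
```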