eval_knn.py
import argparse
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm
from utils.dataloader import PartialDatasetVOC
import models
from models.knn import KNNSegmentator
from utils.logger import WBLogger
import utils.transforms as _transforms
from utils import load_pretrained_weights


def main(args):
    # Weights & Biases logger for this kNN evaluation run
    logger = WBLogger(args, group='knn', job_type=args.arch)

    # Load the frozen ViT backbone and its pretrained weights
    backbone = models.__dict__[args.arch](
        patch_size=args.patch_size,
        num_classes=0
    )
    load_pretrained_weights(backbone, args.weights,
                            checkpoint_key="teacher",
                            model_name=args.arch,
                            patch_size=args.patch_size)

    # kNN segmentator built on top of the frozen backbone features
    knn_segmentator = KNNSegmentator(backbone,
                                     logger,
                                     k=args.n_neighbors,
                                     num_classes=2 if args.segmentation == 'binary' else 21,
                                     feature=args.feature,
                                     patch_labeling=args.patch_labeling,
                                     background_label_percentage=args.background_label_percentage,
                                     smooth_mask=args.smooth_mask,
                                     weighted_majority_vote=args.weighted_majority_vote,
                                     n_blocks=args.n_blocks,
                                     temperature=args.temperature,
                                     use_cuda=True)

    ## TRAINING DATASET ##
    transform = _transforms.Compose([
        _transforms.Resize(256, interpolation=_transforms.INTERPOLATION_BICUBIC),
        _transforms.CenterCrop(224),
        _transforms.ToTensor(),
        _transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ] + ([_transforms.ToBinaryMask()] if args.segmentation == 'binary' else [])
      + [_transforms.MergeContours()]
    )
    train_dataset = PartialDatasetVOC(percentage=args.percentage, root=args.root, image_set='train',
                                      download=False, transforms=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers)

    ## VALIDATION DATASET ##
    val_dataset = datasets.VOCSegmentation(root=args.root, image_set='val', download=False, transforms=transform)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size)

    # Fit the kNN on the (partially labeled) training set, then score on the validation set
    knn_segmentator.fit(train_loader)
    miou, iou_std = knn_segmentator.score(val_loader)
    print(f'mean intersection over union: {miou} (±{iou_std})')


def parser_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', type=str, default="data")
    parser.add_argument('--weights', type=str, default="weights/ViT-S16.pth")
    parser.add_argument('--arch', type=str, default="vit_base")
    parser.add_argument('--feature', type=str, choices=['intermediate', 'query', 'key', 'value'],
                        default='intermediate')
    parser.add_argument('--patch_labeling', type=str, choices=['coarse', 'fine'], default='coarse')
    parser.add_argument('--n_neighbors', type=int, default=20)
    parser.add_argument('--smooth_mask', action='store_true')
    parser.add_argument('--weighted_majority_vote', action='store_true')
    parser.add_argument('--patch_size', type=int, default=16)
    parser.add_argument('--n_blocks', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument("--percentage", type=float, default=0.1)
    parser.add_argument("--background_label_percentage", type=float, default=1.0)
    parser.add_argument("--segmentation", type=str, choices=['binary', 'multi'], default='multi')
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--eval_freq", type=int, default=5)
    parser.add_argument("--workers", type=int, default=4)
    return parser.parse_args()


if __name__ == '__main__':
    args = parser_args()
    main(args)
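

# Example usage (a sketch, not a verified command: the weight file, dataset root,
# and architecture name below are taken from the argument defaults above and may
# need to be adapted to the checkpoints and data actually available):
#
#   python eval_knn.py --root data --weights weights/ViT-S16.pth \
#       --arch vit_base --patch_size 16 --n_neighbors 20 \
#       --segmentation multi --percentage 0.1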