-
Notifications
You must be signed in to change notification settings - Fork 61
/
torch_models.py
57 lines (40 loc) · 1.97 KB
/
torch_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Example of predicting parameters with GHN-2 for torchvision networks, such as ResNet-50.
See the available networks at https://pytorch.org/vision/stable/models.html.
The script predicts all parameters of the chosen network and evaluates it on CIFAR-10 or ImageNet.
Usage: python examples/torch_models.py <dataset> <arch>, where <dataset> is imagenet or cifar10.
Examples:
1. python examples/torch_models.py imagenet resnet50
2. python examples/torch_models.py cifar10 convnext_base
"""
import torchvision
import sys
from ppuda.vision.loader import image_loader
from ppuda.ghn.nn import GHN2
from ppuda.utils import capacity, adjust_net, infer
# Parse the command line: <dataset> <arch>; both are lower-cased.
try:
    dataset = sys.argv[1].lower()  # imagenet, cifar10
    arch = sys.argv[2].lower()  # resnet50, wide_resnet101_2, etc.
    ghn = GHN2(dataset)
except Exception:
    # Missing arguments (IndexError) or a failure constructing GHN2 for
    # `dataset`: show the expected usage, then propagate the original error.
    print('\nExample of usage: python examples/torch_models.py imagenet resnet50\n')
    raise

is_imagenet = dataset == 'imagenet'
# Keep only (images_val, num_classes) from image_loader's result; worker
# processes are used for ImageNet only.
images_val, num_classes = image_loader(dataset, num_workers=8 if is_imagenet else 0)[1:]
if is_imagenet:
    images_val.sampler.generator.manual_seed(1111)  # set the generator seed to reproduce results

# Ignore auxiliary outputs in googlenet for this example.
kw_args = {'aux_logits': False, 'init_weights': False} if arch == 'googlenet' else {}

# Resolve the model constructor by attribute lookup rather than eval(), so the
# command-line argument cannot execute arbitrary code. Dotted names (e.g.
# 'segmentation.fcn_resnet50') still resolve as they did with eval(); an
# unknown name raises AttributeError, as before.
model_fn = torchvision.models
for attr_name in arch.split('.'):
    model_fn = getattr(model_fn, attr_name)

# Predict all parameters (any network from torchvision.models can be used).
# int() mirrors the original '%d' formatting of num_classes.
model = ghn(adjust_net(model_fn(num_classes=int(num_classes), **kw_args),
                       large_input=is_imagenet))

print('\nEvaluation of {} with {} parameters'.format(arch.upper(), capacity(model)[1]))
top1, top5 = infer(model, images_val, verbose=True)

# Reference results for ResNet-50: top5=5.27 on ImageNet and top1=58.62 on CIFAR-10.
# Both are compared with the same tolerance; top1/top5 are presumably float
# percentages, for which an exact `!=` would report spurious mismatches.
if arch == 'resnet50':
    expected, observed = (5.27, top5) if is_imagenet else (58.62, top1)
    if abs(observed - expected) > 0.01:
        print('WARNING: results appear to be different from expected!')

print('\ndone')