-
Notifications
You must be signed in to change notification settings - Fork 22
/
val_demo.py
103 lines (81 loc) · 3.2 KB
/
val_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright (c) 2022 Yawei Li, Kai Zhang, Radu Timofte SPDX license identifier: MIT
# This file may have been modified by ByteDance
import os.path
import logging
from collections import OrderedDict
import torch
from utils import utils_logger
from utils import utils_image as util
from model.rlfn_ntire import RLFN_Prune
def main():
    """Run the RLFN_Prune model over the test images and report its runtime.

    Steps:
      1. Load the checkpoint ``model_zoo/rlfn_ntire_x4.pth`` into an
         ``RLFN_Prune`` network and freeze it for inference.
      2. For every image under ``<cwd>/data/DIV2K_valid_LR_bicubic/X4``,
         run one forward pass, time it, and save the output under
         ``<cwd>/data/DIV2K_valid_LR_bicubic_results``.
      3. Log the average forward-pass time in seconds.

    CUDA is used when available (timed with CUDA events, as before);
    otherwise the script falls back to the CPU and times the forward pass
    with a wall-clock timer instead of crashing on CUDA-only calls.
    """
    # Local import: only the CPU timing fallback needs it.
    import time

    utils_logger.logger_info('NTIRE2022-EfficientSR', log_path='NTIRE2022-EfficientSR.log')
    logger = logging.getLogger('NTIRE2022-EfficientSR')

    # --------------------------------
    # basic settings
    # --------------------------------
    testsets = os.path.join(os.getcwd(), 'data')
    testset_L = 'DIV2K_valid_LR_bicubic'

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # Return value is unused; presumably kept to initialise CUDA eagerly.
        torch.cuda.current_device()
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
    device = torch.device('cuda' if use_cuda else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model_path = os.path.join('model_zoo', 'rlfn_ntire_x4.pth')
    model = RLFN_Prune(in_channels=3, out_channels=3)
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only host; the model is moved to the target device right after.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # number of parameters
    number_parameters = sum(param.numel() for param in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    # --------------------------------
    # read image
    # --------------------------------
    L_folder = os.path.join(testsets, testset_L, 'X4')
    E_folder = os.path.join(testsets, testset_L+'_results')
    util.mkdir(E_folder)

    # record runtime (milliseconds per image)
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(E_folder)

    if use_cuda:
        # CUDA events time the work on the GPU stream itself.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

    for idx, img in enumerate(util.get_image_paths(L_folder), start=1):
        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name+ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        # NOTE(review): presumably uint2tensor4 yields values in [0, 1] and
        # the network works on a 0-255 range - confirm against utils_image.
        img_L = img_L * 255.
        img_L = img_L.to(device)

        # Parameters are already frozen; no_grad makes the inference-only
        # intent explicit and guarantees no autograd state is recorded.
        with torch.no_grad():
            if use_cuda:
                start.record()
                img_E = model(img_L)
                end.record()
                # Wait for the GPU so that elapsed_time() is valid.
                torch.cuda.synchronize()
                elapsed_ms = start.elapsed_time(end)
            else:
                tic = time.perf_counter()
                img_E = model(img_L)
                elapsed_ms = (time.perf_counter() - tic) * 1000.0
        test_results['runtime'].append(elapsed_ms)  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = img_E / 255.
        img_E = util.tensor2uint(img_E)
        # Only the first four characters of the base name are kept for the
        # output file (presumably the DIV2K image id - TODO confirm).
        util.imsave(img_E, os.path.join(E_folder, img_name[:4]+ext))

    runtimes = test_results['runtime']
    if not runtimes:
        # Avoid a ZeroDivisionError when the input folder has no images.
        logger.warning('No images were processed from ({}); average runtime is undefined'.format(L_folder))
        return

    ave_runtime = sum(runtimes) / len(runtimes) / 1000.0
    logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))
# Script entry point: run the validation demo only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()