Train.py
from HerosNet import HerosNet
from Dataset import dataset
import os
import time
import datetime
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as tud
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from Utils import *

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
if __name__ == "__main__":
    print("===> New Model")
    model = HerosNet(Ch=28, stages=8, size=256)

    print("===> Setting GPU")
    model = dataparallel(model, 1)  # set the number of parallel GPUs

    ## Initialize weights (Xavier/Glorot uniform for all conv layers)
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            nn.init.xavier_uniform_(layer.weight)
            # nn.init.constant_(layer.bias, 0.0)
        if isinstance(layer, nn.ConvTranspose2d):
            nn.init.xavier_uniform_(layer.weight)
            # nn.init.constant_(layer.bias, 0.0)

    ## Model config
    parser = argparse.ArgumentParser(description="PyTorch Spectral Compressive Imaging")
    parser.add_argument('--data_path', default='/userhome/zxy/data/', type=str, help='Path of data')
    parser.add_argument('--mask_path', default='/userhome/zxy/mask.mat', type=str, help='Path of mask')
    parser.add_argument("--size", default=256, type=int, help='Training image size')
    parser.add_argument("--trainset_num", default=20000, type=int, help='Number of training samples per epoch')
    parser.add_argument("--testset_num", default=10, type=int, help='Total number of test samples')
    parser.add_argument("--seed", default=1, type=int, help='Random seed')
    parser.add_argument("--batch_size", default=4, type=int, help='Batch size')
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so parse the flag explicitly.
    parser.add_argument("--isTrain", default=True, type=lambda x: str(x).lower() == 'true', help='Train or test')
    opt = parser.parse_args()
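    # Example invocation (the paths below are placeholders; point them at your
    # own copies of the training data and mask):
    #   python Train.py --data_path /path/to/data/ --mask_path /path/to/mask.mat --batch_size 4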
print("Random Seed: ", opt.seed)
torch.manual_seed(opt.seed)
torch.cuda.manual_seed(opt.seed)
print(opt)
## Load training data
key = 'train_list.txt'
file_path = opt.data_path + key
file_list = loadpath(file_path)
HR_HSI = prepare_data(opt.data_path, file_list, 30)
## Load trained model
initial_epoch = findLastCheckpoint(save_dir="./checkpoint") # load the last model in matconvnet style
if initial_epoch > 0:
print('Load model: resuming by loading epoch %03d' % initial_epoch)
model = torch.load(os.path.join("./checkpoint", 'model_%03d.pth' % initial_epoch))
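    # torch.save(model, ...) at the end of each epoch pickles the entire
    # nn.Module, so resuming with torch.load requires the HerosNet class to be
    # importable at load time (saving state_dict() instead would be the more
    # portable alternative).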
    ## Loss function
    criterion = nn.L1Loss()

    ## Optimizer and scheduler
    optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-8)
    scheduler = MultiStepLR(optimizer, milestones=[], gamma=0.1)  # empty milestone list keeps the LR constant
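    # MultiStepLR multiplies the learning rate by gamma at each milestone
    # epoch. To actually decay the LR, one could pass e.g. milestones=[100, 150],
    # dropping it from 1e-4 to 1e-5 at epoch 100 and to 1e-6 at epoch 150
    # (illustrative values, not from the original script).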
    os.makedirs("./checkpoint", exist_ok=True)  # make sure the save directory exists

    ## Training pipeline
    for epoch in range(initial_epoch, 200):
        model.train()
        Dataset = dataset(opt, HR_HSI)
        loader_train = tud.DataLoader(Dataset, num_workers=4, batch_size=opt.batch_size, shuffle=True)
        epoch_loss = 0
        start_time = time.time()
        for i, label in enumerate(loader_train):
            label = label.cuda()  # torch.autograd.Variable is deprecated; a CUDA tensor suffices
            out, _ = model(label)
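            # Deep supervision over the last three of the 8 unrolled stages:
            # the final reconstruction out[7] carries full weight, while the
            # intermediate outputs out[6] and out[5] are weighted by 0.5.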
            loss = criterion(out[7], label) + 0.5 * criterion(out[6], label) + 0.5 * criterion(out[5], label)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 50 == 0:
                print('%4d %4d / %4d loss = %.10f time = %s' % (
                    epoch + 1, i, len(Dataset) // opt.batch_size,
                    epoch_loss / ((i + 1) * opt.batch_size), datetime.datetime.now()))
        scheduler.step()  # advance the LR schedule once per epoch (the explicit epoch argument is deprecated)
        elapsed_time = time.time() - start_time
        print('epoch = %4d , loss = %.10f , time = %4.2f s' % (epoch + 1, epoch_loss / len(Dataset), elapsed_time))
        # Overwritten each epoch with the latest (epoch, mean loss, time) row.
        np.savetxt('train_result.txt', np.hstack((epoch + 1, epoch_loss / len(Dataset), elapsed_time)), fmt='%2.4f')
        torch.save(model, os.path.join("./checkpoint", 'model_%03d.pth' % (epoch + 1)))
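    # Hypothetical smoke test (not part of the original script): verify the
    # forward pass on a random 28-band cube before launching a full run.
    #   x = torch.rand(opt.batch_size, 28, opt.size, opt.size).cuda()
    #   out, _ = model(x)
    #   print(out[7].shape)  # expected: (batch_size, 28, size, size)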