Commit

Update BaseCNN.py
zwx8981 authored Mar 28, 2019
1 parent 5f4177c commit f0c21ff
Showing 1 changed file with 56 additions and 30 deletions.
86 changes: 56 additions & 30 deletions BaseCNN.py
@@ -40,34 +40,43 @@ def default_loader(path):
 
 
 
-class BaseCNN(torch.nn.Module):
+class BaseCNN(nn.Module):
 
     def __init__(self, options):
         """Declare all needed layers."""
         nn.Module.__init__(self)
         # Convolution and pooling layers of VGG-16.
-        self.basemodel = torchvision.models.resnet101(pretrained=True)
+        self.basemodel = torchvision.models.resnet18(pretrained=True)
+        self.options = options
 
         # Linear classifier.
-        self.fc = torch.nn.Linear(2048, 1)
+        self.fc = nn.Linear(512, 1)
+
+        if self.options['multitask']:
+            self.fc2 = nn.Linear(512, 4)
 
-        if options['fc'] == True:
+        if self.options['fc'] == True:
             # Freeze all previous layers.
             for param in self.basemodel.parameters():
                 param.requires_grad = False
             # Initialize the fc layers.
             nn.init.kaiming_normal_(self.fc.weight.data)
             if self.fc.bias is not None:
                 nn.init.constant_(self.fc.bias.data, val=0)
+            if self.options['multitask']:
+                nn.init.kaiming_normal_(self.fc2.weight.data)
+                if self.fc2.bias is not None:
+                    nn.init.constant_(self.fc2.bias.data, val=0)
 
         else:
             for param in self.basemodel.conv1.parameters():
                 param.requires_grad = False
             for param in self.basemodel.bn1.parameters():
                 param.requires_grad = False
-            for param in self.basemodel.layer1.parameters():
-                param.requires_grad = False
-            for param in self.basemodel.layer2.parameters():
-                param.requires_grad = False
+            #for param in self.basemodel.layer1.parameters():
+            #    param.requires_grad = False
+            #for param in self.basemodel.layer2.parameters():
+            #    param.requires_grad = False
+            #for param in self.basemodel.layer3.parameters():
+            #    param.requires_grad = False
 
@@ -86,8 +95,13 @@ def forward(self, X):
         X = self.basemodel.avgpool(X)
         X = X.squeeze(2).squeeze(2)
         #X = torch.mean(torch.mean(X4,2),2)
-        X = self.fc(X)
-        return X
+        if self.options['multitask']:
+            X1 = self.fc(X)
+            X2 = self.fc2(X)
+            return X1, X2
+        else:
+            X = self.fc(X)
+            return X
 
 
 class TrainManager(object):
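
With this change, forward() returns a (quality score, distortion-class logits) pair when 'multitask' is enabled and a single score tensor otherwise. A minimal sketch of the calling contract, assuming a torchvision build whose ResNet uses adaptive average pooling and an options dict with the keys the class reads ('fc', 'multitask'):

    import torch
    from BaseCNN import BaseCNN  # assumes this file is importable as a module

    options = {'fc': True, 'multitask': True}
    model = BaseCNN(options).eval()

    x = torch.randn(2, 3, 288, 216)       # dummy batch at the new crop size
    with torch.no_grad():
        score, cls_logits = model(x)      # two heads when multitask is on
    print(score.shape, cls_logits.shape)  # torch.Size([2, 1]) torch.Size([2, 4])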
@@ -101,7 +115,7 @@ def __init__(self, options, path):
         self._path = path
 
         # Network.
-        self._net = torch.nn.DataParallel(BaseCNN(self._options), device_ids=[0]).cuda()
+        self._net = nn.DataParallel(BaseCNN(self._options), device_ids=[0]).cuda()
         if self._options['fc'] == False:
             self._net.load_state_dict(torch.load(path['fc_root']))
 
@@ -113,13 +127,15 @@ def __init__(self, options, path):
             self._criterion = nn.L1Loss().cuda()
         else:
             self._criterion = nn.SmoothL1Loss().cuda()
+        if self._options['multitask']:
+            self._criterion_cls = nn.CrossEntropyLoss().cuda()
         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
         # Solver.
         if self._options['fc'] == True:
             self._solver = torch.optim.Adam(
                 self._net.module.fc.parameters(), lr=self._options['base_lr'],
                 weight_decay=self._options['weight_decay'])
-            self._scheduler = torch.optim.lr_scheduler.StepLR(self._solver, step_size=6, gamma=0.1)
+            self._scheduler = torch.optim.lr_scheduler.StepLR(self._solver, step_size=10, gamma=0.1)
         else:
             self._solver = torch.optim.Adam(
                 self._net.module.parameters(), lr=self._options['base_lr'],
@@ -128,15 +144,15 @@ def __init__(self, options, path):
 
         train_transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(),
-            torchvision.transforms.Resize((768, 576)),
-            torchvision.transforms.RandomCrop((576, 432)),
+            torchvision.transforms.Resize((384, 288)),
+            torchvision.transforms.RandomCrop((288, 216)),
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                              std=(0.229, 0.224, 0.225))])
 
 
         test_transforms = torchvision.transforms.Compose([
-            torchvision.transforms.Resize((768, 576)),
+            torchvision.transforms.Resize((384, 288)),
             torchvision.transforms.ToTensor(),
             torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                              std=(0.229, 0.224, 0.225))])
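
The input pipeline is halved in each dimension: a 768x576 resize with a 576x432 crop becomes 384x288 with a 288x216 crop, roughly quartering per-image activation memory. A quick shape check under the new settings, assuming Pillow is installed and 'sample.jpg' stands in for any local image:

    from PIL import Image
    import torchvision

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.Resize((384, 288)),
        torchvision.transforms.RandomCrop((288, 216)),
        torchvision.transforms.ToTensor()])

    img = Image.open('sample.jpg').convert('RGB')
    x = transforms(img)
    print(x.shape)  # torch.Size([3, 288, 216]): channels, height, width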
@@ -164,29 +180,34 @@ def train(self):
             pscores = []
             tscores = []
             num_total = 0
-            for X, y in tqdm(self._train_loader):
-                self._scheduler.step()
+            for X, y, b in tqdm(self._train_loader):
                 # Data.
                 X = X.to(self.device)
                 y = y.to(self.device)
+                b = b.to(self.device)
                 #X = torch.tensor(X.cuda())
                 #y = torch.tensor(y.cuda())
 
                 # Clear the existing gradients.
                 self._solver.zero_grad()
                 # Forward pass.
-                score = self._net(X)
-                loss = self._criterion(score, y.view(len(score),1).detach())
+                if self._options['multitask']:
+                    score, cls = self._net(X)
+                    loss = self._criterion(score, y.view(len(score), 1).detach()) + 0.1 * self._criterion_cls(cls, b.detach())
+                else:
+                    score = self._net(X)
+                    loss = self._criterion(score, y.view(len(score), 1).detach())
                 epoch_loss.append(loss.item())
                 # Prediction.
                 num_total += y.size(0)
-                pscores = pscores + score.cpu().tolist()
+                pscores = pscores + score.cpu().tolist()
                 tscores = tscores + y.cpu().tolist()
                 # Backward pass.
                 loss.backward()
                 self._solver.step()
             train_srcc, _ = stats.spearmanr(pscores, tscores)
             test_srcc, test_plcc = self._consitency(self._test_loader)
+            self._scheduler.step()
             if test_srcc > best_srcc:
                 best_srcc = test_srcc
                 best_epoch = t + 1
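
Two changes matter here: the batch now carries a distortion label b that feeds an auxiliary cross-entropy term weighted by 0.1, and scheduler.step() moves out of the batch loop to once per epoch, after the optimizer has stepped, which is the order newer PyTorch releases expect. A standalone sketch of the combined objective on dummy tensors, with the 0.1 weight taken from the code above:

    import torch
    import torch.nn as nn

    criterion = nn.SmoothL1Loss()          # regression head (quality score)
    criterion_cls = nn.CrossEntropyLoss()  # auxiliary head (distortion type)

    score = torch.randn(8, 1)              # predicted quality scores
    y = torch.randn(8)                     # ground-truth MOS values
    cls = torch.randn(8, 4)                # logits over 4 distortion classes
    b = torch.randint(0, 4, (8,))          # ground-truth distortion labels

    loss = criterion(score, y.view(len(score), 1)) + 0.1 * criterion_cls(cls, b)
    print(loss.item())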
@@ -209,21 +230,24 @@ def _consitency(self, data_loader):
         num_total = 0
         pscores = []
         tscores = []
-        for X, y in data_loader:
+        for X, y, _ in tqdm(data_loader):
             # Data.
             X = X.to(self.device)
             y = y.to(self.device)
             #X = torch.tensor(X.cuda())
             #y = torch.tensor(y.cuda())
 
             # Prediction.
-            score = self._net(X)
+            if self._options['multitask']:
+                score, _ = self._net(X)
+            else:
+                score = self._net(X)
             pscores = pscores + score[0].cpu().tolist()
             tscores = tscores + y.cpu().tolist()
 
-            num_total += y.size(0)
-        test_srcc, _ = stats.spearmanr(pscores,tscores)
-        test_plcc, _ = stats.pearsonr(pscores,tscores)
+            #num_total += y.size(0)
+        test_srcc, _ = stats.spearmanr(pscores, tscores)
+        test_plcc, _ = stats.pearsonr(pscores, tscores)
         self._net.train(True)  # Set the model to training phase
         return test_srcc, test_plcc
 
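Both correlation metrics come straight from scipy, and each returns a (statistic, p-value) pair, hence the discarded second value. For example, on made-up scores:

    from scipy import stats

    pscores = [1.2, 3.4, 2.2, 4.8, 3.9]  # predicted scores (illustrative)
    tscores = [1.0, 3.0, 2.5, 5.0, 4.0]  # ground-truth MOS (illustrative)

    srcc, _ = stats.spearmanr(pscores, tscores)  # rank correlation
    plcc, _ = stats.pearsonr(pscores, tscores)   # linear correlation
    print(srcc, plcc)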
@@ -241,8 +265,9 @@ def main():
     parser.add_argument('--weight_decay', dest='weight_decay', type=float,
                         default=5e-4, help='Weight decay.')
     parser.add_argument('--objective', dest='objective', type=str,
-                        default='smoothl1', help='l1 | l2 | smoothl1')
-
+                        default='l2', help='l1 | l2 | smoothl1')
+    parser.add_argument('--multitask', dest='multitask', type=bool,
+                        default=False, help='True or False')
 
     args = parser.parse_args()
     if args.base_lr <= 0:
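
One caveat with the new flag: argparse applies type=bool to the raw string, and any non-empty string is truthy, so '--multitask False' still parses as True; only omitting the flag yields False. A store-true flag is the usual safer pattern, sketched here as an alternative rather than what the commit does:

    parser.add_argument('--multitask', dest='multitask', action='store_true',
                        help='Enable the auxiliary distortion-type head.')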
@@ -260,13 +285,14 @@ def main():
         'epochs': args.epochs,
         'weight_decay': args.weight_decay,
         'objective': args.objective,
+        'multitask': args.multitask,
         'fc': [],
         'train_index': [],
         'test_index': []
     }
 
     path = {
-        'koniq': os.path.join('/home/zwx-sjtu/codebase/koniq-10k/'),
+        'koniq': os.path.join('/home/zwx-sjtu/codebase/IQA_database/koniq-10k/'),
         'fc_model': os.path.join('fc_models'),
         'fc_root': os.path.join('fc_models', 'net_params_best.pkl'),
         'db_model': os.path.join('db_models')
@@ -279,7 +305,7 @@ def main():
     epoch_backup = options['epochs']
     srcc_all = np.zeros((1, 10), dtype=np.float)
 
-    for i in range(0,10):
+    for i in range(0, 10):
         #randomly split train-test set
         random.shuffle(index)
         train_index = index[0:round(0.8*len(index))]
@@ -291,7 +317,7 @@ def main():
         options['fc'] = True
         options['base_lr'] = 1e-2
         options['batch_size'] = 64
-        options['epochs'] = 12
+        options['epochs'] = 20
         manager = TrainManager(options, path)
         best_srcc = manager.train()
 
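Training remains two-stage: the fc head is fit first at a high learning rate with the backbone frozen (options['fc'] = True, now for 20 epochs), then the whole network is fine-tuned from the best fc checkpoint, which TrainManager loads from path['fc_root'] whenever options['fc'] is False. A condensed sketch of the schedule; the stage-2 values are inferred from epoch_backup above and the loading logic in __init__, and are not shown in this diff:

    # Stage 1: train only the new head(s); backbone frozen inside BaseCNN.
    options['fc'] = True
    options['base_lr'] = 1e-2
    options['batch_size'] = 64
    options['epochs'] = 20
    manager = TrainManager(options, path)
    best_srcc = manager.train()

    # Stage 2 (inferred): fine-tune the full network, warm-started from the
    # stage-1 checkpoint that TrainManager loads when options['fc'] is False.
    options['fc'] = False
    options['epochs'] = epoch_backup
    manager = TrainManager(options, path)
    best_srcc = manager.train()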
