Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
chuangua authored Jul 13, 2023
1 parent fa02283 commit 7dcac9a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions ConLoss_MLML.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def forward(ctx, X, y,lambda_):
r = np.sum(S < eigThd)
uprod=np.dot(U[:, 0:U.shape[1] - r],np.transpose(V[:, 0:V.shape[1] - r]))
dX_all = uprod
dX = (dX_c - lambda_ * dX_all) / N * np.float(lambda_)
dX = (dX_c - dX_all) / N * np.float(lambda_)
ctx.dX = torch.FloatTensor(dX).cuda()

obj = (Obj_c - lambda_*Obj_all)/N*np.float(lambda_)
obj = (Obj_c - Obj_all)/N*np.float(lambda_)
obj=torch.FloatTensor([float(obj)])[0].cuda()
return obj

Expand Down
10 changes: 5 additions & 5 deletions train_clml.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
parser.add_argument('--dataset', help='select dataset', default='./dataset/coco_train_0.75left.txt')
parser.add_argument('--data', metavar='DIR', help='path to dataset', default='/home/mscoco')
parser.add_argument('--lr', default=1e-4, type=float)
parser.add_argument('--model-name', default='resnet101')
parser.add_argument('--model-name', default='resnet50')
parser.add_argument('--num-classes', default=80)
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',help='number of data loading workers (default: 16)')
parser.add_argument('--image-size', default=448, type=int, metavar='N', help='input image size (default: 448)')
Expand All @@ -27,8 +27,9 @@
parser.add_argument('--print-freq', '-p', default=64, type=int, metavar='N', help='print frequency (default: 64)')
parser.add_argument('--loss', default='SPLC', type=str, help='select loss function', choices=['BCE','Focal','Hill','SPLC'])
parser.add_argument('--lambda_', type=float, default=1.00, help='CL loss is multiplied by lambda_.')
parser.add_argument('--useclml', type=bool, default=True, help='use clml')
parser.add_argument('--use_clml', type=bool, default=True, help='use clml')
parser.add_argument('--threshold', type=float, default=0.75, help='If the predicted probability is greater than this value, the sample makes up the label')
parser.add_argument('--NeEpoch', type=int, default=1, help='Start Epoch')


def main():
Expand Down Expand Up @@ -89,7 +90,7 @@ def train_multi_label_coco(args, model, train_loader, val_loader, lr):
Epochs = 80
Stop_epoch = 25
weight_decay = 1e-4
use_clml=args.useclml
use_clml=args.use_clml
lam=args.lambda_
if args.loss == 'BCE':
crit1 = AsymmetricLoss(gamma_neg=0, gamma_pos=0, clip=0)
Expand All @@ -111,14 +112,13 @@ def train_multi_label_coco(args, model, train_loader, val_loader, lr):
highest_mAP = 0
trainInfoList = []
scaler = GradScaler()
crit_clml=CLML(tau=args.threshold)
crit_clml=CLML(tau=args.threshold,change_epoch=args.NeEpoch)

for epoch in range(Epochs):
if epoch > Stop_epoch:
break
for batch_idx, (inputData, target ) in enumerate(train_loader):
inputData = inputData.cuda()
#target=target.to(torch.float32)
target = target.cuda() # (batch,3,num_classes)
output,feature = model(inputData)

Expand Down
2 changes: 1 addition & 1 deletion validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

parser = argparse.ArgumentParser(description='PyTorch MS_COCO Validation')
parser.add_argument('--data', metavar='DIR', help='path to dataset', default='/home/MSCOCO_2014/')
parser.add_argument('--model-name', default='resnet101')
parser.add_argument('--model-name', default='resnet50')
parser.add_argument('--model-path', default='/home/model/coco_75_hill.ckpt', type=str)
parser.add_argument('--num-classes', default=80)
parser.add_argument('-j', '--workers', default=16, type=int, metavar='N', help='number of data loading workers (default: 16)')
Expand Down

0 comments on commit 7dcac9a

Please sign in to comment.