diff --git a/train_clml.py b/train_clml.py index 84b6abc..b43c495 100644 --- a/train_clml.py +++ b/train_clml.py @@ -12,7 +12,7 @@ from src.loss_functions.losses import AsymmetricLoss, Hill, SPLC from randaugment import RandAugment from torch.cuda.amp import GradScaler, autocast -from ConLoss_MLML import OLELoss,CLML +from ConLoss_MLML import CLLoss,CLML parser = argparse.ArgumentParser(description='PyTorch MS_COCO Training') parser.add_argument('--dataset', help='select dataset', default='./dataset/coco_train_0.75left.txt') @@ -129,7 +129,7 @@ def train_multi_label_coco(args, model, train_loader, val_loader, lr): loss_classification=crit1(output,target) if use_clml: - loss_clml=crit_clml(output, target, epoch,feature,lam) + loss_clml=crit_clml(output, target, feature, epoch,lam) else: loss_clml=0