
Training Loss problem. #25

Open
740402059 opened this issue Nov 4, 2023 · 3 comments

Comments

@740402059

When I used your algorithm and parameters to train on both the WTH dataset and my own dataset, I found that the loss was very low in the first epoch, then increased sharply in the second epoch, and afterwards remained higher than in the first epoch. This variation in the training loss is perplexing, and I hope you can provide some insight.

@Wentao-Gao

Wentao-Gao commented Nov 19, 2023

I think it is a problem with the loss function: the Time Domain Contrastive Loss uses the MoCo loss, but adds an extra term to the denominator:

[Screenshot 2023-11-19, 1:23 PM: the modified loss equation]

I assume this term is not necessary; if you delete it, you may solve the problem.
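For reference, this is my own rendering of the standard MoCo InfoNCE loss that the model falls back to once the extra denominator term is removed (the screenshot above shows the paper's variant, which I am not reproducing here):

```latex
\mathcal{L} = -\log
  \frac{\exp(q \cdot k^{+} / \tau)}
       {\exp(q \cdot k^{+} / \tau) + \sum_{i=1}^{K} \exp(q \cdot k_{i}^{-} / \tau)}
```

where q is the query, k+ the positive key, k_i- the K negative keys from the queue, and τ the temperature.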

And the loss code I wrote in CoST/cost.py, class CoSTModel, is:


def compute_loss(self, q, k, k_negs):
    # compute logits
    # positive logits: Nx1
    l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
    # negative logits: NxK
    l_neg = torch.einsum('nc,ck->nk', [q, k_negs])

    # apply temperature
    l_pos = l_pos / self.T
    l_neg = l_neg / self.T

    # InfoNCE with the positive excluded from the denominator:
    #   loss = -q.k+ / T + log(sum_i exp(q.k_i- / T))
    # (F.cross_entropy cannot drop its target class from the softmax
    # denominator, so the loss is written out with logsumexp instead)
    loss = (-l_pos.squeeze(-1) + torch.logsumexp(l_neg, dim=1)).mean()

    return loss
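As a quick, self-contained sanity check of the "negatives-only denominator" idea, here is the same loss as a free function on random unit-normalized features. `infonce_neg_only` is my own name and `T=0.07` is an assumed temperature; this is a sketch, not the actual CoSTModel method:

```python
import torch
import torch.nn.functional as F

def infonce_neg_only(q, k, k_negs, T=0.07):
    # positive pair score goes in the numerator only;
    # only the K negatives form the softmax denominator
    l_pos = torch.einsum('nc,nc->n', [q, k])        # N
    l_neg = torch.einsum('nc,ck->nk', [q, k_negs])  # NxK
    return (-l_pos / T + torch.logsumexp(l_neg / T, dim=1)).mean()

torch.manual_seed(0)
N, C, K = 8, 64, 32
q = F.normalize(torch.randn(N, C), dim=1)       # queries
k = F.normalize(torch.randn(N, C), dim=1)       # positive keys
k_negs = F.normalize(torch.randn(C, K), dim=0)  # negative keys (queue)

loss = infonce_neg_only(q, k, k_negs)
print(loss.item())
```

Aligning the positive pair should lower the loss while leaving the denominator untouched, e.g. `infonce_neg_only(q, q, k_negs)` comes out smaller than `infonce_neg_only(q, -q, k_negs)`.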

Hope this helps.

@740402059
Author

Thank you. I made the modification according to your suggestion, and this problem has been solved.

@Wentao-Gao

> Thank you. I made the modification according to your suggestion, and this problem has been solved.

By the way, I just noticed that MoCo v2 uses this paper's loss function.
