
Commit 3f04480
Fix ListMLE loss reduction
perceptiveshawty authored Dec 25, 2022
1 parent cfdf89b commit 3f04480
Showing 1 changed file with 14 additions and 12 deletions.
rankcse/models.py (26 changes: 14 additions & 12 deletions)
@@ -70,9 +70,7 @@ def __init__(self, tau, gamma_):
         self.student_temp_scaled_sim = Similarity(tau)
         self.gamma_ = gamma_
 
-    def forward(self, teacher_top1_sim_pred, z1, z2):
-        student_top1_sim_pred = self.student_temp_scaled_sim(z1.unsqueeze(1), z2.unsqueeze(0))
-
+    def forward(self, teacher_top1_sim_pred, student_top1_sim_pred):
         p = F.log_softmax(student_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
         q = F.softmax(teacher_top1_sim_pred.fill_diagonal_(float('-inf')), dim=-1)
         loss = -(q*p).nansum() / q.nansum()
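
For context, the ListNet term above is a row-wise cross-entropy between the teacher's and the student's softmax ranking distributions, with each example's self-similarity masked off the diagonal. A minimal, self-contained sketch of the same computation, with made-up tensor values:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
student_sim = torch.randn(4, 4)   # e.g. temperature-scaled cosine similarities
teacher_sim = torch.randn(4, 4)   # teacher top-1 similarity predictions

# mask self-similarities, then compare the row-wise ranking distributions
p = F.log_softmax(student_sim.fill_diagonal_(float('-inf')), dim=-1)
q = F.softmax(teacher_sim.fill_diagonal_(float('-inf')), dim=-1)
loss = -(q * p).nansum() / q.nansum()   # nansum skips the masked diagonal
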
@@ -88,11 +86,10 @@ def __init__(self, tau, gamma_):
         self.gamma_ = gamma_
         self.eps = 1e-7
 
-    def forward(self, teacher_top1_sim_pred, z1, z2):
-        student_top1_sim_pred = self.temp_scaled_sim(z1.unsqueeze(1), z2.unsqueeze(0))
+    def forward(self, teacher_top1_sim_pred, student_top1_sim_pred):
 
-        y_pred = student_top1_sim_pred # .fill_diagonal_(float('-inf')).softmax(dim=-1)
-        y_true = teacher_top1_sim_pred # .fill_diagonal_(float('-inf')).softmax(dim=-1)
+        y_pred = student_top1_sim_pred
+        y_true = teacher_top1_sim_pred
 
         # shuffle for randomised tie resolution
         random_indices = torch.randperm(y_pred.shape[-1])
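
The "randomised tie resolution" shuffle exists because the subsequent sort by y_true keeps the input order of equal scores, so without it, tied teacher scores would always resolve in index order. A sketch of the idea (variable names follow the ListMLE code above; values are made up):

import torch

y_pred = torch.tensor([[0.2, 0.9, 0.4]])   # student scores
y_true = torch.tensor([[0.5, 0.8, 0.8]])   # teacher scores with a tie

random_indices = torch.randperm(y_pred.shape[-1])
y_pred_shuffled = y_pred[:, random_indices]
y_true_shuffled = y_true[:, random_indices]

# sorting preserves the order of equal values, so shuffling first
# randomises which of the tied items is ranked higher
_, indices = y_true_shuffled.sort(descending=True, dim=-1)
preds_sorted_by_true = torch.gather(y_pred_shuffled, dim=1, index=indices)
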
@@ -109,8 +106,7 @@ def forward(self, teacher_top1_sim_pred, z1, z2):
         observation_loss = torch.log(cumsums + self.eps) - preds_sorted_by_true_minus_max
         observation_loss[mask] = 0.0
 
-        return self.gamma_ * torch.mean(torch.mean(observation_loss, dim=1))
-
+        return self.gamma_ * torch.mean(torch.sum(observation_loss, dim=1))
 
 class Pooler(nn.Module):
     """
@@ -159,6 +155,12 @@ def cl_init(cls, config):
     cls.mlp = MLPLayer(config)
     cls.sim = Similarity(temp=cls.model_args.temp)
     cls.div = Divergence(beta_=cls.model_args.beta_)
+    if cls.model_args.distillation_loss == "listnet":
+        cls.distillation_loss_fct = ListNet(cls.model_args.tau2, cls.model_args.gamma_)
+    elif cls.model_args.distillation_loss == "listmle":
+        cls.distillation_loss_fct = ListMLE(cls.model_args.tau2, cls.model_args.gamma_)
+    else:
+        raise NotImplementedError
     cls.init_weights()
 
 def cl_forward(cls,
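
Building the distillation loss once in cl_init, instead of re-instantiating it on every forward pass as before, also fails fast on an unrecognised distillation_loss value at model construction time rather than mid-training. A minimal sketch of the same selection pattern, with a hypothetical args object standing in for model_args:

class Args:   # hypothetical stand-in for model_args
    distillation_loss = "listmle"
    tau2 = 0.05
    gamma_ = 1.0

def build_distillation_loss(args):
    if args.distillation_loss == "listnet":
        return ListNet(args.tau2, args.gamma_)    # ListNet/ListMLE as defined above
    elif args.distillation_loss == "listmle":
        return ListMLE(args.tau2, args.gamma_)
    raise NotImplementedError(f"unknown distillation loss: {args.distillation_loss}")
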
@@ -280,8 +282,8 @@ def cl_forward(cls,
     loss = loss_fct(cos_sim, labels)
 
     # RankCSE - knowledge distillation loss
-    distillation_loss_fct = (ListNet(cls.model_args.tau2, cls.model_args.gamma_) if cls.model_args.distillation_loss == "listnet" else ListMLE(cls.model_args.tau2, cls.model_args.gamma_))
-    kd_loss = distillation_loss_fct(teacher_top1_sim_pred.to(cls.device), z1, z2)
+    student_top1_sim_pred = cos_sim.clone()
+    kd_loss = cls.distillation_loss_fct(teacher_top1_sim_pred.to(cls.device), student_top1_sim_pred)
 
     # RankCSE - self-distillation loss
     z1_z2_cos = cos_sim.clone()
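
Passing cos_sim.clone() reuses the similarity matrix already computed for the contrastive loss instead of recomputing it from z1 and z2 inside each loss class. The clone matters: ListNet's fill_diagonal_ mutates its input in place, and cos_sim is cloned again below for the self-distillation term. A small illustration with made-up values:

import torch

cos_sim = torch.randn(4, 4)
student_top1_sim_pred = cos_sim.clone()

# fill_diagonal_ is in-place; the clone keeps cos_sim intact for later use
student_top1_sim_pred.fill_diagonal_(float('-inf'))
assert not torch.isinf(cos_sim).any()
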
@@ -470,4 +472,4 @@ def forward(self,
                 mlm_input_ids=mlm_input_ids,
                 mlm_labels=mlm_labels,
                 teacher_top1_sim_pred=teacher_top1_sim_pred,
-            )
\ No newline at end of file
+            )
