Skip to content

Commit

Permalink
Modified loss function and logits for distilbert and roberta
Browse files Browse the repository at this point in the history
  • Loading branch information
advaithsrao committed Nov 24, 2023
1 parent ac12834 commit 5c09644
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions detector/modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def train(

# Define the loss function with class weights
# loss_function = torch.nn.CrossEntropyLoss(weight=class_weights)
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
pos_weight = torch.tensor(class_weights[1], dtype=torch.float32).to(self.device)
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

for epoch in range(self.num_epochs):
print(f'{"="*20} Epoch {epoch + 1}/{self.num_epochs} {"="*20}')
Expand All @@ -149,7 +150,7 @@ def train(

self.model.zero_grad()
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
logits = outputs[1]
logits = outputs.logits # Use logits attribute to get the predicted logits

# Calculate the loss using the weighted loss function
loss = loss_function(logits, b_labels)
Expand Down Expand Up @@ -181,9 +182,10 @@ def train(

with torch.no_grad():
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
loss = outputs[0]
logits = outputs[1]

# loss = outputs[0]
logits = outputs.logits

loss = loss_function(logits, b_labels)
total_eval_loss += loss.item()
logits = logits.detach().to(self.device).numpy()
label_ids = b_labels.to(self.device).numpy()
Expand Down Expand Up @@ -263,7 +265,7 @@ def predict(

with torch.no_grad():
outputs = self.model(b_input_ids, attention_mask=b_input_mask)
logits = outputs[0]
logits = outputs.logits

logits = logits.detach().cpu().numpy()

Expand Down Expand Up @@ -419,7 +421,8 @@ def train(

# Define the loss function with class weights
# loss_function = torch.nn.CrossEntropyLoss(weight=class_weights)
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights[1])
pos_weight = torch.tensor(class_weights[1], dtype=torch.float32).to(self.device)
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

for epoch in range(self.num_epochs):
print(f'{"="*20} Epoch {epoch + 1}/{self.num_epochs} {"="*20}')
Expand All @@ -435,8 +438,8 @@ def train(

self.model.zero_grad()
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
logits = outputs[1]

logits = outputs.logits
# Calculate the loss using the weighted loss function
loss = loss_function(logits, b_labels)
total_train_loss += loss.item()
Expand Down Expand Up @@ -467,9 +470,9 @@ def train(

with torch.no_grad():
outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
loss = outputs[0]
logits = outputs[1]

logits = outputs.logits

loss = loss_function(logits, b_labels)
total_eval_loss += loss.item()
logits = logits.detach().to(self.device).numpy()
label_ids = b_labels.to(self.device).numpy()
Expand Down

0 comments on commit 5c09644

Please sign in to comment.