Commit cc0f719

Loss function comparison for DistilBERT and RoBERTa changed to float labels and logits

advaithsrao committed Nov 26, 2023
1 parent 3c3dd6b commit cc0f719

Showing 1 changed file with 10 additions and 10 deletions.

detector/modeler.py: 20 changes (10 additions, 10 deletions)
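For context on the change: the diff does not show how loss_function is constructed, so the sketch below assumes a BCE-style criterion such as torch.nn.BCELoss (an assumption; only the call site appears in the hunks). PyTorch's BCE losses require float inputs and float targets, which is what the .float() cast provides. Note that the sketch feeds the sigmoid probabilities, not the thresholded outputs, to the loss, since comparison ops like > are non-differentiable and bool outputs carry no gradient.

import torch
import torch.nn as nn

# Minimal sketch, assuming a BCE-style criterion (the diff does not show
# how loss_function is built). BCE losses require float inputs and float
# targets; integer or bool tensors raise a dtype error.
logits = torch.randn(8, 2, requires_grad=True)     # stand-in for model output
b_labels = torch.tensor([0, 1, 1, 0, 1, 0, 0, 1])  # integer class labels

sigmoid_output = torch.sigmoid(logits[:, 1])       # probability of the positive class
binary_output = sigmoid_output > 0.5               # bool predictions, useful for metrics

loss_function = nn.BCELoss()
# Feeding the probabilities keeps the graph differentiable; the
# thresholded bool outputs carry no gradient.
loss = loss_function(sigmoid_output, b_labels.float())
loss.backward()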
@@ -127,7 +127,7 @@ def train(

 # Initialize variables for early stopping
 best_validation_loss = float("inf")
-patience = 10 # Number of epochs to wait for improvement
+patience = 3 # Number of epochs to wait for improvement
 wait = 0

 class_weights = compute_class_weight('balanced', classes=np.unique(label), y=label)
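For context, compute_class_weight comes from sklearn.utils.class_weight and, with 'balanced', weights each class by n_samples / (n_classes * count(class)). How the resulting weights reach loss_function is not shown in the diff; the pos_weight wiring below is one common pattern, not the file's confirmed code.

import numpy as np
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight

# Sketch: derive balanced class weights from a hypothetical label array.
label = np.array([0, 0, 0, 1])
class_weights = compute_class_weight('balanced', classes=np.unique(label), y=label)
# 'balanced' gives n_samples / (n_classes * count(class)): here [0.667, 2.0]

# One common way to apply them (an assumption; the diff does not show it):
# weight the positive class in a BCE-with-logits criterion.
pos_weight = torch.tensor(class_weights[1] / class_weights[0])  # = 3.0 here
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)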
@@ -157,13 +157,13 @@ def train(
 sigmoid_output = torch.sigmoid(logits[:, 1])

 # Thresholding to convert probabilities to binary values (0 or 1)
-binary_output = (sigmoid_output > 0.5).to(torch.int)
+binary_output = (sigmoid_output > 0.5)

 # # Convert labels to one-hot encoding
 # b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()

 # Calculate the loss using the weighted loss function
-loss = loss_function(binary_output, b_labels)
+loss = loss_function(binary_output, b_labels.float())
 total_train_loss += loss.item()

 # Backward pass
@@ -200,13 +200,13 @@ def train(
 sigmoid_output = torch.sigmoid(logits[:, 1])

 # Thresholding to convert probabilities to binary values (0 or 1)
-binary_output = (sigmoid_output > 0.5).to(torch.int)
+binary_output = (sigmoid_output > 0.5)

 # # Convert labels to one-hot encoding
 # b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()

 # Calculate the loss using the weighted loss function
-loss = loss_function(binary_output, b_labels)
+loss = loss_function(binary_output, b_labels.float())
 total_eval_loss += loss.item()
 logits = logits.detach().cpu().numpy()
 label_ids = b_labels.detach().cpu().numpy()
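The patience change above (10 down to 3 epochs) tightens early stopping: training now aborts after three validation epochs without improvement. The check itself sits outside the hunks shown; the loop below is a sketch of the usual pattern implied by best_validation_loss, patience, and wait, with num_epochs and run_validation_epoch as hypothetical stand-ins.

best_validation_loss = float("inf")
patience = 3  # epochs to wait for improvement
wait = 0

for epoch in range(num_epochs):                # num_epochs: hypothetical
    validation_loss = run_validation_epoch()   # hypothetical helper
    if validation_loss < best_validation_loss:
        best_validation_loss = validation_loss
        wait = 0          # improvement: reset the counter
    else:
        wait += 1         # no improvement this epoch
        if wait >= patience:
            break         # stop training early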
@@ -434,7 +434,7 @@ def train(

 # Initialize variables for early stopping
 best_validation_loss = float("inf")
-patience = 10 # Number of epochs to wait for improvement
+patience = 3 # Number of epochs to wait for improvement
 wait = 0

 class_weights = compute_class_weight('balanced', classes=np.unique(label), y=label)
@@ -464,13 +464,13 @@ def train(
 sigmoid_output = torch.sigmoid(logits[:, 1])

 # Thresholding to convert probabilities to binary values (0 or 1)
-binary_output = (sigmoid_output > 0.5).to(torch.int)
+binary_output = (sigmoid_output > 0.5)

 # # Convert labels to one-hot encoding
 # b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()

 # Calculate the loss using the weighted loss function
-loss = loss_function(binary_output, b_labels)
+loss = loss_function(binary_output, b_labels.float())

 total_train_loss += loss.item()
@@ -507,13 +507,13 @@ def train(
 sigmoid_output = torch.sigmoid(logits[:, 1])

 # Thresholding to convert probabilities to binary values (0 or 1)
-binary_output = (sigmoid_output > 0.5).to(torch.int)
+binary_output = (sigmoid_output > 0.5)

 # # Convert labels to one-hot encoding
 # b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()

 # Calculate the loss using the weighted loss function
-loss = loss_function(binary_output, b_labels)
+loss = loss_function(binary_output, b_labels.float())
 total_eval_loss += loss.item()
 logits = logits.detach().cpu().numpy()
 label_ids = b_labels.detach().cpu().numpy()