
One hot encoded labels for BCE with logits
advaithsrao committed Nov 24, 2023
1 parent b019b32 commit fd0e3a2
Showing 1 changed file with 24 additions and 8 deletions.
detector/modeler.py (24 additions, 8 deletions)
@@ -1,5 +1,5 @@
 import os
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
+# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
 
 import shutil
 import pandas as pd
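For reference, a minimal sketch of what this setting does if it is ever re-enabled: PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's CUDA caching allocator initializes, so it has to be set before the first CUDA allocation, which is why it sat above the torch import.

import os

# Sketch: the allocator reads this variable when CUDA is first used,
# so it must be set before any CUDA tensor is created.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch  # imported after the env var so the setting takes effect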
@@ -17,6 +17,8 @@
 from transformers import AdamW, get_linear_schedule_with_warmup
 
 from torch.utils.data import DataLoader, TensorDataset#, SubsetRandomSampler
+import torch.nn.functional as F
+
 import wandb
 from mlflow.sklearn import save_model
 from scipy.sparse import hstack
@@ -105,7 +107,7 @@ def train(
         # Convert lists to tensors
         input_ids = torch.cat(input_ids, dim=0)
         attention_masks = torch.cat(attention_masks, dim=0)
-        label_ids = torch.stack(label_ids).squeeze() # Create a 1D tensor for label_ids
+        label_ids = torch.stack(label_ids)
 
         # Split the data into train and validation sets
         dataset = TensorDataset(input_ids, attention_masks, label_ids)
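A hedged aside on the dropped .squeeze(): the diff does not show how the individual label tensors are shaped. Assuming each is a 0-d tensor, torch.stack already yields the 1-D (N,) tensor of class indices that F.one_hot expects, so the old .squeeze() was a no-op at best, and with a batch of one it would even collapse (1,) to a 0-d scalar. A standalone sketch:

import torch
import torch.nn.functional as F

label_ids = [torch.tensor(1), torch.tensor(0), torch.tensor(1)]  # hypothetical labels
stacked = torch.stack(label_ids)             # shape (3,): integer class indices
one_hot = F.one_hot(stacked, num_classes=2)  # shape (3, 2)
print(stacked.shape, one_hot.shape)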
@@ -152,8 +154,11 @@ def train(
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
                 logits = outputs.logits # Use logits attribute to get the predicted logits
 
+                # Convert labels to one-hot encoding
+                b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
+
                 # Calculate the loss using the weighted loss function
-                loss = loss_function(logits.squeeze(), b_labels)
+                loss = loss_function(logits, b_labels_one_hot)
                 total_train_loss += loss.item()
 
                 loss.backward()
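This is the substance of the commit: nn.BCEWithLogitsLoss expects float targets with the same shape as its logits, so integer class labels have to be one-hot encoded before the call. A standalone sketch of the shape contract follows; the plain BCEWithLogitsLoss below is a stand-in for the repo's weighted loss_function, whose definition is outside this diff:

import torch
import torch.nn.functional as F

loss_function = torch.nn.BCEWithLogitsLoss()     # stand-in for the weighted loss

logits = torch.randn(4, 2)                       # (batch_size, num_classes)
b_labels = torch.tensor([0, 1, 1, 0])            # integer labels, shape (4,)
b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()  # float, shape (4, 2)

loss = loss_function(logits, b_labels_one_hot)   # shapes match, no squeeze needed

Passing the raw integer labels, as the replaced line did, fails with a shape mismatch: BCE-with-logits applies an element-wise sigmoid rather than a softmax over classes, so it has no notion of class indices.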
@@ -185,7 +190,11 @@ def train(
                 # loss = outputs[0]
                 logits = outputs.logits
 
-                loss = loss_function(logits, b_labels)
+                # Convert labels to one-hot encoding
+                b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
+
+                # Calculate the loss using the weighted loss function
+                loss = loss_function(logits, b_labels_one_hot)
                 total_eval_loss += loss.item()
                 logits = logits.detach().to(self.device).numpy()
                 label_ids = b_labels.to(self.device).numpy()
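One caveat about the two unchanged context lines at the end of this hunk: .numpy() only works on CPU tensors, so when self.device is a GPU the conventional pattern uses .cpu() rather than .to(self.device), with class predictions recovered via an argmax. A sketch of that variant (an observation, not part of the commit):

import torch

logits = torch.randn(4, 2)                 # stand-ins for the batch outputs
b_labels = torch.tensor([0, 1, 1, 0])

logits_np = logits.detach().cpu().numpy()  # .numpy() requires a CPU tensor
label_np = b_labels.cpu().numpy()
preds = logits_np.argmax(axis=1)           # class ids for accuracy/F1 metrics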
@@ -393,7 +402,7 @@ def train(
         # Convert lists to tensors
         input_ids = torch.cat(input_ids, dim=0)
         attention_masks = torch.cat(attention_masks, dim=0)
-        label_ids = torch.stack(label_ids).squeeze() # Create a 1D tensor for label_ids
+        label_ids = torch.stack(label_ids)
 
         # Split the data into train and validation sets
         dataset = TensorDataset(input_ids, attention_masks, label_ids)
@@ -440,8 +449,11 @@ def train(
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
                 logits = outputs.logits
 
+                # Convert labels to one-hot encoding
+                b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
+
                 # Calculate the loss using the weighted loss function
-                loss = loss_function(logits.squeeze(), b_labels)
+                loss = loss_function(logits, b_labels_one_hot)
                 total_train_loss += loss.item()
 
                 loss.backward()
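The loop comments refer to a "weighted loss function" that is constructed elsewhere in the file. One plausible construction for this two-class setup, with purely illustrative weights, would be:

import torch

class_weights = torch.tensor([1.0, 5.0])  # hypothetical weights for classes 0 and 1
loss_function = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights)

pos_weight broadcasts against the (batch_size, num_classes) one-hot targets, scaling the positive term for each class, a common way to counter class imbalance.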
@@ -472,7 +484,11 @@ def train(
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
                 logits = outputs.logits
 
-                loss = loss_function(logits, b_labels)
+                # Convert labels to one-hot encoding
+                b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
+
+                # Calculate the loss using the weighted loss function
+                loss = loss_function(logits, b_labels_one_hot)
                 total_eval_loss += loss.item()
                 logits = logits.detach().to(self.device).numpy()
                 label_ids = b_labels.to(self.device).numpy()
@@ -552,7 +568,7 @@ def predict(
 
             with torch.no_grad():
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask)
-                logits = outputs[0]
+                logits = outputs.logits
 
             logits = logits.detach().cpu().numpy()
 
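outputs[0] and outputs.logits point at the same tensor on Hugging Face's SequenceClassifierOutput; the named attribute just states the intent. A small end-to-end sketch, where the roberta-base checkpoint and the sample text are assumptions since the diff never names the model:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")  # assumed checkpoint
model = AutoModelForSequenceClassification.from_pretrained("roberta-base", num_labels=2)

enc = tokenizer("urgent: wire the funds today", return_tensors="pt")
with torch.no_grad():
    outputs = model(**enc)

assert torch.equal(outputs[0], outputs.logits)  # positional and named access agree
probs = torch.sigmoid(outputs.logits)           # per-class probabilities under BCE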
