Skip to content

Commit

Permalink
Changed code to detach logits and labels to cpu before converting to …
Browse files Browse the repository at this point in the history
…np array
  • Loading branch information
advaithsrao committed Nov 24, 2023
1 parent fd0e3a2 commit c915715
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions detector/modeler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import shutil
import pandas as pd
Expand Down Expand Up @@ -196,8 +196,8 @@ def train(
# Calculate the loss using the weighted loss function
loss = loss_function(logits, b_labels_one_hot)
total_eval_loss += loss.item()
logits = logits.detach().to(self.device).numpy()
label_ids = b_labels.to(self.device).numpy()
logits = logits.detach().cpu().numpy()
label_ids = b_labels.detach().cpu().numpy()
total_eval_accuracy += self.accuracy(logits, label_ids)

avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
Expand Down Expand Up @@ -490,8 +490,8 @@ def train(
# Calculate the loss using the weighted loss function
loss = loss_function(logits, b_labels_one_hot)
total_eval_loss += loss.item()
logits = logits.detach().to(self.device).numpy()
label_ids = b_labels.to(self.device).numpy()
logits = logits.detach().cpu().numpy()
label_ids = b_labels.detach().cpu().numpy()
total_eval_accuracy += self.accuracy(logits, label_ids)

avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
Expand Down

0 comments on commit c915715

Please sign in to comment.