From 9c4ed1f1be4af74437e35f770da24d68d1794f40 Mon Sep 17 00:00:00 2001
From: Advaith Rao
Date: Fri, 24 Nov 2023 20:22:09 -0500
Subject: [PATCH] Modify roberta and distilbert to use autograd

---
 detector/modeler.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/detector/modeler.py b/detector/modeler.py
index fcde853..ec5fc5c 100644
--- a/detector/modeler.py
+++ b/detector/modeler.py
@@ -150,9 +150,9 @@ def train(
             b_input_mask = batch[1].to(self.device)
             b_labels = batch[2].to(self.device)
 
-            self.model.zero_grad()
+            # Forward pass
             outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
-            logits = outputs.logits # Use logits attribute to get the predicted logits
+            logits = outputs.logits
 
             # Convert labels to one-hot encoding
             b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
@@ -161,8 +161,10 @@ def train(
             loss = loss_function(logits, b_labels_one_hot)
             total_train_loss += loss.item()
 
+            # Backward pass
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+            # Update the model parameters
             optimizer.step()
             scheduler.step()
 
@@ -445,7 +447,7 @@ def train(
             b_input_mask = batch[1].to(self.device)
             b_labels = batch[2].to(self.device)
 
-            self.model.zero_grad()
+            # Forward pass
             outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
             logits = outputs.logits
 
@@ -456,8 +458,10 @@ def train(
             loss = loss_function(logits, b_labels_one_hot)
             total_train_loss += loss.item()
 
+            # Backward pass
             loss.backward()
-            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+            # Update the model parameters
             optimizer.step()
             scheduler.step()
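
For reference, a minimal sketch of the per-batch training step as it reads after this patch, assembled from the two (identical) RoBERTa and DistilBERT hunks above. The scaffolding is an assumption: train_dataloader, optimizer, scheduler, loss_function, total_train_loss, and the self.model/self.device attributes come from the surrounding train() method in detector/modeler.py and are not defined by this diff; batch[0] is assumed to hold the input ids, since batch[1] and batch[2] are the attention mask and labels.

    import torch.nn.functional as F

    # One epoch of the patched training loop (sketch; assumes train() already
    # set up self.model, self.device, optimizer, scheduler, loss_function,
    # and train_dataloader).
    total_train_loss = 0.0
    for batch in train_dataloader:
        b_input_ids = batch[0].to(self.device)   # batch[0] assumed to be input ids
        b_input_mask = batch[1].to(self.device)
        b_labels = batch[2].to(self.device)

        # Forward pass
        outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
        logits = outputs.logits

        # Convert labels to one-hot encoding
        b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()

        loss = loss_function(logits, b_labels_one_hot)
        total_train_loss += loss.item()

        # Backward pass: autograd populates .grad on every model parameter
        loss.backward()

        # Update the model parameters, then advance the learning-rate schedule
        optimizer.step()
        scheduler.step()

Note that the patch drops both the explicit self.model.zero_grad() and the torch.nn.utils.clip_grad_norm_ call, so as sketched the per-batch gradients accumulate across iterations unless a zero_grad() call (e.g., optimizer.zero_grad()) survives elsewhere in train().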