Modify roberta and distilbert to use autograd
advaithsrao committed Nov 25, 2023
1 parent c915715 commit 9c4ed1f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions detector/modeler.py
@@ -150,9 +150,9 @@ def train(
  b_input_mask = batch[1].to(self.device)
  b_labels = batch[2].to(self.device)

- self.model.zero_grad()
+ # Forward pass
  outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
- logits = outputs.logits # Use logits attribute to get the predicted logits
+ logits = outputs.logits

  # Convert labels to one-hot encoding
  b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
@@ -161,8 +161,10 @@ def train(
  loss = loss_function(logits, b_labels_one_hot)
  total_train_loss += loss.item()

+ # Backward pass
  loss.backward()
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+ # Update the model parameters
  optimizer.step()
  scheduler.step()

@@ -445,7 +447,7 @@ def train(
  b_input_mask = batch[1].to(self.device)
  b_labels = batch[2].to(self.device)

- self.model.zero_grad()
+ # Forward pass
  outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
  logits = outputs.logits

@@ -456,8 +458,10 @@ def train(
  loss = loss_function(logits, b_labels_one_hot)
  total_train_loss += loss.item()

+ # Backward pass
  loss.backward()
- torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+
+ # Update the model parameters
  optimizer.step()
  scheduler.step()

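For context, below is a minimal sketch of the kind of training step these hunks touch: forward pass through a sequence-classification model, one-hot encoding of the labels, loss computation, backward pass via autograd, then optimizer and scheduler steps. The specific choices here (roberta-base, BCEWithLogitsLoss as loss_function, AdamW, a linear warmup scheduler, and where zero_grad is called) are assumptions for illustration; they are not stated in the diff, and the real detector/modeler.py may differ.

# Hypothetical sketch of one training step on a batch of
# (input_ids, attention_mask, labels). Model, loss, optimizer and
# scheduler choices are assumptions, not taken from the repository.
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import (AutoModelForSequenceClassification,
                          get_linear_schedule_with_warmup)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSequenceClassification.from_pretrained(
    "roberta-base", num_labels=2
).to(device)
optimizer = AdamW(model.parameters(), lr=2e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=1000
)
loss_function = torch.nn.BCEWithLogitsLoss()  # assumed; the diff only shows loss_function(...)


def train_step(batch):
    b_input_ids = batch[0].to(device)
    b_input_mask = batch[1].to(device)
    b_labels = batch[2].to(device)

    optimizer.zero_grad()  # clear old gradients before the new backward pass

    # Forward pass: autograd records the graph needed for backprop
    outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
    logits = outputs.logits

    # Convert integer labels to one-hot floats so they match the logits' shape
    b_labels_one_hot = F.one_hot(b_labels, num_classes=2).float()
    loss = loss_function(logits, b_labels_one_hot)

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()
    scheduler.step()
    return loss.item()

Note that because the model is called with labels=..., outputs.loss would also be available; the code in the diff instead recomputes the loss from the logits against one-hot targets, which is what the sketch mirrors.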
