Skip to content

Commit

Permalink
Small fix for prediction calculation for bert models
Browse files Browse the repository at this point in the history
  • Loading branch information
advaithsrao committed Nov 26, 2023
1 parent 38060b7 commit e8dc4a0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions detector/modeler.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def predict(
with torch.no_grad():
outputs = self.model(b_input_ids, attention_mask=b_input_mask)
loss = outputs[0]
logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
logits = F.softmax(outputs[1]) # Taking the softmax of output

_, prediction= torch.max(logits, dim=1)

Expand Down Expand Up @@ -552,7 +552,7 @@ def predict(
with torch.no_grad():
outputs = self.model(b_input_ids, attention_mask=b_input_mask)
loss = outputs[0]
logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
logits = F.softmax(outputs[1]) # Taking the softmax of output

_, prediction= torch.max(logits, dim=1)

Expand Down

0 comments on commit e8dc4a0

Please sign in to comment.