
Commit 9975d08

Replaced BERT sequence-classification models with the regular base models

advaithsrao committed Nov 27, 2023
1 parent 72b73c0 commit 9975d08
Showing 1 changed file with 10 additions and 9 deletions.
detector/modeler.py (19 changes: 10 additions & 9 deletions)
@@ -13,8 +13,9 @@
 import torch
 from torch import nn
 
-from transformers import RobertaTokenizer, RobertaForSequenceClassification
-from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
+from transformers import RobertaTokenizer, RobertaForSequenceClassification, RobertaModel
+from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertModel
+from transformers import BertModel
 from transformers import AdamW, get_linear_schedule_with_warmup
 
 from torch.utils.data import DataLoader, TensorDataset#, SubsetRandomSampler
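
The import change above swaps the task-specific `RobertaForSequenceClassification` / `DistilBertForSequenceClassification` classes for the bare encoders `RobertaModel`, `DistilBertModel`, and `BertModel`. The practical difference is what `forward` returns: a sequence-classification model yields logits from a built-in head, while a bare encoder yields raw hidden states. A minimal sketch of that contrast, assuming the stock `roberta-base` checkpoint (the checkpoint name and input text are illustrative, not from this repository):

```python
import torch
from transformers import RobertaTokenizer, RobertaModel, RobertaForSequenceClassification

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
inputs = tokenizer("an example email body", return_tensors="pt")

# Task-specific class: adds a classification head and returns per-class logits.
clf = RobertaForSequenceClassification.from_pretrained("roberta-base", num_labels=2)
with torch.no_grad():
    print(clf(**inputs).logits.shape)  # torch.Size([1, 2])

# Bare encoder: no head; returns a hidden state for every token position.
base = RobertaModel.from_pretrained("roberta-base")
with torch.no_grad():
    print(base(**inputs).last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
```

That contrast is what the loss changes in the training hunk below are reacting to.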
@@ -151,17 +152,17 @@ def train(
             total_train_loss = 0
 
             for step, batch in enumerate(train_dataloader):
-                b_input_ids = batch[0].to(self.device)
+                b_input_ids = batch[0].to(self.device, dtype=torch.float)
                 b_input_mask = batch[1].to(self.device)
-                b_labels = batch[2].to(self.device)
+                b_labels = batch[2].to(self.device, dtype=torch.float)
 
                 # Forward pass
                 outputs = self.model(b_input_ids)
 
                 logits = F.sigmoid(outputs) # Apply sigmoid to the final output
 
                 # Compute binary cross-entropy loss
-                loss = F.binary_cross_entropy(logits, b_labels.float())
+                loss = F.binary_cross_entropy(logits, b_labels)
 
                 total_train_loss += loss.item()
 
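With a bare encoder, `self.model(b_input_ids)` returns a model-output object rather than a logits tensor, so the `F.sigmoid(outputs)` line above assumes a scalar score that the encoder itself does not produce. A hedged sketch of the usual pattern follows, where `head` is a hypothetical linear layer that is not part of this commit; note also that input IDs normally stay integer-typed (`torch.long`) for the embedding lookup:

```python
import torch
import torch.nn.functional as F
from torch import nn
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
model = RobertaModel.from_pretrained("roberta-base")
head = nn.Linear(model.config.hidden_size, 1)  # hypothetical classification head

batch = tokenizer(["first email", "second email"], return_tensors="pt", padding=True)
b_labels = torch.tensor([1.0, 0.0])  # float targets for binary cross-entropy

# input_ids stay torch.long; casting them to float would break the embedding lookup.
outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"])
cls_state = outputs.last_hidden_state[:, 0, :]  # take the <s> (CLS) position
logits = head(cls_state).squeeze(-1)            # shape: (batch,)

loss = F.binary_cross_entropy(torch.sigmoid(logits), b_labels)
# Numerically safer equivalent:
# loss = F.binary_cross_entropy_with_logits(logits, b_labels)
```

`torch.sigmoid` is also preferred over `F.sigmoid`, which is deprecated in recent PyTorch releases.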
@@ -366,9 +367,9 @@ def __init__(
         self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)
 
         if self.path != '':
-            self.model = RobertaForSequenceClassification.from_pretrained(self.path, num_labels=self.num_labels).to(self.device)
+            self.model = RobertaModel.from_pretrained(self.path, num_labels=self.num_labels).to(self.device)
         else:
-            self.model = RobertaForSequenceClassification.from_pretrained(self.model_name, num_labels=self.num_labels).to(self.device)
+            self.model = RobertaModel.from_pretrained(self.model_name, num_labels=self.num_labels).to(self.device)
 
     def train(
         self,
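
`num_labels` is still passed here and `from_pretrained` forwards it to the model config, but the bare `RobertaModel` builds no classification head from it, unlike `RobertaForSequenceClassification`. A small illustrative check, assuming `roberta-base`:

```python
from transformers import RobertaModel

# num_labels lands on the config, but the bare encoder attaches no head that uses it.
model = RobertaModel.from_pretrained("roberta-base", num_labels=2)
print(model.config.num_labels)        # 2
print(hasattr(model, "classifier"))   # False: nothing produces per-class logits
```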
@@ -658,9 +659,9 @@ def __init__(
         self.tokenizer = DistilBertTokenizer.from_pretrained(self.model_name)
 
         if self.path != '':
-            self.model = DistilBertForSequenceClassification.from_pretrained(self.path).to(self.device)
+            self.model = DistilBertModel.from_pretrained(self.path).to(self.device)
         else:
-            self.model = DistilBertForSequenceClassification.from_pretrained(self.model_name).to(self.device)
+            self.model = DistilBertModel.from_pretrained(self.model_name).to(self.device)
 
     def train(
         self,
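
The DistilBERT branch makes the same swap. One DistilBERT-specific detail worth sketching: `DistilBertModel` exposes only `last_hidden_state` (DistilBERT has no pooler layer), so any downstream head has to pool the token states itself. The checkpoint name below is the stock one, used purely for illustration:

```python
import torch
from transformers import DistilBertModel, DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertModel.from_pretrained("distilbert-base-uncased")

enc = tokenizer("a short test email", return_tensors="pt")
with torch.no_grad():
    out = model(**enc)

print(out.last_hidden_state.shape)  # torch.Size([1, seq_len, 768])
print(list(out.keys()))             # ['last_hidden_state']: no pooler_output here
```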
