Added classification head to bert models
advaithsrao committed Nov 27, 2023
1 parent 9975d08 commit 55756de
Showing 1 changed file with 78 additions and 38 deletions.
detector/modeler.py: 78 additions & 38 deletions (116 changed lines)
@@ -367,9 +367,15 @@ def __init__(
         self.tokenizer = RobertaTokenizer.from_pretrained(self.model_name)

         if self.path != '':
-            self.model = RobertaModel.from_pretrained(self.path, num_labels=self.num_labels).to(self.device)
+            self.model = RobertaModel.from_pretrained(self.path).to(self.device)
         else:
-            self.model = RobertaModel.from_pretrained(self.model_name, num_labels=self.num_labels).to(self.device)
+            self.model = RobertaModel.from_pretrained(self.model_name).to(self.device)
+
+        self.classification_head = nn.Sequential(
+            nn.Linear(self.model.config.hidden_size, 128),
+            nn.ReLU(),
+            nn.Linear(128, self.num_labels)
+        )

     def train(
         self,
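The head added above follows the standard sequence-classification pattern for encoder-only models: take the final hidden state of the first token (RoBERTa's <s>, playing the role of [CLS]) and map it to class logits with a small MLP. A minimal, self-contained sketch of that pattern; the model name, example text, and num_labels=2 are illustrative assumptions, not values taken from this repository:

import torch
import torch.nn as nn
from transformers import RobertaModel, RobertaTokenizer

num_labels = 2  # assumed for illustration
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base")

# Same head shape as in the commit: hidden_size -> 128 -> num_labels
head = nn.Sequential(
    nn.Linear(encoder.config.hidden_size, 128),
    nn.ReLU(),
    nn.Linear(128, num_labels),
)

batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
with torch.no_grad():
    hidden = encoder(**batch).last_hidden_state  # (batch, seq_len, hidden_size)
    logits = head(hidden[:, 0, :])               # pool the first-token position
print(logits.shape)  # torch.Size([1, 2])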
@@ -433,7 +439,8 @@ def train(
         validation_dataloader = DataLoader(val_dataset, batch_size=self.batch_size)

         # Initialize the optimizer and learning rate scheduler
-        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, eps=self.epsilon)
+        optimizer = AdamW(list(self.model.parameters()) + list(self.classification_head.parameters()),
+                          lr=self.learning_rate, eps=self.epsilon)
         total_steps = len(train_dataloader) * self.num_epochs
         scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
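Handing both parameter lists to a single AdamW keeps one learning rate for the pretrained encoder and the freshly initialized head. An equivalent option is optimizer parameter groups, which would allow a larger learning rate for the new head; a runnable sketch with stand-in modules and illustrative values, not what this commit does:

import torch.nn as nn
from torch.optim import AdamW

# Stand-ins for the transformer and the classification head, only to make
# the sketch self-contained
encoder = nn.Linear(768, 768)
head = nn.Sequential(nn.Linear(768, 128), nn.ReLU(), nn.Linear(128, 2))

optimizer = AdamW(
    [
        {"params": encoder.parameters(), "lr": 2e-5},  # gentle updates for pretrained weights
        {"params": head.parameters(), "lr": 1e-3},     # faster updates for the new head
    ],
    eps=1e-8,
)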

@@ -447,6 +454,7 @@ def train(

             # Training loop
             self.model.train()
+            self.classification_head.train()
             total_train_loss = 0

             for step, batch in enumerate(train_dataloader):
@@ -455,15 +463,21 @@ def train(
                 b_labels = batch[2].to(self.device)

                 # Forward pass
-                outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
-                loss = outputs[0]
-                logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
+                outputs = self.model(b_input_ids, attention_mask=b_input_mask)
+                last_hidden_states = outputs.last_hidden_state
+
+                # Apply classification head
+                logits = self.classification_head(last_hidden_states[:, 0, :])
+
+                loss = F.cross_entropy(logits, b_labels)

                 total_train_loss += loss.item()

                 # Backward pass
                 loss.backward()

+                torch.nn.utils.clip_grad_norm_(list(self.model.parameters()) + list(self.classification_head.parameters()), 1.0)
+
                 # Update the model parameters
                 optimizer.step()
@@ -481,6 +495,7 @@ def train(

             # Evaluation loop
             self.model.eval()
+            self.classification_head.eval()
             total_eval_accuracy = 0
             total_eval_loss = 0
@@ -490,14 +505,16 @@ def train(
                 b_labels = batch[2].to(self.device)

                 with torch.no_grad():
-                    outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
-                    loss = outputs[0]
-                    logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
-
-                    total_eval_loss += loss.item()
-
-                    # logits = logits.detach().cpu().numpy()
-                    # label_ids = b_labels.detach().cpu().numpy()
+                    outputs = self.model(b_input_ids, attention_mask=b_input_mask)
+                    last_hidden_states = outputs.last_hidden_state
+
+                    # Apply classification head
+                    logits = self.classification_head(last_hidden_states[:, 0, :])
+
+                    loss = F.cross_entropy(logits, b_labels)
+
+                    total_eval_loss += loss.item()
+                    total_eval_accuracy += self.accuracy(logits, b_labels)

                 total_eval_accuracy += self.accuracy(logits, b_labels)
@@ -575,9 +592,10 @@ def predict(

             with torch.no_grad():
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask)
-                # loss = outputs[0]
-                # logits = F.softmax(outputs[1]) # Taking the softmax of output
-                logits = outputs.logits
+                last_hidden_states = outputs.last_hidden_state
+
+                # Apply classification head
+                logits = self.classification_head(last_hidden_states[:, 0, :])

             logits = logits.detach().cpu().numpy()
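The rest of predict is collapsed in this diff view. A common way to turn the per-batch logits into label predictions is a row-wise argmax over the class dimension; this is only a sketch of that step, with made-up values, not necessarily what the elided code does:

import numpy as np

# logits: (batch_size, num_labels) array produced by the classification head
logits = np.array([[1.2, -0.3],
                   [-0.8, 2.1]])
batch_predictions = np.argmax(logits, axis=1).tolist()
print(batch_predictions)  # [0, 1]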

@@ -589,8 +607,8 @@ def save_model(
         return predictions

     def save_model(
-        self,
-        path: str
+        self,
+        path: str
     ):
         """Saves the model to the given path.
@@ -601,8 +619,10 @@ def save_model(
         # Check if the directory exists, and if not, create it
         if not os.path.exists(path):
             os.makedirs(path, exist_ok=True)
-
+
+        # Save the transformer model and the classification head
         self.model.save_pretrained(path)
+        torch.save(self.classification_head.state_dict(), os.path.join(path, 'classification_head.pth'))

     def accuracy(
         self,
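save_model now writes two artifacts: the transformer weights via save_pretrained and the head's parameters as classification_head.pth. The matching load path for the head is not shown in this diff, so the following is only a sketch of how the pair could be restored; the directory name, head shape, and num_labels are assumptions that have to match whatever was saved:

import os

import torch
import torch.nn as nn
from transformers import RobertaModel

path = "checkpoints/roberta-classifier"  # hypothetical directory written by save_model
num_labels = 2                           # must match the saved head

encoder = RobertaModel.from_pretrained(path)
head = nn.Sequential(
    nn.Linear(encoder.config.hidden_size, 128),
    nn.ReLU(),
    nn.Linear(128, num_labels),
)
head.load_state_dict(torch.load(os.path.join(path, "classification_head.pth")))
encoder.eval()
head.eval()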
@@ -662,6 +682,12 @@ def __init__(
             self.model = DistilBertModel.from_pretrained(self.path).to(self.device)
         else:
             self.model = DistilBertModel.from_pretrained(self.model_name).to(self.device)
+
+        self.classification_head = nn.Sequential(
+            nn.Linear(self.model.config.hidden_size, 128),
+            nn.ReLU(),
+            nn.Linear(128, self.num_labels)
+        )

     def train(
         self,
@@ -725,7 +751,8 @@ def train(
         validation_dataloader = DataLoader(val_dataset, batch_size=self.batch_size)

         # Initialize the optimizer and learning rate scheduler
-        optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, eps=self.epsilon)
+        optimizer = AdamW(list(self.model.parameters()) + list(self.classification_head.parameters()),
+                          lr=self.learning_rate, eps=self.epsilon)
         total_steps = len(train_dataloader) * self.num_epochs
         scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)
@@ -739,6 +766,7 @@ def train(

             # Training loop
             self.model.train()
+            self.classification_head.train()
             total_train_loss = 0

             for step, batch in enumerate(train_dataloader):
@@ -747,15 +775,21 @@ def train(
                 b_labels = batch[2].to(self.device)

                 # Forward pass
-                outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
-                loss = outputs[0]
-                logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
+                outputs = self.model(b_input_ids, attention_mask=b_input_mask)
+                last_hidden_states = outputs.last_hidden_state
+
+                # Apply classification head
+                logits = self.classification_head(last_hidden_states[:, 0, :])
+
+                loss = F.cross_entropy(logits, b_labels)

                 total_train_loss += loss.item()

                 # Backward pass
                 loss.backward()

+                torch.nn.utils.clip_grad_norm_(list(self.model.parameters()) + list(self.classification_head.parameters()), 1.0)
+
                 # Update the model parameters
                 optimizer.step()
@@ -773,6 +807,7 @@ def train(

             # Evaluation loop
             self.model.eval()
+            self.classification_head.eval()
             total_eval_accuracy = 0
             total_eval_loss = 0
@@ -782,14 +817,16 @@ def train(
                 b_labels = batch[2].to(self.device)

                 with torch.no_grad():
-                    outputs = self.model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
-                    loss = outputs[0]
-                    logits = F.softmax(outputs[1], dim=1) # Taking the softmax of output
-
-                    total_eval_loss += loss.item()
-
-                    # logits = logits.detach().cpu().numpy()
-                    # label_ids = b_labels.detach().cpu().numpy()
+                    outputs = self.model(b_input_ids, attention_mask=b_input_mask)
+                    last_hidden_states = outputs.last_hidden_state
+
+                    # Apply classification head
+                    logits = self.classification_head(last_hidden_states[:, 0, :])
+
+                    loss = F.cross_entropy(logits, b_labels)
+
+                    total_eval_loss += loss.item()
+                    total_eval_accuracy += self.accuracy(logits, b_labels)

                 total_eval_accuracy += self.accuracy(logits, b_labels)
@@ -867,9 +904,10 @@ def predict(

             with torch.no_grad():
                 outputs = self.model(b_input_ids, attention_mask=b_input_mask)
-                # loss = outputs[0]
-                # logits = F.softmax(outputs[1]) # Taking the softmax of output
-                logits = outputs.logits
+                last_hidden_states = outputs.last_hidden_state
+
+                # Apply classification head
+                logits = self.classification_head(last_hidden_states[:, 0, :])

             logits = logits.detach().cpu().numpy()
@@ -881,8 +919,8 @@ def save_model(
         return predictions

     def save_model(
-        self,
-        path: str
+        self,
+        path: str
     ):
         """Saves the model to the given path.
@@ -893,8 +931,10 @@ def save_model(
         # Check if the directory exists, and if not, create it
         if not os.path.exists(path):
             os.makedirs(path, exist_ok=True)
-
+
+        # Save the transformer model and the classification head
         self.model.save_pretrained(path)
+        torch.save(self.classification_head.state_dict(), os.path.join(path, 'classification_head.pth'))

     def accuracy(
         self,