Merge pull request #21 from yuriihavrylko/feature/l12t2-model-optimization

L12T2 Model optimization
Showing 4 changed files with 195 additions and 0 deletions.
@@ -9,4 +9,5 @@ datasets==2.16.1
 wandb==0.16.1
 httpx==0.23.0
 locust==2.20.1
+textpruner==1.1.post2
 ipykernel==6.28.0
@@ -0,0 +1,108 @@
import torch
from torch.nn import KLDivLoss, CrossEntropyLoss, Softmax
from transformers import (
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    BertTokenizer,
    DataCollatorWithPadding
)
from datasets import load_dataset
from src.helpers.wandb_registry import download_model, publish_model

SEED = 42
MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_and_tokenize_data(tokenizer, split, shuffle=True, seed=SEED, max_length=512):
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding='max_length', truncation=True, max_length=max_length)

    dataset = load_dataset("GonzaloA/fake_news", split=split)
    if shuffle:
        dataset = dataset.shuffle(seed=seed)

    return dataset.map(tokenize_function, batched=True)


def initialize_models(teacher_checkpoint, student_pretrained):
    teacher = BertForSequenceClassification.from_pretrained(teacher_checkpoint, local_files_only=True).to(device)
    student = DistilBertForSequenceClassification.from_pretrained(student_pretrained).to(device)
    return teacher, student


def distillation_loss(teacher_logits, student_logits, temperature):
    # KL divergence between temperature-softened teacher and student distributions.
    # With log_target=True, KLDivLoss expects both arguments in log space.
    softmax = Softmax(dim=1)
    kl_div = KLDivLoss(reduction="batchmean", log_target=True)
    soft_teacher_logits = softmax(teacher_logits / temperature)
    soft_student_logits = softmax(student_logits / temperature)
    return kl_div(soft_student_logits.log(), soft_teacher_logits.log())


class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        student_logits = outputs.logits
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits
        loss_distillation = distillation_loss(teacher_logits, student_logits, temperature=2.0)
        labels = inputs.get("labels")
        criterion = CrossEntropyLoss()
        loss_ce = criterion(student_logits.view(-1, self.model.config.num_labels), labels.view(-1))
        # Weighted sum of soft (KL) and hard (CE) targets; the T**2 factor compensates
        # for the gradient scaling introduced by the temperature.
        alpha, T = 0.5, 2.0
        loss = alpha * loss_distillation * (T ** 2) + (1 - alpha) * loss_ce
        return (loss, outputs) if return_outputs else loss


def train_student_model(student, teacher, train_dataset, eval_dataset, tokenizer, training_args):
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    return trainer.evaluate()


def main():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_set = load_and_tokenize_data(tokenizer, "train[:10000]")
    valid_set = load_and_tokenize_data(tokenizer, "validation[:2000]")

    model_path = "/tmp/model"
    distil_model_path = "/tmp/distil/model"

    download_model(MODEL_ID, PROJECT, model_path)

    teacher, student = initialize_models(model_path, "distilbert-base-uncased")

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_strategy="epoch",
        evaluation_strategy="epoch"
    )

    eval_results = train_student_model(student, teacher, train_set, valid_set, tokenizer, training_args)
    print(f"Distillation Evaluation Results: {eval_results}")

    student.save_pretrained(distil_model_path)
    tokenizer.save_pretrained(distil_model_path)
    publish_model(distil_model_path, PROJECT, "bert_fake_news_distil")


if __name__ == "__main__":
    main()
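
A quick sanity check of the loss combination in compute_loss above (a standalone sketch, not part of this commit, with made-up logits and labels): when the student's logits match the teacher's, the temperature-softened KL term vanishes and the combined loss reduces to (1 - alpha) times the cross-entropy term.

# Illustrative only: verify the KL term is ~0 for identical teacher/student logits.
import torch
from torch.nn import CrossEntropyLoss, KLDivLoss, Softmax

def distillation_loss(teacher_logits, student_logits, temperature):
    softmax = Softmax(dim=1)
    kl_div = KLDivLoss(reduction="batchmean", log_target=True)
    return kl_div(softmax(student_logits / temperature).log(),
                  softmax(teacher_logits / temperature).log())

torch.manual_seed(0)
student_logits = torch.randn(4, 2)        # batch of 4, two classes (fake / real)
teacher_logits = student_logits.clone()   # identical logits -> KL term ~ 0
labels = torch.tensor([0, 1, 1, 0])       # dummy hard labels

alpha, T = 0.5, 2.0
kl = distillation_loss(teacher_logits, student_logits, T)
ce = CrossEntropyLoss()(student_logits, labels)
loss = alpha * kl * (T ** 2) + (1 - alpha) * ce
print(f"KL: {kl.item():.4f}  CE: {ce.item():.4f}  combined: {loss.item():.4f}")
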
@@ -0,0 +1,71 @@
from src.helpers.wandb_registry import download_model, publish_model
import textpruner
from transformers import (
    BertForSequenceClassification,
    BertTokenizer
)
from textpruner import summary, TransformerPruner

MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"


def load_model_and_tokenizer(model_path, tokenizer_name):
    model = BertForSequenceClassification.from_pretrained(model_path, local_files_only=True)
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer


def print_model_summary(model, message):
    print(message)
    print(summary(model))


def prune_bert_model(model):
    # Randomly mask FFN hidden units (12 layers x 3072) and attention heads (12 x 12),
    # then prune the masked weights in place.
    pruner = TransformerPruner(model)
    ffn_mask = textpruner.pruners.utils.random_mask_tensor((12, 3072))
    head_mask = textpruner.pruners.utils.random_mask_tensor((12, 12), even_masks=False)
    pruner.prune(head_mask=head_mask, ffn_mask=ffn_mask, save_model=True)
    return model


def print_pruned_model_info(model):
    # Show the reduced FFN and attention weight shapes for each encoder layer.
    for i in range(12):
        print((model.base_model.encoder.layer[i].intermediate.dense.weight.shape,
               model.base_model.encoder.layer[i].intermediate.dense.bias.shape,
               model.base_model.encoder.layer[i].attention.self.key.weight.shape))


def test_inference_time(model, tokenizer, text):
    token = tokenizer(text, return_tensors="pt")
    inference_time = textpruner.inference_time(model, token)
    return inference_time


def main():
    model_path = "/tmp/model"
    pruned_model_path = "/tmp/pruned/model"
    tokenizer_name = "bert-base-uncased"
    text = "News title"

    download_model(MODEL_ID, PROJECT, model_path)

    model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_name)

    print_model_summary(model, "Before pruning:")

    model = prune_bert_model(model)

    print_model_summary(model, "After pruning:")

    print_pruned_model_info(model)

    inference_time = test_inference_time(model, tokenizer, text)
    print(f"Inference time: {inference_time}")

    model.save_pretrained(pruned_model_path)
    tokenizer.save_pretrained(pruned_model_path)
    publish_model(pruned_model_path, PROJECT, "bert_fake_news_pruned")


if __name__ == "__main__":
    main()
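
For intuition about the masks handed to the pruner above (a rough sketch, not part of this commit, using plain torch rather than textpruner): each mask is a binary keep/drop matrix with one row per encoder layer and one column per attention head (12 x 12) or FFN hidden unit (12 x 3072) for BERT-base; the 50% keep ratio below is made up purely for illustration.

# Illustrative only: what a random keep/drop mask implies for BERT-base (not textpruner code).
import torch

torch.manual_seed(0)
n_layers, n_heads, hidden, ffn_dim = 12, 12, 768, 3072

head_mask = (torch.rand(n_layers, n_heads) > 0.5).int()  # 1 = keep head, 0 = prune
ffn_mask = (torch.rand(n_layers, ffn_dim) > 0.5).int()   # 1 = keep FFN unit, 0 = prune

kept_heads = head_mask.sum(dim=1)   # surviving attention heads per layer
kept_ffn = ffn_mask.sum(dim=1)      # surviving FFN units per layer

# Rough FFN parameter count (input + output projection weights, biases ignored).
ffn_params_before = n_layers * 2 * hidden * ffn_dim
ffn_params_after = int((kept_ffn * 2 * hidden).sum())

print("Heads kept per layer:", kept_heads.tolist())
print(f"FFN weights: {ffn_params_before:,} -> {ffn_params_after:,}")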