From 547f664f398fc26daa3d7b7eb74bd3969dba4ba6 Mon Sep 17 00:00:00 2001
From: Yurii Havrylko
Date: Thu, 11 Jan 2024 23:05:35 +0100
Subject: [PATCH 1/4] model pruning

---
 app/requirements-dev.txt |  1 +
 app/src/model/pruning.py | 71 ++++++++++++++++++++++++++++++++++++++++
 2 files changed, 72 insertions(+)
 create mode 100644 app/src/model/pruning.py

diff --git a/app/requirements-dev.txt b/app/requirements-dev.txt
index 16cdb7a..09cbe1b 100644
--- a/app/requirements-dev.txt
+++ b/app/requirements-dev.txt
@@ -9,3 +9,4 @@ datasets==2.16.1
 wandb==0.16.1
 httpx==0.23.0
 locust==2.20.1
+textpruner==1.1.post2
diff --git a/app/src/model/pruning.py b/app/src/model/pruning.py
new file mode 100644
index 0000000..497fd44
--- /dev/null
+++ b/app/src/model/pruning.py
@@ -0,0 +1,71 @@
from src.helpers.wandb_registry import download_model, publish_model
import textpruner
from transformers import (
    BertForSequenceClassification,
    BertTokenizer
)
from textpruner import summary, TransformerPruner


MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"


def load_model_and_tokenizer(model_path, tokenizer_name):
    model = BertForSequenceClassification.from_pretrained(model_path, local_files_only=True)
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer

def print_model_summary(model, message):
    print(message)
    print(summary(model))


def prune_bert_model(model):
    pruner = TransformerPruner(model)
    ffn_mask = textpruner.pruners.utils.random_mask_tensor((12, 3072))
    head_mask = textpruner.pruners.utils.random_mask_tensor((12, 12), even_masks=False)
    pruner.prune(head_mask=head_mask, ffn_mask=ffn_mask, save_model=True)
    return model


def print_pruned_model_info(model):
    for i in range(12):
        print((model.base_model.encoder.layer[i].intermediate.dense.weight.shape,
               model.base_model.encoder.layer[i].intermediate.dense.bias.shape,
               model.base_model.encoder.layer[i].attention.self.key.weight.shape))


def test_inference_time(model, tokenizer, text):
    token = tokenizer(text, return_tensors="pt")
    inference_time = textpruner.inference_time(model, token)
    return inference_time

def main():
    model_path = "/tmp/model"
    pruned_model_path = "/tmp/pruned/model"
    tokenizer_name = "bert-base-uncased"
    text = "News title"

    download_model(MODEL_ID, PROJECT, model_path)

    model, tokenizer = load_model_and_tokenizer(model_path, tokenizer_name)

    print_model_summary(model, "Before pruning:")

    model = prune_bert_model(model)

    print_model_summary(model, "After pruning:")

    print_pruned_model_info(model)

    inference_time = test_inference_time(model, tokenizer, text)
    print(f"Inference time: {inference_time}")

    model.save_pretrained(pruned_model_path)
    tokenizer.save_pretrained(pruned_model_path)
    publish_model(pruned_model_path, PROJECT, "bert_fake_news_pruned")

if __name__ == "__main__":
    main()
\ No newline at end of file

From 7e94cbb56a7c21a330d9950e5c7bd8a8808e91f8 Mon Sep 17 00:00:00 2001
From: Yurii Havrylko
Date: Thu, 11 Jan 2024 23:18:56 +0100
Subject: [PATCH 2/4] knowledge distillation

---
 app/src/model/distilation.py | 106 +++++++++++++++++++++++++++++++++++
 1 file changed, 106 insertions(+)
 create mode 100644 app/src/model/distilation.py

diff --git a/app/src/model/distilation.py b/app/src/model/distilation.py
new file mode 100644
index 0000000..4f64644
--- /dev/null
+++ b/app/src/model/distilation.py
@@ -0,0 +1,106 @@
import torch
from torch.nn import KLDivLoss, CrossEntropyLoss, Softmax
from transformers import (
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    BertTokenizer,
    DataCollatorWithPadding
)
from datasets import load_dataset
from src.helpers.wandb_registry import download_model, publish_model

SEED = 42
MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_and_tokenize_data(tokenizer, split, shuffle=True, seed=SEED, max_length=512):
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding='max_length', truncation=True, max_length=max_length)

    dataset = load_dataset("GonzaloA/fake_news", split=split)
    if shuffle:
        dataset = dataset.shuffle(seed=seed)

    return dataset.map(tokenize_function, batched=True)

def initialize_models(teacher_checkpoint, student_pretrained):
    teacher = BertForSequenceClassification.from_pretrained(teacher_checkpoint, local_files_only=True).to(device)
    student = DistilBertForSequenceClassification.from_pretrained(student_pretrained).to(device)
    return teacher, student

def distillation_loss(teacher_logits, student_logits, temperature):
    softmax = Softmax(dim=1)
    kl_div = KLDivLoss(reduction="batchmean")  # teacher target below is plain probabilities, so log_target stays False
    soft_teacher_logits = softmax(teacher_logits / temperature)
    soft_student_logits = softmax(student_logits / temperature)
    return kl_div(soft_student_logits.log(), soft_teacher_logits)

class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.teacher_model = teacher_model

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        student_logits = outputs.logits
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
            teacher_logits = teacher_outputs.logits
        loss_distillation = distillation_loss(teacher_logits, student_logits, temperature=2.0)
        labels = inputs.get("labels")
        criterion = CrossEntropyLoss()
        loss_ce = criterion(student_logits.view(-1, self.model.config.num_labels), labels.view(-1))
        alpha, T = 0.5, 2.0
        loss = alpha * loss_distillation * (T ** 2) + (1 - alpha) * loss_ce
        return (loss, outputs) if return_outputs else loss

def train_student_model(student, teacher, train_dataset, eval_dataset, tokenizer, training_args):
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    return trainer.evaluate()

if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_set = load_and_tokenize_data(tokenizer, "train[:10000]")
    valid_set = load_and_tokenize_data(tokenizer, "validation[:2000]")

    model_path = "/tmp/model"
    distil_model_path = "/tmp/distil/model"
    tokenizer_name = "bert-base-uncased"

    download_model(MODEL_ID, PROJECT, model_path)

    teacher, student = initialize_models(model_path, "distilbert-base-uncased")

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_strategy="epoch",
        evaluation_strategy="epoch"
    )

    eval_results = train_student_model(student, teacher, train_set, valid_set, tokenizer, training_args)
    print(f"Distillation Evaluation Results: {eval_results}")

    student.save_pretrained(distil_model_path)
    tokenizer.save_pretrained(distil_model_path)
    publish_model(distil_model_path, PROJECT, "bert_fake_news_distil")

From dbb555d8f8cd2d079465b1cbef60b5ecdb6f5916 Mon Sep 17 00:00:00 2001
From: Yurii Havrylko
Date: Thu, 11 Jan 2024 23:19:21 +0100
Subject: [PATCH 3/4] update README

---
 README.md | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/README.md b/README.md
index 55dfa9c..c3a56e6 100644
--- a/README.md
+++ b/README.md
@@ -100,3 +100,18 @@ Run from config
```
kubectl create -f deployment/app-fastapi-scaling.yml
```


### Model optimization

Run pruning:

```
python -m src.model.pruning
```

Run distillation:

```
python -m src.model.distilation
```

From 67edd50eb06f4b2a20b5acf491d75c0f5b4c8cc2 Mon Sep 17 00:00:00 2001
From: Yurii Havrylko
Date: Thu, 11 Jan 2024 23:21:01 +0100
Subject: [PATCH 4/4] fix distillation script

---
 app/src/model/distilation.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/app/src/model/distilation.py b/app/src/model/distilation.py
index 4f64644..8358229 100644
--- a/app/src/model/distilation.py
+++ b/app/src/model/distilation.py
@@ -72,7 +72,7 @@ def train_student_model(student, teacher, train_dataset, eval_dataset, tokenizer
     trainer.train()
     return trainer.evaluate()

-if __name__ == "__main__":
+def main():
     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
     train_set = load_and_tokenize_data(tokenizer, "train[:10000]")
     valid_set = load_and_tokenize_data(tokenizer, "validation[:2000]")
@@ -104,3 +104,5 @@ def train_student_model(student, teacher, train_dataset, eval_dataset, tokenizer
     tokenizer.save_pretrained(distil_model_path)
     publish_model(distil_model_path, PROJECT, "bert_fake_news_distil")

+if __name__ == "__main__":
+    main()