Merge pull request #21 from yuriihavrylko/feature/l12t2-model-optimization

L12T2 Model optimization
yuriihavrylko authored Feb 14, 2024
2 parents dacf26b + 28f38b9 commit d843205
Showing 4 changed files with 195 additions and 0 deletions.
15 changes: 15 additions & 0 deletions README.md
@@ -188,3 +188,18 @@ Run from config
```
kubectl create -f deployment/app-fastapi-scaling.yml
```


### Model optimization

Run pruning:

```
python -m src.model.pruning
```

Run distillation:

```
python -m src.model.distilation
```
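
Both scripts download the fine-tuned BERT checkpoint from the W&B model registry and publish the optimized weights back as `bert_fake_news_pruned` and `bert_fake_news_distil`. Below is a minimal sketch of loading the distilled model afterwards, assuming the same `download_model` helper used by the scripts; the artifact ID and version tag are placeholders, so check the registry for the actual values:

```
from transformers import DistilBertForSequenceClassification, BertTokenizer

from src.helpers.wandb_registry import download_model

# Placeholder artifact ID/version following the repo's naming scheme.
download_model("yurii-havrylko/huggingface/bert_fake_news_distil:v0", "huggingface", "/tmp/distil/model")
model = DistilBertForSequenceClassification.from_pretrained("/tmp/distil/model", local_files_only=True)
tokenizer = BertTokenizer.from_pretrained("/tmp/distil/model")
```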
1 change: 1 addition & 0 deletions app/requirements-dev.txt
@@ -9,4 +9,5 @@ datasets==2.16.1
wandb==0.16.1
httpx==0.23.0
locust==2.20.1
textpruner==1.1.post2
ipykernel==6.28.0
108 changes: 108 additions & 0 deletions app/src/model/distilation.py
@@ -0,0 +1,108 @@
import torch
from torch.nn import KLDivLoss, CrossEntropyLoss
from transformers import (
    BertForSequenceClassification,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
    BertTokenizer,
    DataCollatorWithPadding,
)
from datasets import load_dataset
from src.helpers.wandb_registry import download_model, publish_model

SEED = 42
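# W&B artifact ID of the fine-tuned BERT teacher and the local path it is downloaded to.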
MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_and_tokenize_data(tokenizer, split, shuffle=True, seed=SEED, max_length=512):
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding='max_length', truncation=True, max_length=max_length)

    dataset = load_dataset("GonzaloA/fake_news", split=split)
    if shuffle:
        dataset = dataset.shuffle(seed=seed)

    return dataset.map(tokenize_function, batched=True)

def initialize_models(teacher_checkpoint, student_pretrained):
    teacher = BertForSequenceClassification.from_pretrained(teacher_checkpoint, local_files_only=True).to(device)
    student = DistilBertForSequenceClassification.from_pretrained(student_pretrained).to(device)
    return teacher, student

def distillation_loss(teacher_logits, student_logits, temperature):
    # KL divergence between the temperature-softened teacher and student
    # distributions. With log_target=True both arguments must be
    # log-probabilities, so log_softmax is applied to teacher and student alike.
    kl_div = KLDivLoss(reduction="batchmean", log_target=True)
    log_soft_teacher = torch.log_softmax(teacher_logits / temperature, dim=1)
    log_soft_student = torch.log_softmax(student_logits / temperature, dim=1)
    return kl_div(log_soft_student, log_soft_teacher)

class DistillationTrainer(Trainer):
    def __init__(self, teacher_model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Keep the teacher frozen in eval mode so dropout stays disabled.
        self.teacher_model = teacher_model
        self.teacher_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        student_logits = outputs.logits
        with torch.no_grad():
            teacher_outputs = self.teacher_model(**inputs)
        teacher_logits = teacher_outputs.logits
        # Combine the soft distillation loss with the hard cross-entropy loss;
        # the T**2 factor compensates for the smaller gradients of softened targets.
        alpha, T = 0.5, 2.0
        loss_distillation = distillation_loss(teacher_logits, student_logits, temperature=T)
        labels = inputs.get("labels")
        criterion = CrossEntropyLoss()
        loss_ce = criterion(student_logits.view(-1, self.model.config.num_labels), labels.view(-1))
        loss = alpha * loss_distillation * (T ** 2) + (1 - alpha) * loss_ce
        return (loss, outputs) if return_outputs else loss

def train_student_model(student, teacher, train_dataset, eval_dataset, tokenizer, training_args):
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    trainer = DistillationTrainer(
        teacher_model=teacher,
        model=student,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()
    return trainer.evaluate()

def main():
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    train_set = load_and_tokenize_data(tokenizer, "train[:10000]")
    valid_set = load_and_tokenize_data(tokenizer, "validation[:2000]")

    distil_model_path = "/tmp/distil/model"

    download_model(MODEL_ID, PROJECT, MODEL_PATH)

    teacher, student = initialize_models(MODEL_PATH, "distilbert-base-uncased")

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=4,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=64,
        warmup_steps=100,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_strategy="epoch",
        evaluation_strategy="epoch"
    )

    eval_results = train_student_model(student, teacher, train_set, valid_set, tokenizer, training_args)
    print(f"Distillation Evaluation Results: {eval_results}")

    student.save_pretrained(distil_model_path)
    tokenizer.save_pretrained(distil_model_path)
    publish_model(distil_model_path, PROJECT, "bert_fake_news_distil")

if __name__ == "__main__":
    main()
71 changes: 71 additions & 0 deletions app/src/model/pruning.py
@@ -0,0 +1,71 @@
from src.helpers.wandb_registry import download_model, publish_model
import textpruner
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
)
from textpruner import summary, TransformerPruner


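# W&B artifact ID of the fine-tuned BERT checkpoint and its local download path.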
MODEL_ID = "yurii-havrylko/huggingface/bert_fake_news:v0"
MODEL_PATH = "/tmp/model"
PROJECT = "huggingface"


def load_model_and_tokenizer(model_path, tokenizer_name):
    model = BertForSequenceClassification.from_pretrained(model_path, local_files_only=True)
    tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
    return model, tokenizer

def print_model_summary(model, message):
    print(message)
    print(summary(model))


def prune_bert_model(model):
    # The masks here are random, purely to demonstrate the pruning mechanics;
    # importance-based masks would normally decide which units to drop.
    pruner = TransformerPruner(model)
    # One mask row per encoder layer: 12x3072 covers FFN neurons, 12x12 attention heads.
    ffn_mask = textpruner.pruners.utils.random_mask_tensor((12, 3072))
    head_mask = textpruner.pruners.utils.random_mask_tensor((12, 12), even_masks=False)
    pruner.prune(head_mask=head_mask, ffn_mask=ffn_mask, save_model=True)
    return model


def print_pruned_model_info(model):
    # Print the per-layer weight shapes to confirm the FFN and attention
    # matrices shrank after pruning.
    for i in range(12):
        print((model.base_model.encoder.layer[i].intermediate.dense.weight.shape,
               model.base_model.encoder.layer[i].intermediate.dense.bias.shape,
               model.base_model.encoder.layer[i].attention.self.key.weight.shape))


def test_inference_time(model, tokenizer, text):
    token = tokenizer(text, return_tensors="pt")
    inference_time = textpruner.inference_time(model, token)
    return inference_time

def main():
    pruned_model_path = "/tmp/pruned/model"
    tokenizer_name = "bert-base-uncased"
    text = "News title"

    download_model(MODEL_ID, PROJECT, MODEL_PATH)

    model, tokenizer = load_model_and_tokenizer(MODEL_PATH, tokenizer_name)

    print_model_summary(model, "Before pruning:")

    model = prune_bert_model(model)

    print_model_summary(model, "After pruning:")

    print_pruned_model_info(model)

    inference_time = test_inference_time(model, tokenizer, text)
    print(f"Inference time: {inference_time}")

    model.save_pretrained(pruned_model_path)
    tokenizer.save_pretrained(pruned_model_path)
    publish_model(pruned_model_path, PROJECT, "bert_fake_news_pruned")

if __name__ == "__main__":
    main()
