Skip to content

Commit

Permalink
L6T1: Test code and set up CI tests (#10)
Browse files Browse the repository at this point in the history
* training code as script

* test code

* install dev requirements

* add CI step

* fix python version

* fix path to

* change app path

* split tests and publish

* add necessary init files

* fix path to tests

* add missing deps

* add wandb
  • Loading branch information
yuriihavrylko authored Feb 11, 2024
1 parent f267c07 commit 3f001a6
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 2 deletions.
22 changes: 22 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,29 @@ on:
workflow_dispatch:

jobs:
tests:
runs-on: ubuntu-latest
steps:
- name: 'Checkout GitHub Action'
uses: actions/checkout@main

- name: 'Set up Python'
uses: actions/setup-python@v2
with:
python-version: '3.8'

- name: 'Install dependencies'
run: |
python -m pip install --upgrade pip
pip install -r app/requirements-dev.txt
- name: 'Run pytest'
run: |
cd app/
pytest tests/
push-image:
needs: tests
runs-on: ubuntu-latest
steps:
- name: 'Checkout GitHub Action'
Expand Down
11 changes: 9 additions & 2 deletions app/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
ipykernel==6.28.0
wandb==0.16.1
-r requirements.txt

evaluate==0.4.1
great-expectations==0.18.7
pytest==7.4.4
scikit-learn==1.3.2
accelerate==0.25.0
datasets==2.16.1
wandb==0.16.1
ipykernel==6.28.0
116 changes: 116 additions & 0 deletions app/src/model/training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import argparse
import numpy as np
from datasets import load_dataset
from transformers import (BertForSequenceClassification, BertTokenizer, TrainingArguments, Trainer)
from transformers import DataCollatorWithPadding
import evaluate
import wandb

# Pretrained Hugging Face checkpoint used for both the model and the tokenizer.
MODEL_NAME = "bert-base-uncased"
# Seed used for dataset shuffling and TrainingArguments, for reproducible runs.
SEED = 42
# Default number of samples used for training / evaluation (overridable via CLI).
TRAIN_SIZE = 8000
EVAL_SIZE = 2000
# Hugging Face Hub dataset id for the fake-news classification corpus.
DATASET_NAME = "GonzaloA/fake_news"

def parse_args(args=None):
    """Parse command-line arguments for the training script.

    Args:
        args: Optional list of argument strings to parse instead of
            ``sys.argv`` (useful for testing). ``None`` means use the
            real command line.

    Returns:
        argparse.Namespace with ``train_size`` and ``eval_size`` ints.
    """
    parser = argparse.ArgumentParser(description="Train BERT for sequence classification.")
    parser.add_argument('--train_size', type=int, default=TRAIN_SIZE,
                        help='Number of samples to use for training')
    parser.add_argument('--eval_size', type=int, default=EVAL_SIZE,
                        help='Number of samples to use for evaluation')
    # Bug fix: previously called parser.parse_args(args=None) unconditionally,
    # silently ignoring the `args` parameter; forward it so callers can inject argv.
    return parser.parse_args(args=args)

def load_data(dataset_name=DATASET_NAME):
    """Fetch *dataset_name* via Hugging Face `datasets` and record it in W&B.

    Returns the loaded DatasetDict; also logs the dataset name to the
    active Weights & Biases run (requires wandb.init to have been called).
    """
    loaded = load_dataset(dataset_name)

    wandb.log({"dataset": dataset_name})

    return loaded

def tokenize_data(tokenizer, dataset, padding=True, truncation=True, max_length=512):
    """Tokenize the 'text' field of every example in *dataset* in batches.

    Returns a new dataset with the tokenizer's output columns added.
    """
    def _encode(batch):
        return tokenizer(
            batch["text"],
            padding=padding,
            truncation=truncation,
            max_length=max_length,
        )

    return dataset.map(_encode, batched=True)

def configure_training_args(output_dir="test_trainer"):
    """Build the TrainingArguments used by the Trainer.

    Evaluates and checkpoints once per epoch; logs every 10 steps under
    ``<output_dir>/logs``; fixes the seed for reproducibility.
    """
    log_dir = f"{output_dir}/logs"  # directory for storing logs
    return TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_dir=log_dir,
        logging_steps=10,
        seed=SEED,
    )

def compute_metrics(eval_pred):
    """Compute classification accuracy for a Trainer evaluation step.

    Args:
        eval_pred: Tuple ``(logits, labels)`` — logits of shape
            (n_samples, n_classes) and integer class labels of shape
            (n_samples,).

    Returns:
        dict with a single ``"accuracy"`` float, matching the output
        format of ``evaluate.load("accuracy").compute(...)``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Compute accuracy directly with numpy instead of calling
    # evaluate.load("accuracy") on every evaluation: load() re-resolves
    # (and on first use downloads) the metric script each call, which is
    # wasted work and a network dependency inside the eval loop.
    return {"accuracy": float(np.mean(predictions == labels))}

def prepare_datasets(tokenized_datasets, args):
    """Split tokenized data into training and evaluation subsets.

    The first ``args.train_size`` examples become the training set; the
    next ``args.eval_size`` examples become the evaluation set.
    """
    split_point = args.train_size
    train_split = tokenized_datasets.select(range(split_point))
    eval_split = tokenized_datasets.select(
        range(split_point, split_point + args.eval_size)
    )
    return train_split, eval_split

def train_model(model, tokenizer, train_dataset, eval_dataset):
    """Run the fine-tuning loop and return the fitted Trainer.

    Uses the project-wide TrainingArguments from configure_training_args,
    dynamic padding via DataCollatorWithPadding, and compute_metrics for
    per-epoch accuracy.
    """
    collator = DataCollatorWithPadding(tokenizer=tokenizer)

    trainer = Trainer(
        model=model,
        args=configure_training_args(),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        data_collator=collator,
    )

    trainer.train()
    return trainer

def initialize_wandb(args, project="bert_fake_news_classification", entity=None):
    """Start a Weights & Biases run for this training session.

    Args:
        args: argparse.Namespace of hyperparameters, stored as the run config.
        project: W&B project name (keyword-only override available to callers).
        entity: W&B entity (user or team). Defaults to None so wandb falls
            back to the logged-in account. The previous hard-coded
            "your_wandb_username" placeholder would make wandb.init fail
            for anyone without access to that (nonexistent) entity.
    """
    wandb.init(project=project, entity=entity, config=args)

def save_model_and_tokenizer(trainer, tokenizer, path="./model_checkpoint"):
    """Persist the trained model and its tokenizer under *path*.

    Returns the checkpoint path so callers can log or upload it.
    """
    checkpoint_dir = path
    trainer.save_model(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)
    return checkpoint_dir

def log_to_wandb(dataset_name, artifact_path):
    """Record the dataset name and upload the model artifact to W&B."""
    wandb.log({"dataset": dataset_name})
    wandb.log_artifact(
        artifact_path,
        type="model",
        name="bert_fake_news_classifier",
    )

def finish_wandb():
    """Close out the active Weights & Biases run, flushing pending data."""
    wandb.finish()

def main():
    """End-to-end training entry point: load → tokenize → train → publish."""
    cli_args = parse_args()
    initialize_wandb(cli_args)

    raw_datasets = load_data()
    # Take just enough shuffled examples to cover both splits.
    sample_count = cli_args.train_size + cli_args.eval_size
    subset = raw_datasets["train"].shuffle(seed=SEED).select(range(sample_count))

    classifier = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    bert_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

    # Drop the raw text column after tokenization; the Trainer only needs
    # the tokenizer's output columns (and the labels).
    tokenized = tokenize_data(bert_tokenizer, subset).remove_columns(["text"])

    train_split, eval_split = prepare_datasets(tokenized, cli_args)
    fitted_trainer = train_model(classifier, bert_tokenizer, train_split, eval_split)

    checkpoint_path = save_model_and_tokenizer(fitted_trainer, bert_tokenizer)
    log_to_wandb(DATASET_NAME, checkpoint_path)
    finish_wandb()

# Run training only when executed as a script, not when imported (e.g. by tests).
if __name__ == "__main__":
    main()
Empty file added app/tests/__init__.py
Empty file.
Empty file added app/tests/model/__init__.py
Empty file.
11 changes: 11 additions & 0 deletions app/tests/model/test_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import numpy as np
from src.model.training import compute_metrics

def test_compute_metrics():
    """compute_metrics reports perfect accuracy when every argmax matches its label."""
    sample_logits = np.array([[2, 0.1], [0.1, 2], [2, 0.1]])
    sample_labels = np.array([0, 1, 0])

    metrics = compute_metrics((sample_logits, sample_labels))

    expected_accuracy = 1.0
    assert metrics['accuracy'] == expected_accuracy, f"Expected accuracy: {expected_accuracy}, but got: {metrics['accuracy']}"

0 comments on commit 3f001a6

Please sign in to comment.