setfit model performance advice #540

Open

rolandtannous opened this issue Jul 14, 2024 · 4 comments

@rolandtannous

Hello gents,

I was hoping I could get a second opinion on a situation I am facing while using SetFit for a multi-class classification use case.
The dataset is small, with 255 samples across 9 classes. It suffers from both ambiguity (overlapping classes) and class imbalance. I partially remedied the overlaps by merging classes. The imbalance is still present, but in the new dataset version all classes except one have at least 10 samples, as shown below:

[Screenshot: per-class sample counts after merging]

I am performing a stratified split: 80% training, 10% validation, 10% testing. The train, validation, and test splits preserve the class imbalance observed in the dataset post-merge.
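For reference, a minimal sketch of such an 80/10/10 stratified split with the datasets library (assuming the full data starts as a single Dataset called full_dataset and that the "label" column is a ClassLabel feature, which stratify_by_column requires):

from datasets import DatasetDict

# Split off 20% for validation + test, stratified on the label column
split = full_dataset.train_test_split(test_size=0.2, stratify_by_column="label", seed=40)
# Split that 20% in half: 10% validation, 10% test
holdout = split["test"].train_test_split(test_size=0.5, stratify_by_column="label", seed=40)

dataset = DatasetDict({
    "train": split["train"],         # 80%
    "validation": holdout["train"],  # 10%
    "test": holdout["test"],         # 10%
})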

My model is an embedding model with a logistic regression head, constructed using this code:

from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from setfit import SetFitModel

model_body = SentenceTransformer("BAAI/bge-base-en-v1.5")

# Choose Logistic Regression as the classification head
model_head = LogisticRegression(class_weight="balanced")

# Create a SetFit model, combining the feature extractor and classification head
model = SetFitModel(model_body, model_head)
model.labels = categories
# labels = list(set(dataset['train']['label']))

# Load a SetFit model from Hub
model: SetFitModel = SetFitModel.from_pretrained(
    "BAAI/bge-small-en-v1.5",
)

I am then using the following code (with my latest hyperparameter choices) to train the SetFit model:

from setfit import Trainer, TrainingArguments

# Create Training Arguments
args = TrainingArguments(
    # When an argument is a tuple, the first value is for training the embeddings,
    # and the latter is for training the differentiable classification head:
    batch_size=(32, 2),
    num_iterations=10,
    num_epochs=(5, 16),
    body_learning_rate=(1e-5, 1e-5),
    head_learning_rate=2e-2,
    end_to_end=True,
    show_progress_bar=False,
    report_to="none",
    logging_strategy="steps",
    logging_steps=50,
    eval_steps=50,
    output_dir= f"{current_path}/checkpoints",
    logging_dir = f"{current_path}/runs",
    seed=40
)

# Create Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    metric="accuracy",
)

# Train and evaluate
trainer.train()

This is the logging output I get during training:
[Screenshot: training log output]

The evaluation step results in an accuracy score of 0.6153846153846154.

I am plotting the embedding and eval embedding loss curves:

[Plot: embedding loss and eval embedding loss curves]

When I predict against the test split, I get the following results:

[Screenshot: test-split precision/recall/F1 results]
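(The screenshot was a per-class precision/recall/F1 report. As a rough sketch of how such a report is generated, assuming "text" and "label" columns and that the predicted labels come back in the same format as the test labels:)

from sklearn.metrics import classification_report

# Predict labels for the raw test texts with the trained SetFit model
preds = model.predict(dataset["test"]["text"])
print(classification_report(dataset["test"]["label"], preds))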

Looking at these results, and especially the embedding and eval embedding loss curves, it's obvious the training routine needs improvement. I was initially getting an eval embedding loss curve that was completely flat, but that turned out to be a splitting problem, which is now solved. Still, the curves shouldn't look like this.
I also suspect the classifier might not be learning enough (just a suspicion, based on digging into the predictions it made against the test split).

I welcome any helpful suggestions.

PS:

  1. I have already tried hyperparameters suggested in other GitHub issues here, such as:
    batch_size=(16, 2)
    num_epochs=(3,16),
    body_learning_rate=(2e-5, 1e-5),
    head_learning_rate=1e-2,

but I got this:
[Plot: loss curves with the suggested hyperparameters]

and worse evaluation and test prediction scores.

  2. I have already tried Optuna but do not seem to get anywhere.
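For reference, the documented SetFit pattern for an Optuna search looks roughly like this (a sketch only; the search ranges are placeholders, not recommendations):

from setfit import SetFitModel, Trainer

def model_init(params):
    # Hyperparameters that are not TrainingArguments would arrive here
    return SetFitModel.from_pretrained("BAAI/bge-base-en-v1.5")

def hp_space(trial):
    return {
        "body_learning_rate": trial.suggest_float("body_learning_rate", 1e-6, 1e-4, log=True),
        "num_epochs": trial.suggest_int("num_epochs", 1, 3),
        "batch_size": trial.suggest_categorical("batch_size", [8, 16, 32]),
    }

trainer = Trainer(
    model_init=model_init,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    metric="accuracy",
)
best_run = trainer.hyperparameter_search(direction="maximize", hp_space=hp_space, n_trials=20)
trainer.apply_hyperparameters(best_run.hyperparameters, final_model=True)
trainer.train()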

rolandtannous commented Jul 15, 2024

I've improved it slightly, but it still doesn't look the way it should.
The new training arguments look like this:

# Create Training Arguments
args = TrainingArguments(
    # When an argument is a tuple, the first value is for training the embeddings,
    # and the latter is for training the differentiable classification head:
    batch_size=(32, 2),
    num_iterations=5,
    num_epochs=(3, 16),
    body_learning_rate=(2.5e-5, 1e-5),
    head_learning_rate=1.5e-2,
    end_to_end=True,
    show_progress_bar=False,
    report_to="none",
    logging_strategy="steps",
    logging_steps=50,
    eval_steps=50,
    output_dir= f"{current_path}/checkpoints",
    logging_dir = f"{current_path}/runs",
    seed=40
)

# Create Trainer
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    metric="accuracy",
)

# Train and evaluate
trainer.train()

I am using num_samples = 15 (10 gives me the same results).
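(Assuming this refers to setfit.sample_dataset, i.e. something along the lines of:)

from setfit import sample_dataset

# Draw up to 15 examples per class from the training split
train_dataset = sample_dataset(dataset['train'], label_column="label", num_samples=15)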

The run's logs are as follows:
[Screenshot: training logs for the new run]

The loss curves now look a bit better
[Plot: updated embedding and eval embedding loss curves]

Eval accuracy is 53.8%.

and the precision/recall/F1 on the test split gives me:
[Screenshot: test-split precision/recall/F1 for the new run]

I can't seem to get it to improve beyond this point on this specific dataset. I highly suspect the dataset quality is the main culprit.
It goes to show that the quality of a dataset can either boost a model's performance or completely constrain it.


rolandtannous commented Jul 18, 2024

I solved my own problem.

1- I had left in parameters that are only applicable to a differentiable head.
2- I was still using pre-v1.0.0 syntax. I mistakenly ignored the migration guide and relied on code snippets from some of the GitHub issues here.
3- I fixed the main culprit: the data. I performed data curation, including annotation corrections, class merging, and data augmentation.

The dataset is still small and variability is to be expected in deployment/production, but here are the revised test-set performances on this multi-class problem:

[Screenshots: revised test-set performance metrics]

@chewbm05

Could you show the code for what you did to improve the performance?

@rolandtannous (Author)

@chewbm05
I fixed how I was calling SetFit and dropped function arguments that are only useful if you are using a differentiable head (I am using a LogisticRegression head). This helped: https://huggingface.co/docs/setfit/en/how_to/v1.0.0_migration_guide
Reading through the documentation properly also helped: https://huggingface.co/docs/setfit/v1.0.3/en/index
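For illustration, a minimal sketch of the post-v1.0.0 way to build the model with a scikit-learn LogisticRegression head (head_params here just mirrors the head settings discussed below, and categories is my list of label names):

from setfit import SetFitModel

model = SetFitModel.from_pretrained(
    "BAAI/bge-base-en-v1.5",
    labels=categories,  # class names, so predictions come back as label strings
    head_params={"class_weight": "balanced", "C": 0.1},  # forwarded to LogisticRegression
)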

My specific code probably won't help you, because your data is different from mine, so the changes you need to implement and the hyperparameter values you use would probably need to be different.
However, here are some of the changes that somewhat helped in "my case":

The main change was data transformation:

  • I slightly augmented the underrepresented classes in the dataset using an LLM, so the imbalance went down from severe to low-to-mid level. The number of samples per class was still nowhere near where I would have liked it to be (100 per class); I was still at 40-50 per class.
  • I merged classes that were so semantically close (almost identical) that the model was getting confused by them (sketched below).
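The class merging itself was essentially just a label remap, something along these lines (the class names below are made-up placeholders, and this assumes string labels):

# Hypothetical map folding near-duplicate classes into one
merge_map = {
    "billing_issue": "billing",
    "billing_question": "billing",
}

def merge_labels(example):
    example["label"] = merge_map.get(example["label"], example["label"])
    return example

dataset = dataset.map(merge_labels)  # works on a Dataset or a DatasetDict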

The hyperparameter changes below would probably not have been necessary if the dataset weren't still suffering from some quality issues.

a) I passed class_weight="balanced" to the logistic regression head, along with C=0.1:

 model_head = LogisticRegression(class_weight='balanced', C=0.1)

C is the inverse of regularization strength; you can read more here: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html
I specified C manually based on a conversation; you could probably use grid search to find it instead. I haven't yet quantified the impact of the C hyperparameter alone (versus the combination of setting both C and class_weight), and I'm not even sure C is used here, but when I made that combined change I noticed a positive impact, so I left it as is.
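If you do want to search for C rather than hand-pick it, a rough sketch would be to fit a grid on embeddings from the already fine-tuned body (model.model_body is SetFit's underlying SentenceTransformer; the "text"/"label" column names are assumptions):

from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Encode the training texts with the fine-tuned sentence-transformer body
X_train = model.model_body.encode(dataset['train']['text'])
y_train = dataset['train']['label']

grid = GridSearchCV(
    LogisticRegression(class_weight="balanced", max_iter=1000),
    param_grid={"C": [0.01, 0.1, 1.0, 10.0]},
    scoring="f1_macro",
    cv=5,
)
grid.fit(X_train, y_train)
print(grid.best_params_)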

b) I used a slightly unusual body learning rate, because I suspected the model wasn't learning the sample data enough:

body_learning_rate=8e-6

I used 1 epoch and a batch size of 8 (see the combined TrainingArguments sketch after point (c) below).

c) I tried some non-default sampling strategies.
The default sampling strategy in SetFit is oversampling.
I tried:

  • a 'unique' sampling_strategy while setting a warmup_proportion (I used 0.0.3) in the training arguments;
  • the now-deprecated num_iterations argument, setting it to 35.

Both approaches seemed to help in my case.

You can read more on sampling strategies here: https://huggingface.co/docs/setfit/v1.0.3/en/conceptual_guides/sampling_strategies
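Put together, the training arguments for (b) and (c) looked roughly like this (a sketch, not my exact configuration; the warmup_proportion value is illustrative):

from setfit import TrainingArguments

args = TrainingArguments(
    batch_size=8,
    num_epochs=1,
    body_learning_rate=8e-6,
    sampling_strategy="unique",  # default is "oversampling"
    warmup_proportion=0.03,      # illustrative small warmup fraction
    report_to="none",
)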

Tuning some of these hyperparameters can be more easily automated with Optuna once you figure out what ranges make sense for a smaller dataset with quality issues. I wouldn't say my approach was highly scientific, but it worked.

And again, fix the dataset problems and the performance should improve considerably. Models can't make data problems disappear. Garbage in, Garbage out.
