setfit model performance advice #540
I solved my own problem. 1) I had left parameters in the training arguments that only apply to a differentiable head. The dataset is still small, though, so some variability is to be expected in deployment/production. Here are the revised test set performances on this multi-class problem:
Could you show your code and what you did to improve the performance?
@chewbm05 My specific code probably won't help you, because your data is different from mine, so the changes you need to implement and the hyperparameter values you use would likely differ too. The main change was data transformation: the hyperparameter changes below would probably not have been necessary if the dataset weren't still suffering from some quality issues.

a) I passed class_weight='balanced' to the logistic regression head, along with C=0.1:

```python
model_head = LogisticRegression(class_weight='balanced', C=0.1)
```

C is the inverse of regularization strength; you can read more here: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html

b) I used a slightly unusual body learning rate because I suspected the model wasn't learning the sample data enough: body_learning_rate=8e-6, with 1 epoch and a batch size of 8.

c) I tried some non-default sampling strategies. You can read more on sampling strategies here: https://huggingface.co/docs/setfit/v1.0.3/en/conceptual_guides/sampling_strategies

Setting some of these hyperparameters can be automated more easily with Optuna, once you figure out what ranges to use for a small dataset with quality issues; a sketch of that is below. I wouldn't say my approach was highly scientific, but it worked. And again: fix the dataset problems and the performance should improve considerably. Models can't make data problems disappear. Garbage in, garbage out.
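A rough sketch of what that Optuna automation could look like. The search ranges, the model name, and the `train_ds`/`val_ds` dataset variables are illustrative assumptions, not the values from the original run:

```python
import optuna
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from setfit import SetFitModel, Trainer, TrainingArguments

# train_ds / val_ds: assumed to be datasets.Dataset objects with "text" and "label" columns.

def objective(trial):
    # Sample head and body hyperparameters over plausible ranges for a small, noisy dataset.
    model = SetFitModel(
        model_body=SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2"),  # placeholder body
        model_head=LogisticRegression(
            class_weight="balanced",
            C=trial.suggest_float("C", 0.01, 1.0, log=True),
        ),
    )
    args = TrainingArguments(
        batch_size=trial.suggest_categorical("batch_size", [4, 8, 16]),
        num_epochs=1,
        body_learning_rate=trial.suggest_float("body_learning_rate", 1e-6, 3e-5, log=True),
        sampling_strategy=trial.suggest_categorical(
            "sampling_strategy", ["oversampling", "undersampling", "unique"]
        ),
    )
    trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=val_ds)
    trainer.train()
    # Maximize validation accuracy (the Trainer's default metric).
    return trainer.evaluate()["accuracy"]

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```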
Hello gents,
I was hoping I could get a second opinion about a situation I am facing while using SetFit for a multi-class classification use case.
The dataset is small, with 255 samples across 9 classes. It suffers from both ambiguity (overlapping classes) and class imbalance. I partially remedied the overlaps by merging classes together. The imbalance is still present, but in the new dataset version all classes except one have at least 10 samples, as shown below:
I am performing a stratified split: 80% training, 10% validation, 10% testing. The train, validation, and test splits preserve the class imbalance observed in the dataset post-merge.
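A minimal sketch of such a two-step stratified split with scikit-learn; `texts` and `labels` are hypothetical variables holding the post-merge dataset:

```python
from sklearn.model_selection import train_test_split

# First carve out 80% for training, stratifying on the labels.
train_texts, rest_texts, train_labels, rest_labels = train_test_split(
    texts, labels, test_size=0.2, stratify=labels, random_state=42
)
# Split the remaining 20% evenly into validation and test (10% of the total each),
# again stratifying so each split keeps the original class proportions.
val_texts, test_texts, val_labels, test_labels = train_test_split(
    rest_texts, rest_labels, test_size=0.5, stratify=rest_labels, random_state=42
)
```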
My model is an embeddings model with a logistic regression head, constructed using this code:
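(The code block itself didn't survive extraction; a minimal sketch of what such a construction typically looks like in SetFit v1.x, with a placeholder sentence-transformer body:)

```python
from sentence_transformers import SentenceTransformer
from sklearn.linear_model import LogisticRegression
from setfit import SetFitModel

# Embedding body (placeholder model name) plus a scikit-learn logistic regression head.
model = SetFitModel(
    model_body=SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2"),
    model_head=LogisticRegression(),
)
```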
I am then using the following code (with my latest hyperparameter choices) to train the SetFit model:
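(That code block is also missing; a sketch of a typical SetFit v1.x training call. The hyperparameter values here are placeholders, not the poster's actual choices; note that several arguments accept an (embedding phase, classifier phase) tuple:)

```python
from setfit import Trainer, TrainingArguments

args = TrainingArguments(
    batch_size=(16, 2),               # (embedding phase, classifier phase) — illustrative values
    num_epochs=(3, 16),
    body_learning_rate=(2e-5, 1e-5),
    head_learning_rate=1e-2,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,   # hypothetical dataset variables
    eval_dataset=val_ds,
    metric="accuracy",
)
trainer.train()
metrics = trainer.evaluate()
```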
This is the logging output I get during training:
The evaluation step results in an accuracy score of 0.6153846153846154.
I am plotting the embedding and eval embedding loss curves:
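(A minimal plotting sketch, assuming the embedding losses were collected into plain Python lists, e.g. from the training logs; the variable names are placeholders:)

```python
import matplotlib.pyplot as plt

# train_losses / eval_losses: lists of embedding-loss values gathered per logging step.
plt.plot(train_losses, label="embedding loss")
plt.plot(eval_losses, label="eval embedding loss")
plt.xlabel("logging step")
plt.ylabel("loss")
plt.legend()
plt.show()
```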
When I predict against the test split, I get the following results:
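(The results table is missing from the extraction; predictions against the test split would typically be produced and scored like this, with hypothetical variable names:)

```python
from sklearn.metrics import classification_report

# SetFitModel.predict takes raw texts and returns predicted labels.
preds = model.predict(test_texts)
print(classification_report(test_labels, preds))
```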
Looking at these results, but especially the embedding and eval embedding loss curves, it's obvious the training routine needs improvement. I was initially suffering from an eval embedding loss curve that was flat; that problem turned out to be in the splitting and is now solved. But the curves still shouldn't look like this.
I also suspect the classifier might not be learning enough (it's just a suspicion, based on digging into the predictions it made against the test split).
I welcome any helpful suggestions.
PS: I also tried these hyperparameters:

```python
batch_size=(16, 2),
num_epochs=(3, 16),
body_learning_rate=(2e-5, 1e-5),
head_learning_rate=1e-2,
```

but I got this:

and worse evaluation and test prediction scores.