- Description Slide
- No Kaggle competition
- Install the packages listed in `Pipfile`.
- Put the data files (e.g., `train.json`) in `data/`.
- Set up a wandb account and modify the parameters of the `wandb.init` call in `train_qa.py`, as in the sketch below.
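A minimal sketch of what that call might look like, assuming the script logs the parsed arguments as the run config; the `project` and `entity` values are placeholders to replace with your own account details:

```python
import wandb

def init_wandb(args):
    """Start a wandb run tied to this experiment."""
    wandb.init(
        project="qa-project",   # placeholder: your wandb project name
        entity="your-entity",   # placeholder: your wandb account or team
        name=args.exp_name,     # matches the --exp_name flag
        config=vars(args),      # log all hyperparameters with the run
    )
```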
## Question answering
```bash
# ensemble
python3 train_qa.py --base_model hfl/chinese-xlnet-base --batch_size 2 --max_seq_length 512 --num_epoch 100 --ensemble True --wandb_logging True --exp_name chinese-xlnet-base-512-ensemble

# basic
python3 train_qa.py --base_model hfl/chinese-xlnet-base --batch_size 2 --max_seq_length 512 --num_epoch 100 --wandb_logging True --exp_name chinese-xlnet-base-512
```
- The program saves the best model by exact match on the validation data (the metric can be changed with `--metric_for_best`); see the sketch after this list.
- The model will be saved in `model/`.
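A minimal sketch of that selection logic, with `train_one_epoch` and `compute_em` as hypothetical helpers standing in for the actual training and validation code in `train_qa.py`:

```python
import torch

def fit(model, train_loader, valid_loader, num_epoch, ckpt_path="model/best_model.pt"):
    """Keep only the checkpoint with the best validation exact match."""
    best_em = -1.0
    for epoch in range(num_epoch):
        train_one_epoch(model, train_loader)  # hypothetical helper: one training pass
        em = compute_em(model, valid_loader)  # hypothetical helper: validation exact match
        if em > best_em:                      # save only when validation EM improves
            best_em = em
            torch.save(model.state_dict(), ckpt_path)
```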
```bash
python3 predict.py --base_model <the_base_model> --checkpoint <the_checkpoint_model>
```
- The prediction will be saved in `prediction/`.
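As a rough picture of what the two flags refer to, a prediction script of this kind typically rebuilds the architecture from the base model and then loads the fine-tuned weights; the checkpoint path and model class below are assumptions, not necessarily what `predict.py` does:

```python
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

base_model = "hfl/chinese-xlnet-base"   # value passed as --base_model
checkpoint = "model/best_model.pt"      # value passed as --checkpoint (assumed format)

tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForQuestionAnswering.from_pretrained(base_model)
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.eval()  # inference mode for prediction
```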
| Arg | Default | Help |
|---|---|---|
| train_data | train.json | Training and validation data file |
| data_dir | data/ | Directory containing the dataset |
| base_model | bert-base-chinese | Base pre-trained language model |
| model_dir | model/ | Directory in which model files are saved |
| lr | 1e-5 | Learning rate |
| wd | 1e-2 | Weight decay |
| ensemble | False | Use the ensemble model |
| batch_size | 8 | Batch size |
| train_val_split | 0.1 | Ratio used to split the data into training and validation sets |
| max_seq_length | 512 | Maximum length of the tokenizer output |
| device | cuda:0 | Training device |
| num_epoch | 10 | Number of training epochs |
| n_batch_per_step | 2 | Number of batches to accumulate before each optimizer step |
| metric_for_best | valid_loss | Metric used to select the best model |
| wandb_logging | False | Enable logging to wandb |
| exp_name | bert-base-chinese-512 | Run name on wandb |
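The `n_batch_per_step` argument suggests gradient accumulation: the optimizer steps only once every `n_batch_per_step` batches, emulating a larger effective batch size. A minimal sketch of that pattern, assuming a Hugging Face-style model whose forward pass returns a loss (an illustration, not the exact loop in `train_qa.py`):

```python
def train_epoch(model, train_loader, optimizer, n_batch_per_step=2):
    """One training epoch with gradient accumulation every n_batch_per_step batches."""
    model.train()
    optimizer.zero_grad()
    for i, batch in enumerate(train_loader):
        loss = model(**batch).loss            # batch is assumed to include the targets
        (loss / n_batch_per_step).backward()  # scale so gradients average over the step
        if (i + 1) % n_batch_per_step == 0:
            optimizer.step()                  # update weights once per accumulated step
            optimizer.zero_grad()
```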
## TODO
- Resume training from the last checkpoint.
- Store the training config (e.g., `max_seq_length`, `base_model`) and load it during prediction; see the sketch below.
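One way the second item could be implemented: dump the parsed argparse namespace to JSON at training time and reload it in `predict.py`, so flags such as `max_seq_length` do not have to be repeated. A sketch of the idea (the file name is an assumption):

```python
import json

def save_config(args, path="model/train_config.json"):
    """At training time: persist the argparse namespace next to the checkpoint."""
    with open(path, "w") as f:
        json.dump(vars(args), f, indent=2)

def load_config(path="model/train_config.json"):
    """At prediction time: restore settings such as base_model and max_seq_length."""
    with open(path) as f:
        return json.load(f)
```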