This repository implements the main experiments of our TACL 2024 paper, Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding.
We thank the authors of RISK, on which our code is based.
We tested our code in the following environment.
- OS: Debian GNU/Linux 10 (buster)
- Python: 3.8.3
- CUDA: 11.2
- GPUs: NVIDIA V100 x 2
The experiment with DeBERTa-v3-large requires a different environment:
- OS: Debian GNU/Linux 10 (buster)
- Python: 3.8.3
- CUDA: 11.2
- GPUs: NVIDIA A100 (40GB) x 2
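Before installing anything, it can help to compare your machine with the tested environments listed above. The following is a minimal Bash sketch and is not part of the repository. It assumes the interpreter is `python3` (override with a hypothetical `PYTHON` variable) and compares only the major.minor Python version, so 3.8.x counts as matching 3.8.3. It uses `nvidia-smi` only to list the visible GPU models and memory.

```bash
# Hypothetical helper (not part of the repository): compare the local setup
# with the tested environments listed above.

# Succeeds if version $1 has the same major.minor as version $2,
# e.g. 3.8.10 matches 3.8.3 but 3.10.1 does not.
same_minor() {
  [ "$(echo "$1" | cut -d. -f1,2)" = "$(echo "$2" | cut -d. -f1,2)" ]
}

check_env() {
  local py
  # Assumption: the interpreter is python3 unless PYTHON is set.
  py=$("${PYTHON:-python3}" -c 'import platform; print(platform.python_version())') || return 1
  if same_minor "$py" 3.8.3; then
    echo "Python $py: same minor version as the tested 3.8.3"
  else
    echo "Python $py: differs from the tested 3.8.3"
  fi
  if command -v nvidia-smi >/dev/null 2>&1; then
    # One line per GPU with its model name and total memory
    # (e.g. to tell V100 from A100 40GB).
    nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
  else
    echo "nvidia-smi not found: no NVIDIA driver visible"
  fi
}
```

Source the snippet and call `check_env`. Both experiments were run on two GPUs, so two lines of `nvidia-smi` output are expected on a matching machine.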
git clone https://github.com/CyberAgentAILab/posthoc-control-moe
cd posthoc-control-moe
Install dependencies to reproduce the main results.
# For conda users
conda env create -f environment.yaml
conda activate posthoc-control-moe
# For the others
pip install --force-reinstall --no-cache-dir -r requirements.txt
For the experiment with DeBERTa-v3-large, use environment_deberta.yaml or requirements_deberta.txt.
# For conda users
conda env create -f environment_deberta.yaml
conda activate posthoc-control-moe-deberta
# For the others
pip install --force-reinstall --no-cache-dir -r requirements_deberta.txt
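The two installation routes differ only in the file suffix (`_deberta`) and the conda environment name. A minimal Bash sketch that prints the matching commands for either variant might look like this. It is a hypothetical helper, not part of the repository. It assumes conda is preferred whenever the `conda` command is on `PATH`, and uses a made-up `USE_PIP=1` switch to force the pip route. The commands are only printed, not executed.

```bash
# Hypothetical helper: print the install commands for one variant.
#   main    -> environment.yaml / requirements.txt
#   deberta -> environment_deberta.yaml / requirements_deberta.txt
install_cmds() {
  local variant=${1:-main} suffix="" name="posthoc-control-moe"
  case "$variant" in
    main) ;;
    deberta) suffix="_deberta"; name="posthoc-control-moe-deberta" ;;
    *) echo "unknown variant: $variant" >&2; return 2 ;;
  esac
  # Assumption: use conda if it is installed, unless USE_PIP=1 is set.
  if [ "${USE_PIP:-0}" != 1 ] && command -v conda >/dev/null 2>&1; then
    echo "conda env create -f environment${suffix}.yaml"
    echo "conda activate ${name}"
  else
    echo "pip install --force-reinstall --no-cache-dir -r requirements${suffix}.txt"
  fi
}
```

For example, `install_cmds deberta` prints the DeBERTa-v3-large commands. Copy them into an interactive shell, because `conda activate` generally needs conda's shell hook and may not work inside a plain script.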
Download the datasets from here and place them as follows.
Alternatively, run gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder to download all the datasets at once.
The link is kindly provided by RISK.
./dataset/
├── multinli/
│ ├── train.tsv
│ └── dev_matched.tsv
├── hans/heuristics_evaluation_set.txt
├── qqp_paws/
│ ├── qqp_train.tsv
│ ├── qqp_dev.tsv
│ └── paws_devtest.tsv
└── fever/
├── fever.train.jsonl
├── fever.dev.jsonl
├── symmetric_v0.1/fever_symmetric_generated.jsonl
└── symmetric_v0.2/fever_symmetric_test.jsonl
Original links for the datasets:
- MNLI: https://cims.nyu.edu/~sbowman/multinli/
- HANS: https://github.com/tommccoy1/hans
- QQP and PAWS: https://github.com/google-research-datasets/paws
- FEVER and FEVER-Symmetric: https://github.com/TalSchuster/FeverSymmetric
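After downloading, a quick way to confirm that the files match the layout above is to check each expected path. This minimal Bash sketch is a hypothetical helper, not part of the repository. It assumes the dataset root is `./dataset` unless another path is passed as the first argument. It prints every missing file and returns 1 if anything is missing.

```bash
# Hypothetical helper: verify the dataset layout shown above.
check_datasets() {
  local root=${1:-./dataset} missing=0 f
  local files=(
    multinli/train.tsv
    multinli/dev_matched.tsv
    hans/heuristics_evaluation_set.txt
    qqp_paws/qqp_train.tsv
    qqp_paws/qqp_dev.tsv
    qqp_paws/paws_devtest.tsv
    fever/fever.train.jsonl
    fever/fever.dev.jsonl
    fever/symmetric_v0.1/fever_symmetric_generated.jsonl
    fever/symmetric_v0.2/fever_symmetric_test.jsonl
  )
  for f in "${files[@]}"; do
    if [ ! -f "$root/$f" ]; then
      echo "missing: $root/$f"
      missing=1
    fi
  done
  return "$missing"
}
```

Running `check_datasets` from the repository root prints nothing and returns 0 when all ten files are in place.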
Train the mixture-of-experts model and save the checkpoint that performs best on the in-distribution (ID) dev set.
The commands below use seed 888, which yields performance near the average reported in the paper.
The default seed is 777, and the analyses in the paper were conducted with that seed.
mkdir -p saved_models/mnli
mkdir -p saved_models/qqp
mkdir -p saved_models/fever
# MNLI
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
--best_model_name bert_mos_e10_rs05k8_ep10_lr2e-5_8 --save
# QQP
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset qqp --batch_size 32 --epochs 10 \
--num_experts 15 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
--best_model_name bert_mos_e15_rs1k8_ep10_lr2e-5_8 --save
# FEVER
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset fever --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
--best_model_name bert_mos_e10_rs1k8_ep10_lr2e-5_8 --save
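The three training commands above share every option except the task name, `--num_experts`, `--router_loss`, and the run name. A minimal Bash sketch that prints the command for one task from a small per-task table might look like this. It is a hypothetical wrapper, not part of the repository. It only echoes the command (prefixed with the matching `mkdir -p`), so it is safe to inspect before launching anything.

```bash
# Hypothetical wrapper: print the training command for mnli, qqp, or fever,
# using the per-task settings from the commands above.
train_cmd() {
  local task=$1 experts loss name
  case "$task" in
    mnli)  experts=10; loss=0.5; name=bert_mos_e10_rs05k8_ep10_lr2e-5_8 ;;
    qqp)   experts=15; loss=1;   name=bert_mos_e15_rs1k8_ep10_lr2e-5_8 ;;
    fever) experts=10; loss=1;   name=bert_mos_e10_rs1k8_ep10_lr2e-5_8 ;;
    *) echo "unknown task: $task" >&2; return 2 ;;
  esac
  # Options shared by all three tasks are written out once here.
  echo "mkdir -p saved_models/$task && CUDA_VISIBLE_DEVICES=0,1 accelerate launch" \
    "--config_file accelerate_config.yaml --main_process_port 20880" \
    "src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased" \
    "--dataset $task --batch_size 32 --epochs 10" \
    "--num_experts $experts --router_loss $loss --router_tau 1" \
    "--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/$task" \
    "--best_model_name $name --save"
}
```

For example, `for t in mnli qqp fever; do train_cmd "$t"; done` prints all three commands, and `train_cmd mnli | bash` would run the MNLI one. Since every run uses `--main_process_port 20880` on the same two GPUs, the tasks should be run one after another rather than in parallel.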
For the DeBERTa-v3-large ablation study:
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config_deberta.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
--save_dir saved_models/mnli \
--best_model_name deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --save
Evaluate the post-hoc control over the mixture-of-experts on the out-of-distribution (OOD) test sets.
Some saved models are available here for those who want to check the results quickly.
Download them and place them under saved_models/[task_name]/.
# HANS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
--resume bert_mos_e10_rs05k8_ep10_lr2e-5_8 --evaluate
# PAWS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset qqp --batch_size 32 --epochs 10 \
--num_experts 15 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
--resume bert_mos_e15_rs1k8_ep10_lr2e-5_8 --evaluate
# Symm. v1 and v2
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
--dataset fever --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 1 --router_tau 1 \
--num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
--resume bert_mos_e10_rs1k8_ep10_lr2e-5_8 --evaluate
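Each evaluation command is identical to the corresponding training command except that `--best_model_name NAME --save` becomes `--resume NAME --evaluate`. The same relation holds for the DeBERTa-v3-large commands. A minimal sketch of that substitution as a `sed` filter is shown below. It is a hypothetical helper, not part of the repository, and assumes the name and `--save` appear on the same line, as they do in the commands in this README.

```bash
# Hypothetical helper: turn a training command (read from stdin) into the
# matching evaluation command by swapping the final two options.
# "--save" is matched only as a whole word, so "--save_dir" is left untouched.
eval_from_train() {
  sed -E 's/--best_model_name ([^ ]+) --save([[:space:]]|$)/--resume \1 --evaluate\2/'
}
```

For example, if a training command is stored in a file such as `train_mnli.sh` (a hypothetical name), `eval_from_train < train_mnli.sh > eval_mnli.sh` writes the evaluation command.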
For the DeBERTa-v3-large ablation study:
# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
--config_file accelerate_config_deberta.yaml --main_process_port 20880 \
src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
--dataset mnli --batch_size 32 --epochs 10 \
--num_experts 10 --router_loss 0.5 --router_tau 1 \
--num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
--save_dir saved_models/mnli \
--resume deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --evaluate
If you find our work useful for your research, please consider citing our paper:
@article{honda-etal-2024-eliminate,
author = {Honda, Ukyo and Oka, Tatsushi and Zhang, Peinan and Mita, Masato},
title = {Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding},
journal = {Transactions of the Association for Computational Linguistics},
volume = {12},
pages = {1268--1289},
year = {2024},
month = {10},
address = {Cambridge, MA},
publisher = {MIT Press},
issn = {2307-387X},
doi = {10.1162/tacl_a_00701},
url = {https://doi.org/10.1162/tacl_a_00701},
}