
Post-Hoc Control over Mixture-of-Experts

This repository implements the main experiments of our TACL 2024 paper, Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding.

We thank the authors of RISK; our code is based on their implementation.

Environment

We tested our code in the following environment.

  • OS: Debian GNU/Linux 10 (buster)
  • Python: 3.8.3
  • CUDA: 11.2
  • GPUs: NVIDIA V100 x 2

The experiment with DeBERTa-v3-large requires a different environment.

  • OS: Debian GNU/Linux 10 (buster)
  • Python: 3.8.3
  • CUDA: 11.2
  • GPUs: NVIDIA A100 (40GB) x 2

Getting Started

git clone https://github.com/CyberAgentAILab/posthoc-control-moe
cd posthoc-control-moe

Installation

Install dependencies to reproduce the main results.

# For conda users
conda env create -f environment.yaml
conda activate posthoc-control-moe

# For non-conda users (pip)
pip install --force-reinstall --no-cache-dir -r requirements.txt

For the experiment with DeBERTa-v3-large, use environment_deberta.yaml or requirements_deberta.txt.

# For conda users
conda env create -f environment_deberta.yaml
conda activate posthoc-control-moe-deberta

# For non-conda users (pip)
pip install --force-reinstall --no-cache-dir -r requirements_deberta.txt
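
As an optional sanity check after installation, you can confirm that the GPU stack is visible before launching any experiments (a minimal sketch, assuming torch and accelerate are installed by the files above).

# Optional: verify that PyTorch sees the GPUs and inspect the accelerate setup
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
accelerate env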

Data Preparation

Download the datasets from here and place them as shown below, or simply run gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder to download them all at once. The Google Drive link is kindly provided by the authors of RISK.

./dataset/
  ├── multinli/
  │     ├── train.tsv
  │     └── dev_matched.tsv
  ├── hans/heuristics_evaluation_set.txt
  ├── qqp_paws/
  │     ├── qqp_train.tsv
  │     ├── qqp_dev.tsv
  │     └── paws_devtest.tsv
  └── fever/
        ├── fever.train.jsonl
        ├── fever.dev.jsonl
        ├── symmetric_v0.1/fever_symmetric_generated.jsonl
        └── symmetric_v0.2/fever_symmetric_test.jsonl
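
If you use the gdown command above, a minimal download sketch looks like the following; the name of the locally downloaded folder depends on the shared Drive folder, so rename it to ./dataset/ if it differs.

gdown 'https://drive.google.com/drive/folders/1aleJytl3SAKdGBsxZbxznwusINOnTAzh?usp=share_link' --folder
# Rename the downloaded folder if it is not already named "dataset"
# mv <downloaded_folder_name> dataset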

Original links for the datasets:

Usage

Training

Train the mixture-of-experts model and save the checkpoint that performs best on the in-distribution (ID) dev set. The commands below specify a seed (888) whose performance is close to the average reported in the paper. The default seed is 777, which is the seed used for the analyses in the paper.

mkdir -p saved_models/mnli
mkdir -p saved_models/qqp
mkdir -p saved_models/fever

# MNLI
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset mnli --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 0.5 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
    --best_model_name bert_mos_e10_rs05k8_ep10_lr2e-5_8 --save

# QQP
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset qqp --batch_size 32 --epochs 10 \
    --num_experts 15 --router_loss 1 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
    --best_model_name bert_mos_e15_rs1k8_ep10_lr2e-5_8 --save

# FEVER
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset fever --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 1 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
    --best_model_name bert_mos_e10_rs1k8_ep10_lr2e-5_8 --save
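
The paper reports performance averaged over multiple runs; if you want to sweep seeds yourself, the following is a minimal sketch for MNLI (the seed values and checkpoint-name suffix are illustrative, not the exact configuration used in the paper).

# Illustrative seed sweep for MNLI; seeds and name suffixes are placeholders
for seed in 111 222 333; do
    CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
        --config_file accelerate_config.yaml --main_process_port 20880 \
        src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
        --dataset mnli --batch_size 32 --epochs 10 \
        --num_experts 10 --router_loss 0.5 --router_tau 1 \
        --num_topk_mask 8 --lr 2e-5 --seed ${seed} --save_dir saved_models/mnli \
        --best_model_name bert_mos_e10_rs05k8_ep10_lr2e-5_seed${seed} --save
done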

For the DeBERTa-v3-large ablation study:

# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config_deberta.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
    --dataset mnli --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 0.5 --router_tau 1 \
    --num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
    --save_dir saved_models/mnli \
    --best_model_name deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --save

Evaluation

Evaluate the post-hoc control over the mixture-of-experts on the out-of-distribution (OOD) test sets. Some saved models are available here for those who want to check the results quickly; download them and place them under saved_models/[task_name]/.
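
If you use the released checkpoints, a minimal placement sketch for MNLI looks like this (the source path is a placeholder; the target name matches the --resume argument in the command below).

# Placeholder path: adjust to wherever you downloaded the checkpoint
mkdir -p saved_models/mnli
mv /path/to/downloaded/bert_mos_e10_rs05k8_ep10_lr2e-5_8 saved_models/mnli/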

# HANS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset mnli --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 0.5 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/mnli \
    --resume bert_mos_e10_rs05k8_ep10_lr2e-5_8 --evaluate

# PAWS
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset qqp --batch_size 32 --epochs 10 \
    --num_experts 15 --router_loss 1 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/qqp \
    --resume bert_mos_e15_rs1k8_ep10_lr2e-5_8 --evaluate

# Symm. v1 and v2
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path bert-base-uncased \
    --dataset fever --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 1 --router_tau 1 \
    --num_topk_mask 8 --lr 2e-5 --seed 888 --save_dir saved_models/fever \
    --resume bert_mos_e10_rs1k8_ep10_lr2e-5_8 --evaluate

For the DeBERTa-v3-large ablation study:

# Make sure to use the environment and dependencies prepared for DeBERTa-v3-large
CUDA_VISIBLE_DEVICES=0,1 accelerate launch \
    --config_file accelerate_config_deberta.yaml --main_process_port 20880 \
    src/main_mix.py --model bert_mos --pretrained_path microsoft/deberta-v3-large \
    --dataset mnli --batch_size 32 --epochs 10 \
    --num_experts 10 --router_loss 0.5 --router_tau 1 \
    --num_topk_mask 8 --lr 5e-6 --max_grad_norm 1 --seed 888 \
    --save_dir saved_models/mnli \
    --resume deberta_mos_e10_rs05k8_ep10_lr5e-6g1_bf16_8 --evaluate

Citation

If you find our work useful for your research, please consider citing our paper:

@article{honda-etal-2024-eliminate,
    author = {Honda, Ukyo and Oka, Tatsushi and Zhang, Peinan and Mita, Masato},
    title = {Not Eliminate but Aggregate: Post-Hoc Control over Mixture-of-Experts to Address Shortcut Shifts in Natural Language Understanding},
    journal = {Transactions of the Association for Computational Linguistics},
    volume = {12},
    pages = {1268--1289},
    year = {2024},
    month = {10},
    address = {Cambridge, MA},
    publisher = {MIT Press},
    issn = {2307-387X},
    doi = {10.1162/tacl_a_00701},
    url = {https://doi.org/10.1162/tacl\_a\_00701},
}
