Contrastive Reranker/Reward model (#9171)
* wip contrastive reranker

Signed-off-by: arendu <[email protected]>

* wip

Signed-off-by: arendu <[email protected]>

* wip

Signed-off-by: arendu <[email protected]>

* working reranker training and validation

Signed-off-by: arendu <[email protected]>

* default peft for reranker

Signed-off-by: arendu <[email protected]>

* validation time update

Signed-off-by: arendu <[email protected]>

* reranker test

Signed-off-by: arendu <[email protected]>

* reranker inference

Signed-off-by: arendu <[email protected]>

* reranker inference

Signed-off-by: arendu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: arendu <[email protected]>

* updates

Signed-off-by: arendu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: arendu <[email protected]>

* updates

Signed-off-by: arendu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: arendu <[email protected]>

* also can support rlhf style reward model loss

Signed-off-by: arendu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: arendu <[email protected]>

* Apply isort and black reformatting

Signed-off-by: arendu <[email protected]>

* typo in cicd

Signed-off-by: arendu <[email protected]>

---------

Signed-off-by: arendu <[email protected]>
Signed-off-by: arendu <[email protected]>
Signed-off-by: Adi Renduchintala <[email protected]>
Co-authored-by: arendu <[email protected]>
arendu and arendu authored Jul 10, 2024
1 parent 355d3c5 commit 74e32c8
Showing 16 changed files with 1,115 additions and 51 deletions.
41 changes: 41 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -3198,6 +3198,47 @@ jobs:
      - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
        if: "failure()"

  L2_Megatron_GPT_Reranker:
    needs: [cicd-test-container-setup]
    runs-on: self-hosted-azure
    timeout-minutes: 10
    container:
      image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
      options:
        # --user 0:128
        --device=/dev/nvidia0
        --gpus all
        --shm-size=8g
        --env TRANSFORMERS_OFFLINE=0
        --env HYDRA_FULL_ERROR=1
        --volume /mnt/datadrive/TestData:/home/TestData
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - run: |
          rm -rf /home/TestData/nlp/megatron_ir/working_dir
          python examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py \
          exp_manager.exp_dir='/home/TestData/nlp/megatron_ir/working_dir' \
          model.global_batch_size=4 \
          model.micro_batch_size=4 \
          trainer.devices=1 \
          trainer.num_nodes=1 \
          trainer.max_epochs=null \
          trainer.max_steps=20 \
          trainer.val_check_interval=10 \
          model.restore_from_path='/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo' \
          model.peft.lora_tuning.adapter_dim=8 \
          model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] \
          model.data.validation_ds.write_embeddings_to_file=True \
          model.data.validation_ds.output_file_path_prefix='/home/TestData/nlp/megatron_ir/working_dir/val_embs' \
          model.data.train_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl]
          rm -rf /home/TestData/nlp/megatron_ir/working_dir
      - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main"
        if: "failure()"

  L2_Megatron_GPT_Embedding:
    needs: [cicd-test-container-setup]
    uses: ./.github/workflows/_test_template.yml
@@ -120,7 +120,6 @@ model:
       tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre
 
   data:
-    return_output_tensors: True
     test_ds:
       query_file_names: ??? # Path to a list of JSONL files corresponding to the query data. Data format is identical to validation_ds.
       doc_file_names: ??? # Path to a list of JSONL files corresponding to the doc data. Data format is identical to validation_ds.
@@ -84,6 +84,7 @@ model:
   use_flash_attention: True
   precision: bf16
   apply_rope_fusion: False
+  reward_model_loss: False # Set this to true to perform RLHF style reward model loss -log(sigmoid(accept_logit - reject_logit))
 
   peft:
     peft_scheme: "lora" # can be either adapter,ia3, or ptuning
@@ -126,7 +127,6 @@ model:
       tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre
 
   data:
-    return_output_tensors: True
     train_ds:
       # Example of how to specify paths to multiple datasets
       # file_names:
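The `reward_model_loss` flag added above switches training from the contrastive objective to an RLHF-style pairwise reward loss, -log(sigmoid(accept_logit - reject_logit)). A minimal sketch of that loss in PyTorch (illustrative only, not the repository's implementation; tensor names are assumptions):

import torch
import torch.nn.functional as F


def pairwise_reward_loss(accept_logits: torch.Tensor, reject_logits: torch.Tensor) -> torch.Tensor:
    """-log(sigmoid(accept - reject)), averaged over the batch; logsigmoid is the numerically stable form."""
    return -F.logsigmoid(accept_logits - reject_logits).mean()


# Example: scalar relevance scores for (query, accepted doc) and (query, rejected doc) pairs
accept = torch.tensor([2.1, 0.3])
reject = torch.tensor([1.0, 0.7])
print(pairwise_reward_loss(accept, reject))  # smaller when accepted docs score higher than rejected ones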
@@ -0,0 +1,222 @@
name: megatron_gpt_peft_reranker_tuning

trainer:
  devices: 1
  accelerator: gpu
  num_nodes: 1
  precision: bf16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: null
  max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10 # frequency with which training steps are logged
  val_check_interval: ${trainer.max_steps} # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 0.25 will run val every quarter epoch
  gradient_clip_val: null
  num_sanity_val_steps: 0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: ${name}
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: validation_${model.data.validation_ds.metric.name}
    save_top_k: 1
    mode: min
    save_nemo_on_train_end: True
    filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}'
    model_parallel_size: ${model.tensor_model_parallel_size}
    always_save_nemo: False
    save_best_model: True
  create_early_stopping_callback: False
  early_stopping_callback_params:
    monitor: "val_loss"
    mode: "min"
    min_delta: 0.001
    patience: 10
    verbose: True
    strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training.

model:
  seed: 1234
  tensor_model_parallel_size: 1 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism

  global_batch_size: 128
  micro_batch_size: 4
  restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with
  resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
  save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training.
  sync_batch_comm: False
  megatron_amp_O2: True

  ## Sequence Parallelism
  # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially
  # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
  sequence_parallel: False

  ## Activation Checkpoint
  activations_checkpoint_granularity: selective # 'selective' or 'full'
  activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
  activations_checkpoint_num_layers: null # not used with 'selective'
  activations_checkpoint_layers_per_pipeline: null
  gradient_as_bucket_view: False

  hidden_dropout: 0.0
  attention_dropout: 0.0
  ffn_dropout: 0.0
  temperature: 0.02
  num_soft_negatives: 0 # Number of soft negatives to use for contrastive loss,it should be max(batch_size - 1), 0 means use hard negatives only
  use_all_possible_negatives: False # If True, use all possible negatives for contrastive loss, otherwise use num_soft_negatives, if num_soft_negatives is 0, use hard negatives only
  post_process: False # should be False.
  apply_rope_fusion: False
  transformer_engine: True # required to be True for newer versions of Megatron-LM based models
  mcore_gpt: True # required to be True for newer versions of Megatron-LM based models
  use_flash_attention: True
  precision: bf16

  peft:
    peft_scheme: "mlp_head,lora" # can be either adapter,ia3, or ptuning
    restore_from_path: null

    # Used for adapter peft training
    adapter_tuning:
      type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter'
      adapter_dim: 32
      adapter_dropout: 0.0
      norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used.
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm']
      layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    lora_tuning:
      target_modules: ['attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2'] #
      adapter_dim: 32
      adapter_dropout: 0.0
      column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal
      layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers
      weight_tying: False
      position_embedding_strategy: null # used only when weight_tying is True

    # Used for p-tuning peft training
    p_tuning:
      virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence
      bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck
      embedding_dim: 1024 # the size of the prompt encoder embeddings
      init_std: 0.023

    # Instead of using the GPT LM Head, we can use a custom head for the reranking task
    mlp_head_tuning:
      out_features: 1

    ia3_tuning:
      layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers

    selective_tuning:
      tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre

  data:
    train_ds:
      # Example of how to specify paths to multiple datasets
      # file_names:
      #   - /path/to/squad.jsonl
      #   - /path/to/mnli.jsonl
      #   - /path/to/boolq.jsonl
      # Example of how each dataset is formatted
      # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'}
      file_names: ??? # Path to a list of JSONL files corresponding to the source data.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: True
      num_workers: 0
      memmap_workers: 2
      pin_memory: True
      max_seq_length: 512 # Even if the base model can handle longer sequences, 512 is generally a good choice for training efficiency.
      min_seq_length: 1
      drop_last: True
      # Example of how to specify concat_sampling_probabilities
      # concat_sampling_probabilities:
      #   - 0.5
      #   - 0.25
      #   - 0.25
      concat_sampling_probabilities:
        - 1.0
      label_key: 'output'
      add_eos: True
      add_bos: False
      index_mapping_dir: null # Path to a directory to write index mapping files.
      truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
    validation_ds:
      file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      names: ["validation"] # Names of the corresponding datasets used to log metrics.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: False
      num_workers: 0
      memmap_workers: ${model.data.train_ds.memmap_workers}
      pin_memory: True
      max_seq_length: ${model.data.train_ds.max_seq_length}
      min_seq_length: 1
      drop_last: False
      label_key: ${model.data.train_ds.label_key}
      add_eos: ${model.data.train_ds.add_eos}
      add_bos: ${model.data.train_ds.add_bos}
      write_embeddings_to_file: False
      output_file_path_prefix: "validation_rankings" # Prefix of the file to write predictions to.
      index_mapping_dir: null # Path to a directory to write index mapping files.
      truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
      metric:
        name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null
    test_ds:
      file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds.
      names: null # Names of the corresponding datasets used to log metrics.
      global_batch_size: ${model.global_batch_size}
      micro_batch_size: ${model.micro_batch_size}
      shuffle: False
      num_workers: 0
      memmap_workers: ${model.data.train_ds.memmap_workers}
      pin_memory: True
      max_seq_length: ${model.data.train_ds.max_seq_length}
      min_seq_length: 1
      drop_last: False
      add_eos: ${model.data.train_ds.add_eos}
      add_bos: ${model.data.train_ds.add_bos}
      write_predictions_to_file: True
      output_file_path_prefix: "test_embeddings" # Prefix of the file to write predictions to.
      index_mapping_dir: null # Path to a directory to write index mapping files.
      truncation_method: 'right' # Truncation from which position, Options: ['left', 'right']
      metric:
        name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss']
        average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported.
        num_classes: null

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 50
      min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1
      constant_steps: 0 # Constant steps should also be 0 when min_lr=0
      monitor: val_loss
      reduce_on_plateau: false
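The new reranker config above drives the contrastive objective through `model.temperature`, `model.num_soft_negatives`, and `model.use_all_possible_negatives`, with the `mlp_head` PEFT module producing a single relevance logit per (query, passage) pair. A minimal sketch of a temperature-scaled contrastive loss over one positive and several negative passages per query (an illustration of the general technique, not the exact NeMo code; shapes and names are assumptions):

import torch
import torch.nn.functional as F


def contrastive_reranker_loss(logits: torch.Tensor, temperature: float = 0.02) -> torch.Tensor:
    """Cross-entropy over temperature-scaled relevance logits.

    logits: [batch, 1 + num_negatives]; column 0 is the positive (query, doc) score,
    the remaining columns are scores for negative documents.
    """
    scaled = logits / temperature
    targets = torch.zeros(scaled.size(0), dtype=torch.long, device=scaled.device)  # positive sits at index 0
    return F.cross_entropy(scaled, targets)


# Example: 2 queries, each scored against 1 positive and 3 negative passages
logits = torch.tensor([[0.9, 0.1, -0.3, 0.2],
                       [0.5, 0.4, 0.0, -0.1]])
print(contrastive_reranker_loss(logits))

In the config's terms, `num_soft_negatives=0` corresponds to scoring only the hard negatives paired with each query, while `use_all_possible_negatives=True` widens the negative set used in the denominator.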
@@ -68,7 +68,9 @@ def use_inference_server(cfg, model, trainer):
                 web_ui = get_demo
             loop = asyncio.new_event_loop()
             thread = threading.Thread(
-                target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
+                target=web_ui,
+                daemon=True,
+                args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop),
             )
             thread.start()
         server = MegatronServer(model.cuda())
@@ -93,7 +95,6 @@ def main(cfg) -> None:
     model_cfg = MegatronGPTEmbeddingModel.merge_inference_cfg(cfg.model.restore_from_path, cfg)
 
     with open_dict(model_cfg):
-        model_cfg.data.return_output_tensors = True
         model_cfg.post_process = False
 
     model = MegatronGPTEmbeddingModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)
@@ -0,0 +1,76 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import MutableMapping

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf
from pytorch_lightning.loggers import WandbLogger

from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

mp.set_start_method("spawn", force=True)


def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping:
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_tuning_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    model_cfg = MegatronGPTRerankerModel.merge_cfg_with(cfg.model.restore_from_path, cfg)
    if trainer.global_rank == 0:
        for logger in trainer.loggers:
            if isinstance(logger, WandbLogger):
                fd = flatten_dict(dict(model_cfg), sep="/")
                logger.experiment.config.update(fd)
    model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer)
    peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")]
    peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst]

    if cfg.model.peft.restore_from_path is not None:
        # initialize peft weights from a checkpoint instead of randomly
        # This is not the same as resume training because optimizer states are not restored.
        logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path)
        model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls)
    elif peft_cfg_cls is not None:
        logging.info("Adding adapter weights to the model for PEFT")
        # model.add_adapter(peft_cfg_cls(model_cfg))
        model.add_adapter(peft_cfg_cls)
    else:
        logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}")

    trainer.fit(model)


if __name__ == '__main__':
    main()
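The `flatten_dict` helper in the new script flattens the nested model config into `sep`-joined keys before it is pushed to the W&B run config. A quick usage example (assuming the function defined above):

cfg = {"model": {"peft": {"peft_scheme": "mlp_head,lora"}, "global_batch_size": 128}}
print(flatten_dict(cfg, sep="/"))
# {'model/peft/peft_scheme': 'mlp_head,lora', 'model/global_batch_size': 128}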
