diff --git a/.gitignore b/.gitignore
index cbc04eaf..7fdc48ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,8 @@ cython_debug/
 checkpoints/
 wandb/
+
+nanotron-ckpt/*
+nanotron-ckpt-idefics3/*
+nanotron_checkpoints/*
+nanotron-ckpt-vit/*
\ No newline at end of file
diff --git a/docs/vision.md b/docs/vision.md
new file mode 100644
index 00000000..1a8eb0ca
--- /dev/null
+++ b/docs/vision.md
@@ -0,0 +1,29 @@
+# VLM training
+
+## Installation
+
+Required packages are mostly the same as for LLMs. Pillow is needed to work with images: `pip install pillow`.
+
+Some HF-related scripts may also require `pip install accelerate`.
+
+`hf_transfer` can optionally be installed to speed up downloads from the HF Hub.
+
+All scripts use the [Idefics3 processor](https://huggingface.co/docs/transformers/en/model_doc/idefics3#transformers.Idefics3Processor) instead of a plain tokenizer. By default it is expected to be saved in the `hf_idefics3_processor` folder.
+
+## Overview
+
+The VLM functionality uses the [Idefics3](https://arxiv.org/pdf/2408.12637) architecture, which combines a CLIP-style vision model with a Llama-style language model via the pixel-shuffle technique described in that paper.
+
+[`models/idefics.py`](/src/nanotron/models/idefics.py) contains the VLM implementation.
+
+[`tools/idefics3`](/tools/idefics3/) contains HF/Nanotron conversion and simple evaluation scripts.
+
+[`examples/vqa`](/examples/vqa/) contains a simple fine-tuning example using a small dataset that fits into RAM.
+
+[`examples/caption-pretrain`](/examples/caption-pretrain/) contains code that runs pretraining on a preprocessed/raw captioning dataset (LAION).
+
+The training dataloader uses `datasets.IterableDataset` to load and preprocess the dataset step by step. This allows different encoders for different datasets and is inspired by the Megatron-Energon dataloader. Each dataset requires a sample encoder, which processes individual samples, and a batch encoder, which collates and encodes batches.
+
+> The current sample encoder simply appends the `<image>` token to the caption. It still needs to be checked whether this is sufficient or whether the full prompt template should be used, as in [`tools/idefics3/loss_on_captions_hf.py`](/tools/idefics3/loss_on_captions_hf.py).
+
+The dataloader code lives in [`modular_dataloader`](/src/nanotron/modular_dataloader/), and custom encoders can be added there (a sketch of such an encoder is shown below). Currently the dataloader only supports datasets stored in Parquet format.
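+### Example: adding a custom sample encoder
+
+As a rough sketch of what a custom encoder could look like (the exact base class, registration mechanism and method names live in `modular_dataloader` and may differ from what is assumed here), a sample encoder is built with the Idefics3 processor and the sequence length, receives one raw dataset row, and returns the processor outputs that the batch encoder later collates:
+
+```python
+from dataclasses import dataclass
+from typing import Any, Dict
+
+
+@dataclass
+class MyCaptionSampleEncoder:
+    """Hypothetical sample encoder: turns one raw {caption, image} row into processor outputs."""
+
+    processor: Any        # Idefics3 processor, e.g. loaded from `hf_idefics3_processor`
+    sequence_length: int
+    text_field: str = "caption"
+    image_field: str = "jpg"
+
+    def encode(self, sample: Dict[str, Any]) -> Dict[str, Any]:
+        # Append the <image> placeholder so the processor expands it into image tokens.
+        text = sample[self.text_field] + "<image>"
+        return self.processor(
+            text=text,
+            images=[sample[self.image_field]],
+            return_tensors="np",
+            max_length=self.sequence_length + 1,
+            truncation=True,
+        )
+```
+
+The constructor arguments mirror how the training scripts build encoders (`SAMPLE_ENCODERS[name](processor=..., sequence_length=..., **sample_encoder_args)`, with `text_field`/`image_field` coming from `sample_encoder_args` in the YAML config); the `encode` method name and the exact return format expected by the batch encoder are assumptions and should be adapted to the real interface before registering the class in `SAMPLE_ENCODERS`.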
diff --git a/examples/caption-pretrain/pretrain.yaml b/examples/caption-pretrain/pretrain.yaml new file mode 100644 index 00000000..ea52a895 --- /dev/null +++ b/examples/caption-pretrain/pretrain.yaml @@ -0,0 +1,130 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints_tmp + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + hf_dataset_name_or_type: "parquet" + hf_dataset_data_dir: "/capstor/store/cscs/swissai/a06/zzirui/datasets_raw/laion/laion_prep_train/" + sample_encoder: "caption_process" + sample_encoder_args: + text_field: "caption" + image_field: "jpg" + batch_encoder: "caption_preprocessed" + image_scale_factor: 2 + sample_encoding_workers: 16 + sample_encoding_batch: 16 + batch_encoding_workers: 16 + batch_encoding_batch: 16 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: idefics3_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 + # std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + image_token_id: 128257 + text_config: + bos_token_id: 128000 + eos_token_id: + - 128001 + - 128008 + - 128009 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 131072 + num_attention_heads: 32 + num_hidden_layers: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 8.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: false + use_cache: true + vocab_size: 128260 + pad_token_id: 128002 + scale_factor: 2 + vision_config: + attention_dropout: 0.0 + hidden_act: gelu_pytorch_tanh + hidden_size: 1152 + image_size: 364 + intermediate_size: 4304 + is_using_mup: false + layer_norm_eps: 1.0e-06 + num_attention_heads: 16 + num_channels: 3 + num_hidden_layers: 27 + num_key_value_heads: 16 + patch_size: 14 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 3 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: "hf_idefics3_processor" + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 2 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 1 + sequence_length: 2048 + train_steps: 100 + val_check_interval: -1 diff --git a/examples/caption-pretrain/pretrain_preprocessed.yaml b/examples/caption-pretrain/pretrain_preprocessed.yaml new file mode 100644 index 00000000..9f0436f5 --- /dev/null +++ b/examples/caption-pretrain/pretrain_preprocessed.yaml @@ -0,0 +1,127 @@ +checkpoints: + checkpoint_interval: 100 
+ checkpoints_path: checkpoints_tmp_2 + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + hf_dataset_name_or_type: "parquet" + hf_dataset_data_dir: "/capstor/store/cscs/swissai/a06/zzirui/datasets_raw/laion/laion_train/hold_out_test/" + sample_encoder: "caption_preprocessed" + batch_encoder: "caption_preprocessed" + image_scale_factor: 2 + sample_encoding_workers: 16 + sample_encoding_batch: 16 + batch_encoding_workers: 16 + batch_encoding_batch: 16 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: idefics3_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 + # std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + image_token_id: 128257 + text_config: + bos_token_id: 128000 + eos_token_id: + - 128001 + - 128008 + - 128009 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 131072 + num_attention_heads: 32 + num_hidden_layers: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 8.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: false + use_cache: true + vocab_size: 128260 + pad_token_id: 128002 + scale_factor: 2 + vision_config: + attention_dropout: 0.0 + hidden_act: gelu_pytorch_tanh + hidden_size: 1152 + image_size: 364 + intermediate_size: 4304 + is_using_mup: false + layer_norm_eps: 1.0e-06 + num_attention_heads: 16 + num_channels: 3 + num_hidden_layers: 27 + num_key_value_heads: 16 + patch_size: 14 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 4 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: "hf_idefics3_processor" + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 2 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 4 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/examples/caption-pretrain/run_train.py b/examples/caption-pretrain/run_train.py new file mode 100644 index 00000000..438b663f --- /dev/null +++ b/examples/caption-pretrain/run_train.py @@ -0,0 +1,165 @@ +""" +Nanotron training script example using a custom dataloader. 
+ +Usage: +``` +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=4 examples/caption-pretrain/run_train.py --config-file examples/caption-pretrain/pretrain.yaml +``` + +For preprocessed dataset: +torchrun --nproc_per_node=4 examples/caption-pretrain/run_train.py --config-file examples/caption-pretrain/pretrain_preprocessed.yaml +""" + +import argparse +from typing import Dict, cast + +import datasets +from torch.utils.data import DataLoader + +from nanotron import logging +from nanotron.config import ( + DataArgs, + DatasetStageArgs, +) +from nanotron.config.config import ImageDatasetsArgs +from nanotron.helpers import get_consumed_train_samples_of_a_data_stage_from_ckp +from nanotron.logging import log_rank +from nanotron.modular_dataloader import BATCH_ENCODERS, SAMPLE_ENCODERS +from nanotron.modular_dataloader.iterable import get_train_dataloader +from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks +from nanotron.trainer import DistributedTrainer + +try: + from huggingface_hub import __version__ as hf_hub_version + from transformers import AutoProcessor, AutoTokenizer + from transformers import __version__ as tf_version +except ImportError: + hf_hub_version = None + tf_version = None + +logger = logging.get_logger(__name__) + + +def get_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + consumed_train_samples: int +): + """ + Returns a dataloader for a given data stage. + + data: The data configuration for the current stage. + """ + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + if isinstance(data.dataset, ImageDatasetsArgs): + log_rank("Using iterable dataset from `datasets` library", logger=logger, level=logging.INFO, rank=0) + tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path + log_rank( + f"Loading tokenizer from {tokenizer_path} with HF version {hf_hub_version} and Transformers version {tf_version}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataset = datasets.load_dataset( + data.dataset.hf_dataset_name_or_type, + data_dir=data.dataset.hf_dataset_data_dir, + split=data.dataset.hf_dataset_splits, + streaming=True + ) + + processor = AutoProcessor.from_pretrained( + tokenizer_path, + size={"longest_edge": data.dataset.image_size * data.dataset.image_scale_factor}, + ) + + sample_encoder = SAMPLE_ENCODERS[data.dataset.sample_encoder]( + processor=processor, + sequence_length=trainer.sequence_length, + **data.dataset.sample_encoder_args + ) + + batch_encoder = BATCH_ENCODERS[data.dataset.batch_encoder]( + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=trainer.parallel_context, + processor=processor, + sequence_length=trainer.sequence_length, + **data.dataset.batch_encoder_args + ) + + dataloader = get_train_dataloader( + train_dataset=dataset, + sample_encoder=sample_encoder, + batch_encoder=batch_encoder, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + sample_encoding_batch=data.dataset.sample_encoding_batch, + batch_encoding_batch=data.dataset.batch_encoding_batch, + seed_worker=data.seed, + sample_encoding_workers=data.dataset.sample_encoding_workers, + batch_encoding_workers=data.dataset.batch_encoding_workers, + consumed_train_samples=consumed_train_samples, + drop_last=True, + ) + else: + raise 
ValueError(f"Unsupported dataset case: {data.dataset}") + + return dataloader + + +def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) + + log_rank( + f"[Training Plan] Stage {stage.name}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples + ) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + config_file = args.config_file + + # Load trainer and data + trainer = DistributedTrainer(config_file) + dataloader = get_dataloader(trainer) + + # Train + trainer.train(dataloader) diff --git a/examples/vqa/config_vqa.yaml b/examples/vqa/config_vqa.yaml new file mode 100644 index 00000000..ab6f08ba --- /dev/null +++ b/examples/vqa/config_vqa.yaml @@ -0,0 +1,136 @@ +checkpoints: + checkpoint_interval: 10 + checkpoints_path: checkpoints + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 16 + hf_dataset_config_name: null + hf_dataset_or_datasets: cmarkea/doc-vqa + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 16 + hf_dataset_config_name: null + hf_dataset_or_datasets: cmarkea/doc-vqa + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Annealing Phase + start_training_step: 10 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: debug + run: idefics3_%date_%jobid + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 + make_vocab_size_divisible_by: 1 + model_config: + image_token_id: 128257 + text_config: + bos_token_id: 128000 + eos_token_id: + - 128001 + - 128008 + - 128009 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 14336 + is_llama_config: true + max_position_embeddings: 131072 + num_attention_heads: 32 + num_hidden_layers: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-05 + rope_interleaved: false + rope_scaling: + factor: 8.0 + high_freq_factor: 4.0 + low_freq_factor: 1.0 + original_max_position_embeddings: 8192 + rope_type: llama3 + rope_theta: 500000.0 + tie_word_embeddings: false + use_cache: true + vocab_size: 128260 + pad_token_id: 128002 + scale_factor: 2 + vision_config: + attention_dropout: 0.0 + 
hidden_act: gelu_pytorch_tanh + hidden_size: 1152 + image_size: 364 + intermediate_size: 4304 + is_using_mup: false + layer_norm_eps: 1.0e-06 + num_attention_heads: 16 + num_channels: 3 + num_hidden_layers: 27 + num_key_value_heads: 16 + patch_size: 14 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 13 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: afab + tp: 2 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: "HuggingFaceM4/Idefics3-8B-Llama3" + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 4 + sequence_length: 2048 + train_steps: 15 + val_check_interval: -1 diff --git a/examples/vqa/debug_idefics.sh b/examples/vqa/debug_idefics.sh new file mode 100644 index 00000000..4ad095fa --- /dev/null +++ b/examples/vqa/debug_idefics.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +# Simple script to create a tiny llama model and train it + +set -e -x + +# Create the YAML config file + +EXAMPLE_PATH=$(cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P) +REPO_PATH=$(dirname $EXAMPLE_PATH) + + +# python $EXAMPLE_PATH/debug_config_tiny_llama.py + +# $REPO_PATH="exaples/vqa" + +# Setup from environment variables + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export FI_PROVIDER="efa" + +debugpy-run -m torch.distributed.run -p 5678 \ + -- \ + --nproc_per_node 2 \ + --nnodes 1 \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + $REPO_PATH/vqa/run_train.py --config-file $REPO_PATH/vqa/config_vqa.yaml diff --git a/examples/vqa/run_train.py b/examples/vqa/run_train.py new file mode 100644 index 00000000..d9c0824a --- /dev/null +++ b/examples/vqa/run_train.py @@ -0,0 +1,345 @@ +""" +Nanotron training script example using a custom dataloader. 
+ +Usage: +``` +export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations +torchrun --nproc_per_node=2 examples/vqa/run_train.py --config-file examples/vqa/config_vqa.yaml +``` +""" +import argparse +import dataclasses +from typing import Dict, List, Union, cast + +import datasets +import numpy as np +import torch +from torch.utils.data import DataLoader + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + PretrainDatasetsArgs, +) +from nanotron.dataloader import ( + get_datasets, + get_train_dataloader, +) +from nanotron.helpers import ( + compute_remain_train_steps_of_a_data_stage_from_ckp, + get_consumed_train_samples_of_a_data_stage_from_ckp, +) +from nanotron.logging import log_rank +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks +from nanotron.trainer import DistributedTrainer +from nanotron.utils import main_rank_first + +try: + from huggingface_hub import __version__ as hf_hub_version + from transformers import AutoProcessor, AutoTokenizer + from transformers import __version__ as tf_version +except ImportError: + hf_hub_version = None + tf_version = None + +logger = logging.get_logger(__name__) + + +def vqa_process( + raw_dataset: datasets.Dataset, + processor: AutoProcessor, + dataset_processing_num_proc_per_process: int, + dataset_overwrite_cache: bool, + sequence_length: int, +): + def format_example(example): + messages = [] + for i, x in enumerate(example["en"]): + user_message = { + "role": "user", + "content": [ + {"type": "text", "text": x["question"]}, + ] + } + + if i == 0: + user_message["content"].append( + {"type": "image"}, + ) + + messages.append(user_message) + assistant_message = { + "role": "assistant", + "content": [ + {"type": "text", "text": x["answer"]}, + ] + } + + messages.append(assistant_message) + return messages + + def _process_examples(examples: Dict, images) -> Dict[str, List[np.ndarray]]: + inputs = [ + processor( + text=processor.apply_chat_template(format_example(ex), add_generation_prompt=True), + images = [img], + return_tensors="np", max_length=sequence_length + 1, padding="longest", truncation=True + ) + for ex, img in zip(examples, images) + ] + + inputs = { + k: [v[k] for v in inputs] for k in ["input_ids", "pixel_values"] + } + + return inputs + + train_dataset = raw_dataset.map( + _process_examples, + input_columns=["qa", "image"], + remove_columns=raw_dataset.column_names, + batched=True, + num_proc=dataset_processing_num_proc_per_process, + load_from_cache_file=not dataset_overwrite_cache, + ) + + return train_dataset + + +@dataclasses.dataclass +class DataCollatorForVQA: + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + padding_idx: int = 128_002 + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. 
+ current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + "pixel_values": TensorPointer(group_rank=self.input_pp_rank), + } + + # Make sure we load only what's necessary, ie we only load `input_ids` and `pixel_values` column. + assert all(list(example.keys()) == ["input_ids", "pixel_values"] for example in examples) + + max_n_patches = max([len(examples[i]["pixel_values"][0]) for i in range(len(examples))]) + + padded_pixel_values = [] + + for example in examples: + pixel_values = example["pixel_values"] + current_patches = len(pixel_values[0]) + + # Pad the pixel_values to have max_n_patches along dimension 1 (patches) + padding = ((0, 0), (0, max_n_patches - current_patches), (0, 0), (0, 0), (0, 0)) # Only pad the patches dimension + padded_values = np.pad(pixel_values, pad_width=padding, mode='constant', constant_values=0) + padded_pixel_values.append(padded_values) + + + # Step 3: Stack padded pixel_values and pixel_attention_masks + pixel_values = np.vstack(padded_pixel_values) # Stacked pixel_values + result: Dict[str, Union[np.ndarray, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + result["pixel_values"] = TensorPointer(group_rank=self.input_pp_rank) + + def pad_tokens(inputs = True): + padded_tokens = [] + token_masks = [] + + max_seq_length = max([len(examples[i]["input_ids"][0]) for i in range(len(examples))]) - 1 + # make it divisible by 4 for tp + max_seq_length = max_seq_length + (4 - max_seq_length % 4) % 4 + + for example in examples: + input_ids = example["input_ids"] + if isinstance(input_ids, list): + input_ids = np.array(input_ids) + + if inputs: + input_ids = input_ids[:, :-1] + else: + input_ids = input_ids[:, 1:] + + current_length = input_ids.shape[1] + + padding = ((0, 0), (0, max_seq_length - current_length)) + input_ids = np.pad(input_ids, pad_width=padding, mode='constant', constant_values=self.padding_idx) + padded_tokens.append(input_ids) + + mask = np.ones((1, current_length), dtype=np.bool_) + mask = np.pad(mask, pad_width=padding, mode='constant', constant_values=0) + token_masks.append(mask) + + padded_tokens = np.vstack(padded_tokens) + token_masks = np.vstack(token_masks) + + return padded_tokens, token_masks + + if current_pp_rank == self.input_pp_rank: + result["input_ids"], result["input_mask"] = pad_tokens(inputs=True) + result["pixel_values"] = pixel_values + + if current_pp_rank == self.output_pp_rank: + result["label_ids"], result["label_mask"] = pad_tokens(inputs=False) + + result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} + return result + + +def get_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + consumed_train_samples: int, + num_remaining_train_steps: int, +): + """ + Returns a dataloader for a given data stage. + + data: The data configuration for the current stage. 
+    consumed_train_samples: The number of samples consumed by the model in this stage (each stage starts from zero).
+    num_remaining_train_steps: The number of remaining training steps for this stage.
+    """
+    assert consumed_train_samples >= 0, "consumed_train_samples should be non-negative"
+    assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be non-negative"
+
+    # First, we need to know which ranks to feed the dataloader to
+    input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model)
+
+    if isinstance(data.dataset, PretrainDatasetsArgs):
+        log_rank("Using `datasets` library", logger=logger, level=logging.INFO, rank=0)
+        tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path
+        log_rank(
+            f"Loading tokenizer from {tokenizer_path} and transformers/hf_hub versions {tf_version, hf_hub_version}",
+            logger=logger,
+            level=logging.INFO,
+            rank=0,
+        )
+
+        # We need the 1st device to process the dataset and cache it, then the other devices load from the cache
+        with main_rank_first(trainer.parallel_context.world_pg):
+            # We load the raw dataset
+            raw_dataset = get_datasets(
+                hf_dataset_or_datasets=data.dataset.hf_dataset_or_datasets,
+                hf_dataset_config_name=data.dataset.hf_dataset_config_name,
+                splits=data.dataset.hf_dataset_splits,
+            )["train"]
+
+            # Keep only a small subset so this example dataset fits comfortably in RAM
+            raw_dataset = raw_dataset.select(range(1000))
+
+            processor = AutoProcessor.from_pretrained(tokenizer_path, size={"longest_edge": 2 * 364})
+            train_dataset = vqa_process(
+                raw_dataset=raw_dataset,
+                processor=processor,
+                dataset_processing_num_proc_per_process=data.dataset.dataset_processing_num_proc_per_process,
+                dataset_overwrite_cache=data.dataset.dataset_overwrite_cache,
+                sequence_length=trainer.sequence_length,
+            )
+
+            # We load the processed dataset on the ranks requiring it
+            dataloader = get_train_dataloader(
+                train_dataset=train_dataset,
+                sequence_length=trainer.sequence_length,
+                parallel_context=trainer.parallel_context,
+                input_pp_rank=input_pp_rank,
+                output_pp_rank=output_pp_rank,
+                micro_batch_size=trainer.micro_batch_size,
+                consumed_train_samples=consumed_train_samples,
+                dataloader_num_workers=data.num_loading_workers,
+                seed_worker=data.seed,
+                dataloader_drop_last=True,
+                dataset_columns=["input_ids", "pixel_values"],
+                collator_builder=DataCollatorForVQA,
+            )
+
+            # Check if we have enough samples for train_steps
+            total_tokens_dataset = len(dataloader.dataset) * trainer.sequence_length
+            num_tokens_needed_for_training = (
+                num_remaining_train_steps * trainer.global_batch_size * trainer.sequence_length
+            )
+            assert num_tokens_needed_for_training <= total_tokens_dataset, (
+                f"Dataset is too small for steps ({total_tokens_dataset} < {num_tokens_needed_for_training}), "
+                f"Try train_steps<={len(dataloader.dataset) // trainer.global_batch_size + trainer.iteration_step}"
+            )
+    else:
+        raise ValueError(f"Unhandled case of `self.config.data.dataset`.
Got: {data.dataset}") + + return dataloader + + +def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + consumed_train_samples = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) + assert ( + consumed_train_samples is not None + ), f"Cannot find consumed_train_samples for stage {stage.start_training_step} in the checkpoint" + + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + stage, trainer.config, trainer.metadata + ) + log_rank( + f"[Training Plan] Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, + stage.data, + consumed_train_samples=consumed_train_samples, + num_remaining_train_steps=num_remaining_train_steps, + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + config_file = args.config_file + + # Load trainer and data + trainer = DistributedTrainer(config_file) + dataloader = get_dataloader(trainer) + + # Train + trainer.train(dataloader) diff --git a/run_train.py b/run_train.py index 021d955d..3209b93b 100644 --- a/run_train.py +++ b/run_train.py @@ -8,11 +8,13 @@ ``` """ import argparse +import datasets from typing import Dict, cast import numpy as np from nanotron import logging from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.config.config import ImageDatasetsArgs from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.dataloader import ( clm_process, @@ -25,6 +27,7 @@ get_consumed_train_samples_of_a_data_stage_from_ckp, ) from nanotron.logging import log_rank +from nanotron.modular_dataloader import BATCH_ENCODERS, SAMPLE_ENCODERS from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks from nanotron.trainer import DistributedTrainer from nanotron.utils import main_rank_first @@ -32,7 +35,7 @@ try: from huggingface_hub import __version__ as hf_hub_version - from transformers import AutoTokenizer + from transformers import AutoTokenizer, AutoProcessor from transformers import __version__ as tf_version except ImportError: hf_hub_version = None @@ -96,6 +99,8 @@ def get_dataloader_from_data_stage( splits=data.dataset.hf_dataset_splits, )["train"] + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" @@ -127,6 +132,7 @@ def get_dataloader_from_data_stage( dataloader_num_workers=data.num_loading_workers, seed_worker=data.seed, dataloader_drop_last=True, + dataset_columns=["input_ids"] ) # Check if we have enough samples for train_steps diff --git a/src/nanotron/config/config.py 
b/src/nanotron/config/config.py
index adc1eafd..9f90b0fe 100644
--- a/src/nanotron/config/config.py
+++ b/src/nanotron/config/config.py
@@ -107,11 +107,40 @@ def __post_init__(self):
             self.dataset_weights = list(tmp_dataset_folder.values())
+
+@dataclass
+class ImageDatasetsArgs:
+    hf_dataset_name_or_type: str
+    sample_encoder: str
+    batch_encoder: str
+    sample_encoding_workers: int
+    batch_encoding_workers: int
+    image_scale_factor: int
+
+    image_size: int = 364
+    sample_encoding_batch: int = 1000
+    batch_encoding_batch: int = 1000
+    hf_dataset_splits: Optional[Union[str, list]] = None
+    hf_dataset_config_name: Optional[str] = None
+    hf_dataset_data_dir: Optional[str] = None
+
+    sample_encoder_args: Optional[dict] = None
+    batch_encoder_args: Optional[dict] = None
+
+    def __post_init__(self):
+        if self.hf_dataset_splits is None:
+            self.hf_dataset_splits = "train"
+        if self.image_size is None:
+            self.image_size = 364
+        if self.sample_encoder_args is None:
+            self.sample_encoder_args = {}
+        if self.batch_encoder_args is None:
+            self.batch_encoder_args = {}
+
 @dataclass
 class DataArgs:
     """Arguments related to the data and data files processing"""

-    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs]]
+    dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ImageDatasetsArgs]]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1
@@ -423,6 +452,7 @@ def get_config_from_dict(
         for k, v in config_dict.items()
         if v is not None
     }
+
     return from_dict(
         data_class=config_class,
         data=config_dict,
diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py
index 2630e1d6..49ef08a6 100644
--- a/src/nanotron/config/models_config.py
+++ b/src/nanotron/config/models_config.py
@@ -33,7 +33,7 @@ class LlamaConfig:
     """

     bos_token_id: int = 1
-    eos_token_id: int = 2
+    eos_token_id: Union[int, List[int]] = 2
     hidden_act: str = "silu"
     hidden_size: int = 4096
     initializer_range: float = 0.02
@@ -134,6 +134,39 @@ def n_positions(self):
     @property
     def n_inner(self):
         return self.intermediate_size
+
+@dataclass
+class Idefics3VisionConfig:
+    """Configuration for an Idefics3 vision model (Siglip modification)
+
+    Be careful on having a coherent typing as we use it to reconstruct the model from yaml
+    """
+    hidden_size: int = 768
+    image_size: int = 224
+    patch_size: int = 32
+    num_channels: int = 3
+    num_attention_heads: int = 12
+    num_key_value_heads: int = 12
+    is_using_mup: bool = False
+    intermediate_size: int = 3072
+    hidden_act: str = "gelu_pytorch_tanh"
+    layer_norm_eps: float = 1e-6
+    attention_dropout: float = 0.0
+    num_hidden_layers: int = 12
+
+@dataclass
+class Idefics3Config:
+    """Configuration for an Idefics3 model
+
+    Be careful on having a coherent typing as we use it to reconstruct the model from yaml
+    """
+    vision_config: Idefics3VisionConfig
+    text_config: LlamaConfig
+
+    image_token_id: int = 128257
+    pad_token_id: int = 128_002
+    scale_factor: int = 2
+    vocab_size: int = 128260


-NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any]
+NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Idefics3Config, Idefics3VisionConfig]
\ No newline at end of file
diff --git a/src/nanotron/data/visual.py b/src/nanotron/data/visual.py
new file mode 100644
index 00000000..f019eb5c
--- /dev/null
+++ b/src/nanotron/data/visual.py
@@ -0,0 +1,23 @@
+import glob
+import os
+from typing import Callable, Optional
+
+import pandas as pd
+from torch.utils.data import Dataset
+
+
+class RawParquetFolderDataset(Dataset):
+    def __init__(self, path: str, transform: Optional[Callable] = None):
+        self.path = path
+        self.transform = transform
+
self.files = sorted(glob.glob(os.path.join(path, "*.parquet"))) + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + file = self.files[idx] + df = pd.read_parquet(file) + if self.transform: + df = self.transform(df) + return df \ No newline at end of file diff --git a/src/nanotron/dataloader.py b/src/nanotron/dataloader.py index 61f73557..58244aa2 100644 --- a/src/nanotron/dataloader.py +++ b/src/nanotron/dataloader.py @@ -29,7 +29,7 @@ concatenate_datasets, load_dataset, ) - from transformers import PreTrainedTokenizerBase + from transformers import PreTrainedTokenizerBase, Idefics3Processor from transformers.trainer_pt_utils import DistributedSamplerWithLoop except ImportError: warnings.warn("Datasets and/or Transformers not installed, you'll be unable to use the dataloader.") @@ -119,6 +119,7 @@ def get_datasets( hf_dataset_or_datasets, hf_dataset_config_name, split=split, + trust_remote_code=True ) else: raise ValueError(f"hf_dataset_or_datasets must be a dict or string but is {type(hf_dataset_or_datasets)}") @@ -322,7 +323,7 @@ def _tokenize_and_group_texts(texts: List[str]) -> Dict[str, List[np.ndarray]]: desc=f"Grouping texts in chunks of {sequence_length+1}", ) return train_dataset - + # Adapted from: https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/data/data_collator.py#L607 @dataclasses.dataclass @@ -397,7 +398,7 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni # Cast np.array to torch.Tensor result = {k: v if isinstance(v, TensorPointer) else torch.from_numpy(v) for k, v in result.items()} return result - + # Adapted from https://github.com/huggingface/transformers/blob/47e1676255e5dd86b9541f734cd4f4bdcbb50f4a/src/transformers/trainer.py#L763-L835 def get_sampler( @@ -452,6 +453,8 @@ def get_train_dataloader( dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, use_loop_to_round_batch_size: bool = False, + dataset_columns = ["input_ids"], + collator_builder=None ) -> DataLoader: if not isinstance(train_dataset, datasets.Dataset): raise ValueError(f"training requires a datasets.Dataset, but got {type(train_dataset)}") @@ -461,17 +464,17 @@ def get_train_dataloader( input_pp_rank, output_pp_rank, ]: - train_dataset = train_dataset.with_format(type="numpy", columns=["input_ids"], output_all_columns=True) + train_dataset = train_dataset.with_format(type="numpy", columns=dataset_columns, output_all_columns=True) # Case of ranks not requiring data. We give them an infinite dummy dataloader else: # - assert train_dataset.column_names == ["input_ids"], ( - f"Dataset has to have a single column, with `input_ids` as the column name. " + assert train_dataset.column_names == dataset_columns, ( + f"Dataset should only have {dataset_columns} columns" f"Current dataset: {train_dataset}" ) dataset_length = len(train_dataset) - train_dataset = train_dataset.remove_columns(column_names="input_ids") + train_dataset = train_dataset.remove_columns(column_names=dataset_columns) assert ( len(train_dataset) == 0 ), f"Dataset has to be empty after removing the `input_ids` column. 
Current dataset: {train_dataset}" @@ -480,13 +483,21 @@ def get_train_dataloader( # No need to spawn a lot of workers, we can just use main dataloader_num_workers = 0 - data_collator = DataCollatorForCLM( - sequence_length=sequence_length, - input_pp_rank=input_pp_rank, - output_pp_rank=output_pp_rank, - parallel_context=parallel_context, - ) - + if collator_builder is not None: + data_collator = collator_builder( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + else: + data_collator = DataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..40640428 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -24,6 +24,21 @@ def new_group( # pylint: disable=function-redefined return dist.new_group(ranks=ranks, timeout=timeout, backend=backend, pg_options=pg_options) +def scatter( + output: torch.Tensor, + scatter_list: Optional[List[torch.Tensor]], + src: int, + group: Optional[ProcessGroup] = None, + async_op: bool = False, +) -> Optional[Work]: + if group is None: + group = dist.torch_dist.distributed_c10d._get_default_group() + + assert ( + group.size() > 1 + ), "You should probably not call `scatter` with a single rank, as it copies data over" + + return dist.scatter(tensor=output, scatter_list=scatter_list, src=src, group=group, async_op=async_op) def reduce_scatter_tensor( # pylint: disable=function-redefined output: torch.Tensor, diff --git a/src/nanotron/models/idefics.py b/src/nanotron/models/idefics.py new file mode 100644 index 00000000..403e05db --- /dev/null +++ b/src/nanotron/models/idefics.py @@ -0,0 +1,1136 @@ +import torch +from typing import Dict, Optional, Union +from torch import nn + +from nanotron import logging +from nanotron.config.config import Config +from nanotron.config.models_config import LlamaConfig, RandomInit, SpectralMupInit +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.logging import log_rank +from nanotron.models.base import NanotronModel +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import differentiable_all_gather, differentiable_identity, differentiable_scatter +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear +from nanotron.distributed import dist +from nanotron.config import Idefics3VisionConfig, Idefics3Config +from nanotron.generation.generate_store import AttachableStore +from nanotron.random import RandomStates, branch_random_state +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator +from nanotron.utils import checkpoint_method +from nanotron.models.llama import GLUActivation, LlamaDecoderLayer, 
LlamaModel +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy + + +logger = logging.get_logger(__name__) + +class LlamaEmbeddings(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, input_mask: torch.Tensor): # [batch_size, seq_length] + store = self.get_local_store() + if store is not None: + if "past_length" in store: + past_length = store["past_length"] + else: + past_length = torch.zeros(1, dtype=torch.long, device=input_ids.device).expand(input_ids.shape[0]) + + cumsum_mask = input_mask.cumsum(-1, dtype=torch.long) + # Store new past_length in store + store["past_length"] = past_length + cumsum_mask[:, -1] + + # Format input in `[seq_length, batch_size]` to support high TP with low batch_size + input_embeds = self.token_embedding(input_ids) + return {"input_embeds": input_embeds} + + +class VisionEmbedding(nn.Module, AttachableStore): + """ + Sharded implementation of the Idefics3VisionEmbeddings from huggingface for nanotron. Uses CLIPVit for megatron as a reference. + """ + def __init__(self, tp_pg: dist.ProcessGroup, config: Idefics3VisionConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.tp_pg = tp_pg + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side ** 2 + self.num_positions = self.num_patches + + self.position_embedding = TensorParallelEmbedding( + num_embeddings=self.num_positions, + embedding_dim=self.embed_dim, + pg=tp_pg, + mode=TensorParallelLinearMode.ALL_REDUCE, + ) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> Dict[str, torch.Tensor]: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + first_dim = position_ids.shape[0] + group_size = 
self.tp_pg.size() + + if first_dim % group_size != 0: + position_ids = nn.functional.pad(position_ids, (0, 0, 0, group_size - first_dim % group_size), mode="constant", value=0) + + position_ids = position_ids.to(self.position_embedding.weight.device) + + position_ids = self.position_embedding(position_ids) + + embeddings = embeddings + position_ids[:first_dim] + + return { + "embeddings": embeddings, + } + + +class VisionCoreAttention(nn.Module): + def __init__(self, config: Idefics3VisionConfig, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + assert ( + config.hidden_size % config.num_attention_heads == 0 + ), "hidden_size must be divisible by num_attention_heads" + + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + + self.is_using_mup = config.is_using_mup + self.checkpoint_attention = False + + self.dropout = config.attention_dropout + + @checkpoint_method(attr_name="checkpoint_attention") + def forward( + self, + query_states: torch.Tensor, # [batch_size, q_length, n_local_q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size, kv_length, n_local_kv_heads, inner_dim] + ): + from flash_attn.flash_attn_interface import flash_attn_func + + # NOTE: this scale is for µTransfer, + # in SP, we use sqrt(1/d_h) + softmax_scale = 1 / query_states.shape[-1] if self.is_using_mup else None + causal = False + dropout_rate = self.dropout if self.training else 0.0 + + attn_output = flash_attn_func( + q=query_states, + k=key_states, + v=value_states, + dropout_p=dropout_rate, + softmax_scale=softmax_scale, + causal=causal, + return_attn_probs=False, + ) + + return attn_output + +class VisionSelfAttention(nn.Module, AttachableStore): + def __init__(self, config: Idefics3VisionConfig, parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, layer_idx: int): + super().__init__() + self.layer_idx = layer_idx + assert ( + config.num_attention_heads % tp_pg.size() == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by TP size ({tp_pg.size()})." + try: + assert ( + config.num_key_value_heads % tp_pg.size() == 0 + ), f"Number of key/value heads ({config.num_key_value_heads}) must be divisible by TP size ({tp_pg.size()})." + except AttributeError: + log_rank( + "WARNING: num_key_value_heads not defined, assuming it is equal to num_attention_heads", + logger=logger, + level=logging.WARNING, + rank=0, + ) + # If num_key_value_heads is not defined, we assume that it is equal to num_attention_heads + config.num_key_value_heads = config.num_attention_heads + assert ( + config.num_attention_heads % config.num_key_value_heads == 0 + ), f"Number of attention heads ({config.num_attention_heads}) must be divisible by number of key/value heads ({config.num_key_value_heads})." 
+ self.n_local_q_heads = config.num_attention_heads // tp_pg.size() + self.n_local_kv_heads = config.num_key_value_heads // tp_pg.size() + self.n_repeats = config.num_attention_heads // config.num_key_value_heads + self.is_gqa = config.num_attention_heads != config.num_key_value_heads # Whether we are using GQA or not + self.d_qk = config.hidden_size // config.num_attention_heads + self.d_v = config.hidden_size // config.num_attention_heads + self.d_model = config.hidden_size + self.is_using_mup = config.is_using_mup + + + + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + qkv_contiguous_chunks = ( + config.num_attention_heads * self.d_qk, + config.num_key_value_heads * self.d_qk, + config.num_key_value_heads * self.d_qk + ) + + self.qkv_proj = TensorParallelColumnLinear( + self.d_model, + config.num_attention_heads * self.d_qk + 2 * config.num_key_value_heads * self.d_qk, + pg=tp_pg, + mode=tp_mode, + bias=True, + async_communication=tp_linear_async_communication, + contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather + ) + + self.o_proj = TensorParallelRowLinear( + config.num_attention_heads * self.d_qk, + self.d_model, + pg=tp_pg, + mode=tp_mode, + bias=True, + async_communication=tp_linear_async_communication, + ) + + self.attention = VisionCoreAttention( + config, + parallel_config=parallel_config, + ) + + def forward( + self, + image_hidden_states: torch.Tensor, + sequence_mask + ): + from flash_attn import bert_padding + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func, + flash_attn_with_kvcache, + ) + + hidden_states = image_hidden_states + + qkv_states = self.qkv_proj( + hidden_states + ) # [batch_size, seq_length, n_local_q_heads * d_qk + 2 * n_local_kv_heads * d_qk] + batch_size, q_length, _ = qkv_states.shape + + + if self.is_gqa: + query_states, key_states, value_states = torch.split( + qkv_states, + [ + self.n_local_q_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + self.n_local_kv_heads * self.d_qk, + ], + dim=-1, + ) + + query_states = ( + query_states.contiguous().view(batch_size, q_length, self.n_local_q_heads, self.d_qk) + ) + key_states = ( + key_states.contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) + value_states = ( + value_states.contiguous().view(batch_size, q_length, self.n_local_kv_heads, self.d_qk) + ) + else: + query_states, key_states, value_states = ( + qkv_states.view(batch_size, q_length, 3, self.n_local_q_heads, self.d_qk) + .permute(2, 0, 1, 3, 4) + .contiguous() + ) # [3, batch_size, seq_length, n_local_q_heads, d_qk] + + + # Apply rotary embeddings to query/key states + # NOTE: The layout is different from models/llama.py which is [batch_size, num_heads, seq_length, d_qk] + # Here it is, [batch_size, seq_length, num_heads, d_qk] + # [2, batch_size, seq_length, num_heads, d_qk] + key_value_states = torch.cat([key_states.unsqueeze(0), value_states.unsqueeze(0)], dim=0) + # [batch_size, seq_length, 2, num_heads, d_qk] + key_value_states = key_value_states.permute(1, 2, 0, 3, 4).contiguous() + + # [batch_size, seq_length, num_heads, d_qk] + key_states, value_states = torch.split(key_value_states, 1, dim=2) + + kv_length = key_states.shape[1] + key_states = key_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_qk) + value_states = 
value_states.view(batch_size, kv_length, self.n_local_kv_heads, self.d_v) + + attention_output = self.attention( + query_states=query_states, + key_states=key_states, + value_states=value_states, + ) + + attention_output = ( + attention_output.contiguous().view(batch_size, q_length, self.n_local_q_heads * self.d_v) + ) + + output = self.o_proj(attention_output) + + return {"image_hidden_states": output, "sequence_mask": sequence_mask} + + +class VisionMLP(nn.Module): + def __init__( + self, + config: Idefics3VisionConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + first_contiguous_chunks = ( + config.intermediate_size, # shape of up_linear + ) + self.fc1 = TensorParallelColumnLinear( + config.hidden_size, + config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=True, + async_communication=tp_linear_async_communication, + contiguous_chunks=first_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + self.fc2 = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=True, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + self.act = torch.compile(lambda x: nn.functional.gelu(x, approximate="tanh")) + + def forward(self, image_hidden_states): # [seq_length, batch_size, hidden_dim] + merged_states = self.fc1(image_hidden_states) + image_hidden_states = self.fc2(self.act(merged_states)) + return {"image_hidden_states": image_hidden_states} + + + +class VisionEncoderLayer(nn.Module): + def __init__( + self, + config: Idefics3VisionConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_id: int, + ): + super().__init__() + + self.self_attn = VisionSelfAttention( + config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_id, + ) + + self.layer_norm1 = TritonLayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + self.layer_norm2 = TritonLayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + + self.mlp = VisionMLP( + config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + self.layer_id = layer_id + + + def forward( + self, + image_hidden_states: Union[torch.Tensor, TensorPointer], + sequence_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + hidden_states = image_hidden_states + + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + output = self.self_attn(image_hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["image_hidden_states"] + + hidden_states = hidden_states + residual + + residual = hidden_states + + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(image_hidden_states=hidden_states)["image_hidden_states"] + hidden_states = hidden_states + residual + + return { + "image_hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + } + + +class VisionTransformer(nn.Module): + def __init__( + self, + config: 
Idefics3VisionConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + self.config = config + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.embeddings = VisionEmbedding( + tp_pg=parallel_context.tp_pg, + config=config, + parallel_config=parallel_config, + ) + + self.encoder = nn.ModuleList( + [ + VisionEncoderLayer( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + layer_id=i, + ) + for i in range(config.num_hidden_layers) + ] + ) + + self.post_layernorm = TritonLayerNorm( + normalized_shape=config.hidden_size, + eps=config.layer_norm_eps, + ) + + def forward( + self, + pixel_values: Union[torch.Tensor, TensorPointer], + pixel_attention_mask: Union[torch.Tensor, TensorPointer] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + batch_size = pixel_values.size(0) + + batch_size, num_images, num_channels, height, width = pixel_values.size() + + pixel_values = pixel_values.view(batch_size * num_images, num_channels, height, width) + + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask/pP p + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) == patch_size * patch_size).bool() + + pixel_values = pixel_values.bfloat16() + + if patch_attention_mask is None: + patch_size = self.config.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + + patch_attention_mask = patch_attention_mask.to(pixel_values.device, dtype=torch.bool) + + image_hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + )["embeddings"] + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + + hidden_encoder_states = { + "image_hidden_states": image_hidden_states, + "sequence_mask": patch_attention_mask, + } + + for i, encoder_layer in enumerate(self.encoder): + hidden_encoder_states = encoder_layer(**hidden_encoder_states) + + image_hidden_states = hidden_encoder_states["image_hidden_states"] + image_hidden_states = self.post_layernorm(input=image_hidden_states) + + return image_hidden_states + + +class Idefics3MLP(nn.Module): + def __init__( + self, + config: Idefics3VisionConfig, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. 
+ tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + config.intermediate_size, # shape of gate_linear + config.intermediate_size, # shape of up_linear + ) + self.gate_up_proj = TensorParallelColumnLinear( + config.hidden_size, + 2 * config.intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + self.down_proj = TensorParallelRowLinear( + config.intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + ) + self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act)) + + def forward(self, image_hidden_states): # [seq_length, batch_size, hidden_dim] + merged_states = self.gate_up_proj(image_hidden_states) + image_hidden_states = self.down_proj(self.split_silu_mul(merged_states)) + return {"image_hidden_states": image_hidden_states} + +class Idefics3SimpleMLP(nn.Module): + def __init__( + self, + config: Idefics3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + # TODO @thomasw21: refactor so that we store that default in a single place. + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + hidden_size = config.vision_config.hidden_size + + self.input_size = hidden_size * (config.scale_factor ** 2) + self.output_size = config.text_config.hidden_size + self.proj = nn.Linear( + self.input_size, + self.output_size, + bias = False + ) + + def forward(self, image_hidden_states): + image_hidden_states = self.proj(image_hidden_states) + return {"image_hidden_states": image_hidden_states} + + +class Idefics3Connector(nn.Module): + def __init__( + self, + config: Idefics3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup + ): + super().__init__() + self.scale_factor = config.scale_factor + self.modality_projector = Idefics3SimpleMLP( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + + def pixel_shuffle(self, x, scale_factor=2): + bsz, seq, embed_dim = x.size() + height = width = int(seq**0.5) + + x = x.view(bsz, height, width, embed_dim) + x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(width / scale_factor), int(height / scale_factor), embed_dim * (scale_factor**2)) + x = x.permute(0, 2, 1, 3) + x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2)) + return x + + def forward(self, image_hidden_states): + image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor) + image_hidden_states = self.modality_projector(image_hidden_states=image_hidden_states)["image_hidden_states"] + return {"image_hidden_states": image_hidden_states} + +class InputsMerger(nn.Module): + def __init__( + self, + config: Idefics3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup + ): + super().__init__() + self.tp_pg = tp_pg + self.tp_mode = 
parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.image_token_id = config.image_token_id + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + inputs_embeds: Union[torch.Tensor, TensorPointer], + image_hidden_states: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + num_images, _, vision_hidden_size = image_hidden_states.shape + special_image_token_mask = input_ids == self.image_token_id + new_inputs_embeds = inputs_embeds.clone() + reshaped_image_hidden_states = image_hidden_states.view(-1, vision_hidden_size) + new_inputs_embeds[special_image_token_mask] = reshaped_image_hidden_states + + new_inputs_embeds = new_inputs_embeds.transpose(0, 1) + + if self.tp_mode is TensorParallelLinearMode.ALL_REDUCE: + new_inputs_embeds = differentiable_identity(new_inputs_embeds, group=self.tp_pg) + elif self.tp_mode is TensorParallelLinearMode.REDUCE_SCATTER: + new_inputs_embeds = differentiable_scatter(new_inputs_embeds, group=self.tp_pg) + else: + raise ValueError(f"Got unexpected mode: {self.tp_mode}.") + return {"new_inputs_embeds": new_inputs_embeds} + + +class CombinedEmbeddings(nn.Module): + def __init__( + self, + config: Idefics3Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + tp_pg: dist.ProcessGroup, + ): + super().__init__() + self.text_embeddings = LlamaEmbeddings( + tp_pg=tp_pg, + config=config.text_config, + parallel_config=parallel_config, + ) + + self.vision_model = VisionTransformer( + config=config.vision_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + self.connector = Idefics3Connector( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + + self.inputs_merger = InputsMerger( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + pixel_values: Union[torch.Tensor, TensorPointer], + pixel_attention_mask: Union[torch.Tensor, TensorPointer] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + + llama_output = self.text_embeddings( + input_ids=input_ids, + input_mask=input_mask, + ) + + vision_output = self.vision_model( + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + connector_output = self.connector( + image_hidden_states=vision_output, + ) + + inputs_merger_output = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=llama_output["input_embeds"], + image_hidden_states=connector_output["image_hidden_states"], + ) + + inputs_merger_output["input_mask"] = input_mask + + return inputs_merger_output + +class Idefics3Model(nn.Module): + def __init__( + self, + config: Idefics3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + super().__init__() + + self.config = config + self.image_token_id = config.image_token_id + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + self.tp_pg = tp_pg + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + + self.combined_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=CombinedEmbeddings, + module_kwargs={ + "config": config, + "parallel_config": 
parallel_config, + "parallel_context": parallel_context, + "tp_pg": tp_pg, + }, + module_input_keys={ + "input_ids", + "input_mask", + "pixel_values", + }, + module_output_keys={ + "new_inputs_embeds", + "input_mask", + } + ) + + self.llama = LlamaModel( + config=config.text_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + p2p = self.p2p + ) + + del self.llama.lm_head + del self.llama.token_position_embeddings + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.text_config.hidden_size, + "out_features": config.text_config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO @thomasw21: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + + self.cast_to_fp32 = PipelineBlock( + p2p=self.llama.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + def forward_with_hidden_states( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + pixel_values: Union[torch.Tensor, TensorPointer], + pixel_attention_mask: Union[torch.Tensor, TensorPointer] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Calculated combined visual and textual embeddings + embeds = self.combined_embeddings( + input_ids=input_ids, + input_mask=input_mask, + pixel_values=pixel_values, + ) + + hidden_encoder_states = { + "hidden_states": embeds["new_inputs_embeds"], + "sequence_mask": embeds["input_mask"], + } + + for encoder_block in self.llama.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.llama.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits, hidden_states + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + pixel_values: Union[torch.Tensor, TensorPointer] = None, + pixel_attention_mask: Union[torch.Tensor, TensorPointer] = None, + ): + return self.forward_with_hidden_states( + input_ids=input_ids, input_mask=input_mask, pixel_values=pixel_values, pixel_attention_mask=pixel_attention_mask + )[0] + + def get_block_compute_costs_vision(self): + config = self.config.vision_config + d_ff = config.intermediate_size + d_qkv = config.hidden_size // config.num_attention_heads + + return { + CombinedEmbeddings: 4 * config.num_attention_heads * d_qkv * config.hidden_size + + 3 * d_ff * config.hidden_size * config.num_hidden_layers + } + + def get_block_compute_costs(self): + llama_cost = self.llama.get_block_compute_costs() + costs = self.get_block_compute_costs_vision() + + costs[LlamaDecoderLayer] = llama_cost[LlamaDecoderLayer] + costs[TensorParallelColumnLinear] = self.config.text_config.hidden_size * self.config.text_config.vocab_size + + return costs + + +class Idefics3ForTraining(NanotronModel): + def __init__( + self, + config: Idefics3Config, + parallel_context: ParallelContext, + parallel_config: 
Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + self.model = Idefics3Model( + config=config, + parallel_context=parallel_context, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg + ) + + self.loss = PipelineBlock( + p2p=self.model.llama.p2p, + module_builder=Loss, + module_kwargs={"tp_pg": parallel_context.tp_pg}, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + input_mask: Union[torch.Tensor, TensorPointer], + pixel_values: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + pixel_attention_mask: Union[torch.Tensor, TensorPointer] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + outputs = self.model( + input_ids=input_ids, + input_mask=input_mask, + pixel_values=pixel_values, + pixel_attention_mask=pixel_attention_mask, + ) + + loss = self.loss( + sharded_logits=outputs, + label_ids=label_ids, + label_mask=label_mask, + )["loss"] + + return {"loss": loss} + + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly. + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config.model) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_block_compute_costs(self): + return self.model.get_block_compute_costs() + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + if self.config.text_config.tie_word_embeddings is True: + return ["model.llama.token_position_embeddings.pp_block.token_embedding.weight", "model.llama.lm_head.pp_block.weight"] + else: + return [] + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + return self.model.llama.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) + + +class VisionTransformerNanotron(NanotronModel): + def __init__( + self, + config: Idefics3VisionConfig, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + + p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + + self.model = VisionTransformer( + config=config, + p2p=p2p, + parallel_context=parallel_context, + parallel_config=parallel_config, + ) + + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + pixel_values: Union[torch.Tensor, TensorPointer], + patch_attention_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + return self.model(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly. + Note: + Layernorm weight all 0 or 1 depending on `apply_layernorm_1p` + """ + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config.model) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." 
for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_block_compute_costs(self): + return self.model.get_block_compute_costs() + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + return [] + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [seq_length, batch_size, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. + # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 + + loss = sharded_cross_entropy( + sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float + ).transpose(0, 1) + # TODO @thomasw21: It's unclear what kind of normalization we want to do. + loss = masked_mean(loss, label_mask, dtype=torch.float) + # I think indexing causes a sync we don't actually want + # loss = loss[label_mask].sum() + return {"loss": loss} \ No newline at end of file diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 28a2e30f..96dd24a6 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -676,11 +676,15 @@ def __init__( config: LlamaConfig, parallel_context: ParallelContext, parallel_config: Optional[ParallelismArgs], + p2p: Optional[P2P] = None ): super().__init__() # Declare all the nodes - self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + if p2p is None: + p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + + self.p2p = p2p self.config = config self.parallel_config = parallel_config self.parallel_context = parallel_context @@ -767,17 +771,17 @@ def forward_with_hidden_states( ): # all tensors are optional as most ranks don't need anything from the dataloader. 
- output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) + input_embeds = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] hidden_encoder_states = { - "hidden_states": output["input_embeds"], + "hidden_states": input_embeds, "sequence_mask": input_mask, } for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] - + sharded_logits = self.lm_head(x=hidden_states)["logits"] fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] diff --git a/src/nanotron/modular_dataloader/__init__.py b/src/nanotron/modular_dataloader/__init__.py new file mode 100644 index 00000000..15ae9ce0 --- /dev/null +++ b/src/nanotron/modular_dataloader/__init__.py @@ -0,0 +1,18 @@ +from .caption import ( + CaptionSampleEncoder, + PreprocessedCollator, + PreprocessedSampleEncoder, + ProcessSampleEncoder, + SingleImageBatchEncoder, +) + +SAMPLE_ENCODERS = { + "caption_simple": CaptionSampleEncoder, + "caption_process": ProcessSampleEncoder, + "caption_preprocessed": PreprocessedSampleEncoder +} + +BATCH_ENCODERS = { + "single_image": SingleImageBatchEncoder, + "caption_preprocessed": PreprocessedCollator +} diff --git a/src/nanotron/modular_dataloader/base.py b/src/nanotron/modular_dataloader/base.py new file mode 100644 index 00000000..93a92f02 --- /dev/null +++ b/src/nanotron/modular_dataloader/base.py @@ -0,0 +1,43 @@ +# Inspired by https://github.com/NVIDIA/Megatron-Energon +from abc import ABC +from typing import Any, Dict, Generic, List, TypeVar, Union + +import torch +from transformers import AutoProcessor + +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer + +T_encoded_sample = TypeVar("T_encoded_sample") + + +class SampleEncoder(ABC, Generic[T_encoded_sample]): + """ + Processes a single sample. E.g. formats caption text. + """ + + processor: AutoProcessor + sequence_length: int + + def encode(self, sample: Dict[str, Any]) -> T_encoded_sample: + """ + Encode a sample. + """ + raise NotImplementedError + + +class BatchEncoder(ABC, Generic[T_encoded_sample]): + """ + Collates and encodes a batch of samples. + """ + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + processor: AutoProcessor + sequence_length: int + + def encode(self, batch: List[T_encoded_sample]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + """ + Collate and encode a batch of samples. + """ + raise NotImplementedError diff --git a/src/nanotron/modular_dataloader/caption.py b/src/nanotron/modular_dataloader/caption.py new file mode 100644 index 00000000..305dea5a --- /dev/null +++ b/src/nanotron/modular_dataloader/caption.py @@ -0,0 +1,273 @@ +import io +from dataclasses import dataclass +from typing import Any, Dict, List, Union + +import numpy as np +import torch +from PIL import Image +from transformers import AutoProcessor + +from nanotron import distributed as dist +from nanotron.modular_dataloader.base import BatchEncoder, SampleEncoder +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer + + +@dataclass +class FormattedTextSample: + """ + A sample with formatted text for processing. + :param text: Text should include a single instance. 
+ """ + text: str + image: bytes + +@dataclass +class ProcessedSample: + """ + Tokenized text and image. + """ + input_ids: torch.Tensor + pixel_values: torch.Tensor + +@dataclass +class CaptionSampleEncoder(SampleEncoder[FormattedTextSample]): + """ + Sample encoder for caption samples. + """ + + text_field: str = "caption" + image_field: str = "jpg" + image_token: str = "" + + def encode(self, sample: Dict[str, Any]) -> FormattedTextSample: + """ + Encode a caption sample. + """ + return FormattedTextSample(text= f"{self.image_token}{sample[self.text_field]}", image=sample[self.image_field]) + + +@dataclass +class ProcessSampleEncoder(SampleEncoder[ProcessedSample]): + """ + Sample encoder for caption samples that also applies processor to it. + """ + + processor: AutoProcessor + sequence_length: int + text_field: str = "caption" + image_field: str = "jpg" + + def encode(self, sample: Dict[str, Any]) -> ProcessedSample: + """ + Encode a caption sample. + """ + image_token = self.processor.image_token.content + + text = f"{image_token}{sample[self.text_field]}" + image = sample[self.image_field] + image_arr = byte_img_to_array(image) + inputs = self.processor(text=text, images=[image_arr], return_tensors="pt", padding="longest", max_length=self.sequence_length + 1, truncation=True) + + return ProcessedSample( + input_ids=inputs["input_ids"], + pixel_values=inputs["pixel_values"] + ) + +@dataclass +class PreprocessedSampleEncoder(SampleEncoder[ProcessedSample]): + """ + Sample encoder for caption samples that also applies processor to it. + """ + + processor: AutoProcessor + sequence_length: int + + def encode(self, sample: Dict[str, Any]) -> ProcessedSample: + """ + Encode a caption sample. + """ + input_ids = torch.tensor(sample["input_ids"], dtype=torch.long) + if input_ids.dim() == 1: + input_ids = input_ids.unsqueeze(0) + + pixel_values = torch.tensor(sample["pixel_values"], dtype=torch.float32) + pixel_shape = sample["pixel_shape"] + + pixel_values = pixel_values.reshape(pixel_shape) + + if len(input_ids.shape) == 4: + input_ids = input_ids.unsqueeze(0) + + return ProcessedSample( + input_ids=input_ids, + pixel_values=pixel_values + ) + + +def byte_img_to_array(bimg): + imageStream = io.BytesIO(bimg) + imageFile = Image.open(imageStream) + img_arr = np.array(imageFile) + if len(img_arr.shape) == 2: # Grayscale image + img_arr = np.expand_dims(img_arr, axis=-1) # Add a channel dimension + # imageFile.show() + return img_arr + +@dataclass +class SingleImageBatchEncoder(BatchEncoder[FormattedTextSample]): + """ + Expects an Idefics3 compatible processor. Pads texts and splits images. + Only works for a single image per caption. + - input_pp_rank: Discards last input id token. Returns input_ids, input_mask, pixel_values + - output_pp_rank: Discards first label id token. Returns label_ids, label_mask + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + processor: AutoProcessor + sequence_length: int + + def encode(self, batch: List[FormattedTextSample]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + """ + Encode a batch of caption samples. 
+ """ + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert len(batch) == 0 + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + "pixel_values": TensorPointer(group_rank=self.input_pp_rank), + } + + result = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + result["pixel_values"] = TensorPointer(group_rank=self.input_pp_rank) + + texts = [sample.text for sample in batch] + images = [[byte_img_to_array(sample.image)] for sample in batch] + + inputs = self.processor(text=texts, images=images, return_tensors="pt", padding="longest", max_length=self.sequence_length + 1, truncation=True) + + + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = inputs["input_ids"][:, :-1] + result["input_mask"] = inputs["attention_mask"] + result["pixel_values"] = inputs["pixel_values"] + + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = inputs["input_ids"][:, 1:] + result["label_mask"] = inputs["input_ids"][:, 1:] < self.processor.tokenizer.vocab_size + + return result + + +@dataclass +class PreprocessedCollator(BatchEncoder[ProcessedSample]): + """ + Collates and encodes a batch of samples. + """ + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + processor: AutoProcessor + sequence_length: int + padding_side: str = "right" + + def encode(self, batch: List[ProcessedSample]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + """ + Collate and encode a batch of samples. 
+ """ + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert len(batch) == 0 + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + "pixel_values": TensorPointer(group_rank=self.input_pp_rank), + } + + result = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + result["pixel_values"] = TensorPointer(group_rank=self.input_pp_rank) + + def pad_tokens(inputs = True): + max_seq_len = max(x.input_ids.shape[1] for x in batch) - 1 + # Make it divisible by tp group size + gs = self.parallel_context.tp_pg.size() + max_seq_len = max_seq_len + (gs - max_seq_len % gs) % gs + + padded_token_ids = [] + padded_attention_mask = [] + + for example in batch: + token_ids = example.input_ids + if inputs: + token_ids = token_ids[:, :-1] + else: + token_ids = token_ids[:, 1:] + + current_seq_len = token_ids.shape[1] + padding_length = max_seq_len - current_seq_len + + if inputs: + attention_mask = torch.ones_like(token_ids) + else: + attention_mask = token_ids < self.processor.tokenizer.vocab_size + + padding = torch.zeros((token_ids.shape[0], padding_length), dtype=token_ids.dtype, device=token_ids.device) + + if self.padding_side == "right": + padded_token_ids.append(torch.cat([token_ids, padding], dim=1)) + padded_attention_mask.append(torch.cat([attention_mask, padding], dim=1)) + + elif self.padding_side == "left": + padded_token_ids.append(torch.cat([padding, token_ids], dim=1)) + padded_attention_mask.append(torch.cat([padding, attention_mask], dim=1)) + + padded_token_ids = torch.cat(padded_token_ids, dim=0) + padded_attention_mask = torch.cat(padded_attention_mask, dim=0) + + return padded_token_ids, padded_attention_mask + + + + if current_pp_rank == self.input_pp_rank: + max_n_patches = max(x.pixel_values.shape[1] for x in batch) + padded_pixel_values = [] + + for example in batch: + pixel_values = example.pixel_values + current_patches = pixel_values.shape[1] + + # Pad the pixel_values to have max_n_patches along dimension 1 (patches) + padding = torch.zeros((1, max_n_patches - current_patches) + pixel_values.shape[2:], dtype=pixel_values.dtype, device=pixel_values.device) + padded_pixel_values.append(torch.cat([pixel_values, padding], dim=1)) + + padded_pixel_values = torch.cat(padded_pixel_values, dim=0) + result["pixel_values"] = padded_pixel_values + result["input_ids"], result["input_mask"] = pad_tokens(inputs=True) + if current_pp_rank == self.output_pp_rank: + result["label_ids"], result["label_mask"] = pad_tokens(inputs=False) + + return result diff --git a/src/nanotron/modular_dataloader/iterable.py b/src/nanotron/modular_dataloader/iterable.py new file mode 100644 index 00000000..0d126547 --- /dev/null +++ b/src/nanotron/modular_dataloader/iterable.py @@ -0,0 +1,126 @@ +import itertools +from multiprocessing.pool import ThreadPool +from typing import Any, Dict, List + +from datasets import IterableDataset +from datasets.distributed import split_dataset_by_node +from torch.utils.data import DataLoader + +from nanotron import distributed as dist +from 
nanotron.modular_dataloader.base import BatchEncoder, SampleEncoder +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer + + +def from_columns(batch: Dict[str, List]): + return [{k: batch[k][i] for k in batch} for i in range(len(batch[list(batch.keys())[0]]))] + +class EmptyIterableDataset(IterableDataset): + def __iter__(self): + return iter([]) + + def __len__(self): + return 0 + + + +class EmptyDataset(IterableDataset): + def __init__(self, input_pp_rank: int, output_pp_rank: int, num_shards: int): + super().__init__() + self.input_pp_rank = input_pp_rank + self.output_pp_rank = output_pp_rank + self._num_shards = num_shards + + @property + def num_shards(self): + return self._num_shards + + def __iter__(self): + return itertools.cycle([ + { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + "pixel_values": TensorPointer(group_rank=self.input_pp_rank), + } + ]) + +def get_train_dataloader( + train_dataset: IterableDataset, + sample_encoder: SampleEncoder, + batch_encoder: BatchEncoder, + parallel_context: ParallelContext, + input_pp_rank: int, + output_pp_rank: int, + micro_batch_size: int, + sample_encoding_batch: int, + consumed_train_samples: int, + batch_encoding_batch: int, + seed_worker: int, + sample_encoding_workers: int, + batch_encoding_workers: int, + drop_last: bool = True, +): + if not isinstance(train_dataset, IterableDataset): + raise ValueError("Dataset should be a datasets.IterableDataset") + + if dist.get_rank(parallel_context.pp_pg) not in [input_pp_rank, output_pp_rank]: + + def generator(): + while True: + yield { + "input_ids": TensorPointer(group_rank=input_pp_rank), + "input_mask": TensorPointer(group_rank=input_pp_rank), + "label_ids": TensorPointer(group_rank=output_pp_rank), + "label_mask": TensorPointer(group_rank=output_pp_rank), + "pixel_values": TensorPointer(group_rank=input_pp_rank), + } + + empty_dataset = IterableDataset.from_generator(generator) + + return DataLoader( + empty_dataset, + batch_size=1, + num_workers=0, + collate_fn=lambda x: x[0], + ) + + train_dataset = split_dataset_by_node(train_dataset, rank=parallel_context.dp_pg.rank(), world_size=parallel_context.dp_pg.size()) + train_dataset = train_dataset.shuffle(seed=seed_worker) + + def encode_samples_batched(batch: Dict[str, List]): + batch = from_columns(batch) + + with ThreadPool(sample_encoding_workers) as sample_worker_pool: + encoded_batch = sample_worker_pool.map(sample_encoder.encode, batch) + + return {"sample_encoded": encoded_batch} + + train_dataset = train_dataset.map(encode_samples_batched, batched=True, remove_columns=train_dataset.column_names, batch_size=sample_encoding_batch) + + + def collate_fn(batch: List[Dict[str, Any]]): + batch = [x["sample_encoded"] for x in batch] + return batch_encoder.encode(batch) + + if consumed_train_samples > 0: + dp_size = parallel_context.dp_pg.size() + skip_batches = consumed_train_samples // dp_size + + train_dataset = train_dataset.skip(skip_batches) + + dataloader = DataLoader( + train_dataset, + batch_size=micro_batch_size, + num_workers=1, + collate_fn=collate_fn, + ) + + return dataloader + + + + + + diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py 
b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a..912b1065 100644
--- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
+++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py
@@ -125,6 +125,81 @@ def backward(ctx, grad_output):
         return DifferentiableAllGather.apply(grad_output, group), None
 
+
+class DifferentiableScatter(torch.autograd.Function):
+    """Scatter a tensor along its first dimension from rank 0 in a differentiable fashion"""
+
+    @staticmethod
+    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+        ctx.group = group
+
+        if group.size() == 1:
+            return tensor
+
+        # TODO: scatter along another dimension
+        unsharded_batch_size, *rest_size = tensor.shape
+        if group is None:
+            group = torch_dist.distributed_c10d._get_default_group()
+        assert unsharded_batch_size % group.size() == 0
+
+        tensor = tensor.contiguous()
+
+        sharded_tensor = torch.empty(
+            unsharded_batch_size // group.size(),
+            *rest_size,
+            device=tensor.device,
+            dtype=tensor.dtype,
+            requires_grad=False,
+        )
+
+        if group.rank() == 0:
+            tensor_list = list(torch.split(tensor, unsharded_batch_size // group.size(), dim=0))
+        else:
+            tensor_list = None
+
+        dist.scatter(sharded_tensor, tensor_list, src=0, group=group)
+        return sharded_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        group = ctx.group
+        return DifferentiableAllGather.apply(grad_output, group), None
+
+
+class DifferentiableReduceScatterAvg(torch.autograd.Function):
+    """Reduce scatter in a differentiable fashion"""
+
+    @staticmethod
+    def forward(ctx, tensor, group: Optional[ProcessGroup]):
+        ctx.group = group
+
+        if group.size() == 1:
+            return tensor
+
+        # TODO @thomasw21: shard along another dimension
+        unsharded_batch_size, *rest_size = tensor.shape
+        if group is None:
+            group = torch_dist.distributed_c10d._get_default_group()
+        assert unsharded_batch_size % group.size() == 0
+
+        # TODO @thomasw21: Collectives seem to require tensors to be contiguous
+        # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
+        tensor = tensor.contiguous()
+
+        sharded_tensor = torch.empty(
+            unsharded_batch_size // group.size(),
+            *rest_size,
+            device=tensor.device,
+            dtype=tensor.dtype,
+            requires_grad=False,
+        )
+        dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.AVG)
+        return sharded_tensor
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        group = ctx.group
+        return DifferentiableAllGather.apply(grad_output, group), None
+
+
 # -----------------
 # Helper functions.
# ----------------- @@ -144,3 +219,6 @@ def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): def differentiable_reduce_scatter_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableReduceScatterSum.apply(tensor, group) + +def differentiable_scatter(tensor, group: Optional[ProcessGroup] = None): + return DifferentiableReduceScatterAvg.apply(tensor, group) diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..6608ba99 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -66,7 +66,12 @@ class CheckpointMetadata: metas: TrainingMetadata custom_metas: Optional[Dict[str, Any]] = None - +def to_int(value: str) -> int: + try: + return int(value) + except ValueError: + return int(value.split("(")[1].split(")")[0]) + @dataclasses.dataclass class TensorMetadata: # Mandatory for checkpoint version higher than 1.2 @@ -81,7 +86,7 @@ class TensorMetadata: cast=[Version], type_hooks={ Tuple[SlicesPair, ...]: SlicesPair.tuple_from_str, - Tuple[int, ...]: lambda x: torch.Size(int(size) for size in x.strip("()").split(",") if size), + Tuple[int, ...]: lambda x: torch.Size(to_int(size) for size in x.strip("()").split(",") if size), }, strict=True, ) diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index 96d2be4c..aaf70b1c 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -106,6 +106,7 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde ) raise e else: + print(f"Parameters {name} should be a NanotronParameter") raise NotImplementedError("Parameters are required to be NanotronParameter") diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 21251a32..66ae8f51 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -58,6 +58,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.idefics import Idefics3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -103,6 +104,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "Idefics3Config": Idefics3ForTraining } try: @@ -1008,4 +1010,4 @@ def mark_unsharded_params_as_tied_across_expert( tie_parameters( root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op - ) + ) \ No newline at end of file diff --git a/tools/idefics3/README.md b/tools/idefics3/README.md new file mode 100644 index 00000000..6adebcc2 --- /dev/null +++ b/tools/idefics3/README.md @@ -0,0 +1,39 @@ +# Idefics3 Weight conversion tool +## Conversion +This directory contains the scripts to convert the Idefics3 checkpoints from HuggingFace to Nanotron and vice versa. Nanotron to HF conversion requires `accelerate`: `pip install accelerate` (otherwise empty HF model initialization is very slow). 
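+
+At their core, these scripts build one model, load the other, and copy weights module by module. The main subtlety is that HF keeps Q, K, V and the MLP gate/up projections as separate matrices, while Nanotron stores them fused, so the conversion concatenates them along the output dimension. Below is a minimal sketch for one Llama decoder layer (illustrative only, not the actual script; `hf_layer` and `nanotron_block` stand for an HF decoder layer and the matching Nanotron `decoder[i].pp_block`):
+
+```python
+import torch
+
+@torch.no_grad()
+def fuse_llama_layer(hf_layer, nanotron_block):
+    # HF: separate q/k/v Linear layers -> Nanotron: single fused qkv_proj (concat along dim 0).
+    qkv = torch.cat(
+        [
+            hf_layer.self_attn.q_proj.weight,
+            hf_layer.self_attn.k_proj.weight,
+            hf_layer.self_attn.v_proj.weight,
+        ],
+        dim=0,
+    )
+    nanotron_block.attn.qkv_proj.weight.copy_(qkv)
+
+    # HF: separate gate/up Linear layers -> Nanotron: single fused gate_up_proj.
+    gate_up = torch.cat(
+        [hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight],
+        dim=0,
+    )
+    nanotron_block.mlp.gate_up_proj.weight.copy_(gate_up)
+```
+
+The actual scripts perform this kind of copy for every layer, with shape asserts before each copy.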
+
+- Convert from HuggingFace to Nanotron
+
+`HF_HUB_ENABLE_HF_TRANSFER=1 torchrun --nproc-per-node 1 tools/idefics3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --pretrained-model-name-or-path HuggingFaceM4/Idefics3-8B-Llama3`
+
+- Convert from Nanotron to HuggingFace
+
+`torchrun --nproc-per-node 1 tools/idefics3/convert_nanotron_to_hf.py --huggingface-checkpoint-path idefics3_ckpt --pretrained-model-name-or-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3`
+
+- Combine custom HF Llama and SigLIP models in an Idefics3-like architecture and save the result as a Nanotron model:
+
+`torchrun --nproc-per-node 1 tools/idefics3/build_nanotron_from_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --pretrained-model-name-or-path-llama3 meta-llama/Meta-Llama-3-8B-Instruct --pretrained-model-name-or-path-siglip google/siglip-so400m-patch14-384`
+
+In summary, the conversion scripts do the following:
+- Initialize the HuggingFace model with the pretrained weights. The model definition is [here](https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/modeling_idefics3.py).
+- Initialize a Nanotron model with empty weights. The model definition is [here](https://github.com/huggingface/nanotron/blob/main/src/nanotron/models/idefics.py).
+- Copy the parameters layer by layer from one model to the other.
+- Store the Nanotron model.
+
+The main difference between the HuggingFace and Nanotron implementations lies in the Q, K & V matrices and in the MLP projections: HuggingFace keeps them as separate matrices, while Nanotron fuses them into `qkv_proj` and `gate_up_proj` (see the sketch above). Handling this correctly is crucial for a correct conversion.
+
+The conversion requires at least **1 GPU**, although the weight copies themselves are carried out on the **CPU**. The models are converted with a parallel configuration of DP = PP = TP = 1, but note that the checkpoints generated by Nanotron are topology-agnostic.
+
+## Simple evaluation
+A simple sanity check that the conversion is correct is to run the `loss_on_captions` scripts, which compute the loss on ~100 samples from an image captioning dataset, so the HF and Nanotron results can be compared (see the sketch below).
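+
+The core of such a check on the HF side is roughly the following. This is only an illustrative sketch, not the actual `loss_on_captions_hf.py`: the dataset id and its `image`/`caption` columns are placeholders, and a more careful evaluation would mask the image-token positions in the labels.
+
+```python
+import torch
+from datasets import load_dataset
+from transformers import AutoModelForVision2Seq, AutoProcessor
+
+processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
+model = AutoModelForVision2Seq.from_pretrained(
+    "HuggingFaceM4/Idefics3-8B-Llama3", torch_dtype=torch.bfloat16
+).to("cuda")
+model.eval()
+
+# Placeholder: any captioning dataset with "image" and "caption" columns.
+dataset = load_dataset("user/caption-dataset", split="test[:100]")
+
+losses = []
+for sample in dataset:
+    # Prepend the <image> token so the processor knows where to splice in the image.
+    inputs = processor(
+        text=f"<image>{sample['caption']}",
+        images=[sample["image"]],
+        return_tensors="pt",
+    ).to(model.device)
+    with torch.no_grad():
+        outputs = model(**inputs, labels=inputs["input_ids"])
+    losses.append(outputs.loss.item())
+
+print(f"Mean loss over {len(losses)} samples: {sum(losses) / len(losses):.4f}")
+```
+
+Running the same samples through the Nanotron checkpoint (the second command below) should give closely matching loss values.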
+ +- Check HF model performance: + +`torchrun --nproc-per-node 1 tools/idefics3/loss_on_captions_hf.py --pretrained-model-name-or-path HuggingFaceM4/Idefics3-8B-Llama3` + +- Check Nanotron model performance: + +`torchrun --nproc-per-node 2 tools/idefics3/loss_on_captions_nanotron.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3` + +Could also load a saved dataset from the drive: + +`torchrun --nproc-per-node 2 tools/idefics3/loss_on_captions_nanotron.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --dataset-path "../datasets/ny_captions.hf"` \ No newline at end of file diff --git a/tools/idefics3/build_nanotron_from_hf.py b/tools/idefics3/build_nanotron_from_hf.py new file mode 100644 index 00000000..548b36cc --- /dev/null +++ b/tools/idefics3/build_nanotron_from_hf.py @@ -0,0 +1,526 @@ +""" +HF_HUB_ENABLE_HF_TRANSFER=1 torchrun --nproc-per-node 1 tools/idefics3/build_nanotron_from_hf.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --pretrained-model-name-or-path-llama3 meta-llama/Meta-Llama-3-8B-Instruct --pretrained-model-name-or-path-siglip google/siglip-so400m-patch14-384 +""" +import sys +sys.path.append('.venv/lib/python3.10/site-packages') + +import argparse +from dataclasses import asdict +import json +from pathlib import Path +import torch +from tqdm import tqdm +import yaml +from nanotron import logging +from nanotron.config.config import Config, GeneralArgs, LoggingArgs, ModelArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit, Idefics3VisionConfig, Idefics3Config +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.models.base import build_model +from nanotron.models.idefics import Idefics3ForTraining, Idefics3Model, VisionTransformer +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize.weights import save_weights +from nanotron.trainer import mark_tied_parameters + +from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def copy_weights_from_hf_to_nanotron_llama(nanotron_model, hf_model, nanotron_config, + additional_vocab_size): + nanotron_llama_config = nanotron_config.text_config + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.layers[i].input_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.layers[i].self_attn.q_proj.weight, + hf_model.layers[i].self_attn.k_proj.weight, + hf_model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == 
nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.shape + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.layers[i].mlp.gate_proj.weight, + hf_model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.final_layer_norm.pp_block.weight.shape == hf_model.norm.weight.shape + with torch.no_grad(): + nanotron_model.final_layer_norm.pp_block.weight.copy_(hf_model.norm.weight) + +def nanotron_config_from_hf_config_llama(hf_config, additional_vocab_size=3): + return LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size + additional_vocab_size, + ) + + + +def copy_weights_from_hf_to_nanotron_vision( + nanotron_model: VisionTransformer, + hf_model: AutoModel, + nanotron_vision_config: Idefics3VisionConfig +): + log_rank("Copying weights from Idefic3 ViT model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + # Vision Embeddings + log_rank("Copying Vision Embeddings...", logger=logger, level=logging.INFO, rank=0) + + assert ( + nanotron_model.embeddings.patch_embedding.weight.shape == hf_model.embeddings.patch_embedding.weight.shape + ) + + assert( + nanotron_model.embeddings.patch_embedding.bias.shape == hf_model.embeddings.patch_embedding.bias.shape + ) + + assert ( + nanotron_model.embeddings.position_embedding.weight.shape + == hf_model.embeddings.position_embedding.weight.shape + ) + + with torch.no_grad(): + 
nanotron_model.embeddings.patch_embedding.weight.copy_(hf_model.embeddings.patch_embedding.weight) + + nanotron_model.embeddings.patch_embedding.bias.copy_(hf_model.embeddings.patch_embedding.bias) + + nanotron_model.embeddings.position_embedding.weight.copy_(hf_model.embeddings.position_embedding.weight) + + + log_rank("Copied Vision Embeddings", logger=logger, level=logging.INFO, rank=0) + + for i in tqdm( + range(nanotron_vision_config.num_hidden_layers), + desc="Copying Vision Layers", + total=nanotron_vision_config.num_hidden_layers, + ): + assert ( + nanotron_model.encoder[i].layer_norm1.weight.shape == hf_model.encoder.layers[i].layer_norm1.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm1.weight.copy_(hf_model.encoder.layers[i].layer_norm1.weight) + + assert ( + nanotron_model.encoder[i].layer_norm1.bias.shape == hf_model.encoder.layers[i].layer_norm1.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm1.bias.copy_(hf_model.encoder.layers[i].layer_norm1.bias) + + tmp_qkv_proj = torch.cat( + [ + hf_model.encoder.layers[i].self_attn.q_proj.weight, + hf_model.encoder.layers[i].self_attn.k_proj.weight, + hf_model.encoder.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + + assert ( + tmp_qkv_proj.shape == nanotron_model.encoder[i].self_attn.qkv_proj.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + tmp_qkv_proj_bias = torch.cat( + [ + hf_model.encoder.layers[i].self_attn.q_proj.bias, + hf_model.encoder.layers[i].self_attn.k_proj.bias, + hf_model.encoder.layers[i].self_attn.v_proj.bias, + ], + dim=0, + ) + + assert ( + tmp_qkv_proj_bias.shape == nanotron_model.encoder[i].self_attn.qkv_proj.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.qkv_proj.bias.copy_(tmp_qkv_proj_bias) + + ## O + + assert ( + nanotron_model.encoder[i].self_attn.o_proj.weight.shape == hf_model.encoder.layers[i].self_attn.out_proj.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.o_proj.weight.copy_(hf_model.encoder.layers[i].self_attn.out_proj.weight) + + assert ( + nanotron_model.encoder[i].self_attn.o_proj.bias.shape == hf_model.encoder.layers[i].self_attn.out_proj.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.o_proj.bias.copy_(hf_model.encoder.layers[i].self_attn.out_proj.bias) + + # Layer Norm 2 + + assert ( + nanotron_model.encoder[i].layer_norm2.weight.shape == hf_model.encoder.layers[i].layer_norm2.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm2.weight.copy_(hf_model.encoder.layers[i].layer_norm2.weight) + + assert ( + nanotron_model.encoder[i].layer_norm2.bias.shape == hf_model.encoder.layers[i].layer_norm2.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm2.bias.copy_(hf_model.encoder.layers[i].layer_norm2.bias) + + # MLP + ## FC1 + + assert ( + nanotron_model.encoder[i].mlp.fc1.weight.shape == hf_model.encoder.layers[i].mlp.fc1.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc1.weight.copy_(hf_model.encoder.layers[i].mlp.fc1.weight) + + assert ( + nanotron_model.encoder[i].mlp.fc1.bias.shape == hf_model.encoder.layers[i].mlp.fc1.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc1.bias.copy_(hf_model.encoder.layers[i].mlp.fc1.bias) + + ## FC2 + + assert ( + nanotron_model.encoder[i].mlp.fc2.weight.shape == hf_model.encoder.layers[i].mlp.fc2.weight.shape + ) + + with 
torch.no_grad(): + nanotron_model.encoder[i].mlp.fc2.weight.copy_(hf_model.encoder.layers[i].mlp.fc2.weight) + + assert ( + nanotron_model.encoder[i].mlp.fc2.bias.shape == hf_model.encoder.layers[i].mlp.fc2.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc2.bias.copy_(hf_model.encoder.layers[i].mlp.fc2.bias) + + log_rank("Copied Vision Layers", logger=logger, level=logging.INFO, rank=0) + + # Post layer norm + + assert ( + nanotron_model.post_layernorm.weight.shape == hf_model.post_layernorm.weight.shape + ) + + with torch.no_grad(): + nanotron_model.post_layernorm.weight.copy_(hf_model.post_layernorm.weight) + + assert ( + nanotron_model.post_layernorm.bias.shape == hf_model.post_layernorm.bias.shape + ) + + with torch.no_grad(): + nanotron_model.post_layernorm.bias.copy_(hf_model.post_layernorm.bias) + + log_rank("Copied Post Layer Norm", logger=logger, level=logging.INFO, rank=0) + +def copy_weights_from_hf_to_nanotron_remaining( + nanotron_model: Idefics3Model, + hf_model_llama: AutoModel, + nanotron_config: Idefics3Config +): + + log_rank("Copying weights from Idefic3 Llama embeddings to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + hf_vocab_size = hf_model_llama.config.vocab_size + + assert ( + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size].shape + == hf_model_llama.embed_tokens.weight.shape + ) + with torch.no_grad(): + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size].copy_( + hf_model_llama.embed_tokens.weight + ) + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace LLama3 Model") + group.add_argument( + "--pretrained-model-name-or-path-llama3", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + group = parser.add_argument_group(title="HuggingFace SigLIP Model") + group.add_argument( + "--pretrained-model-name-or-path-siglip", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path_llama3}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + hf_model_llama = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path_llama3, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config_llama = hf_model_llama.config + + + # Set Nanotron LlamaConfig + vocab_size = hf_config_llama.vocab_size + + # Expand & ensure that it's divisible by 4 + + additional_vocab_size = 4 - 
(vocab_size % 4) + nanotron_llama_config = nanotron_config_from_hf_config_llama(hf_config_llama, additional_vocab_size) + + # Load SigLIP HF model + log_rank( + f"Loading pretrained SigLIP Model: {args.pretrained_model_name_or_path_siglip}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + hf_model_siglip = AutoModel.from_pretrained( + args.pretrained_model_name_or_path_siglip, torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to(DEVICE) + hf_config_siglip = hf_model_siglip.config.vision_config + + # Set Nanotron SigLIPConfig + nanotron_vision_config = Idefics3VisionConfig( + hidden_size=hf_config_siglip.hidden_size, + image_size=hf_config_siglip.image_size, + intermediate_size=hf_config_siglip.intermediate_size, + num_hidden_layers= hf_config_siglip.num_hidden_layers, + num_attention_heads=hf_config_siglip.num_attention_heads, + num_key_value_heads=hf_config_siglip.num_attention_heads, + num_channels=hf_config_siglip.num_channels, + patch_size=hf_config_siglip.patch_size, + hidden_act=hf_config_siglip.hidden_act, + layer_norm_eps=hf_config_siglip.layer_norm_eps, + attention_dropout=hf_config_siglip.attention_dropout, + is_using_mup=False + ) + + pad_token_id = hf_config_llama.pad_token_id + if pad_token_id is None: + pad_token_id = 128002 + + nanotron_idefics3_config = Idefics3Config( + text_config=nanotron_llama_config, + vision_config=nanotron_vision_config, + image_token_id=vocab_size + 1, + pad_token_id=pad_token_id, + scale_factor=2, + vocab_size=vocab_size + additional_vocab_size, + ) + + # Init Idefics3 Nanotron model + log_rank("Init empty Nanotron Idefics3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: Idefics3ForTraining( + config=nanotron_idefics3_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + copy_weights_from_hf_to_nanotron_vision( + nanotron_model=nanotron_model.model.combined_embeddings.pp_block.vision_model, + hf_model=hf_model_siglip.vision_model, + nanotron_vision_config=nanotron_vision_config, + ) + + log_rank("Copied weights from HF SigLIP model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + + # Copy weights from HF to Nanotron + copy_weights_from_hf_to_nanotron_llama( + nanotron_model=nanotron_model.model.llama, + hf_model=hf_model_llama.model, + nanotron_config=nanotron_idefics3_config, + additional_vocab_size=additional_vocab_size + ) + + log_rank("Copied weights from HF Llama model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + + + copy_weights_from_hf_to_nanotron_remaining( + nanotron_model=nanotron_model.model, + hf_model_llama=hf_model_llama.model, + nanotron_config=nanotron_idefics3_config, + ) + + nanotron_checkpoint_path = Path( + args.nanotron_checkpoint_path + ) + + save_weights( + model=nanotron_model, + root_folder=nanotron_checkpoint_path, + parallel_context=parallel_context, + ) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Idefics-Custom"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_idefics3_config, + ), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), + ) + 
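+    # Note: init_method points back at this same checkpoint directory, so the saved config.yaml
+    # can be reused as the `model.init_method.path` of a later training config.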
log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_idefics3_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) \ No newline at end of file diff --git a/tools/idefics3/convert_hf_to_nanotron.py b/tools/idefics3/convert_hf_to_nanotron.py new file mode 100644 index 00000000..7c359ea9 --- /dev/null +++ b/tools/idefics3/convert_hf_to_nanotron.py @@ -0,0 +1,521 @@ +""" +HF_HUB_ENABLE_HF_TRANSFER=1 torchrun --nproc-per-node 1 tools/idefics3/convert_hf_to_nanotron.py --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --pretrained-model-name-or-path HuggingFaceM4/Idefics3-8B-Llama3 +""" +import sys +sys.path.append('.venv/lib/python3.10/site-packages') + +import argparse +from dataclasses import asdict +import json +from pathlib import Path +import torch +from tqdm import tqdm +import yaml +from nanotron import logging +from nanotron.config.config import Config, GeneralArgs, LoggingArgs, ModelArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit, Idefics3VisionConfig, Idefics3Config +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.models.base import build_model +from nanotron.models.idefics import Idefics3ForTraining, Idefics3Model, VisionTransformer +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize.weights import save_weights +from nanotron.trainer import mark_tied_parameters + +from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoModel + + +def copy_weights_from_hf_to_nanotron_llama(nanotron_model, hf_model, nanotron_config, + additional_vocab_size): + nanotron_llama_config = nanotron_config.text_config + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.layers[i].input_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.layers[i].self_attn.q_proj.weight, + hf_model.layers[i].self_attn.k_proj.weight, + hf_model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.shape + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## O + assert ( + hf_model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + 
nanotron_model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Gate Up Proj + tmp_gate_up_proj = torch.cat( + [ + hf_model.layers[i].mlp.gate_proj.weight, + hf_model.layers[i].mlp.up_proj.weight, + ], + dim=0, + ) + + assert tmp_gate_up_proj.shape == nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + + ## Down Proj + assert ( + hf_model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + hf_model.layers[i].mlp.down_proj.weight + ) + + # Post attn layer norm + assert ( + hf_model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.final_layer_norm.pp_block.weight.shape == hf_model.norm.weight.shape + with torch.no_grad(): + nanotron_model.final_layer_norm.pp_block.weight.copy_(hf_model.norm.weight) + +def nanotron_config_from_hf_config_llama(hf_config, additional_vocab_size=3): + return LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size + additional_vocab_size, + ) + + + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Idefic3 Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + +def copy_weights_from_hf_to_nanotron_vision( + nanotron_model: VisionTransformer, + hf_model: AutoModel, + nanotron_vision_config: Idefics3VisionConfig +): + log_rank("Copying weights from Idefic3 ViT model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + # Vision Embeddings + log_rank("Copying Vision Embeddings...", logger=logger, level=logging.INFO, rank=0) + + assert ( + 
nanotron_model.embeddings.patch_embedding.weight.shape == hf_model.embeddings.patch_embedding.weight.shape + ) + + assert( + nanotron_model.embeddings.patch_embedding.bias.shape == hf_model.embeddings.patch_embedding.bias.shape + ) + + assert ( + nanotron_model.embeddings.position_embedding.weight.shape + == hf_model.embeddings.position_embedding.weight.shape + ) + + with torch.no_grad(): + nanotron_model.embeddings.patch_embedding.weight.copy_(hf_model.embeddings.patch_embedding.weight) + + nanotron_model.embeddings.patch_embedding.bias.copy_(hf_model.embeddings.patch_embedding.bias) + + nanotron_model.embeddings.position_embedding.weight.copy_(hf_model.embeddings.position_embedding.weight) + + + log_rank("Copied Vision Embeddings", logger=logger, level=logging.INFO, rank=0) + + for i in tqdm( + range(nanotron_vision_config.num_hidden_layers), + desc="Copying Vision Layers", + total=nanotron_vision_config.num_hidden_layers, + ): + assert ( + nanotron_model.encoder[i].layer_norm1.weight.shape == hf_model.encoder.layers[i].layer_norm1.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm1.weight.copy_(hf_model.encoder.layers[i].layer_norm1.weight) + + assert ( + nanotron_model.encoder[i].layer_norm1.bias.shape == hf_model.encoder.layers[i].layer_norm1.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm1.bias.copy_(hf_model.encoder.layers[i].layer_norm1.bias) + + tmp_qkv_proj = torch.cat( + [ + hf_model.encoder.layers[i].self_attn.q_proj.weight, + hf_model.encoder.layers[i].self_attn.k_proj.weight, + hf_model.encoder.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + + assert ( + tmp_qkv_proj.shape == nanotron_model.encoder[i].self_attn.qkv_proj.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + tmp_qkv_proj_bias = torch.cat( + [ + hf_model.encoder.layers[i].self_attn.q_proj.bias, + hf_model.encoder.layers[i].self_attn.k_proj.bias, + hf_model.encoder.layers[i].self_attn.v_proj.bias, + ], + dim=0, + ) + + assert ( + tmp_qkv_proj_bias.shape == nanotron_model.encoder[i].self_attn.qkv_proj.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.qkv_proj.bias.copy_(tmp_qkv_proj_bias) + + ## O + + assert ( + nanotron_model.encoder[i].self_attn.o_proj.weight.shape == hf_model.encoder.layers[i].self_attn.out_proj.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.o_proj.weight.copy_(hf_model.encoder.layers[i].self_attn.out_proj.weight) + + assert ( + nanotron_model.encoder[i].self_attn.o_proj.bias.shape == hf_model.encoder.layers[i].self_attn.out_proj.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].self_attn.o_proj.bias.copy_(hf_model.encoder.layers[i].self_attn.out_proj.bias) + + # Layer Norm 2 + + assert ( + nanotron_model.encoder[i].layer_norm2.weight.shape == hf_model.encoder.layers[i].layer_norm2.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm2.weight.copy_(hf_model.encoder.layers[i].layer_norm2.weight) + + assert ( + nanotron_model.encoder[i].layer_norm2.bias.shape == hf_model.encoder.layers[i].layer_norm2.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].layer_norm2.bias.copy_(hf_model.encoder.layers[i].layer_norm2.bias) + + # MLP + ## FC1 + + assert ( + nanotron_model.encoder[i].mlp.fc1.weight.shape == hf_model.encoder.layers[i].mlp.fc1.weight.shape + ) + + with torch.no_grad(): + 
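+            # The SigLIP MLP weights (fc1/fc2) and their biases map one-to-one between HF and
+            # Nanotron, so only shape checks and direct copies are needed here.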
nanotron_model.encoder[i].mlp.fc1.weight.copy_(hf_model.encoder.layers[i].mlp.fc1.weight) + + assert ( + nanotron_model.encoder[i].mlp.fc1.bias.shape == hf_model.encoder.layers[i].mlp.fc1.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc1.bias.copy_(hf_model.encoder.layers[i].mlp.fc1.bias) + + ## FC2 + + assert ( + nanotron_model.encoder[i].mlp.fc2.weight.shape == hf_model.encoder.layers[i].mlp.fc2.weight.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc2.weight.copy_(hf_model.encoder.layers[i].mlp.fc2.weight) + + assert ( + nanotron_model.encoder[i].mlp.fc2.bias.shape == hf_model.encoder.layers[i].mlp.fc2.bias.shape + ) + + with torch.no_grad(): + nanotron_model.encoder[i].mlp.fc2.bias.copy_(hf_model.encoder.layers[i].mlp.fc2.bias) + + log_rank("Copied Vision Layers", logger=logger, level=logging.INFO, rank=0) + + # Post layer norm + + assert ( + nanotron_model.post_layernorm.weight.shape == hf_model.post_layernorm.weight.shape + ) + + with torch.no_grad(): + nanotron_model.post_layernorm.weight.copy_(hf_model.post_layernorm.weight) + + assert ( + nanotron_model.post_layernorm.bias.shape == hf_model.post_layernorm.bias.shape + ) + + with torch.no_grad(): + nanotron_model.post_layernorm.bias.copy_(hf_model.post_layernorm.bias) + + log_rank("Copied Post Layer Norm", logger=logger, level=logging.INFO, rank=0) + + +def copy_weights_from_hf_to_nanotron_remaining( + nanotron_model: Idefics3Model, + hf_model: AutoModel, + nanotron_config: Idefics3Config +): + + log_rank("Copying weights from Idefic3 Llama embeddings to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + hf_vocab_size = hf_model.model.text_model.config.vocab_size + + assert ( + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size].shape + == hf_model.model.text_model.embed_tokens.weight.shape + ) + with torch.no_grad(): + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size].copy_( + hf_model.model.text_model.embed_tokens.weight + ) + + log_rank("Copying weights from Idefic3 Connector to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + assert ( + nanotron_model.combined_embeddings.pp_block.connector.modality_projector.proj.weight.shape == hf_model.model.connector.modality_projection.proj.weight.shape + ) + + with torch.no_grad(): + nanotron_model.combined_embeddings.pp_block.connector.modality_projector.proj.weight.copy_(hf_model.model.connector.modality_projection.proj.weight) + + log_rank("Copied Connector", logger=logger, level=logging.INFO, rank=0) + + log_rank("Copying weights from Idefic3 Head to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + vocab_size = hf_model.vocab_size + + assert ( + nanotron_model.lm_head.pp_block.weight[:vocab_size].shape == hf_model.lm_head.weight.shape + ) + + with torch.no_grad(): + nanotron_model.lm_head.pp_block.weight[:vocab_size].copy_(hf_model.lm_head.weight) + + log_rank("Copied Head", logger=logger, level=logging.INFO, rank=0) + + +def main(args): + additional_vocab_size = 1 + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Idefics3 
model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + hf_model = AutoModelForVision2Seq.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config = hf_model.config + hf_config_vision = hf_config.vision_config + + # Set Nanotron LlamaConfig + nanotron_llama_config = nanotron_config_from_hf_config_llama(hf_config.text_config, additional_vocab_size) + + # Set Nanotron SigLIPConfig + nanotron_vision_config = Idefics3VisionConfig( + hidden_size=hf_config_vision.hidden_size, + image_size=hf_config_vision.image_size, + intermediate_size=hf_config_vision.intermediate_size, + num_hidden_layers= hf_config_vision.num_hidden_layers, + num_attention_heads=hf_config_vision.num_attention_heads, + num_key_value_heads=hf_config_vision.num_attention_heads, + num_channels=hf_config_vision.num_channels, + patch_size=hf_config_vision.patch_size, + hidden_act=hf_config_vision.hidden_act, + layer_norm_eps=hf_config_vision.layer_norm_eps, + attention_dropout=hf_config_vision.attention_dropout, + is_using_mup=False + ) + + nanotron_idefics3_config = Idefics3Config( + text_config=nanotron_llama_config, + vision_config=nanotron_vision_config, + image_token_id=hf_config.image_token_id, + pad_token_id=hf_config.vision_config.pad_token_id, + scale_factor=hf_config.scale_factor, + vocab_size=nanotron_llama_config.vocab_size + ) + # Init Idefics3 Nanotron model + log_rank("Init empty Nanotron Idefics3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: Idefics3ForTraining( + config=nanotron_idefics3_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + copy_weights_from_hf_to_nanotron_vision( + nanotron_model=nanotron_model.model.combined_embeddings.pp_block.vision_model, + hf_model=hf_model.model.vision_model, + nanotron_vision_config=nanotron_vision_config, + ) + + log_rank("Copied weights from HF SigLIP model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + + + + # Copy weights from HF to Nanotron + copy_weights_from_hf_to_nanotron_llama( + nanotron_model=nanotron_model.model.llama, + hf_model=hf_model.model.text_model, + nanotron_config=nanotron_idefics3_config, + additional_vocab_size=additional_vocab_size, + ) + + log_rank("Copied weights from HF Llama model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + + + copy_weights_from_hf_to_nanotron_remaining( + nanotron_model=nanotron_model.model, + hf_model=hf_model, + nanotron_config=nanotron_idefics3_config, + ) + + nanotron_checkpoint_path = Path( + args.nanotron_checkpoint_path + ) + + save_weights( + model=nanotron_model, + root_folder=nanotron_checkpoint_path, + parallel_context=parallel_context, + ) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Idefics3"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_idefics3_config, + ), + tokenizer=TokenizerArgs(nanotron_checkpoint_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + 
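+    # model_config.json stores only the Idefics3Config (text + vision sub-configs); the
+    # parallelism and tokenizer settings are kept in config.yaml above.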
with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_idefics3_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) \ No newline at end of file diff --git a/tools/idefics3/convert_nanotron_to_hf.py b/tools/idefics3/convert_nanotron_to_hf.py new file mode 100644 index 00000000..3b23280e --- /dev/null +++ b/tools/idefics3/convert_nanotron_to_hf.py @@ -0,0 +1,520 @@ +""" +torchrun --nproc-per-node 1 tools/idefics3/convert_nanotron_to_hf.py --huggingface-checkpoint-path idefics3_ckpt --pretrained-model-name-or-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 +""" + +import sys +sys.path.insert(0, '/capstor/scratch/cscs/eguo/vlm_convert/nanotron') + +import argparse +import os +from dataclasses import asdict +import json +from pathlib import Path +import torch +from tqdm import tqdm +import yaml +from nanotron import logging +from nanotron.config import Config, LoggingArgs, ParallelismArgs, get_config_from_file +from nanotron.config.models_config import ExistingCheckpointInit, Idefics3VisionConfig, Idefics3Config +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.config.models_config import LlamaConfig as LlamaConfigNanotron +from nanotron.models.base import build_model +from nanotron.models.llama import LlamaForTraining +from nanotron.models.idefics import Idefics3ForTraining, Idefics3Model, VisionTransformer +from nanotron.parallel.context import ParallelContext +from nanotron.trainer import mark_tied_parameters +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import load_weights +from nanotron.serialize.weights import save_weights + +from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoModel, AutoModelForCausalLM, AutoConfig +from transformers.models.llama import LlamaConfig as LlamaConfigHF +from transformers import Idefics3Config as Idefics3ConfigHF +from accelerate import init_empty_weights + + +logger = logging.get_logger(__name__) + +DEVICE = torch.device("cpu") +TORCH_DTYPE = torch.bfloat16 + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Path to Save Converted HuggingFace Idefic3 Model") + group.add_argument( + "--huggingface-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted HF Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Idefic3 Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo in nanotron", + ) + + args = parser.parse_args() + + return args + +def copy_weights_from_nanotron_to_hf_llama(nanotron_model, hf_model, nanotron_llama_config, additional_vocab_size): + # Copy params from Nanotron to HF + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + + # Decoder layers + for i in tqdm( + range(nanotron_llama_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_llama_config.num_hidden_layers, + ): + # Input layer norm + assert ( + 
hf_model.layers[i].input_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.input_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.layers[i].input_layernorm.weight.copy_( + nanotron_model.decoder[i].pp_block.input_layernorm.weight + ) + + # Self-attention QKV split + qkv_proj = nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight + q_size = nanotron_llama_config.num_attention_heads * nanotron_llama_config.hidden_size // nanotron_llama_config.num_attention_heads + k_size = nanotron_llama_config.num_key_value_heads * nanotron_llama_config.hidden_size // nanotron_llama_config.num_attention_heads + v_size = nanotron_llama_config.num_key_value_heads * nanotron_llama_config.hidden_size // nanotron_llama_config.num_attention_heads + + q, k, v = torch.split(qkv_proj, [q_size, k_size, v_size], dim=0) + + assert q.shape == hf_model.layers[i].self_attn.q_proj.weight.shape + assert k.shape == hf_model.layers[i].self_attn.k_proj.weight.shape + assert v.shape == hf_model.layers[i].self_attn.v_proj.weight.shape + + with torch.no_grad(): + hf_model.layers[i].self_attn.q_proj.weight.copy_(q) + hf_model.layers[i].self_attn.k_proj.weight.copy_(k) + hf_model.layers[i].self_attn.v_proj.weight.copy_(v) + + # Output projection (O) + assert ( + hf_model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + with torch.no_grad(): + hf_model.layers[i].self_attn.o_proj.weight.copy_( + nanotron_model.decoder[i].pp_block.attn.o_proj.weight + ) + + # MLP: Gate and Up Proj + gate_up_proj = nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight + split_size = nanotron_llama_config.intermediate_size + gate_proj, up_proj = torch.split(gate_up_proj, [split_size, split_size], dim=0) + + assert gate_proj.shape == hf_model.layers[i].mlp.gate_proj.weight.shape + assert up_proj.shape == hf_model.layers[i].mlp.up_proj.weight.shape + + with torch.no_grad(): + hf_model.layers[i].mlp.gate_proj.weight.copy_(gate_proj) + hf_model.layers[i].mlp.up_proj.weight.copy_(up_proj) + + # MLP: Down Proj + assert ( + hf_model.layers[i].mlp.down_proj.weight.shape + == nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.shape + ) + with torch.no_grad(): + hf_model.layers[i].mlp.down_proj.weight.copy_( + nanotron_model.decoder[i].pp_block.mlp.down_proj.weight + ) + + # Post-attention Layer Norm + assert ( + hf_model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + with torch.no_grad(): + hf_model.layers[i].post_attention_layernorm.weight.copy_( + nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight + ) + + # Final Layer Norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.final_layer_norm.pp_block.weight.shape == hf_model.norm.weight.shape + with torch.no_grad(): + hf_model.norm.weight.copy_(nanotron_model.final_layer_norm.pp_block.weight) + + log_rank("Llama weight copying completed successfully!", logger=logger, level=logging.INFO, rank=0) + + +def copy_weights_from_nanotron_to_hf_vision( + nanotron_model: VisionTransformer, + hf_model: AutoModel, + nanotron_vision_config: Idefics3VisionConfig +): + log_rank("Copying weights from Nanotron model to HF model...", logger=logger, level=logging.INFO, rank=0) + + # Vision Embeddings + log_rank("Copying Vision Embeddings...", logger=logger, level=logging.INFO, rank=0) + + assert ( + nanotron_model.embeddings.patch_embedding.weight.shape == 
hf_model.embeddings.patch_embedding.weight.shape + ) + assert ( + nanotron_model.embeddings.patch_embedding.bias.shape == hf_model.embeddings.patch_embedding.bias.shape + ) + assert ( + nanotron_model.embeddings.position_embedding.weight.shape + == hf_model.embeddings.position_embedding.weight.shape + ) + + with torch.no_grad(): + hf_model.embeddings.patch_embedding.weight.copy_( + nanotron_model.embeddings.patch_embedding.weight + ) + hf_model.embeddings.patch_embedding.bias.copy_( + nanotron_model.embeddings.patch_embedding.bias + ) + hf_model.embeddings.position_embedding.weight.copy_( + nanotron_model.embeddings.position_embedding.weight + ) + + log_rank("Copied Vision Embeddings", logger=logger, level=logging.INFO, rank=0) + + for i in tqdm( + range(nanotron_vision_config.num_hidden_layers), + desc="Copying Vision Layers", + total=nanotron_vision_config.num_hidden_layers, + ): + # Layer Norm 1 + assert ( + nanotron_model.encoder[i].layer_norm1.weight.shape == hf_model.encoder.layers[i].layer_norm1.weight.shape + ) + assert ( + nanotron_model.encoder[i].layer_norm1.bias.shape == hf_model.encoder.layers[i].layer_norm1.bias.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].layer_norm1.weight.copy_( + nanotron_model.encoder[i].layer_norm1.weight + ) + hf_model.encoder.layers[i].layer_norm1.bias.copy_( + nanotron_model.encoder[i].layer_norm1.bias + ) + + # QKV Projections + tmp_qkv_proj = nanotron_model.encoder[i].self_attn.qkv_proj.weight.chunk(3, dim=0) + + assert ( + tmp_qkv_proj[0].shape == hf_model.encoder.layers[i].self_attn.q_proj.weight.shape + ) + assert ( + tmp_qkv_proj[1].shape == hf_model.encoder.layers[i].self_attn.k_proj.weight.shape + ) + assert ( + tmp_qkv_proj[2].shape == hf_model.encoder.layers[i].self_attn.v_proj.weight.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].self_attn.q_proj.weight.copy_(tmp_qkv_proj[0]) + hf_model.encoder.layers[i].self_attn.k_proj.weight.copy_(tmp_qkv_proj[1]) + hf_model.encoder.layers[i].self_attn.v_proj.weight.copy_(tmp_qkv_proj[2]) + + # QKV Biases + tmp_qkv_proj_bias = nanotron_model.encoder[i].self_attn.qkv_proj.bias.chunk(3, dim=0) + + assert ( + tmp_qkv_proj_bias[0].shape == hf_model.encoder.layers[i].self_attn.q_proj.bias.shape + ) + assert ( + tmp_qkv_proj_bias[1].shape == hf_model.encoder.layers[i].self_attn.k_proj.bias.shape + ) + assert ( + tmp_qkv_proj_bias[2].shape == hf_model.encoder.layers[i].self_attn.v_proj.bias.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].self_attn.q_proj.bias.copy_(tmp_qkv_proj_bias[0]) + hf_model.encoder.layers[i].self_attn.k_proj.bias.copy_(tmp_qkv_proj_bias[1]) + hf_model.encoder.layers[i].self_attn.v_proj.bias.copy_(tmp_qkv_proj_bias[2]) + + # Output Projection + assert ( + nanotron_model.encoder[i].self_attn.o_proj.weight.shape == hf_model.encoder.layers[i].self_attn.out_proj.weight.shape + ) + assert ( + nanotron_model.encoder[i].self_attn.o_proj.bias.shape == hf_model.encoder.layers[i].self_attn.out_proj.bias.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].self_attn.out_proj.weight.copy_( + nanotron_model.encoder[i].self_attn.o_proj.weight + ) + hf_model.encoder.layers[i].self_attn.out_proj.bias.copy_( + nanotron_model.encoder[i].self_attn.o_proj.bias + ) + + # Layer Norm 2 + assert ( + nanotron_model.encoder[i].layer_norm2.weight.shape == hf_model.encoder.layers[i].layer_norm2.weight.shape + ) + assert ( + nanotron_model.encoder[i].layer_norm2.bias.shape == hf_model.encoder.layers[i].layer_norm2.bias.shape + ) + + with 
torch.no_grad(): + hf_model.encoder.layers[i].layer_norm2.weight.copy_( + nanotron_model.encoder[i].layer_norm2.weight + ) + hf_model.encoder.layers[i].layer_norm2.bias.copy_( + nanotron_model.encoder[i].layer_norm2.bias + ) + + # MLP Layers + assert ( + nanotron_model.encoder[i].mlp.fc1.weight.shape == hf_model.encoder.layers[i].mlp.fc1.weight.shape + ) + assert ( + nanotron_model.encoder[i].mlp.fc1.bias.shape == hf_model.encoder.layers[i].mlp.fc1.bias.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].mlp.fc1.weight.copy_( + nanotron_model.encoder[i].mlp.fc1.weight + ) + hf_model.encoder.layers[i].mlp.fc1.bias.copy_( + nanotron_model.encoder[i].mlp.fc1.bias + ) + + assert ( + nanotron_model.encoder[i].mlp.fc2.weight.shape == hf_model.encoder.layers[i].mlp.fc2.weight.shape + ) + assert ( + nanotron_model.encoder[i].mlp.fc2.bias.shape == hf_model.encoder.layers[i].mlp.fc2.bias.shape + ) + + with torch.no_grad(): + hf_model.encoder.layers[i].mlp.fc2.weight.copy_( + nanotron_model.encoder[i].mlp.fc2.weight + ) + hf_model.encoder.layers[i].mlp.fc2.bias.copy_( + nanotron_model.encoder[i].mlp.fc2.bias + ) + + log_rank("Copied Vision Layers", logger=logger, level=logging.INFO, rank=0) + + # Post Layer Norm + assert ( + nanotron_model.post_layernorm.weight.shape == hf_model.post_layernorm.weight.shape + ) + assert ( + nanotron_model.post_layernorm.bias.shape == hf_model.post_layernorm.bias.shape + ) + + with torch.no_grad(): + hf_model.post_layernorm.weight.copy_(nanotron_model.post_layernorm.weight) + hf_model.post_layernorm.bias.copy_(nanotron_model.post_layernorm.bias) + + log_rank("Copied Post Layer Norm", logger=logger, level=logging.INFO, rank=0) + +def copy_weights_from_nanotron_to_hf_remaining( + nanotron_model: Idefics3Model, + hf_model: AutoModel, + nanotron_config: Idefics3Config +): + log_rank("Copying weights from Idefic3 Llama embeddings to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + hf_vocab_size = hf_model.model.text_model.config.vocab_size + + assert ( + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size].shape + == hf_model.model.text_model.embed_tokens.weight.shape + ) + with torch.no_grad(): + hf_model.model.text_model.embed_tokens.weight.copy_( + nanotron_model.combined_embeddings.pp_block.text_embeddings.token_embedding.weight[:hf_vocab_size] + ) + + log_rank("Copying weights from Idefic3 Connector to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + + assert ( + nanotron_model.combined_embeddings.pp_block.connector.modality_projector.proj.weight.shape == hf_model.model.connector.modality_projection.proj.weight.shape + ) + + with torch.no_grad(): + hf_model.model.connector.modality_projection.proj.weight.copy_(nanotron_model.combined_embeddings.pp_block.connector.modality_projector.proj.weight) + + log_rank("Copied Connector", logger=logger, level=logging.INFO, rank=0) + + vocab_size = hf_model.vocab_size + + assert ( + nanotron_model.lm_head.pp_block.weight[:vocab_size].shape == hf_model.lm_head.weight.shape + ) + + with torch.no_grad(): + hf_model.lm_head.weight.copy_(nanotron_model.lm_head.pp_block.weight[:vocab_size]) + + log_rank("Copied Head", logger=logger, level=logging.INFO, rank=0) + + +def main(args): + # Init Nanotron Parallel Utilities + additional_vocab_size = 1 + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + 
tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Nanotron checkpoint config + log_rank( + f"Loading Nanotron checkpoint config file: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + nanotron_config = get_config_from_file( + os.path.join(args.pretrained_model_name_or_path, "config.yaml"), config_class=Config, model_config_class=None + ) + nanotron_idefics3_config = nanotron_config.model.model_config + nanotron_llama_config = nanotron_idefics3_config.text_config + nanotron_vision_config = nanotron_idefics3_config.vision_config + + + # Init Idefics3 Nanotron model + log_rank("Init empty Nanotron Idefics3 Model", logger=logger, level=logging.INFO, rank=0) + + nanotron_model = build_model( + model_builder=lambda: Idefics3ForTraining( + config=nanotron_idefics3_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # # Load Nanotron Checkpoint + log_rank("Loading Nanotron Idefics3 Model...", logger=logger, level=logging.INFO, rank=0) + load_weights( + model=nanotron_model, parallel_context=parallel_context, root_folder=Path(args.pretrained_model_name_or_path) + ) + + # Build empty HF Model + # log_rank("Init empty HF Llama3 Model", logger=logger, level=logging.INFO, rank=0) + + # hf_llama_model = AutoModelForCausalLM.from_config( # WARN This takes a long time + # config=LlamaConfigHF(**asdict(nanotron_llama_config)), + # torch_dtype=TORCH_DTYPE, + # attn_implementation="flash_attention_2", + # ).to(DEVICE) + + # log_rank("Init empty HF SigLIP Model", logger=logger, level=logging.INFO, rank=0) + # hf_siglip_model = AutoModel.from_config( + # config=SigLIPConfigHF(**asdict(nanotron_vision_config)), + # torch_dtype=TORCH_DTYPE, + # attn_implementation="flash_attention_2", + # ).to(DEVICE) + + + log_rank("Init empty HF Idefics3 Model", logger=logger, level=logging.INFO, rank=0) + + with init_empty_weights(): + hf_idefics3_model = AutoModelForVision2Seq.from_config( + config=Idefics3ConfigHF(**asdict(nanotron_idefics3_config)), + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + ).to_empty(device=DEVICE) + + + + # hf_idefics3_model = AutoModelForVision2Seq.from_pretrained( + # args.hf_pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + # ).to(DEVICE) + + # Copy weights from Nanotron to Hugging Face + copy_weights_from_nanotron_to_hf_llama( + nanotron_model=nanotron_model.model.llama, + # hf_model=hf_llama_model, + hf_model=hf_idefics3_model.model.text_model, + nanotron_llama_config=nanotron_idefics3_config.text_config, + additional_vocab_size=additional_vocab_size, + ) + + log_rank("Copied weights from Nanotron Llama model to HF model!", logger=logger, level=logging.INFO, rank=0) + + copy_weights_from_nanotron_to_hf_vision( + nanotron_model=nanotron_model.model.combined_embeddings.pp_block.vision_model, + # hf_model=hf_siglip_model, + hf_model=hf_idefics3_model.model.vision_model, + nanotron_vision_config=nanotron_vision_config, + ) + + log_rank("Copied weights from Nanotron SigLIP model to HF model!", logger=logger, level=logging.INFO, rank=0) + + copy_weights_from_nanotron_to_hf_remaining( + nanotron_model=nanotron_model.model, + 
hf_model=hf_idefics3_model, + nanotron_config=nanotron_idefics3_config, + ) + + # log_rank("Copied weights from Nanotron Idefics3 model to HF model!", logger=logger, level=logging.INFO, rank=0) + + hf_checkpoint_path = Path( + args.huggingface_checkpoint_path, + ) + + # save_weights( + # model=hf_idefics3_model, + # parallel_context=parallel_context, + # root_folder=hf_checkpoint_path, + # ) + + # Store weights + log_rank("Saving HF model Checkpoint and Tokenizer!", logger=logger, level=logging.INFO, rank=0) + # hf_llama_model.save_pretrained(args.hugging_face_checkpoint_path_llama, from_pt=True) + # # Store tokenizer + # tokenizer_llama = AutoTokenizer.from_pretrained(nanotron_llama_config.tokenizer.tokenizer_name_or_path) + # tokenizer_llama.save_pretrained(args.hugging_face_checkpoint_path_llama) + # log_rank( + # f"Checkpoint conversion finished, check {args.hugging_face_checkpoint_path_llama}", + # logger=logger, + # level=logging.INFO, + # rank=0, + # ) + + # # Store weights + # hf_siglip_model.save_pretrained(args.hugging_face_checkpoint_path_siglip, from_pt=True) + # # Store tokenizer + # tokenizer_siglip = AutoTokenizer.from_pretrained(nanotron_vision_config.tokenizer.tokenizer_name_or_path) + # tokenizer_siglip.save_pretrained(args.hugging_face_checkpoint_path_siglip) + + hf_idefics3_model.save_pretrained(args.huggingface_checkpoint_path, from_pt=True) + hf_idefics3_model.config.save_pretrained(args.huggingface_checkpoint_path) + # tokenizer_idefics3 = AutoTokenizer.from_pretrained(nanotron_config.tokenizer.tokenizer_name_or_path) + # tokenizer_idefics3.save_pretrained(args.huggingface_checkpoint_path) + + log_rank( + f"Checkpoint conversion finished, check {args.huggingface_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + +if __name__ == "__main__": + _args = get_args() + main(_args) + + \ No newline at end of file diff --git a/tools/idefics3/generate_hf_predictions.py b/tools/idefics3/generate_hf_predictions.py new file mode 100644 index 00000000..73186545 --- /dev/null +++ b/tools/idefics3/generate_hf_predictions.py @@ -0,0 +1,630 @@ +""" +torchrun --nproc-per-node 1 tools/idefics3/generate_hf_predictions.py --pretrained-model-name-or-path HuggingFaceM4/Idefics3-8B-Llama3 +""" + +import argparse +import os +from typing import List, Optional +from PIL import Image + +import numpy as np +import requests +import torch + + +from transformers import AutoProcessor, AutoModelForVision2Seq +from transformers.image_utils import load_image +from transformers.modeling_flash_attention_utils import _flash_attention_forward + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + +messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What’s the difference between these two images?"}, + {"type": "image"}, + {"type": "image"}, + ], +}, +{ + "role": "assistant", + "content": [ + {"type": "text", "text": "The difference is that one image is about dogs and the other one about cats."}, + ], +}] + + +url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg" + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + 
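+    # Only the HF model path is configurable here; the chat messages, image URLs and dtype are
+    # fixed module-level constants above.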
return args + + +def forward_embedding( + self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + + patch_embeds = self.patch_embedding(pixel_values) + + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + +def forward_attn( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value = None, + output_attentions: bool = False, + use_cache: bool = False, + layer_idx: int = 0, + **kwargs, + ): + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(Idefics2VisionRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + # logger.warning_once( + # f"The input hidden states seems to be silently casted in float32, this might be related to" + # f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + # f" {target_dtype}." + # ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + +def forward_encoder_layer( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + layer_idx: int, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + + hidden_states, attn_weights = forward_attn( + self.self_attn, + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + layer_idx=layer_idx, + ) + + hidden_states = residual + hidden_states + + + residual = hidden_states + + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + +# Ignore copy +def forward_encoder( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for i, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = forward_encoder_layer( + encoder_layer, + hidden_states, + attention_mask, + layer_idx=i, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return hidden_states, encoder_states, all_attentions + + +def forward_vision( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_size = self.patch_size + patch_attention_mask = torch.ones( + ( + batch_size, + pixel_values.size(2) // patch_size, + pixel_values.size(3) // patch_size, + ) + ) + patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) + + hidden_states = forward_embedding(self.embeddings, pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + patch_attention_mask = None + elif not self._use_flash_attention_2: + patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + + encoder_outputs = forward_encoder( + self.encoder, + inputs_embeds=hidden_states, + attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state,) + encoder_outputs[1:] + + return last_hidden_state + +def forward_text_model( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + # logger.warning_once( + # "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + # ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # logger.warning_once( + # "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + # "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + # "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + # ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for i, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return hidden_states, next_cache, all_hidden_states, all_self_attns + + +def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_attention_mask: Optional[torch.BoolTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + 
else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_seen_tokens = 0 + if use_cache: + if past_key_values is None: + past_key_values = DynamicCache() + past_seen_tokens = past_key_values.get_seq_length() + + if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: + raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") + + if inputs_embeds is None: + inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device) + + # START VISUAL INPUTS INTEGRATION + if pixel_values is not None and image_hidden_states is not None: + raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time") + elif pixel_values is not None: + batch_size, num_images, num_channels, height, width = pixel_values.shape + pixel_values = pixel_values.to(dtype=self.dtype) # fp16 compatibility + pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:]) + + # Remove padding images - padding images are full 0. + nb_values_per_image = pixel_values.shape[1:].numel() + real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image + pixel_values = pixel_values[real_images_inds].contiguous() + + # Handle the vision attention mask + if pixel_attention_mask is None: + pixel_attention_mask = torch.ones( + size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)), + dtype=torch.bool, + device=pixel_values.device, + ) + else: + # Remove padding images from the mask + pixel_attention_mask = pixel_attention_mask.view( + batch_size * num_images, *pixel_attention_mask.shape[2:] + ) + pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous() + + patch_size = self.config.vision_config.patch_size + patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size) + patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size) + patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + # Get sequence from the vision encoder + image_hidden_states = forward_vision( + self.vision_model, + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + + # Modality projection & resampling + image_hidden_states = self.connector(image_hidden_states) + + elif image_hidden_states is not None: + image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device) + + if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None: + # When we generate, we don't want to replace the potential image_token_id that we generated by images + # that simply don't exist + + inputs_embeds = self.inputs_merger( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_hidden_states=image_hidden_states, + ) + + outputs = forward_text_model( + self.text_model, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return tuple(v for v in [*outputs, image_hidden_states] if v is not None) + + return outputs + + + +def main(args): + + model = AutoModelForVision2Seq.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + device_map="auto", + ).eval() + + image_1 = Image.open(requests.get(url_1, 
stream=True).raw) + image_2 = Image.open(requests.get(url_2, stream=True).raw) + images = [image_1, image_2] + + processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3") + text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(images=images, text=text, return_tensors="pt").to(DEVICE) + + with torch.no_grad(): + # output = model(**inputs) + + output = model.model( + use_cache=False, + **inputs + ) + + logits = model.lm_head(output.last_hidden_state) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(logits[0, predicted_token, :], -1) + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[HF Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + + # Compute accuracy + # predictions = np.argmax(output.logits.cpu(), axis=2).flatten().tolist() + # labels = tokens.cpu().flatten()[1:].tolist() + # print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/idefics3/generate_nanotron_predictions.py b/tools/idefics3/generate_nanotron_predictions.py new file mode 100644 index 00000000..d2b52bb8 --- /dev/null +++ b/tools/idefics3/generate_nanotron_predictions.py @@ -0,0 +1,170 @@ +""" +torchrun --nproc-per-node 2 tools/idefics3/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 +""" +import argparse +import os +from pathlib import Path + +import requests + +import nanotron.distributed as dist +import numpy as np +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.idefics import Idefics3ForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +# from sklearn.metrics import accuracy_score +from transformers import AutoTokenizer, AutoProcessor +from PIL import Image + + +messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What’s the difference between these two images?"}, + {"type": "image"}, + {"type": "image"}, + ], +}, +{ + "role": "assistant", + "content": [ + {"type": "text", "text": "The difference is that one image is about dogs and the other one about cats."}, + ], +}] + + +url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg" + +SEQ_LENGTH = 512 # For truncating the TXT if GPU can't fit too many tokens + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory containing a Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + 
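+    # Note: launch this script with torchrun; ParallelContext expects the world size
+    # (--nproc-per-node) to equal dp * pp * tp, and main() fixes dp = pp = 1, so the
+    # number of launched processes should match the --tp value passed here.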
group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the Nanotron Checkpoint") + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + RANK = dist.get_rank(parallel_context.world_pg) + + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + + model = build_model( + model_builder=lambda: Idefics3ForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, # TODO Check with different parallelism if cpu is available + ) + + + #torch.Size([484, 26, 768]) + + mark_tied_parameters(model=model, parallel_context=parallel_context) + sanity_check(root_module=model) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) + + + image_1 = Image.open(requests.get(url_1, stream=True).raw) + image_2 = Image.open(requests.get(url_2, stream=True).raw) + images = [image_1, image_2] + + # Using non-Idefics3 image size may break the pixel shuffle + # For example, instead of 384 you should use either 364 or 404 + image_size = nanotron_config.model.model_config.vision_config.image_size + + image_size = 364 + + target_image_seq_len = int(((image_size // nanotron_config.model.model_config.vision_config.patch_size) ** 2) / (nanotron_config.model.model_config.scale_factor**2)) + + processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", image_seq_len=target_image_seq_len, size= {"longest_edge": 4*image_size}, max_image_size = {"longest_edge": image_size}) + + text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(images=images, text=text, return_tensors="pt", image_seq_len=target_image_seq_len).to(DEVICE) + + inputs = { + "input_ids": inputs['input_ids'], + "input_mask": inputs['attention_mask'], + "pixel_values": inputs['pixel_values'].bfloat16(), + "pixel_attention_mask": inputs['pixel_attention_mask'], + } + + model.eval() + + with torch.no_grad(): + output = model.model(**inputs) + + if not RANK: + print(output.shape) + + predicted_tokens = [5, 27, 34] # Index of the predictions to compare across models + term_cols = int(os.get_terminal_size().columns / 3) + + for predicted_token in predicted_tokens: + + print("\n", "=" * term_cols, f"Predictions of token {predicted_token}", "=" * term_cols) + next_tokens = torch.softmax(output.transpose(0, 1)[0, predicted_token, :], -1) + + topk_next_tokens = torch.topk(next_tokens, 10) + + print( + *[ + f"[Nanotron Model] Next token: {idx.item()}, probability: {prob}" + for idx, prob in zip(topk_next_tokens.indices, topk_next_tokens.values) + ], + sep="\n", + ) + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git 
a/tools/idefics3/loss_on_captions_hf.py b/tools/idefics3/loss_on_captions_hf.py new file mode 100644 index 00000000..5e2ab2fc --- /dev/null +++ b/tools/idefics3/loss_on_captions_hf.py @@ -0,0 +1,132 @@ +""" +torchrun --nproc-per-node 1 tools/idefics3/loss_on_captions_hf.py --pretrained-model-name-or-path HuggingFaceM4/Idefics3-8B-Llama3 +""" + +import argparse +import os +from typing import List, Optional +from PIL import Image + +import numpy as np +import requests +import torch + +from transformers import AutoProcessor, AutoModelForVision2Seq +from transformers.image_utils import load_image +from transformers.modeling_flash_attention_utils import _flash_attention_forward + +from torch.utils.data import DataLoader + +from datasets import load_dataset + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def caption_to_messages(caption): + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What do we see in this image?"}, + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "This image shows: " + caption}, + ] + }, + ] + + return messages + +def check_image(image): + image = np.array(image) + if image.ndim == 2: + image = image[:, :, None] + return image + +def collate_fn(examples, processor): + captions = [ + processor.apply_chat_template(caption_to_messages(example["image_description"])) for example in examples + ] + images = [[check_image(example["image"])] for example in examples] + + inputs = processor(text=captions, images=images, return_tensors="pt", padding="max_length", max_length=2049, truncation=True, padding_side="right") + + input_ids = inputs["input_ids"][:, :-1] + attention_mask = inputs["attention_mask"][:, :-1] == 1 + label_ids = inputs["input_ids"][:, 1:] + label_mask = label_ids < processor.tokenizer.vocab_size + pixel_values = inputs["pixel_values"] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label_ids, "label_mask": label_mask, "pixel_values": pixel_values} + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + + model = AutoModelForVision2Seq.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=TORCH_DTYPE, + attn_implementation="flash_attention_2", + device_map="auto", + ).eval() + + + dataset = load_dataset("jmhessel/newyorker_caption_contest", 'explanation', split="validation[:100]") + processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", size={"longest_edge": 2 * 364}) + + dataloader = DataLoader(dataset, batch_size=16, num_workers=16, collate_fn=lambda x: collate_fn(x, processor)) + + total_loss = 0 + total_acc = 0 + + for batch in dataloader: + inputs = {k: v.to(DEVICE) for k, v in batch.items()} + + with torch.no_grad(): + output = model( + use_cache=False, + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + pixel_values=inputs["pixel_values"], + ) + + logits = output[0] + label_mask = inputs["label_mask"] + labels = inputs["labels"] + + loss = torch.nn.functional.cross_entropy(logits[label_mask], labels[label_mask]) + total_loss += loss.item() + + acc = 
(logits.argmax(dim=-1)[label_mask] == labels[label_mask]).float().mean().item() + + total_acc += acc + + print(f"Average Loss: {total_loss / len(dataloader)}") + print(f"Average Accuracy: {total_acc / len(dataloader)}") + + # Average Loss: 2.1875 + # Average Accuracy: 0.6541961346353803 + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/idefics3/loss_on_captions_nanotron.py b/tools/idefics3/loss_on_captions_nanotron.py new file mode 100644 index 00000000..b426b1f2 --- /dev/null +++ b/tools/idefics3/loss_on_captions_nanotron.py @@ -0,0 +1,206 @@ +""" +torchrun --nproc-per-node 1 tools/idefics3/loss_on_captions_nanotron.py --tp 1 --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Idefics3-8B-Llama3 --dataset-path "../datasets/ny_captions.hf" +""" +import argparse +import os +from pathlib import Path + +import nanotron.distributed as dist +import torch.distributed as torch_dist +import numpy as np +import torch +from nanotron.config import Config, ParallelismArgs, get_config_from_file +from nanotron.models import build_model +from nanotron.models.idefics import Idefics3ForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.parallel.pipeline_parallel.engine import AllForwardAllBackwardPipelineEngine +from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode +from nanotron.serialize import load_weights +from nanotron.trainer import mark_tied_parameters +from transformers import AutoProcessor +from datasets import load_dataset, load_from_disk +from torch.utils.data import DataLoader +from tqdm.auto import tqdm + +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + +def caption_to_messages(caption): + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What do we see in this image?"}, + ] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "This image shows: " + caption}, + ] + }, + ] + + return messages + +def check_image(image): + image = np.array(image) + if image.ndim == 2: + image = image[:, :, None] + return image + +def collate_fn(examples, processor): + captions = [ + processor.apply_chat_template(caption_to_messages(example["image_description"])) for example in examples + ] + images = [[check_image(example["image"])] for example in examples] + + inputs = processor(text=captions, images=images, return_tensors="pt", padding="max_length", max_length=2049, truncation=True, padding_side="right") + + input_ids = inputs["input_ids"][:, :-1] + attention_mask = inputs["attention_mask"][:, :-1] == 1 + label_ids = inputs["input_ids"][:, 1:] + label_mask = label_ids < processor.tokenizer.vocab_size + pixel_values = inputs["pixel_values"] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": label_ids, "label_mask": label_mask, "pixel_values": pixel_values} + + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory containing a Nanotron Checkpoint", + ) + group = parser.add_argument_group(title="Dataset") + group.add_argument( + "--dataset-path", + type=str, + required=False, + help="A path to a directory containing the dataset", + ) + + group = parser.add_argument_group(title="Nanotron Parallelism") + group.add_argument("--tp", type=int, required=True, help="Tensor Parallelism Degree of the 
Nanotron Checkpoint") + + args = parser.parse_args() + + return args + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=args.tp, + pp_engine=AllForwardAllBackwardPipelineEngine(), + tp_mode=TensorParallelLinearMode.ALL_REDUCE, + tp_linear_async_communication=False, + ) + assert ( + parallel_config.tp_mode == TensorParallelLinearMode.ALL_REDUCE + and parallel_config.tp_linear_async_communication is False + ) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + RANK = dist.get_rank(parallel_context.world_pg) + + nanotron_config = get_config_from_file( + os.path.join(args.nanotron_checkpoint_path, "config.yaml"), config_class=Config, model_config_class=None + ) + + model = build_model( + model_builder=lambda: Idefics3ForTraining( + config=nanotron_config.model.model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, # TODO Check with different parallelism if cpu is available + ) + + mark_tied_parameters(model=model, parallel_context=parallel_context) + sanity_check(root_module=model) + + # Load checkpoint directly in memory and then only keep the state dictionary + load_weights(model=model, parallel_context=parallel_context, root_folder=Path(args.nanotron_checkpoint_path)) + + if args.dataset_path is not None: + dataset = load_from_disk(args.dataset_path) + else: + dataset = load_dataset("jmhessel/newyorker_caption_contest", 'explanation', split="validation[:100]") + + processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3", size={"longest_edge": 2 * 364}) + + dataloader = DataLoader(dataset, batch_size=4, num_workers=16, collate_fn=lambda x: collate_fn(x, processor)) + + total_loss = 0 + total_acc = 0 + n_samples = 0 + + def gather_logits(logits, parallel_context): + tp_pg = parallel_context.tp_pg + if tp_pg.size() == 1: + return logits + + sharded_shape = logits.shape + + tensor_list = [torch.empty(sharded_shape, device=logits.device, dtype=logits.dtype) for _ in range(tp_pg.size())] + + torch_dist.all_gather(tensor_list, logits, group=tp_pg) + + logits = torch.cat(tensor_list, dim=-1) + + return logits + + for batch in tqdm(dataloader): + inputs = {k: v.to(DEVICE) for k, v in batch.items()} + + with torch.no_grad(): + output = model.model( + input_ids=inputs["input_ids"], + input_mask=inputs["attention_mask"], + pixel_values=inputs["pixel_values"], + ) + + logits = gather_logits(output, parallel_context).transpose(0, 1) + + label_mask = inputs["label_mask"] + labels = inputs["labels"] + + loss = torch.nn.functional.cross_entropy(logits[label_mask], labels[label_mask]) + + acc = (logits.argmax(dim=-1)[label_mask] == labels[label_mask]).float().mean().item() + + if RANK == 0: + total_acc += acc + total_loss += loss.item() + + + if RANK == 0: + print(f"Average Loss: {total_loss / len(dataloader)}") + print(f"Average Accuracy: {total_acc / len(dataloader)}") + + # Average Loss: 2.1875 (HF) + # Average Loss: 2.112454278128488 (Nanotron TP=1) + # Average Loss: 2.112218448093959 (Nanotron TP=2) + + # Average Accuracy: 0.6541961346353803 (HF) + # Average Loss: 0.6715155754770551 (Nanotron TP=1) + # Average Loss: 0.6702071257999965 (Nanotron TP=2) + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/tools/llama3/convert_hf_to_nanotron.py 
b/tools/llama3/convert_hf_to_nanotron.py index e30610a3..561c3d72 100644 --- a/tools/llama3/convert_hf_to_nanotron.py +++ b/tools/llama3/convert_hf_to_nanotron.py @@ -51,82 +51,17 @@ def get_args(): return args - -def main(args): - # Init Nanotron Parallel Utilities - parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) - - parallel_context = ParallelContext( - data_parallel_size=parallel_config.dp, - pipeline_parallel_size=parallel_config.pp, - tensor_parallel_size=parallel_config.tp, - ) - - set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) - - # Load Llama3-8B HF model - log_rank( - f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", - logger=logger, - level=logging.INFO, - rank=0, - ) - hf_model = AutoModelForCausalLM.from_pretrained( - args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" - ).to(DEVICE) - hf_config = hf_model.config - - # Set Nanotron LlamaConfig - nanotron_llama_config = LlamaConfigNanotron( - bos_token_id=hf_config.bos_token_id, - eos_token_id=hf_config.eos_token_id, - hidden_act=hf_config.hidden_act, - hidden_size=hf_config.hidden_size, - initializer_range=hf_config.initializer_range, - intermediate_size=hf_config.intermediate_size, - is_llama_config=True, - max_position_embeddings=hf_config.max_position_embeddings, - num_attention_heads=hf_config.num_attention_heads, - num_hidden_layers=hf_config.num_hidden_layers, - num_key_value_heads=hf_config.num_key_value_heads, - pad_token_id=None, - pretraining_tp=hf_config.pretraining_tp, - rms_norm_eps=hf_config.rms_norm_eps, - rope_scaling=hf_config.rope_scaling, - rope_theta=hf_config.rope_theta, - rope_interleaved=False, - tie_word_embeddings=hf_config.tie_word_embeddings, - use_cache=hf_config.use_cache, - vocab_size=hf_config.vocab_size, - ) - - # Init Llama3-8B Nanotron model - log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) - nanotron_model = build_model( - model_builder=lambda: LlamaForTraining( - config=nanotron_llama_config, - parallel_context=parallel_context, - parallel_config=parallel_config, - random_states=None, - ), - parallel_context=parallel_context, - dtype=TORCH_DTYPE, - device=DEVICE, - ) - - mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) - sanity_check(root_module=nanotron_model) - +def copy_weights_from_hf_to_nanotron(nanotron_model, hf_model, nanotron_llama_config): # Copy params from HF to Nanotron log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) # Token embeddings log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) assert ( - nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + nanotron_model.token_position_embeddings.pp_block.token_embedding.weight.shape == hf_model.model.embed_tokens.weight.shape ) with torch.no_grad(): - nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + nanotron_model.token_position_embeddings.pp_block.token_embedding.weight.copy_( hf_model.model.embed_tokens.weight ) @@ -139,10 +74,10 @@ def main(args): # Input layer norm assert ( hf_model.model.layers[i].input_layernorm.weight.shape - == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.input_layernorm.weight.shape ) with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + 
nanotron_model.decoder[i].pp_block.input_layernorm.weight.copy_( hf_model.model.layers[i].input_layernorm.weight ) @@ -156,17 +91,17 @@ def main(args): ], dim=0, ) - assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + assert tmp_qkv_proj.shape == nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.shape with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + nanotron_model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) ## O assert ( hf_model.model.layers[i].self_attn.o_proj.weight.shape - == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + == nanotron_model.decoder[i].pp_block.attn.o_proj.weight.shape ) with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + nanotron_model.decoder[i].pp_block.attn.o_proj.weight.copy_( hf_model.model.layers[i].self_attn.o_proj.weight ) @@ -180,41 +115,118 @@ def main(args): dim=0, ) - assert tmp_gate_up_proj.shape == nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape + assert tmp_gate_up_proj.shape == nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.shape with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) + nanotron_model.decoder[i].pp_block.mlp.gate_up_proj.weight.copy_(tmp_gate_up_proj) ## Down Proj assert ( hf_model.model.layers[i].mlp.down_proj.weight.shape - == nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.shape + == nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.shape ) with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.mlp.down_proj.weight.copy_( + nanotron_model.decoder[i].pp_block.mlp.down_proj.weight.copy_( hf_model.model.layers[i].mlp.down_proj.weight ) # Post attn layer norm assert ( hf_model.model.layers[i].post_attention_layernorm.weight.shape - == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + == nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.shape ) with torch.no_grad(): - nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + nanotron_model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( hf_model.model.layers[i].post_attention_layernorm.weight ) - # Last layer norm - log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) - assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape - with torch.no_grad(): - nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + with torch.no_grad(): + nanotron_model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + with torch.no_grad(): + nanotron_model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + +def nanotron_config_from_hf_config(hf_config): + return LlamaConfigNanotron( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + 
intermediate_size=hf_config.intermediate_size, + is_llama_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + pretraining_tp=hf_config.pretraining_tp, + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + ) + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Llama3-8B HF model + log_rank( + f"Loading pretrained Llama3 Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config = hf_model.config + + # Set Nanotron LlamaConfig + nanotron_llama_config = nanotron_config_from_hf_config(hf_config) + + # Init Llama3-8B Nanotron model + log_rank("Init empty Nanotron Llama3 Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: LlamaForTraining( + config=nanotron_llama_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context) + sanity_check(root_module=nanotron_model) + + # Copy weights from HF to Nanotron + copy_weights_from_hf_to_nanotron( + nanotron_model=nanotron_model.model, + hf_model=hf_model, + nanotron_llama_config=nanotron_llama_config, + ) - # LM_Head - log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) - assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape - with torch.no_grad(): - nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) # Store weights diff --git a/tools/llama3/generate_hf_predictions.py b/tools/llama3/generate_hf_predictions.py index b16774a4..d83d02f9 100644 --- a/tools/llama3/generate_hf_predictions.py +++ b/tools/llama3/generate_hf_predictions.py @@ -65,9 +65,9 @@ def main(args): ) # Compute accuracy - predictions = np.argmax(output.logits.cpu(), axis=2).flatten().tolist() - labels = tokens.cpu().flatten()[1:].tolist() - print(f"\nAccuracy: {accuracy_score(labels, predictions)}") + # predictions = np.argmax(output.logits.cpu(), axis=2).flatten().tolist() + # labels = tokens.cpu().flatten()[1:].tolist() + # print(f"\nAccuracy: {accuracy_score(labels, predictions)}") # Results ## [TP=1] HF 8B: 0.8308823529411765 ## [TP=2]HF 70B: 0.8860294117647058 diff --git a/tools/llama3/generate_nanotron_predictions.py b/tools/llama3/generate_nanotron_predictions.py index fbede799..389354a0 100644 --- a/tools/llama3/generate_nanotron_predictions.py +++ 
b/tools/llama3/generate_nanotron_predictions.py @@ -1,5 +1,5 @@ """ -torchrun --nproc-per-node 2 tools/llama3/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/NanotronLlama38B +torchrun --nproc-per-node 2 tools/llama3/generate_nanotron_predictions.py --tp 2 --nanotron-checkpoint-path nanotron_checkpoints/Nanotron-Llama-3-8B """ import argparse import os
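For reference, the fused projections whose shapes `copy_weights_from_hf_to_nanotron` asserts follow the usual Llama layout: Q, K and V weights are concatenated along the output dimension into `qkv_proj`, and gate/up into `gate_up_proj`. A small shape check for a Llama-3-8B-sized config (values taken from the text_config in the example YAML; illustrative arithmetic, not part of the converter):

```python
# Llama-3-8B-style dimensions (see the text_config above)
hidden_size = 4096
num_attention_heads = 32
num_key_value_heads = 8
intermediate_size = 14336
head_dim = hidden_size // num_attention_heads            # 128

# qkv_proj.weight = cat([q, k, v], dim=0)
qkv_rows = (num_attention_heads + 2 * num_key_value_heads) * head_dim
assert (qkv_rows, hidden_size) == (6144, 4096)

# gate_up_proj.weight = cat([gate, up], dim=0)
assert (2 * intermediate_size, hidden_size) == (28672, 4096)
```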