From 4956922fe8d7654c03fc31aa4fed7b7d2bf53b49 Mon Sep 17 00:00:00 2001 From: Zhijian Liu <5782437+zhijian-liu@users.noreply.github.com> Date: Wed, 14 Aug 2024 19:40:27 -0400 Subject: [PATCH] Revert "Remove `transformer_replace`" (#161) --- environment_setup.sh | 1 + llava/model/language_model/llava_llama.py | 106 +- .../runtime/zero/mics.py | 689 +++ .../runtime/zero/partition_parameters.py | 2287 ++++++++ .../utils/groups.py | 585 ++ .../transformers_replace/modeling_utils.py | 4830 +++++++++++++++++ .../models/gemma/__init__.py | 111 + .../models/gemma/configuration_gemma.py | 145 + .../models/gemma/modeling_gemma.py | 1331 +++++ .../models/llama/configuring_llama.py | 188 + .../models/llama/modeling_llama.py | 1234 +++++ .../models/llama/tokenization_llama.py | 480 ++ .../models/mistral/__init__.py | 57 + .../models/mistral/configuration_mistral.py | 150 + .../models/mistral/modeling_mistral.py | 1376 +++++ .../models/mixtral/__init__.py | 57 + .../models/mixtral/configuration_mixtral.py | 167 + .../models/mixtral/modeling_mixtral.py | 1632 ++++++ tests/python/test_packing.py | 111 +- 19 files changed, 15463 insertions(+), 74 deletions(-) create mode 100644 llava/train/deepspeed_replace_deprecated/runtime/zero/mics.py create mode 100644 llava/train/deepspeed_replace_deprecated/runtime/zero/partition_parameters.py create mode 100644 llava/train/deepspeed_replace_deprecated/utils/groups.py create mode 100755 llava/train/transformers_replace/modeling_utils.py create mode 100755 llava/train/transformers_replace/models/gemma/__init__.py create mode 100755 llava/train/transformers_replace/models/gemma/configuration_gemma.py create mode 100755 llava/train/transformers_replace/models/gemma/modeling_gemma.py create mode 100755 llava/train/transformers_replace/models/llama/configuring_llama.py create mode 100755 llava/train/transformers_replace/models/llama/modeling_llama.py create mode 100755 llava/train/transformers_replace/models/llama/tokenization_llama.py create mode 100755 llava/train/transformers_replace/models/mistral/__init__.py create mode 100755 llava/train/transformers_replace/models/mistral/configuration_mistral.py create mode 100755 llava/train/transformers_replace/models/mistral/modeling_mistral.py create mode 100755 llava/train/transformers_replace/models/mixtral/__init__.py create mode 100755 llava/train/transformers_replace/models/mixtral/configuration_mixtral.py create mode 100755 llava/train/transformers_replace/models/mixtral/modeling_mixtral.py diff --git a/environment_setup.sh b/environment_setup.sh index bac05aa7..4de992fe 100755 --- a/environment_setup.sh +++ b/environment_setup.sh @@ -28,4 +28,5 @@ pip install -e ".[eval]" # Install HF's Transformers pip install git+https://github.com/huggingface/transformers@v4.37.2 site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])') +cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/ cp -rv ./llava/train/deepspeed_replace/* $site_pkg_path/deepspeed/ diff --git a/llava/model/language_model/llava_llama.py b/llava/model/language_model/llava_llama.py index de835f6b..28adb5f0 100755 --- a/llava/model/language_model/llava_llama.py +++ b/llava/model/language_model/llava_llama.py @@ -15,40 +15,28 @@ # This file is modified from https://github.com/haotian-liu/LLaVA/ +import inspect import os -from functools import partial from typing import List, Optional, Tuple, Union -from unittest.mock import patch import torch -from torch.nn import functional as F from transformers import AutoConfig, AutoModel, 
PretrainedConfig, PreTrainedModel from transformers.modeling_outputs import CausalLMOutputWithPast -from llava.utils import distributed as dist - from ...train.utils import calculate_loss_weight from ..configuration_llava import LlavaConfig from ..llava_arch import LlavaMetaForCausalLM, LlavaMetaModel -def _get_unpad_data(attention_mask, _seqlens_in_batch, *args, **kwargs): - indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() - max_seqlen_in_batch = _seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(_seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) - return (indices, cu_seqlens, max_seqlen_in_batch) - - class LlavaLlamaConfig(LlavaConfig): model_type = "llava_llama" -# FIXME we will follow the convention to add a new class for CausalLM in the future +## FIXME we will follow the convention to add a new class for CausalLM in the future class LlavaLlamaModel(LlavaMetaModel, LlavaMetaForCausalLM, PreTrainedModel): config_class = LlavaLlamaConfig main_input_name = "input_embeds" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True def __init__(self, config: LlavaLlamaConfig = None, *args, **kwargs) -> None: super().__init__(config) @@ -104,15 +92,16 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - packing: bool = True, - seqlens_in_batch: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, dpo_forward: bool = False, - **kwargs, ) -> Union[Tuple, CausalLMOutputWithPast]: self.freezed_module_patch() - if inputs_embeds is None: ( input_ids, @@ -125,50 +114,71 @@ def forward( input_ids, position_ids, attention_mask, past_key_values, labels, images ) - if packing and self.training and not dpo_forward: + support_packing = "seqlens_in_batch" in inspect.signature(self.llm.forward).parameters + + if self.training and support_packing and not dpo_forward: ( _, + new_position_ids, + new_attention_mask, + _, + new_inputs_embeds, + new_labels, + sorted_seqlens_in_batch, + ) = self.repack_multimodal_data( + input_ids, position_ids, attention_mask, - _, + past_key_values, inputs_embeds, labels, - _seqlens_in_batch, - ) = self.repack_multimodal_data( - input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels ) - if _seqlens_in_batch is None: - _seqlens_in_batch = seqlens_in_batch - - wraps = partial(_get_unpad_data, _seqlens_in_batch=_seqlens_in_batch) - with patch(self.llm.__module__ + "._get_unpad_data", wraps=wraps) as m: - outputs = self.llm( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - labels=labels, - **kwargs, - ) - if 0 in attention_mask and m.call_count == 0: - raise ValueError("The forward function of the model should call `_get_unpad_data` function.") + if sorted_seqlens_in_batch is None: + sorted_seqlens_in_batch = seqlens_in_batch + new_input_ids = None + past_key_values = None else: - outputs = self.llm( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, + new_attention_mask = attention_mask + new_position_ids = position_ids + new_inputs_embeds = inputs_embeds + new_labels = labels + sorted_seqlens_in_batch = 
attention_mask.sum(-1).int() + new_input_ids = input_ids + + if support_packing: + outputs = self.llm.forward( + input_ids=new_input_ids, + attention_mask=new_attention_mask, + position_ids=new_position_ids, past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - labels=labels, - **kwargs, + inputs_embeds=new_inputs_embeds, + labels=new_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + seqlens_in_batch=sorted_seqlens_in_batch, + ) + else: + outputs = self.llm.forward( + input_ids=new_input_ids, + attention_mask=new_attention_mask, + position_ids=new_position_ids, + past_key_values=past_key_values, + inputs_embeds=new_inputs_embeds, + labels=new_labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, ) # Loss rescale for SP & DP loss match - if dist.size() > 1: - loss_weight = calculate_loss_weight(labels) - outputs.loss = outputs.loss * loss_weight + loss_weight = calculate_loss_weight(new_labels) + outputs.loss = outputs.loss * loss_weight if dpo_forward: - return outputs.logits, labels + return outputs.logits, new_labels return outputs diff --git a/llava/train/deepspeed_replace_deprecated/runtime/zero/mics.py b/llava/train/deepspeed_replace_deprecated/runtime/zero/mics.py new file mode 100644 index 00000000..529f8578 --- /dev/null +++ b/llava/train/deepspeed_replace_deprecated/runtime/zero/mics.py @@ -0,0 +1,689 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import sys +from typing import List + +import deepspeed +import torch +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.runtime.zero.mics_utils import MiCS_CommGroups, create_mics_comm_groups, scale_tensors +from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload +from deepspeed.runtime.zero.partition_parameters import AllGatherCoalescedHandle, Init, ZeroParamStatus +from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3 +from deepspeed.runtime.zero.utils import is_zero_param +from deepspeed.utils import instrument_w_nvtx, log_dist +from torch import Tensor +from torch.nn import Parameter + + +def has_hierarchical_all_gather_groups(comm_groups: MiCS_CommGroups): + result = False + if comm_groups.param_intra_node_group is not None and comm_groups.param_inter_node_shard_group is not None: + result = True + return result + + +def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None): + return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True) + + +class MiCS_AllGatherCoalescedHandle(AllGatherCoalescedHandle): + """This handle assumes that no need to + copy data out from a contiguous tensor + """ + + def __init__(self, allgather_handle, params: List[Parameter], partitions: List[Tensor], world_size: int) -> None: + super().__init__(allgather_handle, params, partitions, world_size) + + def wait(self) -> None: + """ """ + # let the current stream to op + try: + # print("HANDLE", self.allgather_handle) + instrument_w_nvtx(self.allgather_handle.wait)() + except (ValueError, RuntimeError) as e: + log_dist( + f"WARNING: Runtime Error while waiting the collective all-gather, possibly due to the _IllegalWork", + ranks=[0], + ) + 
log_dist(f"Error message: {e}", ranks=[0]) + + if self.complete: + return + + for _, param in enumerate(self.params): + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.complete = True + + +class MiCS_Init(Init): + def __init__( + self, + module=None, + data_parallel_group=None, + sequence_data_parallel_group=None, + mem_efficient_linear=True, + remote_device=None, + pin_memory=False, + config_dict_or_path=None, + config=None, + enabled=True, + dtype=None, + mpu=None, + ): + """A context manager to partition the model parameters during the model + construction with MiCS partition strategy. Model states are partitioned + to the number of devices specified via ``mics_shard_size`` field in the + deepspeed config json file. The context manager also introduces + hierarchical communication method to reduce the cost of inter-node + communications, which can be enabled with + ``mics_hierarchical_params_gather`` field in deepspeed config. + + Args: + module (``torch.nn.Module``, optional): If provided, partition the model as + if it was constructed in the context. + data_parallel_group (``deepspeed.comm`` process group, optional): + The group of processes to partition among. Defaults to all processes. + mem_efficient_linear (bool, optional): Replace + torch.nn.functional.linear with an implementation that allows + DeepSpeed to partition parameters. Defaults to ``True``. + remote_device (string, optional): The initial device to store model + weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU + memory. The model may still be moved to GPU based on the + offload settings for training. Defaults to param offload device if a config is + defined, otherwise GPU. + pin_memory (bool, optional): Potentially increase performance by + using pinned memory for model weights. ``remote_device`` must be + ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``. + config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration + for swapping fp16 params to NVMe. + config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. + enabled (bool, optional): If ``False``, this context has no + effect. Defaults to ``True``. + dtype (``dtype``, optional): Can be used to change the data type of the parameters. + Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` + mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}. + + This context follows the same logic as ``deepspeed.zero.Init()``, but + with the modification for partition size of each parameter. + + Examples + -------- + + #. Allocate a model and partition it among all processes: + + .. code-block:: python + # the config_dict_or_path is required to let the context manager know + # how partition the parameters. + # The configuration has to include the field ``mics_shard_size`` + with deepspeed.zero.MiCS_Init(config_dict_or_path=ds_config): + model = MyLargeModel() + + + #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes: + + .. code-block:: python + + with deepspeed.zero.MiCS_Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device="cpu", + pin_memory=True + config_dict_or_path=ds_config): + model = MyLargeModel() + + + #. Partition an already-allocated model in CPU memory: + + .. 
code-block:: python + + model = deepspeed.zero.MiCS_Init(module=model, + config_dict_or_path=ds_config) + """ + + assert config_dict_or_path is not None, "Must provide configuration for MiCS Initialization" + _ds_config = deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) + if not dist.is_initialized(): + dist.init_distributed() + assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + + if data_parallel_group is None and sequence_data_parallel_group is None: + ds_process_group = dist.get_world_group() + elif sequence_data_parallel_group is not None: + ds_process_group = sequence_data_parallel_group + elif data_parallel_group is not None: + ds_process_group = data_parallel_group + else: # both given + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + + self.mics_comm_groups = create_mics_comm_groups( + _ds_config.mics_shard_size, + ds_process_group, + hierarchical_allgather=_ds_config.mics_hierarchial_params_gather, + mpu=mpu, + ) + + super().__init__( + module, + data_parallel_group, + mem_efficient_linear, + remote_device, + pin_memory, + config_dict_or_path, + config, + enabled, + dtype, + mpu, + ) + + def _convert_to_deepspeed_param(self, param): + super()._convert_to_deepspeed_param(param) + # attach communication groups to every param + param.comm = self.mics_comm_groups + + # record existing all_gather_coalesced implementation + # so that we can fallback later + old_all_gather_coalesced = param.all_gather_coalesced + + def _param_all_gather_coalesced(params, param_buffers=None, **kwargs): + """""" + mics_comm_groups: MiCS_CommGroups = params[0].comm + hierarchical_all_gather = has_hierarchical_all_gather_groups(mics_comm_groups) + if dist.has_coalescing_manager() and hierarchical_all_gather: + return self._hierarchical_all_gather_params(params, param_buffers) + elif dist.has_coalescing_manager(): + return self._flat_all_gather_with_coalescing_manager(params, param_buffers) + else: + return old_all_gather_coalesced(params, **kwargs) + + # change the all_gather_coalesced method + param.all_gather_coalesced = _param_all_gather_coalesced + + def _pre_all_gather(self, params, params_buffers=None): + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. 
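+ # For illustration: ds_id is assigned incrementally as parameters are
+ # converted, so every rank sees the same ids; sorting by ds_id (e.g. params
+ # carrying ds_id 7, 3, 5 always become [3, 5, 7]) gives an identical
+ # flattening order on all ranks, so partition i of the gathered tensor
+ # refers to the same parameter everywhere.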
+ params = sorted(params, key=lambda p: p.ds_id) + return params, params_buffers + + def _flat_all_gather_with_coalescing_manager(self, params, params_buffers=None): + """""" + params, params_buffers = self._pre_all_gather(params, params_buffers) + mics_comm_groups: MiCS_CommGroups = params[0].comm + param_shard_size = mics_comm_groups.param_shard_size + rank_in_group = mics_comm_groups.param_shard_rank + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + output_tensors = [] + input_tensors = [] + for i, p in enumerate(params): + t_size = p.ds_tensor.ds_numel * param_shard_size + if params_buffers is not None and params_buffers[i] is not None: + assert ( + params_buffers[i].numel() == t_size + ), f"params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}" + flat_out = params_buffers[i] + else: + flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1) + # flat_out = torch.zeros(t_size, dtype=p.dtype, device=self.local_device).view(-1) + output_tensors.append(flat_out) + _flat_input = p.ds_tensor.data.view(-1) + input_tensors.append(_flat_input) + + input_tensor = torch.cat(input_tensors, dim=0) + flat_tensor = torch.cat(output_tensors, dim=0) + partitions: List[Parameter] = [] + for i in range(param_shard_size): + partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], out=partitions[rank_in_group] + ) + # Ensure all gather output size is correct + assert partitions[rank_in_group].numel() * param_shard_size == flat_tensor.numel() + handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, mics_comm_groups.param_shard_group) + + # Clean the buffer after communication + # torch.cuda.empty_cache() + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=param_shard_size, + ) + + # flat_tensor = torch.empty(partition_sz * world_size, + # dtype=dtype, + # device=get_accelerator().current_device_name(), + # requires_grad=False) + + # # must have to change the status of the param + # # and ensure they are on the device + # params, params_buffers = self._pre_all_gather(params, params_buffers) + + # mics_comm_groups: MiCS_CommGroups = params[0].comm + # param_shard_size = mics_comm_groups.param_shard_size + + # output_tensors = [] + # input_tensors = [] + # for i, p in enumerate(params): + # t_size = p.ds_tensor.ds_numel * param_shard_size + # if params_buffers is not None and params_buffers[i] is not None: + # assert params_buffers[i].numel( + # ) == t_size, f'params_to_gather_buffers[{i}] size {params_buffers[i].numel()} does not match with t_size {t_size}' + # flat_out = params_buffers[i] + # else: + # # flat_out = torch.empty(t_size, dtype=p.dtype, device=self.local_device, requires_grad=False).view(-1) + # flat_out = torch.zeros(t_size, dtype=p.dtype, device=self.local_device).view(-1) + # output_tensors.append(flat_out) + # _flat_input = p.ds_tensor.data.view(-1) + # input_tensors.append(_flat_input) + + # # all_gather_handle = dist.all_gather_coalesced(output_tensors, + # # input_tensors, + # # group=mics_comm_groups.param_shard_group, + # # async_op=True) + # # print("ALL GATHER COALESCED: output_tensors shape", torch.cat(output_tensors, dim=0).shape, "input_tensors shape", torch.cat(input_tensors, dim=0)) + # input_tensor = torch.cat(input_tensors, dim=0) + # flat_tensor = torch.cat(output_tensors, 
dim=0) + # partition_sz = sum(p.ds_tensor.ds_numel for p in params) + # # flat_tensor = torch.empty(partition_sz * param_shard_size, + # # dtype=dtype, + # # device=get_accelerator().current_device_name(), + # # requires_grad=False) + # partitions: List[Parameter] = [] + # for i in range(param_shard_size): + # partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + + # all_gather_handle = dist.all_gather_into_tensor( + # input_tensor, + # flat_tensor, + # group=mics_comm_groups.param_shard_group, + # async_op=True) + + # for idx, param in enumerate(params): + # param.data = output_tensors[idx].narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + # # return MiCS_AllGatherCoalescedHandle(allgather_handle=all_gather_handle, + # # params=params, + # # partitions=[], + # # world_size=param_shard_size) + + # return AllGatherCoalescedHandle(allgather_handle=all_gather_handle, + # params=params, + # partitions=partitions, + # world_size=param_shard_size) + + def _hierarchical_all_gather_params(self, params, params_buffers=None): + """""" + + # NOTE: This function is not used in the current implementation. Buggy. + params, params_buffers = self._pre_all_gather(params, params_buffers) + + mics_comm_groups: MiCS_CommGroups = params[0].comm + inter_node_comm_group = mics_comm_groups.param_inter_node_shard_group + intra_node_comm_group = mics_comm_groups.param_intra_node_group + local_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group) + inter_rank = dist.get_rank(group=mics_comm_groups.param_inter_node_shard_group) + intra_rank = dist.get_rank(group=mics_comm_groups.param_intra_node_group) + + param_shard_size = mics_comm_groups.param_shard_size + # rank_in_group = mics_comm_groups.param_shard_rank + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + inter_node_size = dist.get_world_size(group=inter_node_comm_group) + intra_node_size = dist.get_world_size(group=intra_node_comm_group) + param_tensors = [] + for i, p in enumerate(params): + param_size = p.ds_tensor.ds_numel * param_shard_size + if params_buffers is not None and params_buffers[i] is not None: + assert ( + params_buffers[i].numel() == param_size + ), f"param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}" + param_tensor = params_buffers[i] + else: + param_tensor = torch.empty( + param_size, dtype=p.dtype, device=self.local_device, requires_grad=False + ).view(-1) + param_tensors.append(param_tensor) + # param_tensor = torch.cat(param_tensors, dim=0) + + # inter node all-gather + inter_outputs = [] + inter_inputs = [] + for i, p in enumerate(params): + inter_size = p.ds_tensor.ds_numel * inter_node_size + _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size) + inter_outputs.append(_out) + inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device)) + # inter_partition = sum(p.ds_tensor.ds_numel for p in params) + + inter_input_tensor = torch.cat(inter_inputs, dim=0) + inter_output_tensor = torch.cat(inter_outputs, dim=0) + inter_partitions: List[Parameter] = [] + # inter_partitions = inter_inputs + # print("inter_rank", inter_rank, "inter_node_size", inter_node_size, "intra_rank", intra_rank, "intra_node_size", intra_node_size, "partition_sz", partition_sz,) + for i in range(inter_node_size): + inter_partitions.append(inter_output_tensor.narrow(0, partition_sz * i, partition_sz)) + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], out=inter_partitions[inter_rank] + ) 
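+ # At this point this rank's local shards of all params have been packed into
+ # its slot of the inter-node output buffer; the all-gather below exchanges
+ # these slots across nodes, and the intra-node all-gather further down
+ # assembles the full parameters within each node (note this hierarchical
+ # path is marked above as unused and buggy).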
+ # sync enqueue + dist.allgather_fn(inter_output_tensor, inter_partitions[inter_rank], group=inter_node_comm_group, async_op=True) + + # intra node all-gather + intra_outputs = [] + intra_inputs = [] + for i, p in enumerate(params): + # partition param into multiple chunks for allgather + # because inter-node all-gather outputs are in a continues memory + # while in param memory, those inter-node data are placed in different + # location. + # each chunk is an intra-node output + param_chunk = ( + param_tensors[i].view((inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1) + ) + param_chunk.copy_(inter_output_list[i].detach().clone().view(param_chunk.size())) + output_chunks = torch.chunk(param_tensors[i], inter_node_size) + for j, _out in enumerate(output_chunks): + intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel + local_offset = local_rank * p.ds_tensor.ds_numel + _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel) + intra_outputs.append(_out) + intra_inputs.append(_in) + + intra_input_tensor = torch.cat(intra_inputs, dim=0) + intra_output_tensor = torch.cat(intra_outputs, dim=0) + intra_partitions: List[Parameter] = [] + for i in range(intra_node_size): + intra_partitions.append( + intra_output_tensor.narrow(0, partition_sz * inter_node_size * i, partition_sz * inter_node_size) + ) + # for i in range(intra_node_size): + # intra_partitions.append(intra_output_tensor.narrow(0, partition_sz * inter_node_size * i, partition_sz * inter_node_size)) + # instrument_w_nvtx(torch.cat)([p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], + # out=intra_partitions[intra_rank]) + + assert intra_partitions[intra_rank].numel() * intra_node_size == intra_output_tensor.numel() + handle = _dist_allgather_fn(intra_partitions[intra_rank], intra_output_tensor, intra_node_comm_group) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=intra_partitions, + world_size=intra_node_size, + ) + + # for i, param in enumerate(params): + # param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + # return MiCS_AllGatherCoalescedHandle( + # allgather_handle=all_gather_handle, + # params=params, + # partitions=[], + # world_size=param_shard_size, + # ) + + # inter_node_size = dist.get_world_size(group=inter_node_comm_group) + # intra_node_size = dist.get_world_size(group=intra_node_comm_group) + # param_tensors = [] + # for i, p in enumerate(params): + # param_size = p.ds_tensor.ds_numel * param_shard_size + # if params_buffers is not None and params_buffers[i] is not None: + # assert params_buffers[i].numel( + # ) == param_size, f'param_buffers[{i}] size {params_buffers[i].numel()} does not match with param_size {param_size}' + # param_tensor = params_buffers[i] + # else: + # param_tensor = torch.empty(param_size, dtype=p.dtype, device=self.local_device, + # requires_grad=False).view(-1) + # param_tensors.append(param_tensor) + + # # inter node all-gather + # inter_outputs = [] + # inter_inputs = [] + # for i, p in enumerate(params): + # inter_size = p.ds_tensor.ds_numel * inter_node_size + # _out = param_tensors[i].narrow(0, local_rank * inter_size, inter_size) + # inter_outputs.append(_out) + # inter_inputs.append(p.ds_tensor.data.view(-1).to(self.local_device)) + # # sync enqueue + # # dist.all_gather_coalesced(inter_outputs, inter_inputs, group=inter_node_comm_group, async_op=False) + # dist.all_gather(inter_outputs, inter_inputs, 
group=inter_node_comm_group, async_op=False) + + # # intra node all-gather + # intra_outputs = [] + # intra_inputs = [] + # for i, p in enumerate(params): + # # partition param into multiple chunks for allgather + # # because inter-node all-gather outputs are in a continues memory + # # while in param memory, those inter-node data are placed in different + # # location. + # # each chunk is an intra-node output + # param_chunk = param_tensors[i].view( + # (inter_node_size, intra_node_size, p.ds_tensor.ds_numel)).narrow(1, local_rank, 1) + # param_chunk.copy_(inter_outputs[i].detach().clone().view(param_chunk.size())) + # output_chunks = torch.chunk(param_tensors[i], inter_node_size) + # for j, _out in enumerate(output_chunks): + # intra_chunk_size = intra_node_size * p.ds_tensor.ds_numel + # local_offset = local_rank * p.ds_tensor.ds_numel + # _in = param_tensors[i].narrow(0, j * intra_chunk_size + local_offset, p.ds_tensor.ds_numel) + # intra_outputs.append(_out) + # intra_inputs.append(_in) + + # # all_gather_handle = dist.all_gather_coalesced(intra_outputs, + # # intra_inputs, + # # group=intra_node_comm_group, + # # async_op=True) + # all_gather_handle = dist.all_gather(intra_outputs, + # intra_inputs, + # group=intra_node_comm_group, + # async_op=True) + # for i, param in enumerate(params): + # param.data = param_tensors[i].narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + # return MiCS_AllGatherCoalescedHandle( + # allgather_handle=all_gather_handle, + # params=params, + # partitions=[], + # world_size=param_shard_size, + # ) + + def get_partition_dp_group(self, param): + return param.comm.param_shard_group + + def get_partition_rank(self): + return self.mics_comm_groups.param_shard_rank + + @property + def num_partitions(self): + return self.mics_comm_groups.param_shard_size + + +class MiCS_Offload(DeepSpeedZeRoOffload): + """Wrapper to change the behavior for parameter sharding""" + + def _convert_to_zero_parameters(self, ds_config, module, mpu): + """overload the parent class function for convert the parameters""" + log_dist(f"Convert to zero parameters from MiCS Offload manager", ranks=[0]) + non_zero_params = [p for p in module.parameters() if not is_zero_param(p)] + if non_zero_params: + zero_params = [p for p in module.parameters() if is_zero_param(p)] + if zero_params: + zero_params[0].convert_to_zero_parameters(param_list=non_zero_params) + else: + group = None + if mpu: + group = mpu.get_data_parallel_group() + + MiCS_Init( + module=module, + data_parallel_group=group, + dtype=self.dtype, + config_dict_or_path=ds_config, + remote_device=self.offload_device, + pin_memory=self.offload_param_pin_memory, + mpu=mpu, + ) + + +class MiCS_Optimizer(DeepSpeedZeroOptimizer_Stage3): + """ + MiCS Optimizer + """ + + def __init__( + self, + module, + init_optimizer, + timers, + ds_config, + static_loss_scale=1, + dynamic_loss_scale=False, + dynamic_loss_args=None, + verbose=True, + contiguous_gradients=True, + reduce_bucket_size=500000000, + prefetch_bucket_size=50000000, + max_reuse_distance=1000000000, + max_live_parameters=1000000000, + param_persistence_threshold=100000, + model_persistence_threshold=sys.maxsize, + dp_process_group=None, + reduce_scatter=True, + overlap_comm=False, + offload_optimizer_config=None, + offload_param_config=None, + sub_group_size=1000000000000, + offload_ratio=0.0, + mpu=None, + clip_grad=0, + gradient_accumulation_dtype=torch.float16, + communication_data_type=torch.float16, + postscale_gradients=True, + gradient_predivide_factor=1, + 
gradient_accumulation_steps=1, + elastic_checkpoint=False, + aio_config=None, + ): + + log_dist("Init MiCS optimizer", ranks=[0]) + super().__init__( + module, + init_optimizer, + timers, + ds_config, + static_loss_scale, + dynamic_loss_scale, + dynamic_loss_args, + verbose, + contiguous_gradients, + reduce_bucket_size, + prefetch_bucket_size, + max_reuse_distance, + max_live_parameters, + param_persistence_threshold, + model_persistence_threshold, + dp_process_group, + reduce_scatter, + overlap_comm, + offload_optimizer_config, + offload_param_config, + sub_group_size, + offload_ratio, + mpu, + clip_grad, + gradient_accumulation_dtype, + communication_data_type, + postscale_gradients, + gradient_predivide_factor, + gradient_accumulation_steps, + elastic_checkpoint, + aio_config, + ) + first_param = next(module.parameters()) + # overload the dp_process_group and partition_count + assert hasattr(first_param, "comm"), " ".join( + [ + "Sharded parameters don't have the MiCS_CommGroups attached.", + "Might due to the use of deepspeed.zero.Init context for initializing the weights.", + "To use MiCS sharding, please use deepspeed.zero.MiCS_Init instead for initializing parameter.", + ] + ) + self.dp_process_group = first_param.comm.param_shard_group + self.partition_count = first_param.comm.param_shard_size + + def initialize_ds_offload( + self, + *args, + **kwargs, + ): + return MiCS_Offload(*args, **kwargs) + + def partition_grads(self, params_to_release: List[Parameter], grad_partitions: List[Tensor]) -> None: + grad_buffers = super().partition_grads(params_to_release, grad_partitions) + # perform all-reduce among replication groups + # the function will perform accumulation boundary check + self.allreduce_mics_shard_grads(params_to_release, grad_buffers) + + @instrument_w_nvtx + def allreduce_mics_shard_grads(self, params, partitioned_grads_buffers: List[Tensor]): + """ """ + # TODO: improve the condition check + if not self.is_gradient_accumulation_boundary or len(partitioned_grads_buffers) == 0: + return + + mics_comm_groups: MiCS_CommGroups = params[0].comm + param_repli_group = mics_comm_groups.param_repli_group + param_repli_size = mics_comm_groups.param_repli_size + + if param_repli_size is None or param_repli_size <= 1: + return + if not get_accelerator().on_accelerator(partitioned_grads_buffers[0]): + raise RuntimeError("Local sharding has no support for CPU offloading") + + if dist.has_all_reduce_coalesced(): + scale_tensors(partitioned_grads_buffers, param_repli_size) + dist.all_reduce_coalesced(tensors=partitioned_grads_buffers, group=param_repli_group) + else: + # manually coalescing all-reduce + aggregated_buffer: Tensor = torch.cat(partitioned_grads_buffers) + aggregated_buffer.div_(param_repli_size) + dist.all_reduce(aggregated_buffer, group=param_repli_group) + offset = 0 + for grad_buff in partitioned_grads_buffers: + grad_buff.view(-1).copy_(aggregated_buffer.narrow(0, offset, grad_buff.numel())) + offset += grad_buff.numel() + + def load_state_dict( + self, + state_dict_list, + load_optimizer_states=True, + load_from_fp32_weights=False, + checkpoint_folder=None, + load_serial=None, + ): + r"""Loading the ZeRO-3/MiCS partitioned checkpoints + Because the self.dp_process_group is replaced with the communicator for + partition group we can call the load_state_dict logic from ZeRO-3. 
+ """ + super().load_state_dict(state_dict_list, load_optimizer_states, load_from_fp32_weights, checkpoint_folder) diff --git a/llava/train/deepspeed_replace_deprecated/runtime/zero/partition_parameters.py b/llava/train/deepspeed_replace_deprecated/runtime/zero/partition_parameters.py new file mode 100644 index 00000000..a24c491d --- /dev/null +++ b/llava/train/deepspeed_replace_deprecated/runtime/zero/partition_parameters.py @@ -0,0 +1,2287 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import functools +import itertools +import logging +import math +import os +import types +from collections import defaultdict +from enum import Enum +from typing import Callable, Iterable, List + +import deepspeed +import torch +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.comm.comm import init_distributed +from deepspeed.inference.quantization.utils import ( + WEIGHT_QUANTIZATION_LAYERS, + _quantize_param, + wrap_load_from_state_dict, + wrap_quantized_functional, +) +from deepspeed.runtime.config_utils import get_config_default +from deepspeed.runtime.zero.config import DeepSpeedZeroConfig +from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum +from deepspeed.runtime.zero.utils import assert_ints_same_as_other_ranks, is_zero_param +from deepspeed.utils import groups, instrument_w_nvtx, logger +from deepspeed.utils.debug import ( + debug_module2name, + debug_param2name_id, + debug_param2name_id_shape, + debug_param2name_id_shape_device, + debug_param2name_id_shape_status, +) +from torch import Tensor +from torch.nn import Module, Parameter + +from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus +from ..utils import see_memory_usage +from .linear import zero3_linear_wrap + +partitioned_param_data_shape = [0] +zero_init_context = 0 +top_level_context = None + + +class NoGatherHandle: + def __init__(self, param: Parameter) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = ( + Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale) + .to(device=get_accelerator().current_device_name(), non_blocking=True) + .view(param.ds_shape) + ) + else: + param.data = param.ds_tensor.data.to( + device=get_accelerator().current_device_name(), non_blocking=True + ).view(param.ds_shape) + self.__param = param + + def wait(self) -> None: + if not get_accelerator().is_synchronized_device(): + get_accelerator().current_stream().synchronize() + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + +class NoGatherCoalescedHandle: + def __init__(self, params: List[Parameter]) -> None: + self.__params = params + self.__complete = False + + for param in self.__params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + if hasattr(param.ds_tensor, "ds_quant_scale"): + param.data = ( + Init.quantizer_module.dequantize(param.ds_tensor.data, param.ds_tensor.ds_quant_scale) + .to(device=get_accelerator().current_device_name(), non_blocking=True) + .view(param.ds_shape) + ) + else: + param.data = param.ds_tensor.data.to( + device=get_accelerator().current_device_name(), non_blocking=True + ).view(param.ds_shape) + + @instrument_w_nvtx + def wait(self) -> None: + if self.__complete: + return + + if 
not get_accelerator().is_synchronized_device(): + get_accelerator().current_stream().synchronize() + for param in self.__params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + param.ds_status = ZeroParamStatus.AVAILABLE + + self.__complete = True + + +def _dist_allgather_fn(input_tensor: Tensor, output_tensor: Tensor, group=None): + return instrument_w_nvtx(dist.allgather_fn)(output_tensor, input_tensor, group=group, async_op=True) + + +def print_rank_0(message, debug=False, force=False): + rank = dist.get_rank() + if rank == 0 and (debug or force): + print(message) + # other variations + # - print for all ranks w/o interleaving + # printflock(f"[{rank}] {message}") + # - print to log file per rank + # log_rank_file(rank, message) + + +def debug_rank0(msg: str) -> None: + if dist.get_rank() == 0: + logger.debug(msg) + + +def _init_external_params(module): + if not hasattr(module, "_external_params"): + module._external_params = {} + + def external_parameters(self): + return self._external_params.items() + + def all_parameters(self): + return itertools.chain(self.named_parameters(self, recurse=False), external_parameters(self)) + + module.ds_external_parameters = types.MethodType(external_parameters, module) + module.all_parameters = types.MethodType(all_parameters, module) + + +def register_external_parameter(module, parameter): + """Instruct DeepSpeed to coordinate ``parameter``'s collection and partitioning in + the forward and backward passes of ``module``. + + This is used when a parameter is accessed outside of its owning module's + ``forward()``. DeepSpeed must know to collect it from its partitioned + state and when to release the memory. + + .. note:: + This is only applicable to training with ZeRO stage 3. + + Args: + module (``torch.nn.Module``): The module that requires ``parameter`` in its forward pass. + parameter (``torch.nn.Parameter``): The parameter to register. + + Raises: + RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``. + + + Examples + ======== + + #. Register a weight that is used in another module's forward pass (line 6). + Parameter ``layer1.weight`` is used by ``layer2`` (line 11). + + .. code-block:: python + :linenos: + :emphasize-lines: 6,11 + + class ModuleZ3(torch.nn.Module): + def __init__(self, *args): + super().__init__(self, *args) + self.layer1 = SomeLayer() + self.layer2 = OtherLayer() + deepspeed.zero.register_external_parameter(self, self.layer1.weight) + + def forward(self, input): + x = self.layer1(input) + # self.layer1.weight is required by self.layer2.forward + y = self.layer2(x, self.layer1.weight) + return y + """ + if not isinstance(parameter, torch.nn.Parameter): + raise RuntimeError("Parameter is not a torch.nn.Parameter") + + if not hasattr(module, "_external_params"): + _init_external_params(module) + + key = id(parameter) + module._external_params[key] = parameter + + +def unregister_external_parameter(module, parameter): + """Reverses the effects of :meth:`register_external_parameter`. + + Args: + module (``torch.nn.Module``): The module to affect. + parameter (``torch.nn.Parameter``): The parameter to unregister. + + Raises: + RuntimeError: If ``parameter`` is not of type ``torch.nn.Parameter``. + RuntimeError: If ``parameter`` is not a registered external parameter of ``module``. 
+ """ + if not isinstance(parameter, torch.nn.Parameter): + raise RuntimeError("Parameter is not a torch.nn.Parameter") + + if not hasattr(module, "_external_params") or id(parameter) not in module._external_params: + raise RuntimeError("Parameter is not a registered external parameter of module.") + + key = id(parameter) + del module._external_params[key] + + +class ZeroParamType(Enum): + + # same as regular pytorch parameters + NORMAL = 1 + + # parameters are partitioned across data parallel process + PARTITIONED = 2 + + # the parameter is held with a unique process rank + # and is not available on all other process + REMOTE = 3 + + +class ZeroParamStatus(Enum): + # parameters are fully present and ready for use on all processes + AVAILABLE = 1 + + # parameters are either partitioned or remote in some or all process + NOT_AVAILABLE = 2 + + # parameters are being gathered. + INFLIGHT = 3 + + +_orig_torch_tensor = torch.tensor +_orig_torch_empty = torch.empty +_orig_torch_zeros = torch.zeros +_orig_torch_ones = torch.ones +_orig_torch_full = torch.full +_orig_torch_arange = torch.arange +_orig_torch_eye = torch.eye +_orig_torch_randn = torch.randn + + +def zero_wrapper_for_fp_tensor_constructor(fn: Callable, target_fp_dtype: torch.dtype) -> Callable: + def wrapped_fn(*args, **kwargs) -> Tensor: + if kwargs.get("device", None) is None: + kwargs["device"] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + tensor: Tensor = fn(*args, **kwargs) + if tensor.is_floating_point(): + tensor.data = tensor.data.to(target_fp_dtype) + + return tensor + + return wrapped_fn + + +def get_new_tensor_fn_for_dtype(dtype: torch.dtype) -> Callable: + def new_tensor(cls, *args, **kwargs) -> Tensor: + device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + if not args: + args = (0,) + tensor = _orig_torch_empty(0, device=device).new_empty(*args, **kwargs) + if tensor.is_floating_point(): + tensor = tensor.to(dtype) + + return tensor + + return new_tensor + + +# https://stackoverflow.com/a/63851681/9201239 +def get_all_subclasses(cls): + subclass_list = [] + + def recurse(cl): + for subclass in cl.__subclasses__(): + subclass_list.append(subclass) + recurse(subclass) + + recurse(cls) + + return set(subclass_list) + + +@instrument_w_nvtx +def free_param(param: Parameter) -> None: + """Free underlying storage of a parameter.""" + assert not param.ds_active_sub_modules, param.ds_summary() + if get_accelerator().on_accelerator(param.data): + # need to make sure that we don't free the parameter while it is still + # being used for computation + if not get_accelerator().is_synchronized_device(): + param.data.record_stream(get_accelerator().current_stream()) + # param.data doesn't store anything meaningful in partitioned state + param.data = torch.empty(0, dtype=param.dtype, device=param.device) + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + + +reuse_buffers = False +temp_contiguous_tensor = None +empty_buffers = {} + + +# Inserts _post_init_method at the end of init method +# for all sub classes of torch.nn.Module +class InsertPostInitMethodToModuleSubClasses: + num_module_parameters = 0 + num_module_elements = 0 + + def __init__(self, enabled=True, mem_efficient_linear=True, ds_config=None, dtype=None): + self.mem_efficient_linear = mem_efficient_linear + self.enabled = enabled + self._set_dtype(ds_config, dtype) + assert self.dtype in [ + torch.half, + torch.bfloat16, + torch.float, + ], f"Invalid data type {self.dtype}, allowed values are [torch.half, 
torch.bfloat16, torch.float]" + self.wrapped_cls = set() + self.skip_init_depth = 0 + + self.quantized_initialization = None + if ( + ds_config is not None + and ds_config.weight_quantization_config + and ds_config.weight_quantization_config.quantized_initialization + ): + self.quantized_initialization = ds_config.weight_quantization_config.quantized_initialization + + def __enter__(self): + if not self.enabled: + return + + global zero_init_context + if zero_init_context == 0: + self.patch_init_and_builtins() + global top_level_context + top_level_context = self + + zero_init_context += 1 + + def __exit__(self, exc_type, exc_value, traceback): + if not self.enabled: + return + + global zero_init_context + zero_init_context -= 1 + + # Exiting the top level context + if zero_init_context == 0: + self.unpatch_init_and_builtins() + global top_level_context + top_level_context = None + + if dist.get_rank() == 0: + billion_elems = InsertPostInitMethodToModuleSubClasses.num_module_elements / 1e9 + num_params = InsertPostInitMethodToModuleSubClasses.num_module_parameters + logger.info( + f"finished initializing model - num_params = {num_params}, num_elems = {billion_elems:.2f}B" + ) + + # Now that we cleaned up the metaclass injection, raise the exception. + if exc_type is not None: + return False + + # To be implemented by inheriting classes + def _post_init_method(self, module): + pass + + def _set_dtype(self, ds_config, dtype): + if ds_config is not None and dtype is None: + if ds_config.bfloat16_enabled and ds_config.fp16_enabled: + raise RuntimeError("bfloat16 and fp16 cannot be enabled at once") + + if ds_config.bfloat16_enabled: + self.dtype = torch.bfloat16 + elif ds_config.fp16_enabled: + self.dtype = torch.half + else: + self.dtype = torch.float + else: + self.dtype = ( + dtype or torch.float16 + if get_accelerator().is_fp16_supported() + else torch.bfloat16 + if get_accelerator().is_bf16_supported + else torch.float32 + ) + + def patch_init_and_builtins(self): + def apply_with_gather(orig_module_apply_fn: Callable) -> Callable: + """many models make use of child modules like Linear or Embedding which + perform their own weight initialization in their __init__ methods, + but will then have more weight initialization in a parent module's __init__ + method that modifies weights of child modules, which is typically done + using the Module.apply method. + + since the Init context manager partitions child modules immediately after + they are initialized, without modifying apply we would entirely skip + any initialization done by parent modules. + + to get around this issue, we wrap the function passed to Module.apply + so that the applied function is applied to child modules correctly. + """ + + def get_wrapped_fn_to_apply(fn_to_apply: Callable) -> Callable: + if hasattr(fn_to_apply, "wrapped"): + return fn_to_apply + + @functools.wraps(fn_to_apply) + def wrapped_fn_to_apply(module_to_apply_fn_to: Module) -> None: + """gathers parameters before calling apply function. afterwards + parameters are broadcasted to ensure consistency across all ranks + then re-partitioned. + + takes the following steps: + 1. allgathers parameters for the current module being worked on + 2. calls the original function + 3. broadcasts root rank's parameters to the other ranks + 4. 
re-partitions the parameters + """ + + # TODO Delay error checking for dangling partitioned parameters to post module init + # raise RuntimeError(f"not all parameters for {module_to_apply_fn_to.__class__.__name__}, " + # f"were zero params, is it possible that the parameters were " + # f"overwritten after they were initialized? " + # f"params: {[p for p in module_to_apply_fn_to.parameters(recurse=False)]} ") + + params_to_apply_fn_to: Iterable[Parameter] = list( + sorted( + [p for p in module_to_apply_fn_to.parameters(recurse=False) if is_zero_param(p)], + key=lambda p: p.ds_id, + ) + ) + + for param in params_to_apply_fn_to: + param.all_gather() + + fn_to_apply(module_to_apply_fn_to) + + for param in params_to_apply_fn_to: + dist.broadcast(param.data, 0, group=param.ds_process_group) + + for param in params_to_apply_fn_to: + param.partition(has_been_updated=True) + + wrapped_fn_to_apply.wrapped = True + + return wrapped_fn_to_apply + + @functools.wraps(orig_module_apply_fn) + def wrapped_apply(module: Module, fn_to_apply: Callable) -> None: + orig_module_apply_fn(module, get_wrapped_fn_to_apply(fn_to_apply)) + + return wrapped_apply + + def hook_for_skip_init(module): + # this function is intended for handling the logic of torch.nn.utils.skip_init + # skip_init:module_cls(*args, **kwargs).to_empty(device=final_device), where kwargs['device']='meta' + # the function call occurs between module_cls(*args, **kwargs) and to_empty(device=final_device). + def partition_after_empty_init(f): + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + _module = f(module, *args, **kwargs) + # here is the post-hook for module.apply(empty_like...) + # after module.apply(empty_like...), the module has completed its empty init on real device + # since skip_init won't involve any computations or weight adjustments, we can directly utilize post_init + self._post_init_method(_module) + return _module + + return wrapper + + def post_wrapper_to_empty(f): + # append some wrapper restoration after to_empty() call + @functools.wraps(f) + def wrapper(*args, **kwargs): + res = f(*args, **kwargs) + # restore _apply hook + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class_apply(subclass) + # self restore + module.to_empty = f + return res + + return wrapper + + def _enable_class_apply(cls): + cls._old_apply_of_skip_init_hook = cls._apply + cls._apply = partition_after_empty_init(cls._apply) + + def _disable_class_apply(cls): + cls._apply = cls._old_apply_of_skip_init_hook + + # add hooks for to_empty: apply_(empty_like) + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class_apply(subclass) + + # add a restore hook when exiting skip_init + module.to_empty = post_wrapper_to_empty(module.to_empty) + + def partition_after(f): + @functools.wraps(f) + def wrapper(module, *args, **kwargs): + + # important logic: We want to run post_init only after child's __init__ is + # completed, and do nothing after __init__ of any of its parents and grandparents in + # the inheritance ancestry. This way the partitioning will need to happen only once + # when the whole object is ready to be partitioned and not before. This is because + # often the child module will need to tweak the weights - for example running a + # custom weights init function. 
So if a parent created the weights param, the child + # won't need to gather it in order to tweak it + + print_rank_0(f"Before initializing {module.__class__.__name__}", force=False) + + is_child_module = False + if not hasattr(module, "_ds_child_entered"): + # child's __init__ was called, since parents all see the same object they can now skip post_init + is_child_module = True + setattr(module, "_ds_child_entered", True) + + init_on_meta = "device" in kwargs and kwargs["device"] == "meta" + if init_on_meta: + self.skip_init_depth += 1 + + f(module, *args, **kwargs) + if init_on_meta and self.skip_init_depth == 1: + # check and handle the logic of empty_init + hook_for_skip_init(module) + if is_child_module: + # child's __init__ is done, now we can run a single post_init on the child object + delattr(module, "_ds_child_entered") + + print_rank_0(f"Running post_init for {module.__class__.__name__}", force=False) + if self.skip_init_depth == 0: + self._post_init_method(module) + + print_rank_0(f"After initializing followed by post init for {module.__class__.__name__}", force=False) + if init_on_meta: + self.skip_init_depth -= 1 + + return wrapper + + def _enable_class(cls): + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) + + def _init_subclass(cls, **kwargs): + cls._old_init = cls.__init__ + cls.__init__ = partition_after(cls.__init__) + + # Replace .__init__() for all existing subclasses of torch.nn.Module recursively + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _enable_class(subclass) + + # holding onto some methods so we can put them back the way they were in __exit__ + torch.nn.modules.module.Module._old_init_subclass = torch.nn.modules.module.Module.__init_subclass__ + torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply + torch.Tensor.__old_new__ = torch.Tensor.__new__ + + # Replace .__init__() for future subclasses of torch.nn.Module + torch.nn.modules.module.Module.__init_subclass__ = classmethod(_init_subclass) + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = apply_with_gather(torch.nn.modules.module.Module._old_apply) + + self._add_tensor_creation_wrappers() + + if self.mem_efficient_linear: + print_rank_0( + "nn.functional.linear has been overridden with a more memory efficient version. 
This will persist unless manually reset.", + force=False, + ) + self.linear_bk = torch.nn.functional.linear + torch.nn.functional.linear = zero3_linear_wrap + + if self.quantized_initialization: + print_rank_0("nn.functional.linear has been overridden with quantized linear version.", force=False) + torch.nn.functional.linear = wrap_quantized_functional(torch.nn.functional.linear) + torch.nn.functional.embedding = wrap_quantized_functional(torch.nn.functional.embedding) + for cls in WEIGHT_QUANTIZATION_LAYERS: + cls._load_from_state_dict = wrap_load_from_state_dict(cls._load_from_state_dict) + + logger.info("Enable Zero3 engine with INT4 quantization.") + + self.patched = True + + def unpatch_init_and_builtins(self): + if self.patched: + + def _disable_class(cls): + cls.__init__ = cls._old_init + + for subclass in get_all_subclasses(torch.nn.modules.module.Module): + _disable_class(subclass) + + # putting methods back the way we found them + torch.nn.modules.module.Module.__init_subclass__ = torch.nn.modules.module.Module._old_init_subclass + if Init.override_module_apply: + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + + self._remove_tensor_creation_wrappers() + + self.patched = False + + def _add_tensor_creation_wrappers(self): + torch.Tensor.__new__ = get_new_tensor_fn_for_dtype(self.dtype) + torch.tensor = zero_wrapper_for_fp_tensor_constructor(_orig_torch_tensor, self.dtype) + torch.empty = zero_wrapper_for_fp_tensor_constructor(_orig_torch_empty, self.dtype) + torch.zeros = zero_wrapper_for_fp_tensor_constructor(_orig_torch_zeros, self.dtype) + torch.ones = zero_wrapper_for_fp_tensor_constructor(_orig_torch_ones, self.dtype) + torch.full = zero_wrapper_for_fp_tensor_constructor(_orig_torch_full, self.dtype) + torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype) + torch.eye = zero_wrapper_for_fp_tensor_constructor(_orig_torch_eye, self.dtype) + torch.randn = zero_wrapper_for_fp_tensor_constructor(_orig_torch_randn, self.dtype) + + def _remove_tensor_creation_wrappers(self): + torch.Tensor.__new__ = torch.Tensor.__old_new__ + torch.tensor = _orig_torch_tensor + torch.empty = _orig_torch_empty + torch.zeros = _orig_torch_zeros + torch.ones = _orig_torch_ones + torch.full = _orig_torch_full + torch.arange = _orig_torch_arange + torch.eye = _orig_torch_eye + torch.randn = _orig_torch_randn + + +def shutdown_init_context(): + """ + This function is used to initialize deepspeed engine inside the context of Init. + We need to remove the wrappers but keep the context. + """ + if top_level_context: + top_level_context.unpatch_init_and_builtins() + + +def restore_init_context(): + """ + This function is used to restore the wrappers after deepspeed engine is initialized. 
+ """ + if top_level_context: + top_level_context.patch_init_and_builtins() + + +class AllGatherHandle: + def __init__(self, handle, param: Parameter, quantization=None) -> None: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to be available") + + self.__handle = handle + self.__param = param + self.__quantization = quantization + + def wait(self) -> None: + instrument_w_nvtx(self.__handle.wait)() + if self.__quantization: + instrument_w_nvtx(self.__quantization.quant_handle.wait)() + self.__param.data = self.__quantization.backend.dequantize( + self.__quantization.quantized_param, self.__quantization.scale_buffer + ).to(self.__param.device) + self.__param.ds_status = ZeroParamStatus.AVAILABLE + + +class AllGatherCoalescedHandle: + def __init__( + self, + allgather_handle, + params: List[Parameter], + partitions: List[Tensor], + world_size: int, + use_secondary_tensor=False, + quantization=None, + ) -> None: + self.allgather_handle = allgather_handle + self.params = params + self.partitions = partitions + self.world_size = world_size + self.use_secondary_tensor = use_secondary_tensor + self.complete = False + self.quantization = quantization + + for param in self.params: + if param.ds_status != ZeroParamStatus.INFLIGHT: + raise RuntimeError(f"expected param {param.ds_summary()} to not be available") + + @instrument_w_nvtx + def wait(self) -> None: + if self.complete: + return + + instrument_w_nvtx(self.allgather_handle.wait)() + + if self.quantization: + instrument_w_nvtx(self.quantization.quant_handle.wait)() + flat_tensor = self.quantization.backend.dequantize( + self.quantization.quantized_param, self.quantization.scale_buffer + ).to(self.params[0].device) + + self.partitions: List[Parameter] = [] + for i in range(self.world_size): + self.partitions.append( + flat_tensor.narrow(0, self.quantization.partition_sz * i, self.quantization.partition_sz) + ) + + # split the single tensor out into individual tensors + param_offset = 0 + for param in self.params: + assert param.ds_status == ZeroParamStatus.INFLIGHT, f"expected param {param.ds_summary()} to be inflight" + partitions: List[Tensor] = [] + ds_tensor_numel = param.ds_tensor.ds_numel + # if self.use_secondary_tensor: + # ds_tensor_numel *= param.ds_secondary_tensor_num_of_groups + if self.use_secondary_tensor: + ds_tensor_numel = param.ds_secondary_tensor.ds_numel + for rank in range(self.world_size): + param_start = rank * ds_tensor_numel + if param_start < param.ds_numel: + part_to_copy = self.partitions[rank].narrow( + 0, param_offset, min(param.ds_numel - param_start, ds_tensor_numel) + ) + partitions.append(part_to_copy) + param.data = instrument_w_nvtx(torch.cat)(partitions).view(param.ds_shape) + param.ds_status = ZeroParamStatus.AVAILABLE + + for part_to_copy in partitions: + if not get_accelerator().is_synchronized_device(): + part_to_copy.record_stream(get_accelerator().current_stream()) + + param_offset += ds_tensor_numel + + self.complete = True + + +class MultipleAllGatherHandles: + def __init__(self, handles: List[AllGatherCoalescedHandle]): + self.handles = handles + + def wait(self) -> None: + for handle in self.handles: + handle.wait() + + +class QuantizationInfo: + # a placeholder object to store all quant related vars used in handles + def __init__(self) -> None: + self.quantized_param = None + self.backend = None + self.quant_handle = None + self.scale_buffer = None + + +class CUDAQuantizer: + async_flag = True + target_group_size = 8000 # the optimal 
size is 4k, so we set the target to be below 8k + group_size_cache = dict() + quantizer_cuda_module = None + + def __init__(self) -> None: + if CUDAQuantizer.quantizer_cuda_module is None: + CUDAQuantizer.quantizer_cuda_module = deepspeed.ops.op_builder.QuantizerBuilder().load() + + def quantize(self, param, groups=None): + if groups is None: + try: + groups = self.group_size_cache[param.numel()] + except KeyError: + groups = math.ceil(param.numel() / self.target_group_size) + while groups < param.numel(): + if param.numel() % (8 * groups) == 0: + break + groups += 1 + while True: + if ( + param.numel() % (8 * groups * 2) == 0 and param.numel() / groups > self.target_group_size + ): # hard limit of 16k group_size + groups *= 2 + else: + break + assert ( + param.numel() % (8 * groups) == 0 + ), f"Qantized weight requires the number of weights be a multiple of 8. Yet {param.numel()} cannot be divided by 8*{groups}" + assert param.numel() / groups < 16000, f"{param.numel()} / {groups} is larger than 16k" + assert ( + param.numel() > groups + ), f"Adaptive grouping algorithm cannot find a group size for input tensor of size {param.numel()}" + self.group_size_cache[param.numel()] = groups + return self.quantizer_cuda_module.quantize( + param.to(get_accelerator().device_name()), groups, 8, self.quantizer_cuda_module.Symmetric + ) + + def dequantize(self, quantized_param, scale): + return self.quantizer_cuda_module.dequantize( + quantized_param, scale, scale.numel(), 8, self.quantizer_cuda_module.Symmetric + ) + + +def _no_gather_coalesced(params: Iterable[Parameter]) -> AllGatherCoalescedHandle: + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(f"expect param.ds_status == ZeroParamStatus.NOT_AVAILABLE, got{param.ds_summary()}") + param.ds_status = ZeroParamStatus.INFLIGHT + + params = sorted(params, key=lambda p: p.ds_id) + if len(params) == 1: + (param,) = params + return NoGatherHandle(param) + return NoGatherCoalescedHandle(params) + + +# Replaces all parameters in module with Scattered Parameters +class Init(InsertPostInitMethodToModuleSubClasses): + param_id = 0 + param_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "param_persistence_threshold") + model_persistence_threshold = get_config_default(DeepSpeedZeroConfig, "model_persistence_threshold") + num_persisted_parameters = 0 + num_persisted_elements = 0 + apply_param_persistence = False + override_module_apply = get_config_default(DeepSpeedZeroConfig, "override_module_apply") + + def __init__( + self, + module=None, + data_parallel_group=None, + mem_efficient_linear=True, + remote_device=None, + pin_memory=False, + config_dict_or_path=None, + config=None, + enabled=True, + dtype=None, + mpu=None, + zero_param_parallel_group=None, + zero_quantized_weights=False, + zero_quantized_nontrainable_weights=False, + sequence_data_parallel_group=None, + param_swapper=None, + ): + """A context to enable massive model construction for training with + ZeRO-3. Models are automatically partitioned (or, sharded) across the + system and converted to half precision. + + Args: + module (``torch.nn.Module``, optional): If provided, partition the model as + if it was constructed in the context. + data_parallel_group (``deepspeed.comm`` process group, optional): + The group of processes to partition among. Defaults to all processes. + mem_efficient_linear (bool, optional): Replace + torch.nn.functional.linear with an implementation that allows + DeepSpeed to partition parameters. 
Defaults to ``True``. + remote_device (string, optional): The initial device to store model + weights e.g., ``cpu``, ``nvme``. Passing ``"cpu"`` will create the model in CPU + memory. The model may still be moved to GPU based on the + offload settings for training. Defaults to param offload device if a config is + defined, otherwise GPU. + pin_memory (bool, optional): Potentially increase performance by + using pinned memory for model weights. ``remote_device`` must be + ``"cpu"``. Defaults to pin_memory value in config, otherwise ``False``. + config_dict_or_path (dict or ``json file``, optional): If provided, provides configuration + for swapping fp16 params to NVMe. + config (dict or ``json file``, optional): Deprecated, use config_dict_or_path instead. + enabled (bool, optional): If ``False``, this context has no + effect. Defaults to ``True``. + dtype (``dtype``, optional): Can be used to change the data type of the parameters. + Supported options are ``torch.half`` and ``torch.float``. Defaults to ``None`` + mpu (``object``, optional): A model parallelism unit object that implements get_{model,data}_parallel_{rank,group,world_size}. + zero_param_parallel_group(``object``, optional): Parallel (comm) group for dual partitioning of ZeRO params. + zero_quantized_weights (bool, optional): If ``True``, turn on quantized weights in all gather weights. Default is ``False`` + zero_quantized_nontrainable_weights (bool, optional): If ``True``, nontrainable weights will be stored in quantized format. Default is ``False`` + param_swapper (``deepspeed.runtime.swap_tensor.partitioned_param_swapper.AsyncPartitionedParameterSwapper``, optional): [Experimental] Use existing parameter swapper. Defaults to ``None``. + This argument will be removed in the near future. + + This context accelerates model initialization and enables models that + are too large to allocate in their entirety in CPU memory. It has the + following effects: + + #. allocates tensors to either GPU or CPU memory or NVMe + #. converts floating point tensors to half precision + #. immediately partitions tensors among the group of data-parallel devices + #. (*optional*) replaces ``torch.nn.functional.linear`` with a more + memory-efficient implementation + + These modifications allow for models that exceed the size of local CPU/GPU + memory/NVMe, but fit within the total NVMe capacity (*i.e.*, aggregate CPU + or GPU memory or NVMe) across all nodes. Consider initializing a model with one + trillion parameters, whose weights occupy two terabytes (TB) in half + precision. The initial CPU allocation in full precision requires 4TB of + memory *per process*, and so a system with 8 GPUs per node would need 32TB of + CPU memory due to data-parallel redundancies. Instead, by immediately + partitioning tensors we remove the redundancies. The result is that + regardless of the number of GPUs, we still only require the original 4TB. This + allows for a linear increase in model size with the aggregate system memory. + For example, if a node has 1TB of memory and 8 GPUs, we could fit a trillion + parameter model with 4 nodes and 32 GPUs. + + Important: If the fp16 weights of the model can't fit onto a single GPU memory + this feature must be used. + + .. note:: + Initializes ``deepspeed.comm`` if it has not already been done so. + See :meth:`deepspeed.init_distributed` for more information. + + .. note:: + Only applicable to training with ZeRO-3. + + Examples + -------- + + #. Allocate a model and partition it among all processes: + + .. 
code-block:: python + + with deepspeed.zero.Init(): + model = MyLargeModel() + + + #. Allocate a model in pinned CPU memory and partition it among a subgroup of processes: + + .. code-block:: python + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), remote_device="cpu", pin_memory=True): + model = MyLargeModel() + + + #. Partition an already-allocated model in CPU memory: + + .. code-block:: python + + model = deepspeed.zero.Init(module=model) + """ + if config is not None: + config_dict_or_path = config + logger.warning(f"zero.Init: the `config` argument is deprecated. Please use `config_dict_or_path` instead.") + _ds_config = ( + deepspeed.runtime.config.DeepSpeedConfig(config_dict_or_path, mpu) + if config_dict_or_path is not None + else None + ) + if _ds_config is not None: + if _ds_config.zero_config.memory_efficient_linear and _ds_config.compile_config.enabled: + # memory_efficient_linear displays numerous errors when torch.compile is enabled. + # Refer to https://github.com/pytorch/pytorch/issues/119059 for details. + # Further investigation into performance is necessary, even after resolving this issue because + # the `memory_efficient_linear` module may lead to more graph breaks compared to the original implementation. + logger.warning(f"memory_efficient_linear is disabled when torch.compile is enabled.") + mem_efficient_linear = False + else: + mem_efficient_linear = _ds_config.zero_config.memory_efficient_linear + + super().__init__(enabled=enabled, mem_efficient_linear=mem_efficient_linear, ds_config=_ds_config, dtype=dtype) + if not dist.is_initialized(): + init_distributed() + assert dist.is_initialized(), "Parameters cannot be scattered without initializing deepspeed.comm" + + if data_parallel_group is None and sequence_data_parallel_group is None: + self.ds_process_group = dist.get_world_group() + elif sequence_data_parallel_group is not None: + self.ds_process_group = sequence_data_parallel_group + elif data_parallel_group is not None: + self.ds_process_group = data_parallel_group + else: # both given + raise ValueError( + "Both 'data_parallel_group' and 'sequence_data_parallel_group' were specified. Please provide only one of these arguments." + ) + + self.rank = dist.get_rank(group=self.ds_process_group) + self.dp_world_size = dist.get_world_size(group=self.ds_process_group) + + self.zero_param_process_group = zero_param_parallel_group + if ( + _ds_config is not None + and _ds_config.zero_config.zero_hpz_partition_size > 1 + and self.zero_param_process_group is None + ): + groups._create_zero_param_parallel_group(_ds_config.zero_config.zero_hpz_partition_size) + self.zero_param_process_group = groups._get_zero_param_intra_parallel_group() + + self.num_ranks_in_param_group = self.dp_world_size + self.rank_in_group = self.rank + self.num_param_groups = 1 + if self.zero_param_process_group is not None: + self.num_ranks_in_param_group = groups._get_zero_param_intra_parallel_group_world_size() + self.num_param_groups = int(self.dp_world_size / self.num_ranks_in_param_group) + self.rank_in_group = groups._get_zero_param_intra_parallel_rank_in_mygroup() + print_rank_0(f"hpZeRO group size: {self.num_ranks_in_param_group}", force=True) + + logger.debug( + "hpZeRO partition parameter my rank in world {} my rank in group {} ranks in my param partition group: {} ".format( + self.rank, self.rank_in_group, groups._get_zero_param_intra_parallel_group_ranks() + ) + ) + + # Local device is the device where the parameters are consumed, must be default device. 
+ # It is the device where parameters are fully instantiated using allgather + self.local_device = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"])) + get_accelerator().set_device(self.local_device) + + self.quantized_weights = zero_quantized_weights + if _ds_config is not None and _ds_config.zero_config.zero_quantized_weights and not self.quantized_weights: + self.quantized_weights = _ds_config.zero_config.zero_quantized_weights + self.quantized_nontrainable_weights = zero_quantized_nontrainable_weights + if ( + _ds_config is not None + and _ds_config.zero_config.zero_quantized_nontrainable_weights + and not self.quantized_nontrainable_weights + ): + self.quantized_nontrainable_weights = _ds_config.zero_config.zero_quantized_nontrainable_weights + + self.module = module + if self.quantized_weights or self.quantized_nontrainable_weights: + self.quantizer_module = CUDAQuantizer() + print_rank_0(f"Using quantizer for weights: {self.quantizer_module.__class__.__name__}", force=True) + + if _ds_config is not None: + Init.override_module_apply = _ds_config.zero_config.override_module_apply + + if _ds_config.zero_config.offload_param is not None: + remote_device = _ds_config.zero_config.offload_param.device + pin_memory = _ds_config.zero_config.offload_param.pin_memory + + self._validate_remote_device(remote_device, _ds_config) + + # Remote device is the device where parameter partitions are stored + # It can be same as local_device or it could be CPU or NVMe. + self.remote_device = self.local_device if remote_device in [None, OffloadDeviceEnum.none] else remote_device + self.pin_memory = ( + pin_memory if (self.remote_device in [OffloadDeviceEnum.cpu, OffloadDeviceEnum.nvme]) else False + ) + + # Enable fp16 param swapping to NVMe + if self.remote_device == OffloadDeviceEnum.nvme: + self.param_swapper = param_swapper or AsyncPartitionedParameterSwapper(_ds_config, self.dtype) + else: + self.param_swapper = None + + # If we are provided an already-allocated module to prepare. 
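+ # Each parameter of the given module is moved to the local device, converted
+ # to a ZeRO (ds_*) parameter, broadcast from the group's rank 0 and then
+ # partitioned (see _zero_init_param below). Illustrative usage, mirroring the
+ # "already-allocated model" example in the class docstring:
+ #
+ #     model = MyModel()                    # constructed normally
+ #     deepspeed.zero.Init(module=model)    # parameters partitioned in place
+ #
+ # (`MyModel` is a placeholder, not part of this file.)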
+ if module is not None: + assert isinstance(module, torch.nn.Module) + self._convert_to_zero_parameters(module.parameters(recurse=True)) + + self.use_all_gather_into_tensor = dist.has_all_gather_into_tensor() + if not self.use_all_gather_into_tensor: + logger.info(f"all_gather_into_tensor API is not available in torch {torch.__version__}") + + def _update_persist_config(self, ds_config): + Init.apply_param_persistence = True + Init.param_persistence_threshold = ds_config.zero_config.param_persistence_threshold + Init.model_persistence_threshold = ds_config.zero_config.model_persistence_threshold // self.num_partitions + + def _zero_init_param(self, param): + self._convert_to_deepspeed_param(param) + if dist.get_world_group() == self.get_dp_process_group(): + dist.broadcast(param.data, 0, self.get_dp_process_group()) + else: + dist.broadcast( + param.data, dist.get_global_rank(self.get_dp_process_group(), 0), self.get_dp_process_group() + ) + param.partition() + + def _convert_to_zero_parameters(self, param_list): + for param in param_list: + if is_zero_param(param): + continue + + param.data = param.data.to(self.local_device) + self._zero_init_param(param) + + def _validate_remote_device(self, remote_device, ds_config): + if ds_config is not None: + if remote_device in [None, OffloadDeviceEnum.cpu]: + if ds_config.zero_config.offload_param is not None: + offload_param_device = ds_config.zero_config.offload_param.device + assert ( + offload_param_device != OffloadDeviceEnum.nvme + ), f"'device' in DeepSpeed Config cannot be {offload_param_device} if remote device is {remote_device}." + + if remote_device == OffloadDeviceEnum.nvme: + assert ( + ds_config.zero_config.offload_param is not None + ), f'"offload_param" must be defined in DeepSpeed Config if remote device is {OffloadDeviceEnum.nvme}.' + + assert ( + ds_config.zero_config.offload_param.nvme_path is not None + ), f'"nvme_path" in DeepSpeed Config cannot be None if remote device is {OffloadDeviceEnum.nvme}' + + def _post_init_method(self, module): + # see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False) + print_rank_0(f"Converting Params in {module.__class__.__name__}", force=False) + see_memory_usage(f"Before converting and partitioning params in {module.__class__.__name__}", force=False) + + for name, param in module.named_parameters(recurse=False): + print_rank_0(f"Analyzing param {name} in {module.__class__.__name__}", force=False) + InsertPostInitMethodToModuleSubClasses.num_module_parameters += 1 + InsertPostInitMethodToModuleSubClasses.num_module_elements += param.numel() + if not is_zero_param(param): + if not get_accelerator().on_accelerator(param): + param.data = param.data.to(self.local_device) + + if name == "weight" and self.quantized_initialization and type(module) in WEIGHT_QUANTIZATION_LAYERS: + _quantize_param(param, self.quantized_initialization) + + self._zero_init_param(param) + print_rank_0( + f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}" + ) + + see_memory_usage( + f"Param count {InsertPostInitMethodToModuleSubClasses.num_module_elements}. 
After converting and partitioning params in {module.__class__.__name__}", + force=False, + ) + + def _convert_to_deepspeed_param(self, param): + + # Partitioned, Normal, Remote + param.ds_param_type = ZeroParamType.PARTITIONED + + # Replicated vs Partitioned vs Inflight + param.ds_status = ZeroParamStatus.AVAILABLE + + # Stores the shape of the original tensor + param.ds_shape = param.shape + + # Stores the number of elements in the original parameter without padding + param.ds_numel = param.numel() + + # Stores the partitioned copy of the tensor + param.ds_tensor = None + + # Keeps track of how many active sub-modules need this param at any given point in time + param.ds_active_sub_modules = set() + + # If this flag is true, then the parameters are replicated throughput training + # And only partitioned before the step + if ( + Init.apply_param_persistence + and param.ds_numel <= Init.param_persistence_threshold + and Init.num_persisted_elements + param.ds_numel <= Init.model_persistence_threshold + ): + param.ds_persist = True + Init.num_persisted_parameters += 1 + Init.num_persisted_elements += param.ds_numel + else: + param.ds_persist = False + + param.is_external_param = False + + # The group that the parameter is scattered across. + param.ds_process_group = self.ds_process_group + + # Stores the secondary partitioned copy of the tensor + param.ds_secondary_tensor = None + + # Process group for secondary partition all (group) gather + param.ds_zero_param_process_group = self.zero_param_process_group + param.ds_secondary_tensor_group_size = self.num_ranks_in_param_group + param.ds_secondary_tensor_num_of_groups = self.num_param_groups + + # This is set to the Async Param swapper if remote device is nvme + # else this is set to None + param.nvme_swapper = self.param_swapper + + # DeepSpeed Param ID + param.ds_id = Init.param_id + Init.param_id += 1 + + def all_gather(param_list=None, async_op=False, hierarchy=0): + cls = param + if param_list is None: + param_list = [cls] + return self._all_gather(param_list, async_op=async_op, hierarchy=hierarchy) + + def _all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group): + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + use_secondary_tensor = params[0].ds_secondary_tensor is not None + + if use_secondary_tensor: + # partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + partition_sz = sum(p.ds_secondary_tensor.ds_numel for p in params) + + flat_tensor = torch.empty( + partition_sz * world_size, + dtype=dtype, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + + partitions: List[Parameter] = [] + for i in range(world_size): + partitions.append(flat_tensor.narrow(0, partition_sz * i, partition_sz)) + + if use_secondary_tensor: + instrument_w_nvtx(torch.cat)( + [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params], + out=partitions[rank_in_group], + ) + else: + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params], + out=partitions[rank_in_group], + ) + # Ensure all gather output size is correct + assert partitions[rank_in_group].numel() * world_size == flat_tensor.numel() + handle = _dist_allgather_fn(partitions[rank_in_group], flat_tensor, ds_process_group) + # Fix get_partition_dp_group(params[0])) + + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=partitions, + world_size=world_size, + 
use_secondary_tensor=use_secondary_tensor, + ) + + @instrument_w_nvtx + def all_gather_coalesced( + params: Iterable[Parameter], safe_mode: bool = False, quantize: bool = False + ) -> AllGatherCoalescedHandle: + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(params) + + if self.num_partitions == 1: + return _no_gather_coalesced(params) + + for param in params: + if param.ds_status != ZeroParamStatus.NOT_AVAILABLE: + raise RuntimeError(param.ds_summary()) + param.ds_status = ZeroParamStatus.INFLIGHT + + # use appropriate all gather process group + ds_process_group = self.ds_process_group + rank_in_group = self.rank + world_size = self.dp_world_size + use_secondary_tensor = params[0].ds_secondary_tensor is not None + if self.zero_param_process_group and use_secondary_tensor: + ds_process_group = self.zero_param_process_group # intragroup + rank_in_group = self.rank_in_group + world_size = self.num_ranks_in_param_group + + # pprint(dir(ds_process_group)) + # ensure that each rank has params in same order. the allgather + # is done by flattening the parameter list into a single tensor that + # can be allgathered in a single call - this means that if each rank + # gives a list of the same parameters in a different order we will + # silently get incorrect parameter values, and have very difficult + # to debug correctness issues. + params = sorted(params, key=lambda p: p.ds_id) + + if logger.isEnabledFor(logging.DEBUG): + debug_rank0(f"-allgather_coalesced: {[p.ds_id for p in params]}") + + if safe_mode: + # ensure that same list (with same ordering) of parameters are + # being allgathered across all ranks, otherwise could mix + # data between tensors. + assert_ints_same_as_other_ranks([p.ds_id for p in params]) + # ensure that tensors from each rank agree on the same ds_numel + # otherwise could mix data between tensors. 
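+ # Concrete failure this guards against (illustrative): if rank 0 passes
+ # params [A, B] while rank 1 passes [B, A], the flattened allgather still
+ # completes, but A on rank 0 is reassembled partly from B's shards (and vice
+ # versa), yielding silently corrupted weights instead of an error.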
+ assert_ints_same_as_other_ranks([p.ds_tensor.ds_numel for p in params]) + + if len(params) == 1: + # have an opportunity to avoid some intermediate memory allocations + param = params[0] + buffer_size = math.ceil(param.ds_numel / world_size) * world_size + if use_secondary_tensor: + buffer_size = ( + param.ds_secondary_tensor.shape[0] * world_size + ) # make sure out is appropriately sized + + param_ds_tensor = param.ds_secondary_tensor if use_secondary_tensor else param.ds_tensor + param_buffer = torch.empty( + buffer_size, + dtype=param_ds_tensor.dtype if not quantize else torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + if not quantize: + handles = _dist_allgather_fn( + param_ds_tensor.to(get_accelerator().current_device_name()), + param_buffer, + ds_process_group, + ) + param.data = param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) + return AllGatherHandle(handles, param) + else: + if hasattr(param_ds_tensor, "ds_quant_scale"): + scales = param_ds_tensor.ds_quant_scale + quantized_param = param_ds_tensor.data + else: + quantized_param, scales = self.quantizer_module.quantize(param_ds_tensor) + handle = _dist_allgather_fn( + quantized_param.to(get_accelerator().current_device_name()), param_buffer, ds_process_group + ) + + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=scales.dtype, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + quant_handle = _dist_allgather_fn( + scales.to(get_accelerator().current_device_name()), quant_scale_buffer, ds_process_group + ) + quant_info = QuantizationInfo() + quant_info.quantized_param = ( + param_buffer.narrow(0, 0, param.ds_numel).view(param.ds_shape).to(param.device) + ) + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + return AllGatherHandle(handle, param, quantization=quant_info) + + else: + if not quantize: + dtype_params = defaultdict(list) + for p in params: + dtype_params[p.ds_tensor.dtype].append(p) + handles = [] + for dtype, params in dtype_params.items(): + handles.append(_all_gather_dtype(dtype, params, world_size, rank_in_group, ds_process_group)) + + return MultipleAllGatherHandles(handles) + + else: + partition_sz = sum(p.ds_tensor.ds_numel for p in params) + + if use_secondary_tensor: + partition_sz = sum(p.ds_tensor.ds_numel * p.ds_secondary_tensor_num_of_groups for p in params) + + flat_tensor = torch.empty( + partition_sz * world_size, + dtype=torch.int8, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + + if use_secondary_tensor: + if hasattr(params[0].ds_secondary_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_secondary_tensor.data.to(get_accelerator().current_device_name()) for p in params] + ) + scales = instrument_w_nvtx(torch.cat)( + [ + p.ds_secondary_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) + for p in params + ] + ) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_secondary_tensor.to(get_accelerator().current_device_name()) for p in params] + ) + ) + else: + if hasattr(params[0].ds_tensor, "ds_quant_scale"): + quantized_param = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.data.to(get_accelerator().current_device_name()) for p in params] + ) + scales = instrument_w_nvtx(torch.cat)( + [p.ds_tensor.ds_quant_scale.to(get_accelerator().current_device_name()) for p in 
params] + ) + else: + quantized_param, scales = self.quantizer_module.quantize( + instrument_w_nvtx(torch.cat)( + [p.ds_tensor.to(get_accelerator().current_device_name()) for p in params] + ) + ) + quant_scale_buffer = torch.empty( + scales.numel() * world_size, + dtype=torch.float32, + device=get_accelerator().current_device_name(), + requires_grad=False, + ) + handle = _dist_allgather_fn(quantized_param, flat_tensor, ds_process_group) + quant_handle = _dist_allgather_fn(scales, quant_scale_buffer, ds_process_group) + quant_info = QuantizationInfo() + quant_info.quantized_param = flat_tensor + quant_info.backend = self.quantizer_module + quant_info.quant_handle = quant_handle + quant_info.scale_buffer = quant_scale_buffer + quant_info.partition_sz = partition_sz + quant_info.world_size = world_size + return AllGatherCoalescedHandle( + allgather_handle=handle, + params=params, + partitions=None, + world_size=world_size, + use_secondary_tensor=use_secondary_tensor, + quantization=quant_info, + ) + + def partition(param_list=None, hierarchy=0, has_been_updated=False): + cls = param + print_rank_0(f"{'--'*hierarchy}----Partitioning param {debug_param2name_id_shape_device(cls)}", force=False) + if param_list is None: + param_list = [cls] + self._partition(param_list, has_been_updated=has_been_updated) + + def reduce_gradients_at_owner(param_list=None, hierarchy=0): + cls = param + if param_list is None: + param_list = [cls] + print_rank_0( + f"{'--'*hierarchy}----Reducing Gradients for param with ids {[param.ds_id for param in param_list]} to owner" + ) + self._reduce_scatter_gradients(param_list) + + def partition_gradients(param_list=None, partition_buffers=None, hierarchy=0, accumulate=False): + cls = param + print_rank_0( + f"{'--'*hierarchy}----Partitioning param gradient with id {debug_param2name_id_shape_device(cls)}" + ) + if param_list is None: + param_list = [cls] + if isinstance(partition_buffers, torch.Tensor): + partition_buffers = [partition_buffers] + + self._partition_gradients(param_list, partition_buffers=partition_buffers, accumulate=accumulate) + + def aligned_size(): + return self._aligned_size(param) + + def padding_size(): + return self._padding_size(param) + + def partition_numel(): + return self._partition_numel(param) + + def item_override(): + param.all_gather() + return param._orig_item() + + def ds_summary(slf: torch.Tensor, use_debug_name: bool = False) -> dict: + return { + "id": debug_param2name_id(slf) if use_debug_name else slf.ds_id, + "status": slf.ds_status.name, + "numel": slf.numel(), + "ds_numel": slf.ds_numel, + "shape": tuple(slf.shape), + "ds_shape": tuple(slf.ds_shape), + "requires_grad": slf.requires_grad, + "grad_shape": tuple(slf.grad.shape) if slf.grad is not None else None, + "persist": slf.ds_persist, + "active_sub_modules": slf.ds_active_sub_modules, + "ds_tensor.shape": slf.ds_tensor.shape if slf.ds_tensor is not None else None, + } + + def convert_to_zero_parameters(param_list): + self._convert_to_zero_parameters(param_list) + + def allgather_before(func: Callable) -> Callable: + def wrapped(*args, **kwargs): + param.all_gather() + return func(*args, **kwargs) + + return wrapped + + # Collectives for gathering and partitioning parameters + param.all_gather = all_gather + param.all_gather_coalesced = all_gather_coalesced + param.partition = partition + + # Collective for averaging gradients + param.reduce_gradients_at_owner = reduce_gradients_at_owner + param.partition_gradients = partition_gradients + + # Partitioning size utilities + 
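+ # Worked example (illustrative): with ds_numel=10 and num_partitions=4,
+ # padding_size() == 2 and aligned_size() == 12, so each rank owns an equal
+ # 3-element shard and partition_numel() == 3 once the param is partitioned.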
param.aligned_size = aligned_size + param.padding_size = padding_size + param.partition_numel = partition_numel + param.ds_summary = types.MethodType(ds_summary, param) + + param.item = allgather_before(param.item) + + param.convert_to_zero_parameters = convert_to_zero_parameters + + def _aligned_size(self, param): + return param.ds_numel + self._padding_size(param) + + def _padding_size(self, param): + remainder = param.ds_numel % self.num_partitions + return (self.num_partitions - remainder) if remainder else 0 + + def _aligned_size_sec(self, param): + return param.ds_numel + self._padding_size_sec(param) + + def _padding_size_sec(self, param): + remainder = param.ds_numel % self.num_ranks_in_param_group + return (self.num_ranks_in_param_group - remainder) if remainder else 0 + + def _partition_numel(self, param): + return param.ds_tensor.ds_numel + + def _ensure_availability_of_partitioned_params(self, params): + swap_in_list = [] + swap_in_flight = [] + for param in params: + if param.ds_tensor.status == PartitionedParamStatus.NOT_AVAILABLE: + assert ( + param.ds_tensor.final_location == OffloadDeviceEnum.nvme + and param.ds_status == ZeroParamStatus.NOT_AVAILABLE + ) + swap_in_list.append(param) + if param.ds_tensor.status == PartitionedParamStatus.INFLIGHT: + assert ( + param.ds_tensor.final_location == OffloadDeviceEnum.nvme + and param.ds_status == ZeroParamStatus.NOT_AVAILABLE + ) + swap_in_flight.append(param) + if len(swap_in_list) > 0: + swap_in_list[0].nvme_swapper.swap_in(swap_in_list, async_op=False) + elif len(swap_in_flight) > 0: + swap_in_flight[0].nvme_swapper.synchronize_reads() + + @instrument_w_nvtx + def _all_gather(self, param_list, async_op=False, hierarchy=None): + + # fetches from nvme if the partition is not available and in nvme + self._ensure_availability_of_partitioned_params(param_list) + + handles = [] + all_gather_list = [] + for param in param_list: + if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: + if async_op: + handle = self._allgather_param(param, async_op=async_op, hierarchy=hierarchy) + param.ds_status = ZeroParamStatus.INFLIGHT # if async_op else ZeroParamStatus.AVAILABLE + handles.append(handle) + else: + all_gather_list.append(param) + # note: param_list may contain params that are already in flight / aviailable. 
So we need to use all_gather_list + if not async_op: + if len(all_gather_list) == 1: + ret_value = self._allgather_params(all_gather_list, hierarchy=hierarchy) + else: + all_gather_quantize_list = [] + all_gather_nonquantize_list = [] + for param in all_gather_list: + if hasattr(param.ds_tensor, "ds_quant_scale") or ( + hasattr(param, "ds_secondary_tensor") and hasattr(param.ds_secondary_tensor, "ds_quant_scale") + ): + all_gather_quantize_list.append(param) + else: + all_gather_nonquantize_list.append(param) + # _allgather_params_coalesced always return None + self._allgather_params_coalesced(all_gather_nonquantize_list, hierarchy, quantize=False) + self._allgather_params_coalesced(all_gather_quantize_list, hierarchy, quantize=True) + for param in all_gather_list: + param.ds_status = ZeroParamStatus.AVAILABLE + return None + + return handles + + def _partition(self, param_list, force=False, has_been_updated=False): + for param in param_list: + print_rank_0(f"Before Partitioning Param {param.ds_id}", force=False) + if self.zero_param_process_group is not None: + self._partition_param_sec(param) + self._partition_param(param, has_been_updated=has_been_updated) + + param.ds_status = ZeroParamStatus.NOT_AVAILABLE + # if param.ds_tensor is not None: + # assert id(param.data) == id(param.ds_tensor.data), \ + # "After the parameters are initially partitioned, make sure we are not recreating the partition." + # print_rank_0(f"After Partitioning Param {param.ds_id} {param.ds_tensor.size()} {param.ds_tensor}",force=False) + + @instrument_w_nvtx + def _partition_param(self, param, buffer=None, has_been_updated=False): + assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" + global reuse_buffers + print_rank_0(f"Param id {param.ds_id} status is {param.ds_status}", force=False) + if param.ds_status is ZeroParamStatus.AVAILABLE: + print_rank_0(f"Partitioning param id {param.ds_id} reuse buffers {reuse_buffers}", force=False) + # if reuse_buffers and False: + # numel = buffer.numel() + # buffer = param.data.view(-1) + # print_rank_0( + # "Returning buffer for param {param.ds_id} with numel {param.ds_numel} to empty buffers", + # force=False) + # if numel in empty_buffers: + # empty_buffers[numel].append(buffer) + + # if deepspeed.comm.get_rank(): + # print(f"Releasing {param.data.numel()}") + + if param.ds_tensor is not None and not has_been_updated: ##param already partitioned + + # print_rank_0(f"Param {param.ds_id} pri {param.ds_tensor.size()} loc? 
{param.ds_tensor.final_location}", force=True) + # param.data = param.ds_tensor.data + + see_memory_usage(f"Before partitioning param {param.ds_id} {param.shape}", force=False) + # param.data does not store anything meaningful in partitioned state + free_param(param) + see_memory_usage(f"After partitioning param {param.ds_id} {param.shape}", force=False) + + if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: + print_rank_0(f"Param {param.ds_id} partition released since it exists in nvme", force=False) + param.nvme_swapper.remove_partition_and_release_buffers([param]) + print_rank_0( + f"after swap Param {param.ds_id} {param.ds_tensor.shape} partition released since it exists in nvme", + force=False, + ) + + return + + tensor_size = self._aligned_size(param) + partition_size = tensor_size // self.num_partitions + if param.ds_tensor is None: + final_location = None + if self.remote_device == OffloadDeviceEnum.nvme and self.param_swapper.swappable_tensor( + numel=partition_size + ): + final_location = OffloadDeviceEnum.nvme + buffer = self.param_swapper.get_buffer(param, partition_size) + partitioned_tensor = torch.empty(0, dtype=param.dtype, device=buffer.device) + partitioned_tensor.data = buffer.data + print_rank_0(f"ID {param.ds_id} Initializing partition for the first time for nvme offload.") + + else: + if param.ds_persist: + device = self.local_device + elif self.remote_device == OffloadDeviceEnum.nvme: + device = OffloadDeviceEnum.cpu + else: + device = self.remote_device + + partitioned_tensor = torch.empty(partition_size, dtype=param.dtype, device=device) + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + partitioned_tensor, partitioned_tensor.ds_quant_scale = self.quantizer_module.quantize( + partitioned_tensor + ) + + if device == OffloadDeviceEnum.cpu and self.pin_memory: + partitioned_tensor = get_accelerator().pin_memory(partitioned_tensor) + + partitioned_tensor.requires_grad = False + param.ds_tensor = partitioned_tensor + param.ds_tensor.ds_numel = partition_size + param.ds_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_tensor.final_location = final_location + + start = partition_size * self.get_partition_rank() + end = start + partition_size + + one_dim_param = param.contiguous().view(-1) + + if start < param.ds_numel and end <= param.ds_numel: + src_tensor = one_dim_param.narrow(0, start, partition_size) + + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. + param.ds_tensor.copy_(src_tensor) + + # partitioned_tensor = src_tensor.clone().detach().to(self.remote_device) + + else: + # partitioned_tensor = torch.zeros(partition_size, + # dtype=param.dtype, + # device=self.remote_device ) + + if start < param.ds_numel: + elems_to_copy = param.ds_numel - start + with torch.no_grad(): + # make sure param.ds_tensor requires_grad always be false, + # otherwise, torch tracer will complain. 
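+ # Illustrative example: with ds_numel=10 and partition_size=4, the rank whose
+ # shard starts at 8 has only elems_to_copy=2 real elements; the copy below
+ # fills the first 2 slots of its 4-element ds_tensor and leaves the padded
+ # tail uninitialized.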
+ param.ds_tensor.narrow(0, 0, elems_to_copy).copy_(one_dim_param.narrow(0, start, elems_to_copy)) + + # print(f"Remote device {self.remote_device}") + + # param.ds_tensor = partitioned_tensor + + # param.data = param.ds_tensor.data + + # param.data does not store anything meaningful in partitioned state + + see_memory_usage(f"Before partitioning param {param.ds_id} {param.shape}", force=False) + free_param(param) + see_memory_usage(f"After partitioning param {param.ds_id} {param.shape}", force=False) + + if param.ds_tensor.final_location == OffloadDeviceEnum.nvme: + self.param_swapper.swap_out_and_release([param]) + print_rank_0(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.") + see_memory_usage(f"ID {param.ds_id} Offloaded to nvme offload and buffers released.", force=False) + + print_rank_0(f"ID {param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}") + + @instrument_w_nvtx + def _partition_param_sec(self, param, buffer=None, has_been_updated=False): + assert param.ds_status is not ZeroParamStatus.INFLIGHT, f" {param} Cannot partition a param in flight" + global reuse_buffers + ##support for NVME secondary param offload + # print_rank_0(f"SEC Param id {param.ds_id} status is {param.ds_status}", force=True) + if param.ds_status is ZeroParamStatus.AVAILABLE: + # check padding + tensor_size = self._aligned_size_sec(param) + partition_size = tensor_size // self.dp_world_size + + secondary_partition_size = int(tensor_size // self.num_ranks_in_param_group) + if param.ds_secondary_tensor is None: + final_location = None + secondary_partitioned_tensor = torch.empty( + secondary_partition_size, dtype=param.dtype, device=self.remote_device + ) + + if self.pin_memory: + secondary_partitioned_tensor = secondary_partitioned_tensor.pin_memory() + # quantize the tensor if it's not trainable + if not param.requires_grad and self.quantized_nontrainable_weights: + ( + secondary_partitioned_tensor, + secondary_partitioned_tensor.ds_quant_scale, + ) = self.quantizer_module.quantize(secondary_partitioned_tensor) + secondary_partitioned_tensor.requires_grad = False + param.ds_secondary_tensor = secondary_partitioned_tensor + param.ds_secondary_tensor.ds_numel = secondary_partition_size + param.ds_secondary_tensor.status = PartitionedParamStatus.AVAILABLE + param.ds_secondary_tensor.final_location = final_location + + # use rank in group for secondary tensor + secondary_start = secondary_partition_size * self.rank_in_group + + secondary_end = secondary_start + secondary_partition_size + + one_dim_param = param.contiguous().view(-1) + + # ds_numel is unpadded, so the last chunk of the secondary tensor might not be secondary_partition_size + sec_numel = param.ds_numel - secondary_start if secondary_end > param.ds_numel else secondary_partition_size + + # copy from full tensor to secondary tensor + param.ds_secondary_tensor.narrow(0, 0, sec_numel).copy_(one_dim_param.narrow(0, secondary_start, sec_numel)) + + # TODO: This is a temporary fix to avoid the issue that 2nd tensor all-gather happens before 2nd tensor partition is done + get_accelerator().current_stream().synchronize() + + print_rank_0( + f"{param.ds_id} partitioned type {param.dtype} dev {param.device} shape {param.shape}", force=False + ) + + def _param_status(self, param): + if param.ds_tensor is not None: + print_rank_0( + f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned numel {param.ds_tensor.numel()}, data numel {param.data.numel()}" + ) + else: + 
print_rank_0( + f"Param id {param.ds_id}, param status: {param.ds_status}, param numel {param.ds_numel}, partitioned ds_tensor {param.ds_tensor}, data numel {param.data.numel()}" + ) + + def _allgather_param(self, param, async_op=False, hierarchy=0): + + partition_size = param.ds_tensor.ds_numel + + tensor_size = partition_size * self.num_partitions + aligned_param_size = self._aligned_size(param) + assert ( + tensor_size == aligned_param_size + ), f"param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}" + + print_rank_0( + f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}" + ) + + see_memory_usage( + f"Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ", + force=False, + ) + flat_tensor = torch.zeros(aligned_param_size, dtype=param.dtype, device=param.device).view(-1) + see_memory_usage( + f"After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ", + force=False, + ) + + get_accelerator().synchronize() + + print_rank_0( + f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}" + ) + # if not flat_tensor.numel() > 100000: + # replicated_tensor = flat_tensor.narrow(0, + # 0, + # param.ds_numel).view(param.ds_shape) + # param.data = replicated_tensor.data + # return None + if self.use_all_gather_into_tensor: + handle = dist.all_gather_into_tensor( + flat_tensor, + param.ds_tensor.to(get_accelerator().device_name()), + group=self.get_partition_dp_group(param), + async_op=async_op, + ) + else: + partitions = [] + for i in range(self.num_partitions): + partitions.append(flat_tensor.narrow(0, partition_size * i, partition_size)) + + if i == dist.get_rank(group=self.get_partition_dp_group(param)): + partitions[i].data.copy_(param.ds_tensor.data, non_blocking=True) + + handle = dist.all_gather( + partitions, + partitions[self.get_partition_rank()], + group=self.get_partition_dp_group(param), + async_op=async_op, + ) + + replicated_tensor = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape) + param.data = replicated_tensor.data + return handle + + def _allgather_params_coalesced(self, param_list, hierarchy=0, quantize=False): + """blocking call + avoid explicit memory copy in _allgather_params + """ + if len(param_list) == 0: + return + + if self.num_partitions == 1: + handle = _no_gather_coalesced(param_list) + handle.wait() + return None + + # collect local tensors and partition sizes + partition_sizes = [] + local_tensors = [] + if quantize: + quantize_scale_sizes = [] + quantize_scale_tensors = [] + for param in param_list: + partition_sizes.append(param.ds_tensor.ds_numel) + local_tensors.append(param.ds_tensor.to(get_accelerator().device_name())) + if quantize: + quantize_scale_sizes.append(param.ds_tensor.ds_quant_scale.numel()) + quantize_scale_tensors.append(param.ds_tensor.ds_quant_scale.to(get_accelerator().device_name())) + # allocate memory for allgather params + allgather_params = [] + if quantize: + allgather_quantize_scale = [] + for psize in partition_sizes: + tensor_size = psize * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device).view( + -1 + ) + flat_tensor.requires_grad = False + allgather_params.append(flat_tensor) + if quantize: + for psize in quantize_scale_sizes: + tensor_size = psize * 
self.num_partitions + flat_tensor = torch.empty( + tensor_size, dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, device=self.local_device + ).view(-1) + flat_tensor.requires_grad = False + allgather_quantize_scale.append(flat_tensor) + + # launch + launch_handles = [] + launch_quantize_handles = [] + for param_idx, param in enumerate(param_list): + input_tensor = local_tensors[param_idx].view(-1) + + if self.use_all_gather_into_tensor: + # try the _all_gather_base from Pytorch master + h = dist.all_gather_into_tensor( + allgather_params[param_idx], input_tensor, group=self.get_partition_dp_group(param), async_op=True + ) + if quantize: + quantize_handle = dist.all_gather_into_tensor( + allgather_quantize_scale[param_idx], + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True, + ) + launch_quantize_handles.append(quantize_handle) + else: + output_list = [] + for i in range(self.num_partitions): + psize = partition_sizes[param_idx] + partition = allgather_params[param_idx].narrow(0, i * psize, psize) + output_list.append(partition) + if not get_accelerator().on_accelerator(partition): + logger.warning( + f"param {param_idx}, partition {i} is not on CUDA, partition shape {partition.size()}" + ) + + # back to old all_gather function + h = dist.all_gather(output_list, input_tensor, group=self.get_partition_dp_group(param), async_op=True) + if quantize: + output_scale_list = [] + for i in range(self.num_partitions): + psize = quantize_scale_sizes[param_idx] + partition = allgather_quantize_scale[param_idx].narrow(0, i * psize, psize) + output_scale_list.append(partition) + quant_handle = dist.all_gather( + output_scale_list, + quantize_scale_tensors[param_idx], + group=self.get_partition_dp_group(param), + async_op=True, + ) + launch_quantize_handles.append(quant_handle) + launch_handles.append(h) + + # Wait ensures the operation is enqueued, but not necessarily complete. 
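+ # Only the last launched handle is waited on here; this presumably relies on
+ # collectives on the same process group completing in launch order, and full
+ # completion is in any case enforced by the accelerator synchronize() at the
+ # end of this function before callers consume the gathered data.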
+ launch_handles[-1].wait() + if quantize: + for quant_handle in launch_quantize_handles: + quant_handle.wait() + + # assign to param.data (not copy) + for i, param in enumerate(param_list): + gathered_tensor = allgather_params[i] + if quantize: + gathered_tensor = self.quantizer_module.dequantize(gathered_tensor, allgather_quantize_scale[i]) + param.data = gathered_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data + + # guarantee the communication to be completed + get_accelerator().synchronize() + + return None + + def _allgather_params(self, param_list, hierarchy=0): + if len(param_list) == 0: + return + + partition_size = sum([param.ds_tensor.ds_numel for param in param_list]) + + tensor_size = partition_size * self.num_partitions + flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device) + flat_tensor.requires_grad = False + partitions = [] + for i in range(self.num_partitions): + start = partition_size * i + + partitions.append(flat_tensor.narrow(0, start, partition_size)) + + if i == self.get_partition_rank(): + offset = 0 + for param in param_list: + param_numel = param.ds_tensor.ds_numel + + partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data) + + offset += param_numel + + if hasattr(param_list[0], "ds_quant_scale"): + scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list]) + scale_tensor_size = scale_size * self.world_size + flat_scale_tensor = torch.empty( + scale_tensor_size, dtype=param_list[0].ds_tensor.ds_quant_scale.dtype, device=self.local_device + ) + flat_scale_tensor.requires_grad = False + scale_partitions = [] + for i in range(self.world_size): + start = scale_tensor_size * i + scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size)) + if i == self.rank: + offset = 0 + for param in param_list: + param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel + + scale_partitions[i].narrow(0, offset, param_scale_numel).copy_( + param.ds_tensor.ds_quant_scale.data + ) + + offset += param_scale_numel + + dist.all_gather_into_tensor( + flat_tensor, partitions[self.get_partition_rank()], group=self.get_partition_dp_group(param), async_op=False + ) + if hasattr(param_list[0], "ds_quant_scale"): + dist.all_gather( + flat_scale_tensor, + param_list[0].ds_quant_scale, + group=self.get_partition_dp_group(param), + async_op=False, + ) + param_offset = 0 + + for param in param_list: + param_partition_size = param.ds_tensor.ds_numel + param_size = param.ds_numel + replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device) + + for i in range(self.num_partitions): + + start = i * partition_size + + param_start = i * param_partition_size + + if param_start < param_size: + numel_to_copy = min(param_size - param_start, param_partition_size) + + part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy) + + replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy) + # param_offset += param.data.numel() + param_offset += param.ds_tensor.ds_numel + if hasattr(param_list[0], "ds_quant_scale"): + replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor) + param.data = replicated_tensor.data + + return None + + def _reduce_scatter_gradients(self, param_list): + # print_rank_0([param.grad for param in param_list]) + # assert any([param.grad is None for param in param_list]), "None gradients cannot be reduce scattered" + + handles_and_reduced_partitions = [] + 
for param in param_list: + assert ( + param.grad.numel() == param.ds_numel + ), f"{param.grad.numel()} != {param.ds_numel} Cannot reduce scatter gradients whose size is not same as the params" + + handles_and_reduced_partitions.append(self._reduce_scatter_gradient(param)) + + for param, (handle, reduced_partition) in zip(param_list, handles_and_reduced_partitions): + if handle is not None: + handle.wait() + + # some ranks may have partitions that are padded to go beyond the grad size. + # For these ranks the output of reduce scatter is a separate buffer and needs + # to be copied in + partition_size = param.ds_tensor.ds_numel + start = self.get_partition_rank() * partition_size + end = start + partition_size + # print_rank_0("REduce scatter was executed for param {param.ds_id}") + if start < param.ds_numel < end: + elements = param.ds_numel - start + param.grad.view(-1).narrow(0, start, elements).copy_(reduced_partition.narrow(0, 0, elements)) + + def _reduce_scatter_gradient(self, param): + + partition_size = param.ds_tensor.ds_numel + # output = torch.empty(partition_size, dtype=param.dtype, device=param.device) + + total_size = partition_size * self.num_partitions + input_list = [] + + for i in range(self.num_partitions): + + start = i * partition_size + end = start + partition_size + + # print("before reduce scatter gradients") + if start < param.ds_numel and end <= param.ds_numel: + input = param.grad.view(-1).narrow(0, start, partition_size) + else: + input = torch.zeros(partition_size, dtype=param.dtype, device=param.device) + + if start < param.ds_numel: + elements = param.ds_numel - start + input.narrow(0, 0, elements).copy_(param.grad.view(-1).narrow(0, start, elements)) + # print("after reduce scatter gradients") + input_list.append(input) + + rank = dist.get_rank(group=self.get_partition_dp_group(param)) + handle = dist.reduce_scatter( + input_list[rank], input_list, group=self.get_partition_dp_group(param), async_op=True + ) + + return handle, input_list[rank] + + def _partition_gradients(self, param_list, partition_buffers=None, accumulate=False): + if partition_buffers is None: + partition_buffers = [None] * len(param_list) + + for param, partition_buffer in zip(param_list, partition_buffers): + self._partition_gradient(param, partition_buffer=partition_buffer, accumulate=accumulate) + + def _partition_gradient(self, param, partition_buffer=None, accumulate=False): + + # import pdb;pdb.set_trace() + # param.grad=None + # param.grad.test() + print_rank_0( + f"Partitioning param {param.ds_id} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.ds_numel}" + ) + see_memory_usage("Before partitioning gradients", force=False) + partition_size = param.ds_tensor.ds_numel + + if partition_buffer is None: + assert not accumulate, "No buffer to accumulate to" + partition_buffer = torch.zeros(partition_size, dtype=param.dtype, device=param.device) + else: + assert ( + partition_buffer.numel() >= partition_size + ), f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}" + + rank = dist.get_rank(group=self.get_partition_dp_group(param)) + start = partition_size * rank + end = start + partition_size + + dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size) + + # print("before partition gradients") + if start < param.ds_numel: + elements = min(param.ds_numel - start, partition_size) + + dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements) + src_tensor = 
param.grad.view(-1).narrow(0, start, elements) + + # just copy the grad partition to the buffer + if not accumulate: + dest_tensor.copy_(src_tensor) + + # if source and destination are on same device, + # add to the provided buffer + elif src_tensor.device == dest_tensor.device: + dest_tensor.add_(src_tensor) + + # if source and destination are on different device, copy first to src + # then add and move back to the destination. This seems to run faster + # when src is gpu and dest is cpu + # adding directly to cpu is very slow + else: + acc_tensor = torch.empty(src_tensor.numel(), dtype=param.dtype, device=param.device) + + acc_tensor.copy_(dest_tensor) + acc_tensor.add_(src_tensor) + dest_tensor.copy_(acc_tensor) + + # partition_buffer.view(-1).narrow( + # 0, + # 0, + # elements).copy_(param.grad.view(-1).narrow(0, + # start, + # elements)) + + # print("after partition gradients") + param.grad.data = dest_tensor_full_buffer.data + see_memory_usage("After partitioning gradients", force=False) + + def get_partition_dp_group(self, param): + return param.ds_process_group + + def get_partition_rank(self): + """subclass can overload to specify different relative rank in + parameter partition group""" + return self.rank + + @property + def num_partitions(self): + return self.dp_world_size + + def get_dp_process_group(self): + """Return the communication group with all data-parallel ranks""" + return self.ds_process_group + + +class GatheredParameters: + def __init__(self, params, modifier_rank=None, fwd_module=None, enabled=True): + """A context that collects parameters that were partitioned via a + :class:`deepspeed.zero.Init` context. The parameters are partitioned + again upon exit. + + Args: + params (``torch.nn.Parameter``): A single parameter, or an iterable of parameters (list, tuple, generator) of parameters to collect. + It's assumed that all parameters are zero params. + modifier_rank (int, optional): If specified, this rank's parameter will be + broadcasted on exit from the context. This argument is required if ``params`` are + modified, so that all processes have a consistent view of the data. Defaults + to ``None``. + fwd_module (``torch.nn.Module``, optional): If specified, ``params`` will be + registered as external parameters of ``fwd_module``. See :meth:`deepspeed.zero.register_external_parameter`. + enabled (bool, optional): If ``False``, this context is a no-op. Defaults to ``True``. + + Important: Make sure to use ``modifier_rank`` that is not ``None`` (e.g., ``modifier_rank=0``) + if you need the GPU memory allocated by gather to be released upon exit from the context manager. + + Important: if ``params`` isn't an iterable of parameters or a single parameter it'll be silently ignored! + + Examples + ======== + + #. Allocate a partitioned module, initialize its weight on rank 0, and update all + processes. + + .. code-block:: python + + with deepspeed.zero.Init(): + linear = torch.nn.Linear(1000, 1000) + + with deepspeed.zero.GatheredParameters(linear.weight, modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + linear.weight.zero_() + + with deepspeed.zero.GatheredParameters(linear.weight, modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + linear.weight.zero_() + + #. Collect a partitioned weight to pass to another module during + training. The parameter will be registered as an external parameter + and made available during the backward pass. + + .. 
code-block:: python + :emphasize-lines: 6 + + def forward(self, input): + x = self.layer1(input) + + # self.layer1.weight is required by self.layer2.forward + with deepspeed.zero.GatheredParameters(self.layer1.weight, fwd_module=self): + y = self.layer2(x, self.layer1.weight) + return y + + + #. Pretrained model loading + + .. code-block:: python + + with deepspeed.zero.Init(): + model = MyModel() + + state_dict = torch.load(model_path, map_location="cpu") + + + def load(module: nn.Module, prefix=""): + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0): + if deepspeed.comm.get_rank() == 0: + module._load_from_state_dict(state_dict, prefix) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + + load(model, prefix="") + + If this approach is not used, then the full model will first be copied to each GPU. For models + bigger than the memory of a single GPU, this method is required. + """ + + self.enabled = enabled + if not enabled: + return + + if isinstance(params, Iterable) and not isinstance(params, torch.Tensor): + # deal with generators like model.parameters() + # must convert to list to be able to iterate more than once if we get a generator + params = list(params) + else: + # single param + params = [params] + # enable if at least one is zero-param, otherwise a noop + if not any(is_zero_param(p) for p in params): + self.enabled = False + return + + self.params = [p for p in params if hasattr(p, "ds_id")] + self.params = sorted( + set(self.params), key=lambda x: x.ds_id + ) # remove the duplicates to prevent racing condition, we must also make sure the order is the same on all ranks otherwise we'll get deadlocks + self.src_rank = None + if modifier_rank is not None: + if self.params[0].ds_process_group == dist.get_world_group(): + self.src_rank = modifier_rank + else: + # A group was specified; convert DP rank to global rank + self.src_rank = dist.get_global_rank(self.params[0].ds_process_group, modifier_rank) + self.fwd_module = fwd_module + if self.fwd_module is not None: + # is a no-op if already registered + for p in self.params: + register_external_parameter(self.fwd_module, p) + + def __enter__(self): + if not self.enabled: + return + self.params[0].all_gather(param_list=self.params) + + def __exit__(self, *exc): + if not self.enabled: + return + if self.src_rank is None: + self.params[0].partition(param_list=self.params, has_been_updated=False) + return + + handles = [dist.broadcast(p.data, self.src_rank, group=p.ds_process_group, async_op=True) for p in self.params] + for h in handles: + h.wait() + self.params[0].partition(param_list=self.params, has_been_updated=True) diff --git a/llava/train/deepspeed_replace_deprecated/utils/groups.py b/llava/train/deepspeed_replace_deprecated/utils/groups.py new file mode 100644 index 00000000..1c3571a7 --- /dev/null +++ b/llava/train/deepspeed_replace_deprecated/utils/groups.py @@ -0,0 +1,585 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# The file has been adapted from https://github.com/NVIDIA/Megatron-LM and retains the following license from the original file + +# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + Support different forms of parallelism in DeepSpeed using multiple process groups. + Given that there are multiple scenarios and use-cases, this file is going to be updated + frequently. For now, the group creation needed for the training scenario is being implemented. + For inference and other new scenarios, the code will be either reused or added to this file. +""" + +from deepspeed import comm as dist +from deepspeed.accelerator import get_accelerator +from deepspeed.utils import log_dist +from deepspeed.utils.bwc import bwc_pipeline_parallel_world_size, bwc_tensor_model_parallel_world_size +from deepspeed.utils.exceptions import DeprecatedException + +# Expert parallel group that the current rank belongs to. +_EXPERT_PARALLEL_GROUP = {} +# Expert data parallel group that the current rank belongs to. +_EXPERT_DATA_PARALLEL_GROUP = {} +# dist world group needs to be cloned for some cases +_WORLD_GROUP = None +# ZeRO parameter partitioning group that the current rank belongs to. +_ZERO_PARAM_INTRA_PARALLEL_GROUP = None +# global object to maintain mpu object if passed by a Megatron client +mpu = None +# global object that stores tensor parallel world size for experts +expert_tensor_parallel_world_size = 1 +# All to All quantized graident communication groups +_ALL_TO_ALL_GROUP = {} + +_DATA_PARALLEL_GROUP = None + + +# Deprecated groups initialize function. +def initialize(ep_size=1, mpu=None): + """Deprecated function. Retained to inform the users.""" + raise DeprecatedException( + "Please do not use the groups.initialize() API as it is deprecated. Instead, pass the desired ep_size to deepspeed.moe.layer.MoE(..,ep_size,..)" + ) + + +def _ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, f"{numerator} is not divisible by {denominator}" + + +# Not currently used. Helper function to create a model (tensor) parallel group. +def _create_model_parallel(model_parallel_size_): + """ + Initialize model data parallel groups. + + Arguments: + model_parallel_size: number of GPUs used to parallelize model. + + Returns: + Tuple of data parallel group and model parallel group + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model. The present function will + create 4 model parallel groups and 2 data parallel groups as: + 4 model parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 data parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + log_dist(f"Creating model parallel group with size {model_parallel_size_}", ranks=[0]) + # Get world size and rank. Ensure some consistencies. 
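+    # Concretely (using the 8-GPU example from the docstring above): with
+    # model_parallel_size = 2, rank r ends up in the data-parallel group
+    # range(r % 2, 8, 2) and in the model-parallel group
+    # range((r // 2) * 2, (r // 2) * 2 + 2), i.e. the [g0, g2, g4, g6] / [g0, g1]
+    # layout shown above.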
+ assert dist.is_initialized() + world_size = dist.get_world_size() + model_parallel_size = min(model_parallel_size_, world_size) + _ensure_divisibility(world_size, model_parallel_size) + rank = dist.get_rank() + + _DATA_PARALLEL_GROUP = None + _MODEL_PARALLEL_GROUP = None + # Build the data parallel groups. + for i in range(model_parallel_size): + ranks = range(i, world_size, model_parallel_size) + group = dist.new_group(ranks) + if i == (rank % model_parallel_size): + _DATA_PARALLEL_GROUP = group + + # Build the model parallel groups. + for i in range(world_size // model_parallel_size): + ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size) + group = dist.new_group(ranks) + if i == (rank // model_parallel_size): + _MODEL_PARALLEL_GROUP = group + + return _DATA_PARALLEL_GROUP, _MODEL_PARALLEL_GROUP + + +def _create_expert_and_data_parallel(expert_parallel_size_, use_data_before_expert_parallel_=False): + """ + Create expert and data parallel groups. + + Note: Caller of this function is responsible to check if the groups already exist. + + Example - E + D parallel + world_size = 16 + expert_parallel_size = 2 # number of experts in same group + expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params + expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all + data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology + """ + assert dist.is_initialized() + + log_dist(f"Creating expert and data parallel groups with size {expert_parallel_size_}", ranks=[0]) + world_size = dist.get_world_size() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) + rank = dist.get_rank() + + pp_stride = world_size // pp_world_size + _ensure_divisibility(pp_stride, expert_parallel_size_) + + group_name = f"ep_size_{expert_parallel_size_}" + + # Build the expert data parallel groups. + global _EXPERT_DATA_PARALLEL_GROUP + + ep_stride = pp_stride // expert_parallel_size_ + + # Only create group if it does not already exist + if group_name not in _EXPERT_DATA_PARALLEL_GROUP: + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(expert_parallel_size_): + if use_data_before_expert_parallel_: + ranks = range(pp_stage_start + i * ep_stride, pp_stage_start + (i + 1) * ep_stride) + else: + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, expert_parallel_size_) + group = dist.new_group(ranks) + log_dist( + f"Creating expert data parallel process group named {group_name} " f"with ranks: {list(ranks)}", [0] + ) + if rank in ranks: + _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + + # Build the expert parallel groups. 
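+    # For the E + D example in the docstring (world_size = 16, expert_parallel_size_ = 2,
+    # no pipeline parallelism, default topology), the loop below creates the groups
+    # [0, 1], [2, 3], ..., [14, 15]; each rank only keeps a handle to the group it belongs to.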
+ global _EXPERT_PARALLEL_GROUP + + # Only create group if it does not already exist + if group_name not in _EXPERT_PARALLEL_GROUP: + if use_data_before_expert_parallel_: + for pp_stage_start in range(0, world_size, pp_stride): + for i in range(ep_stride): + ranks = range(pp_stage_start + i, pp_stage_start + pp_stride, ep_stride) + group = dist.new_group(ranks) + log_dist( + f"creating expert parallel process group named {group_name} " f"with ranks: {list(ranks)}", [0] + ) + if rank in ranks: + _EXPERT_PARALLEL_GROUP[group_name] = group + else: + for i in range(world_size // expert_parallel_size_): + ranks = range(i * expert_parallel_size_, (i + 1) * expert_parallel_size_) + group = dist.new_group(ranks) + log_dist( + f"creating expert parallel process group named {group_name} " f"with ranks: {list(ranks)}", [0] + ) + if rank in ranks: + _EXPERT_PARALLEL_GROUP[group_name] = group + + +def _get_expert_parallel_ranks( + world_size, + tensor_parallel_size_, + expert_parallel_size_, + pipeline_parallel_size_=1, + use_data_before_expert_parallel_=False, +): + """Generate expert parallel and expert data parallel group ranks list. + + Example - E + M + D parallel + world_size = 16 + model_degree = 2 + expert_degree = 4 # number of experts in same group + mp_group = [0, 1], [2,3], [4,5] ... + data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15] + expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15] + expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] + + Args: + world_size (int): Distributed world size. + tensor_parallel_size_ (int): Tensor parallel group size. + expert_parallel_size_ (int): Expert parallel group size. + pipeline_parallel_size_ (int): Pipeline parallel group size + use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology + Returns: + Expert parallel group ranks and Expert data parallel group ranks list. + """ + _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_) + dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_) + _ensure_divisibility(dp_world_size, expert_parallel_size_) + + # Generate data parallel groups + data_parallel_groups = [] + dp_group_size = tensor_parallel_size_ + pp_stride = world_size // pipeline_parallel_size_ + + if use_data_before_expert_parallel_: + dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_ + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list()) + for ds in range(dp_stride): + # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] + # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] + data_parallel_groups[-1].extend( + list( + range( + pp_stage_start + i + ds * tensor_parallel_size_, + pp_stage_next, + dp_stride * tensor_parallel_size_, + ) + ) + ) + else: + for pp_stage_start in range(0, world_size, pp_stride): + pp_stage_next = pp_stage_start + pp_stride + for i in range(dp_group_size): + data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size))) + + expert_parallel_groups = [] + expert_data_parallel_groups = [] + for dp_ranks in data_parallel_groups: + # partition of expert parallel groups, e.g. 
[0,2,4,6], [8,10,12,14] + part_ep_groups = [] + for i in range(0, dp_world_size, expert_parallel_size_): + part_ep_groups.append(dp_ranks[i : i + expert_parallel_size_]) + expert_parallel_groups.extend(part_ep_groups) + + # zip part_ep_groups get expert data parallel ranks, e.g [0,8],[2,10],[4,12],[6,14] + for expert_dp_ranks in zip(*part_ep_groups): + expert_data_parallel_groups.append(list(expert_dp_ranks)) + + return expert_parallel_groups, expert_data_parallel_groups + + +def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False): + """ + Create expert and data parallel groups based on MPU (model parallel) group. + + Note: Caller of this function is responsible to check if the groups already exist. + + Example - E + M + D parallel + world_size = 16 + model_degree = 2 + expert_degree = 4 # number of experts in same group + mp_group = [0, 1], [2,3], [4,5] ... + data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15] + expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15] + expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] + """ + assert dist.is_initialized(), "dist is not initialized" + tensor_parallel_size_ = bwc_tensor_model_parallel_world_size(mpu) + + global expert_tensor_parallel_world_size + expert_tensor_parallel_world_size = tensor_parallel_size_ + + world_size = dist.get_world_size() + rank = dist.get_rank() + dp_world_size = mpu.get_data_parallel_world_size() + pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) + + _ensure_divisibility(world_size, tensor_parallel_size_) + _ensure_divisibility(dp_world_size, expert_parallel_size_) + + log_dist( + f"Creating deepspeed groups with model parallel size {tensor_parallel_size_}, " + f"pipeline parallel size {pp_world_size}, expert parallel size {expert_parallel_size_}, " + f"world size {world_size}, dp world size {dp_world_size}", + [0], + ) + + global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP + + group_name = f"ep_size_{expert_parallel_size_}" + + # Only create groups if they don't already exist + # Need to check conditions outside the group creation loop because of the way torch.dist group creation works + if group_name not in _EXPERT_DATA_PARALLEL_GROUP and group_name not in _EXPERT_PARALLEL_GROUP: + expert_parallel_groups, expert_data_parallel_groups = _get_expert_parallel_ranks( + world_size, tensor_parallel_size_, expert_parallel_size_, pp_world_size, use_data_before_expert_parallel_ + ) + for ranks in expert_parallel_groups: + group = dist.new_group(ranks) + if rank in list(ranks): + _EXPERT_PARALLEL_GROUP[group_name] = group + + for ranks in expert_data_parallel_groups: + group = dist.new_group(ranks) + if rank in list(ranks): + _EXPERT_DATA_PARALLEL_GROUP[group_name] = group + + +def _get_max_expert_size(): + """Get the maximum ep_size from all the created groups.""" + assert _EXPERT_PARALLEL_GROUP is not None, "Warning! Process group not initialized" + keylist = [] + for key in _EXPERT_PARALLEL_GROUP.keys(): + # index 2 is ep_size in the group name: ep_size_ + index = 2 + keylist.append(int(key.split("_")[index])) + return max(keylist) if len(keylist) > 0 else None + + +def _get_max_expert_size_name(): + """Get the name of the group with max. 
ep_size""" + return f"ep_size_{_get_max_expert_size()}" + + +def _get_max_expert_parallel_group(): + """Get the max expert parallel size.""" + return _get_expert_parallel_group(_get_max_expert_size_name()) + + +def _get_expert_parallel_group(group_name): + """Get the expert parallel group the caller rank belongs to.""" + assert group_name in _EXPERT_PARALLEL_GROUP, "expert parallel group is not initialized" + return _EXPERT_PARALLEL_GROUP[group_name] + + +def _get_expert_parallel_group_dict(): + """Get the expert parallel group dict.""" + return _EXPERT_PARALLEL_GROUP + + +def _get_expert_data_parallel_group(group_name): + """Get the expert data parallel group the caller rank belongs to.""" + assert group_name in _EXPERT_DATA_PARALLEL_GROUP, "expert data parallel group is not initialized" + return _EXPERT_DATA_PARALLEL_GROUP[group_name] + + +def _get_expert_data_parallel_group_dict(): + """Get the expert data parallel group dict.""" + return _EXPERT_DATA_PARALLEL_GROUP + + +def _clone_world_group(): + """Create a clone of the world group + Note: We need to clone the dist world group because we + use dist.get_global_rank() utility function in DeepSpeed at many places. + As that function does not work on dist.group.WORLD, we + need to keep a clone of it. + """ + assert dist.is_initialized(), "dist is not initialized" + global _WORLD_GROUP + if _WORLD_GROUP is None: + # If not cloned already, clone the world group + _WORLD_GROUP = dist.new_group(ranks=range(dist.get_world_size())) + return _WORLD_GROUP + + +def _get_local_all_to_all_group(): + assert dist.is_initialized(), "dist is not initialized" + global _ALL_TO_ALL_GROUP + device_per_node = get_accelerator().device_count() + num_local = dist.get_world_size() // device_per_node + if num_local == 0 and dist.get_world_size() > 0: + assert dist.get_world_size() >= 1, "num_gpus must >=1, cannot initialize All-To-All" + cur_rank = [] + for i in range(dist.get_world_size()): + cur_rank.append(i) + _ALL_TO_ALL_GROUP["local_0"] = dist.new_group(ranks=cur_rank) + elif num_local == 1: + assert ( + dist.get_world_size() == device_per_node + ), "num_gpus not equal to device per node, cannot initialize All-To-All" + _ALL_TO_ALL_GROUP["local_0"] = dist.new_group(ranks=[i for i in range(device_per_node)]) + else: + assert dist.get_world_size() > device_per_node, "num_nodes<2 cannot initialize All-To-All" + for i in range(num_local): + local_rank = [j + device_per_node * i for j in range(device_per_node)] + _ALL_TO_ALL_GROUP[f"local_{i}"] = dist.new_group(ranks=local_rank) + + for i in range(device_per_node): + cur_rank = [] + for j in range(num_local): + cur_rank.append(i + j * device_per_node) + _ALL_TO_ALL_GROUP[f"global_{i}"] = dist.new_group(ranks=cur_rank) + return _ALL_TO_ALL_GROUP + + +def _get_data_parallel_group(): + """Get the data parallel group the caller rank belongs to.""" + assert dist.is_initialized(), "dist is not initialized" + global mpu + if mpu is not None: + return mpu.get_data_parallel_group() + # Return the clone of dist world group + return _clone_world_group() + + +def _get_broadcast_src_rank(): + return dist.get_global_rank(_get_sequence_data_parallel_group(), 0) + + +def _get_expert_broadcast_src_rank(group_name): + return dist.get_global_rank(_get_expert_data_parallel_group(group_name), 0) + + +def _get_expert_parallel_world_size(group_name): + """Return world size for the expert parallel group.""" + return dist.get_world_size(group=_get_expert_parallel_group(group_name)) + + +def 
_get_expert_data_parallel_world_size(group_name): + """Return world size for the expert data parallel group.""" + return dist.get_world_size(group=_get_expert_data_parallel_group(group_name)) + + +def _get_expert_parallel_rank(group_name): + """Return my rank for the expert parallel group.""" + return dist.get_rank(group=_get_expert_parallel_group(group_name)) + + +def _get_expert_parallel_src_rank(group_name): + """Calculate the global rank corresponding to a local rank zero + in the expert parallel group.""" + global_rank = dist.get_rank() + local_world_size = _get_expert_parallel_world_size(group_name) + return (global_rank // local_world_size) * local_world_size + + +def _get_expert_data_parallel_rank(group_name): + """Return my rank for the expert data parallel group.""" + return dist.get_rank(group=_get_expert_data_parallel_group(group_name)) + + +def _get_data_parallel_world_size(): + """Return world size for the data parallel group.""" + global mpu + if mpu is not None: + return mpu.get_data_parallel_world_size() + return dist.get_world_size(group=_get_data_parallel_group()) + + +def _get_model_parallel_world_size(): + """Return world size for the model parallel group.""" + global mpu + if mpu is not None: + return mpu.get_model_parallel_world_size() + return 1 + + +def _get_data_parallel_rank(): + """Return my rank for the data parallel group.""" + return dist.get_rank(group=_get_data_parallel_group()) + + +def _get_sequence_parallel_world_size(): + """Return world size for the model parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, "get_sequence_parallel_world_size"): + return mpu.get_sequence_parallel_world_size() + return 1 + + +def _get_sequence_parallel_rank(): + """Return my rank for the data parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, "get_sequence_parallel_rank"): + return mpu.get_sequence_parallel_rank() + return 0 + + +def _get_sequence_parallel_group(): + global mpu + if mpu is not None and hasattr(mpu, "get_sequence_parallel_group"): + return mpu.get_sequence_parallel_group() + return None + + +def _get_sequence_data_parallel_world_size(): + """Return world size for the model parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, "get_sequence_data_parallel_world_size"): + return mpu.get_sequence_data_parallel_world_size() + return _get_data_parallel_world_size() + + +def _get_sequence_data_parallel_rank(): + """Return my rank for the data parallel group.""" + global mpu + if mpu is not None and hasattr(mpu, "get_sequence_data_parallel_rank"): + return mpu.get_sequence_data_parallel_rank() + return _get_data_parallel_rank() + + +def _get_sequence_data_parallel_group(): + global mpu + # When sequence parallelism is enabled, the process group for zero sharding and + # gradient allreduce must be across both dimensions of data and sequence parallelism. + if mpu is not None and hasattr(mpu, "get_sequence_data_parallel_group"): + return mpu.get_sequence_data_parallel_group() + return _get_data_parallel_group() + + +def _get_expert_model_parallel_world_size(): + global expert_tensor_parallel_world_size + return expert_tensor_parallel_world_size + + +def _create_zero_param_parallel_group(group_size): + """ + Create parameter partitioning group within ZeRO data parallel groups. 
+
+    Example - ZP + D parallel
+    world_size = 16
+    zero_hpz_partition_size = 2 # number of ranks with replicated params (dual partitioning)
+    zero_param_intra_parallel_group = [0, 1], [2, 3], [4, 5], [6, 7], ..., [14, 15] - segmented (subgroup) with rep partition
+    data_parallel_group = [0,1,...,15] - all reduce is on ZeRO model
+    """
+    assert dist.is_initialized()
+    global _ZERO_PARAM_INTRA_PARALLEL_GROUP
+    # Only create group if it does not already exist
+    if _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None:
+        return
+
+    world_size = dist.get_world_size()
+    rank = dist.get_rank()
+
+    zero_param_parallel_size_ = min(group_size, world_size)
+    _ensure_divisibility(world_size, zero_param_parallel_size_)
+
+    # Build the ZeRO param intra parallel groups.
+    for i in range(world_size // zero_param_parallel_size_):
+        ranks = range(i * zero_param_parallel_size_, (i + 1) * zero_param_parallel_size_)
+        group = dist.new_group(ranks)
+        if i == (rank // zero_param_parallel_size_):
+            _ZERO_PARAM_INTRA_PARALLEL_GROUP = group
+
+
+def _get_zero_param_intra_parallel_group():
+    """Get the ZeRO parameter partitioning intra parallel group the caller rank belongs to."""
+    # assert _ZERO_PARAM_INTRA_PARALLEL_GROUP is not None, \
+    #     'ZeRO parameter partitioning group is not initialized'
+    # TODO: Add warning
+    return _ZERO_PARAM_INTRA_PARALLEL_GROUP
+
+
+def _zero_param_parallel_is_initialized():
+    """Check if ZeRO data parallel with parameter partitioning groups are initialized."""
+    # TODO: assert that MPU is not set
+    if _ZERO_PARAM_INTRA_PARALLEL_GROUP is None and _DATA_PARALLEL_GROUP is None:
+        return False
+    return True
+
+
+def _get_zero_param_intra_parallel_rank_in_mygroup():
+    """Return my rank for the ZeRO parameter intra parallel group."""
+    return dist.get_rank(group=_get_zero_param_intra_parallel_group())
+
+
+def _get_zero_param_intra_parallel_group_world_size():
+    """Return world size for the ZeRO parameter intra parallel group."""
+    return dist.get_world_size(group=_get_zero_param_intra_parallel_group())
+
+
+def _get_zero_param_intra_parallel_group_ranks():
+    """Return all ranks for the ZeRO parameter intra parallel group."""
+    return dist.get_all_ranks_from_group(group=_get_zero_param_intra_parallel_group())
diff --git a/llava/train/transformers_replace/modeling_utils.py b/llava/train/transformers_replace/modeling_utils.py
new file mode 100755
index 00000000..888192ef
--- /dev/null
+++ b/llava/train/transformers_replace/modeling_utils.py
@@ -0,0 +1,4830 @@
+# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections +import copy +import functools +import gc +import importlib.metadata +import inspect +import json +import os +import re +import shutil +import tempfile +import warnings +from contextlib import contextmanager +from dataclasses import dataclass +from functools import partial, wraps +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version +from torch import Tensor, nn +from torch.nn import CrossEntropyLoss, Identity +from torch.utils.checkpoint import checkpoint + +from .activations import get_activation +from .configuration_utils import PretrainedConfig +from .dynamic_module_utils import custom_object_save +from .generation import GenerationConfig, GenerationMixin +from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled +from .pytorch_utils import ( # noqa: F401 + Conv1D, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + id_tensor_storage, + prune_conv1d_layer, + prune_layer, + prune_linear_layer, +) +from .safetensors_conversion import auto_conversion +from .utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + DUMMY_INPUTS, + FLAX_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + ContextManagers, + ModelOutput, + PushToHubMixin, + cached_file, + copy_func, + download_url, + extract_commit_hash, + has_file, + is_accelerate_available, + is_auto_awq_available, + is_auto_gptq_available, + is_bitsandbytes_available, + is_flash_attn_2_available, + is_offline_mode, + is_optimum_available, + is_peft_available, + is_remote_url, + is_safetensors_available, + is_torch_sdpa_available, + is_torch_tpu_available, + logging, + replace_return_docstrings, + strtobool, +) +from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files +from .utils.import_utils import ( + ENV_VARS_TRUE_VALUES, + is_sagemaker_mp_enabled, + is_torch_fx_proxy, + is_torchdynamo_compiling, +) +from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod +from .utils.versions import require_version_core + +XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() +XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() + +if is_accelerate_available(): + from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights + from accelerate.hooks import add_hook_to_module + from accelerate.utils import ( + check_tied_parameters_on_same_device, + find_tied_parameters, + get_balanced_memory, + get_max_memory, + load_offloaded_weights, + offload_weight, + save_offload_index, + set_module_tensor_to_device, + ) + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import load_file as safe_load_file + from safetensors.torch import save_file as safe_save_file + +logger = logging.get_logger(__name__) + + +_init_weights = True + + +def is_fsdp_enabled(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 + and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1 + ) + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", -1)) == 0 + ) + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel 
import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") +else: + IS_SAGEMAKER_MP_POST_1_10 = False + +if is_peft_available(): + from .utils import find_adapter_config_file + +TORCH_INIT_FUNCTIONS = { + "uniform_": nn.init.uniform_, + "normal_": nn.init.normal_, + "trunc_normal_": nn.init.trunc_normal_, + "constant_": nn.init.constant_, + "xavier_uniform_": nn.init.xavier_uniform_, + "xavier_normal_": nn.init.xavier_normal_, + "kaiming_uniform_": nn.init.kaiming_uniform_, + "kaiming_normal_": nn.init.kaiming_normal_, + "uniform": nn.init.uniform, + "normal": nn.init.normal, + "xavier_uniform": nn.init.xavier_uniform, + "xavier_normal": nn.init.xavier_normal, + "kaiming_uniform": nn.init.kaiming_uniform, + "kaiming_normal": nn.init.kaiming_normal, +} + + +@contextmanager +def no_init_weights(_enable=True): + """ + Context manager to globally disable weight initialization to speed up loading large models. + + TODO(Patrick): Delete safety argument `_enable=True` at next major version. . + """ + global _init_weights + old_init_weights = _init_weights + + if _enable: + _init_weights = False + + def _skip_init(*args, **kwargs): + pass + + # # Save the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, _skip_init) + try: + yield + finally: + _init_weights = old_init_weights + if _enable: + # # Restore the original initialization functions + for name, init_func in TORCH_INIT_FUNCTIONS.items(): + setattr(torch.nn.init, name, init_func) + + +def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): + try: + return next(parameter.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_first_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): + """ + Returns the first parameter dtype (can be non-floating) or asserts if none were found. + """ + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch > 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): + """ + Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found. 
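+
+    Example (illustrative; any module whose parameters are already floating point behaves the same way):
+
+    ```py
+    >>> import torch
+    >>> from torch import nn
+    >>> get_parameter_dtype(nn.Linear(2, 2).half())
+    torch.float16
+    ```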
+ """ + last_dtype = None + for t in parameter.parameters(): + last_dtype = t.dtype + if t.is_floating_point(): + # Adding fix for https://github.com/pytorch/xla/issues/4152 + # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1 + # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf + # NOTE: `is_torch_tpu_available()` is checked last as it induces a graph break in torch dynamo + if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): + return torch.bfloat16 + if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_tpu_available(): + if t.dtype == torch.float: + return torch.bfloat16 + if t.dtype == torch.double: + return torch.float32 + return t.dtype + + if last_dtype is not None: + # if no floating dtype was found return whatever the first dtype is + return last_dtype + + # For nn.DataParallel compatibility in PyTorch > 1.5 + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + last_tuple = None + for tuple in gen: + last_tuple = tuple + if tuple[1].is_floating_point(): + return tuple[1].dtype + + if last_tuple is not None: + # fallback to the last dtype + return last_tuple[1].dtype + + # fallback to buffer dtype + for t in parameter.buffers(): + last_dtype = t.dtype + if t.is_floating_point(): + return t.dtype + return last_dtype + + +def get_state_dict_float_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` or asserts if none were found. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + raise ValueError("couldn't find any floating point dtypes in state_dict") + + +def get_state_dict_dtype(state_dict): + """ + Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype. + """ + for t in state_dict.values(): + if t.is_floating_point(): + return t.dtype + + # if no floating dtype was found return whatever the first dtype is + else: + return next(state_dict.values()).dtype + + +def dtype_byte_size(dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def shard_checkpoint( + state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME +): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger than `max_shard_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. 
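+
+    Example (sizes below are purely illustrative: two ~4MB float32 tensors plus a tiny one,
+    sharded with a 5MB limit, end up in two shards):
+
+    ```py
+    >>> sd = {"a": torch.zeros(1000, 1000), "b": torch.zeros(1000, 1000), "c": torch.zeros(10)}
+    >>> shards, index = shard_checkpoint(sd, max_shard_size="5MB")
+    >>> sorted(shards.keys())
+    ['pytorch_model-00001-of-00002.bin', 'pytorch_model-00002-of-00002.bin']
+    ```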
+ + + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`): + The name of the model save file. + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [{}] + last_block_size = 0 + total_size = 0 + storage_id_to_block = {} + + for key, weight in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(weight, str): + continue + else: + storage_id = id_tensor_storage(weight) + + # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block` + if storage_id in storage_id_to_block: + block_id = storage_id_to_block[storage_id] + sharded_state_dicts[block_id][key] = weight + continue + + weight_size = weight.numel() * dtype_byte_size(weight.dtype) + + # If this weight is going to tip up over the maximal size, we split, but only if we have put at least one + # weight in the current shard. + if last_block_size + weight_size > max_shard_size and len(sharded_state_dicts[-1]) > 0: + sharded_state_dicts.append({}) + last_block_size = 0 + + sharded_state_dicts[-1][key] = weight + last_block_size += weight_size + total_size += weight_size + storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1 + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors") + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True): + """ + This is the same as + [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict) + but for a sharded checkpoint. + + This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being + loaded in the model. + + Args: + model (`torch.nn.Module`): The model in which to load the checkpoint. + folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint. + strict (`bool`, *optional`, defaults to `True`): + Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint. + prefer_safe (`bool`, *optional*, defaults to `False`) + If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the + safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible. 
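+
+    Example (a minimal sketch; `MyModel`, `config` and the folder path are placeholders):
+
+    ```py
+    >>> model = MyModel(config)  # doctest: +SKIP
+    >>> load_sharded_checkpoint(model, "./sharded_checkpoint")  # doctest: +SKIP
+    ```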
+ + Returns: + `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields + - `missing_keys` is a list of str containing the missing keys + - `unexpected_keys` is a list of str containing the unexpected keys + """ + # Load the index + index_file = os.path.join(folder, WEIGHTS_INDEX_NAME) + safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME) + + index_present = os.path.isfile(index_file) + safe_index_present = os.path.isfile(safe_index_file) + + if not index_present and not (safe_index_present and is_safetensors_available()): + filenames = ( + (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME) if is_safetensors_available() else (WEIGHTS_INDEX_NAME,) + ) + raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.") + + load_safe = False + if safe_index_present: + if prefer_safe: + if is_safetensors_available(): + load_safe = True # load safe due to preference + else: + logger.warning(f"Cannot load sharded checkpoint at {folder} safely since safetensors is not installed!") + elif not index_present: + load_safe = True # load safe since we have no other choice + + load_index = safe_index_file if load_safe else index_file + + with open(load_index, encoding="utf-8") as f: + index = json.load(f) + + shard_files = list(set(index["weight_map"].values())) + + # If strict=True, error before loading any of the state dicts. + loaded_keys = index["weight_map"].keys() + model_keys = model.state_dict().keys() + missing_keys = [key for key in model_keys if key not in loaded_keys] + unexpected_keys = [key for key in loaded_keys if key not in model_keys] + if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0): + error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}" + if len(missing_keys) > 0: + str_missing_keys = ",".join([f'"{k}"' for k in missing_keys]) + error_message += f"\nMissing key(s): {str_missing_keys}." + if len(unexpected_keys) > 0: + str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys]) + error_message += f"\nMissing key(s): {str_unexpected_keys}." + raise RuntimeError(error_message) + + loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu") + + for shard_file in shard_files: + state_dict = loader(os.path.join(folder, shard_file)) + model.load_state_dict(state_dict, strict=False) + + # Make sure memory is freed before we load the next state dict. + del state_dict + gc.collect() + + # Return the same thing as PyTorch load_state_dict function. + return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys) + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + if checkpoint_file.endswith(".safetensors") and is_safetensors_available(): + # Check format of the archive + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + if metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." 
+ ) + return safe_load_file(checkpoint_file) + try: + if ( + is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0 + ) or (is_fsdp_enabled() and not is_local_dist_rank_0()): + map_location = "meta" + else: + map_location = "cpu" + + return torch.load(checkpoint_file, map_location=map_location) + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read(7) == "version": + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def set_initialized_submodules(model, state_dict_keys): + """ + Sets the `_is_hf_initialized` flag in all submodules of a given model when all its weights are in the loaded state + dict. + """ + for module_name, module in model.named_modules(): + loaded_keys = [k.replace(f"{module_name}.", "") for k in state_dict_keys if k.startswith(f"{module_name}.")] + if len(set(module.state_dict().keys()) - set(loaded_keys)) == 0: + module._is_hf_initialized = True + + +def _load_state_dict_into_model(model_to_load, state_dict, start_prefix): + # Convert old format to new format if needed from a PyTorch state_dict + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, "_metadata", None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: nn.Module, state_dict, prefix=""): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) + # Parameters of module and children will start with prefix. We can exit early if there are none in this + # state_dict + if len([key for key in state_dict if key.startswith(prefix)]) > 0: + if is_deepspeed_zero3_enabled(): + import deepspeed + + # In sharded models, each shard has only part of the full state_dict, so only gather + # parameters that are in the current state_dict. 
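+                # Gathering is per-module and non-recursive (recurse=False), so e.g. for
+                # prefix "encoder.layer.0." only that submodule's own weights are
+                # materialized at once, keeping peak memory bounded to a single layer.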
+ named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) + params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] + if len(params_to_gather) > 0: + # because zero3 puts placeholders in model params, this context + # manager gathers (unpartitions) the params of the current layer, then loads from + # the state dict and then re-partitions them again + with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): + if torch.distributed.get_rank() == 0: + module._load_from_state_dict(*args) + else: + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, state_dict, prefix + name + ".") + + load(model_to_load, state_dict, prefix=start_prefix) + # Delete `state_dict` so it could be collected by GC earlier. Note that `state_dict` is a copy of the argument, so + # it's safe to delete it. + del state_dict + + return error_msgs + + +def find_submodule_and_param_name(model, long_key, start_prefix): + """ + A helper util to find the last sub-module and the param/buffer name. If `start_prefix` is supplied it'll be removed + from the start of the key + """ + + if len(start_prefix) > 0 and long_key.startswith(start_prefix): + long_key = ".".join(long_key.split(".")[1:]) + + split_key = long_key.split(".") + submodule = model + while len(split_key) > 1: + if hasattr(submodule, split_key[0]): + submodule = getattr(submodule, split_key[0]) + del split_key[0] + else: + submodule = None + break + if submodule == model: + submodule = None + return submodule, split_key[0] + + +def _move_model_to_meta(model, loaded_state_dict_keys, start_prefix): + """ + Moves `loaded_state_dict_keys` in model to meta device which frees up the memory taken by those params. + + `start_prefix` is used for models which insert their name into model keys, e.g. `bert` in + `bert.pooler.dense.weight` + + """ + + # dematerialize param storage for keys that are going to be replaced by state_dict, by + # putting those on the meta device + for k in loaded_state_dict_keys: + submodule, param_name = find_submodule_and_param_name(model, k, start_prefix) + if submodule is not None: + # selectively switch to the meta device only those params/buffers that will + # be next replaced from state_dict. This a complex way to do p.to_("meta") + # since we have no in-place to_ for tensors. + new_val = getattr(submodule, param_name) + if isinstance(new_val, torch.nn.Parameter): + # isinstance returns False for Params on meta device, so switch after the check + new_val = torch.nn.Parameter(new_val.to("meta")) + else: + new_val = new_val.to("meta") + setattr(submodule, param_name, new_val) + + +def _load_state_dict_into_meta_model( + model, + state_dict, + loaded_state_dict_keys, # left for now but could be removed, see below + start_prefix, + expected_keys, + device_map=None, + offload_folder=None, + offload_index=None, + state_dict_folder=None, + state_dict_index=None, + dtype=None, + is_quantized=False, + is_safetensors=False, + keep_in_fp32_modules=None, +): + """ + This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its + params on a `meta` device. It replaces the model params with the data from the `state_dict`, while moving the + params back to the normal device, but only for `loaded_state_dict_keys`. + + `start_prefix` is used for models which insert their name into model keys, e.g. 
`bert` in + `bert.pooler.dense.weight` + + """ + + # XXX: remaining features to implement to be fully compatible with _load_state_dict_into_model + # - deepspeed zero 3 support + # - need to copy metadata if any - see _load_state_dict_into_model + # - handling error_msgs - mimicking the error handling in module._load_from_state_dict() + # - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case + # they won't get loaded. + + if is_quantized: + from .integrations import set_module_quantized_tensor_to_device + + error_msgs = [] + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if "gamma" in key: + new_key = key.replace("gamma", "weight") + if "beta" in key: + new_key = key.replace("beta", "bias") + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + for param_name, param in state_dict.items(): + # First part of the test is always true as load_state_dict_keys always contains state_dict keys. + if param_name not in loaded_state_dict_keys or param_name not in expected_keys: + continue + + if param_name.startswith(start_prefix): + param_name = param_name[len(start_prefix) :] + + module_name = param_name + set_module_kwargs = {} + + # We convert floating dtypes to the `dtype` passed. We want to keep the buffers/params + # in int/uint/bool and not cast them. + if dtype is not None and torch.is_floating_point(param): + if ( + keep_in_fp32_modules is not None + and any( + module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules + ) + and dtype == torch.float16 + ): + param = param.to(torch.float32) + + # For backward compatibility with older versions of `accelerate` + # TODO: @sgugger replace this check with version check at the next `accelerate` release + if "dtype" in list(inspect.signature(set_module_tensor_to_device).parameters): + set_module_kwargs["dtype"] = torch.float32 + else: + param = param.to(dtype) + + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model + if dtype is None: + old_param = model + splits = param_name.split(".") + for split in splits: + old_param = getattr(old_param, split) + if old_param is None: + break + + if old_param is not None: + param = param.to(old_param.dtype) + + set_module_kwargs["value"] = param + + if device_map is None: + param_device = "cpu" + else: + # find next higher level module that is defined in device_map: + # bert.lm_head.weight -> bert.lm_head -> bert -> '' + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + # TODO: group all errors and raise at the end. 
+ raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] + + if param_device == "disk": + if not is_safetensors: + offload_index = offload_weight(param, param_name, offload_folder, offload_index) + elif param_device == "cpu" and state_dict_index is not None: + state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index) + elif not is_quantized: + # For backward compatibility with older versions of `accelerate` + set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) + else: + if param.dtype == torch.int8 and param_name.replace("weight", "SCB") in state_dict.keys(): + fp16_statistics = state_dict[param_name.replace("weight", "SCB")] + else: + fp16_statistics = None + + if "SCB" not in param_name: + set_module_quantized_tensor_to_device( + model, param_name, param_device, value=param, fp16_statistics=fp16_statistics + ) + + return error_msgs, offload_index, state_dict_index + + +def _add_variant(weights_name: str, variant: Optional[str] = None) -> str: + if variant is not None: + splits = weights_name.split(".") + splits = splits[:-1] + [variant] + splits[-1:] + weights_name = ".".join(splits) + + return weights_name + + +class ModuleUtilsMixin: + """ + A few utilities for `torch.nn.Modules`, to be used as a mixin. + """ + + @staticmethod + def _hook_rss_memory_pre_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_pre_forward = mem.rss + return None + + @staticmethod + def _hook_rss_memory_post_forward(module, *args, **kwargs): + try: + import psutil + except ImportError: + raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") + + process = psutil.Process(os.getpid()) + mem = process.memory_info() + module.mem_rss_post_forward = mem.rss + mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward + module.mem_rss_diff = mem_rss_diff + (module.mem_rss_diff if hasattr(module, "mem_rss_diff") else 0) + return None + + def add_memory_hooks(self): + """ + Add a memory hook before and after each sub-module forward pass to record increase in memory consumption. + + Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero + with `model.reset_memory_hooks_state()`. + """ + for module in self.modules(): + module.register_forward_pre_hook(self._hook_rss_memory_pre_forward) + module.register_forward_hook(self._hook_rss_memory_post_forward) + self.reset_memory_hooks_state() + + def reset_memory_hooks_state(self): + """ + Reset the `mem_rss_diff` attribute of each module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]). + """ + for module in self.modules(): + module.mem_rss_diff = 0 + module.mem_rss_post_forward = 0 + module.mem_rss_pre_forward = 0 + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 
+ """ + return get_parameter_dtype(self) + + def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min + + return encoder_extended_attention_mask + + @staticmethod + def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None): + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + else: + device = attention_mask.device + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = torch.cat( + [ + torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + + def get_extended_attention_mask( + self, attention_mask: Tensor, input_shape: Tuple[int], device: torch.device = None, dtype: torch.float = None + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + if not (attention_mask.dim() == 2 and self.config.is_decoder): + # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder` + if device is not None: + warnings.warn( + "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning + ) + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
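+        # Shape-wise: a [batch_size, seq_length] padding mask becomes
+        # [batch_size, 1, 1, seq_length] (or a causal variant for decoders), and a
+        # user-provided [batch_size, from_seq_length, to_seq_length] mask becomes
+        # [batch_size, 1, from_seq_length, to_seq_length]; the 1s/0s are converted
+        # further below into 0.0 / dtype-min so the mask can simply be added to the
+        # attention scores.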
+ if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder( + input_shape, attention_mask, device + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min + return extended_attention_mask + + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.unsqueeze(-1) + else: + head_mask = [None] * num_hidden_layers + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + assert head_mask.dim() == 5, f"head_mask.dim != 5, instead {head_mask.dim()}" + head_mask = head_mask.to(dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. + + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. 
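A quick shape check mirroring `_convert_head_mask_to_5d` above: a per-head mask of shape `[num_heads]` is expanded to `[num_hidden_layers, 1, num_heads, 1, 1]` so it broadcasts over batch and sequence dimensions.

import torch

num_hidden_layers, num_heads = 12, 12
head_mask = torch.ones(num_heads)                          # 1.0 = keep head, 0.0 = mask head
head_mask_5d = head_mask[None, None, :, None, None].expand(num_hidden_layers, -1, -1, -1, -1)
print(head_mask_5d.shape)                                  # torch.Size([12, 1, 12, 1, 1])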
+ """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding) + ] + total_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + else: + total_parameters = list(self.parameters()) + + total_numel = [] + is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False) + if is_loaded_in_4bit: + if is_bitsandbytes_available(): + import bitsandbytes as bnb + else: + raise ValueError( + "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong" + " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. " + ) + + for param in total_parameters: + if param.requires_grad or not only_trainable: + # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are + # used for the 4bit quantization (uint8 tensors are stored) + if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit): + total_numel.append(param.numel() * 2) + else: + total_numel.append(param.numel()) + + return sum(total_numel) + + def estimate_tokens(self, input_dict: Dict[str, Union[torch.Tensor, Any]]) -> int: + """ + Helper function to estimate the total number of tokens from the model inputs. + + Args: + inputs (`dict`): The model inputs. + + Returns: + `int`: The total number of tokens. + """ + if not hasattr(self, "warnings_issued"): + self.warnings_issued = {} + if self.main_input_name in input_dict: + return input_dict[self.main_input_name].numel() + elif "estimate_tokens" not in self.warnings_issued: + logger.warning( + "Could not estimate the number of tokens of the input, floating-point operations will not be computed" + ) + self.warnings_issued["estimate_tokens"] = True + return 0 + + def floating_point_ops( + self, input_dict: Dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True + ) -> int: + """ + Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a + batch with this transformer model. Default approximation neglects the quadratic dependency on the number of + tokens (valid if `12 * d_model << sequence_length`) as laid out in [this + paper](https://arxiv.org/pdf/2001.08361.pdf) section 2.1. Should be overridden for transformers with parameter + re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths. + + Args: + batch_size (`int`): + The batch size for the forward pass. + + sequence_length (`int`): + The number of tokens in each line of the batch. + + exclude_embeddings (`bool`, *optional*, defaults to `True`): + Whether or not to count embedding and softmax operations. + + Returns: + `int`: The number of floating-point operations. + """ + + return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) + + +class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin): + r""" + Base class for all models. + + [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading, + downloading and saving models as well as a few methods common to all models to: + + - resize the input embeddings, + - prune heads in the self-attention heads. 
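To make the `floating_point_ops` approximation above concrete: with the default `6 * tokens * parameters` estimate, a forward-plus-backward pass can be sized roughly as follows, assuming `model` is a loaded `PreTrainedModel` and `input_ids` is a prepared batch tensor.

num_params = model.num_parameters(exclude_embeddings=True)   # non-embedding parameters
num_tokens = input_ids.numel()                               # batch_size * sequence_length
approx_flops = 6 * num_tokens * num_params                   # same estimate floating_point_ops() returns
print(f"~{approx_flops / 1e12:.2f} TFLOPs for this batch")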
+ + Class attributes (overridden by derived classes): + + - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + for this model architecture. + - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, + taking as arguments: + + - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint. + - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model. + - **path** (`str`) -- A path to the TensorFlow checkpoint. + + - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived + classes of the same architecture adding modules on top of the base model. + - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization. + - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP + models, `pixel_values` for vision models and `input_values` for speech models). + """ + + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + _auto_class = None + _no_split_modules = None + _skip_keys_device_placement = None + _keep_in_fp32_modules = None + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. + _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't + # trained, but which are either deterministic or tied variables) + _keys_to_ignore_on_save = None + # a list of `state_dict` keys that are potentially tied to another key in the state_dict. + _tied_weights_keys = None + + is_parallelizable = False + supports_gradient_checkpointing = False + + # Flash Attention 2 support + _supports_flash_attn_2 = False + + # SDPA support + _supports_sdpa = False + + # Has support for a `Cache` instance as `past_key_values` + _supports_cache_class = False + + @property + def dummy_inputs(self) -> Dict[str, torch.Tensor]: + """ + `Dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network. + """ + return {"input_ids": torch.tensor(DUMMY_INPUTS)} + + @property + def framework(self) -> str: + """ + :str: Identifies that this is a PyTorch model. + """ + return "pt" + + def __init__(self, config: PretrainedConfig, *inputs, **kwargs): + super().__init__() + if not isinstance(config, PretrainedConfig): + raise ValueError( + f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " + "`PretrainedConfig`. 
To create a model from a pretrained model use " + f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" + ) + # Save config and origin of the pretrained weights if given in model + config = self._autoset_attn_implementation( + config, torch_dtype=torch.get_default_dtype(), check_device_map=False + ) + self.config = config + + self.name_or_path = config.name_or_path + self.warnings_issued = {} + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + # Overwrite the class attribute to make it an instance attribute, so models like + # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute + # when a different component (e.g. language_model) is used. + self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + """ + self.init_weights() + self._backward_compatibility_gradient_checkpointing() + + def _backward_compatibility_gradient_checkpointing(self): + if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() + # Remove the attribute now that is has been consumed, so it's no saved in the config. + delattr(self.config, "gradient_checkpointing") + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. + """ + torch_dtype = kwargs.pop("torch_dtype", None) + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + + # override default dtype if needed + dtype_orig = None + if torch_dtype is not None: + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in _from_config. + config._attn_implementation = kwargs.pop("attn_implementation", None) + config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, check_device_map=False + ) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + # this immediately partitions the model across all gpus, to avoid the overhead in time + # and memory copying it on CPU or each GPU first + if ( + "mics_shard_size" in deepspeed_config()["zero_optimization"] + and deepspeed_config()["zero_optimization"]["mics_shard_size"] > 1 + ): + print("Detected DeepSpeed ZeRO-3 with MiCS: activating zero.MiCS_Init() for this model") + with deepspeed.zero.MiCS_Init(config_dict_or_path=deepspeed_config()): + model = cls(config, **kwargs) + else: + print("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()): + model = cls(config, **kwargs) + else: + model = cls(config, **kwargs) + + # restore default dtype if it was modified + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + + return model + + @classmethod + def _autoset_attn_implementation( + cls, + config, + use_flash_attention_2: bool = False, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + ): + """ + Automatically checks and dispatches to a default attention implementation. 
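For reference, a sketch of the `zero_optimization` block that would take the MiCS branch above: the code only checks for a `mics_shard_size` entry greater than 1 in the DeepSpeed config, and the surrounding fields follow an ordinary ZeRO-3 setup. The concrete values here are placeholders.

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "mics_shard_size": 4,          # > 1 routes model construction through zero.MiCS_Init()
        "overlap_comm": True,
    },
    "bf16": {"enabled": True},
    "train_micro_batch_size_per_gpu": 1,
}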
In order of priority: + 1. An implementation specified in `config._attn_implementation` (due for example to the argument attn_implementation="sdpa" in from_pretrained). + 2. DEPRECATED: if use_flash_attention_2 is set to `True` and `flash_attn` is available, flash attention. (`LlamaFlashAttention` for example) + 3. SDPA implementation, if available and supported by the model type. (`LlamaSdpaAttention` for example) + 4. The default model's implementation otherwise (`LlamaAttention` for example) . + """ + # Here we use config._attn_implementation_internal to check whether the attention implementation was explicitely set by the user. + # The property `PretrainedConfig._attn_implementation` is never `None`, for backward compatibility (always fall back on "eager"). + # The `hasattr` here is used as some Transformers tests for some reason do not call PretrainedConfig __init__ (e.g. test_no_super_init_config_and_model) + requested_attn_implementation = None + if hasattr(config, "_attn_implementation_internal") and config._attn_implementation_internal is not None: + if config._attn_implementation != "flash_attention_2" and use_flash_attention_2: + raise ValueError( + f'Both attn_implementation="{config._attn_implementation}" and `use_flash_attention_2=True` were used when loading the model, which are not compatible.' + ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' + ) + + if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' + if cls._supports_flash_attn_2: + message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_sdpa: + message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' + raise ValueError(message + ".") + + # If a config is passed with a preset attn_implementation, we skip the automatic dispatch and use the user-provided config, with hard checks that the requested attention implementation is available. + requested_attn_implementation = config._attn_implementation_internal + + if use_flash_attention_2: + logger.warning_once( + 'The model was loaded with use_flash_attention_2=True, which is deprecated and may be removed in a future release. Please use `attn_implementation="flash_attention_2"` instead.' + ) + config._attn_implementation = "flash_attention_2" + + if config._attn_implementation == "flash_attention_2": + cls._check_and_enable_flash_attn_2( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) + elif requested_attn_implementation in [None, "sdpa"]: + # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. + config = cls._check_and_enable_sdpa( + config, hard_check_only=False if requested_attn_implementation is None else True + ) + else: + config._attn_implementation = "eager" + + return config + + @classmethod + def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype: + """ + Change the default dtype and return the previous one. This is needed when wanting to instantiate the model + under specific dtype. + + Args: + dtype (`torch.dtype`): + a floating dtype to set to. 
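The dispatch logic above is normally driven from `from_pretrained`; a typical way to request a specific backend explicitly looks like the sketch below, assuming `flash-attn` is installed, a GPU is available, and the model class supports the backend. The checkpoint name is a placeholder.

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",                # placeholder checkpoint
    torch_dtype=torch.bfloat16,                # FA2 requires fp16/bf16, see the dtype check below
    attn_implementation="flash_attention_2",   # or "sdpa" / "eager"
)
print(model.config._attn_implementation)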
+ + Returns: + `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was + modified. If it wasn't, returns `None`. + + Note `set_default_dtype` currently only works with floating-point types and asserts if for example, + `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception. + """ + if not dtype.is_floating_point: + raise ValueError( + f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype" + ) + + logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.") + dtype_orig = torch.get_default_dtype() + torch.set_default_dtype(dtype) + return dtype_orig + + @property + def base_model(self) -> nn.Module: + """ + `torch.nn.Module`: The main body of the model. + """ + return getattr(self, self.base_model_prefix, self) + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternativelly, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + @classmethod + def _check_and_enable_flash_attn_2( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 2 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_2: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 2.0 yet. Please open an issue on GitHub to " + "request support for this architecture: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_2_available(): + preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:" + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2." + + if importlib.util.find_spec("flash_attn") is None: + raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}") + + flash_attention_version = version.parse(importlib.metadata.version("flash_attn")) + if torch.version.cuda: + if flash_attention_version < version.parse("2.1.0"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}") + elif torch.version.hip: + if flash_attention_version < version.parse("2.0.4"): + raise ImportError( + f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Make sure to have that version installed - detected version {flash_attention_version}. {install_message}" + ) + else: + raise ImportError(f"{preface} Flash Attention 2 is not available. 
{install_message}") + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning( + "You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + raise ValueError( + f"Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. You passed {torch_dtype}, this might lead to" + " unexpected behaviour." + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 2.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + if not hard_check_only: + config._attn_implementation = "flash_attention_2" + return config + + @classmethod + def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: + """ + Checks the availability of SDPA for a given model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_2" so that the model can initialize the correct attention module. + """ + if hard_check_only: + if not cls._supports_sdpa: + raise ValueError( + f"{cls.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet. Please open an issue on GitHub to " + "request support for this architecture: https://github.com/huggingface/transformers/issues/new" + ) + if not is_torch_sdpa_available(): + raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.1.1.") + + if not is_torch_sdpa_available() or not cls._supports_sdpa: + return config + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + if _is_bettertransformer: + return config + + if not hard_check_only: + config._attn_implementation = "sdpa" + return config + + def enable_input_require_grads(self): + """ + Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping + the model weights fixed. 
+ """ + + def make_inputs_require_grads(module, input, output): + output.requires_grad_(True) + + self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) + + def disable_input_require_grads(self): + """ + Removes the `_require_grads_hook`. + """ + self._require_grads_hook.remove() + + def get_input_embeddings(self) -> nn.Module: + """ + Returns the model's input embeddings. + + Returns: + `nn.Module`: A torch module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + else: + raise NotImplementedError + + def set_input_embeddings(self, value: nn.Module): + """ + Set model's input embeddings. + + Args: + value (`nn.Module`): A module mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + base_model.set_input_embeddings(value) + else: + raise NotImplementedError + + def get_output_embeddings(self) -> nn.Module: + """ + Returns the model's output embeddings. + + Returns: + `nn.Module`: A torch module mapping hidden states to vocabulary. + """ + return None # Overwrite for models with output embeddings + + def _init_weights(self, module): + """ + Initialize the weights. This method should be overridden by derived class and is + the only initialization method that will be called when loading a checkpoint + using `from_pretrained`. Any attempt to initialize outside of this function + will be useless as the torch.nn.init function are all replaced with skip. + """ + pass + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_hf_initialized", False): + return + self._init_weights(module) + module._is_hf_initialized = True + + def tie_weights(self): + """ + Tie the weights between the input embeddings and the output embeddings. + + If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the + weights instead. + """ + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() + if output_embeddings is not None: + self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings()) + + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) + self._tie_encoder_decoder_weights(self.encoder, self.decoder, self.base_model_prefix) + + for module in self.modules(): + if hasattr(module, "_tie_weights"): + module._tie_weights() + + @staticmethod + def _tie_encoder_decoder_weights(encoder: nn.Module, decoder: nn.Module, base_model_prefix: str): + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" + " weights are correctly initialized." 
+ ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Module, + encoder_pointer: nn.Module, + module_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + ): + assert isinstance(decoder_pointer, nn.Module) and isinstance( + encoder_pointer, nn.Module + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + return + + encoder_modules = encoder_pointer._modules + decoder_modules = decoder_pointer._modules + if len(decoder_modules) > 0: + assert ( + len(encoder_modules) > 0 + ), f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}" + + all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules.keys()} + encoder_layer_pos = 0 + for name, module in decoder_modules.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len( + encoder_modules + ) != len(decoder_modules): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_modules: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" + " a circular dependency between two or more `nn.Modules` of your model." + ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_modules[decoder_name], + encoder_modules[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """Tie or clone module weights depending of whether we are using TorchScript or not""" + if self.config.torchscript: + output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone()) + else: + output_embeddings.weight = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + output_embeddings.bias.data = nn.functional.pad( + output_embeddings.bias.data, + ( + 0, + output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0], + ), + "constant", + 0, + ) + if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"): + output_embeddings.out_features = input_embeddings.num_embeddings + + def _get_no_split_modules(self, device_map: str): + """ + Get the modules of the model that should not be spit when using device_map. We iterate through the modules to + get the underlying `_no_split_modules`. + + Args: + device_map (`str`): + The device map value. 
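A quick way to confirm what `tie_weights`/`_tie_or_clone_weights` above actually did for a given checkpoint, assuming `model` is a loaded `PreTrainedModel` with an LM head:

emb_in = model.get_input_embeddings()
emb_out = model.get_output_embeddings()
if emb_out is not None:
    tied = emb_in.weight.data_ptr() == emb_out.weight.data_ptr()
    print(f"tie_word_embeddings={model.config.tie_word_embeddings}, weights shared: {tied}")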
Options are ["auto", "balanced", "balanced_low_0", "sequential"] + + Returns: + `List[str]`: List of modules that should not be split + """ + _no_split_modules = set() + modules_to_check = [self] + while len(modules_to_check) > 0: + module = modules_to_check.pop(-1) + # if the module does not appear in _no_split_modules, we also check the children + if module.__class__.__name__ not in _no_split_modules: + if isinstance(module, PreTrainedModel): + if module._no_split_modules is None: + raise ValueError( + f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model " + "class needs to implement the `_no_split_modules` attribute." + ) + else: + _no_split_modules = _no_split_modules | set(module._no_split_modules) + modules_to_check += list(module.children()) + return list(_no_split_modules) + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. 
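A typical call site for `resize_token_embeddings`, assuming `model` and `tokenizer` come from the same checkpoint; `pad_to_multiple_of=64` keeps the embedding rows Tensor-Core friendly as the docstring above suggests, and "<image>" is just an example added token.

num_added = tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)
print(model.get_input_embeddings().weight.shape[0], model.config.vocab_size)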
+ """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Update base model and current model config + self.config.vocab_size = model_embeds.weight.shape[0] + self.vocab_size = model_embeds.weight.shape[0] + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) + if hasattr(old_embeddings, "_hf_hook"): + hook = old_embeddings._hf_hook + add_hook_to_module(new_embeddings, hook) + old_embeddings_requires_grad = old_embeddings.weight.requires_grad + new_embeddings.requires_grad_(old_embeddings_requires_grad) + self.set_input_embeddings(new_embeddings) + + # Update new_num_tokens with the actual size of new_embeddings + if pad_to_multiple_of is not None: + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None): + new_num_tokens = new_embeddings.weight.shape[0] + else: + new_num_tokens = new_embeddings.weight.shape[0] + + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() + new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens) + if hasattr(old_lm_head, "_hf_hook"): + hook = old_lm_head._hf_hook + add_hook_to_module(new_lm_head, hook) + old_lm_head_requires_grad = old_lm_head.weight.requires_grad + new_lm_head.requires_grad_(old_lm_head_requires_grad) + self.set_output_embeddings(new_lm_head) + + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """ + Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly + initialized vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_embeddings (`torch.nn.Embedding`): + Old embeddings to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the embedding matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. 
For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + + + Return: + `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if + `new_num_tokens` is `None` + """ + + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + logger.info( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" + f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." + " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None): + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + else: + old_num_tokens, old_embedding_dim = old_embeddings.weight.size() + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_embeddings + + if not isinstance(old_embeddings, nn.Embedding): + raise TypeError( + f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You" + " should either use a different resize function or make sure that `old_embeddings` are an instance of" + f" {nn.Embedding}." + ) + + # Build new embeddings + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. + new_embeddings = nn.Embedding( + new_num_tokens, + old_embedding_dim, + device=old_embeddings.weight.device, + dtype=old_embeddings.weight.dtype, + ) + + # initialize all new embeddings (in particular added tokens) + self._init_weights(new_embeddings) + + # Copy token embeddings from the previous weights + + # numbers of tokens to copy + n = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + params = [old_embeddings.weight, new_embeddings.weight] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + else: + new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] + + return new_embeddings + + def _get_resized_lm_head( + self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + ) -> nn.Linear: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. 
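The rounding applied above when `pad_to_multiple_of` is set is plain ceiling division; for example, a 32,005-entry vocabulary padded to a multiple of 64:

new_num_tokens, pad_to_multiple_of = 32005, 64
padded = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
print(padded)  # 32064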
Reducing the size will remove vectors from the end + + Args: + old_lm_head (`torch.nn.Linear`): + Old lm head liner layer to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults + to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, + vocab_size` else `vocab_size, lm_head_dim`. + + Return: + `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is + `None` + """ + if new_num_tokens is None: + return old_lm_head + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None): + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + else: + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size() + ) + + if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled(): + return old_lm_head + + if not isinstance(old_lm_head, nn.Linear): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Linear}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. 
+ new_lm_head = nn.Linear( + *new_lm_head_shape, + bias=has_new_lm_head_bias, + device=old_lm_head.weight.device, + dtype=old_lm_head.weight.dtype, + ) + + # initialize new lm head (in particular added tokens) + self._init_weights(new_lm_head) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + if is_deepspeed_zero3_enabled(): + import deepspeed + + params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias] + with deepspeed.zero.GatheredParameters(params, modifier_rank=0): + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + else: + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + def resize_position_embeddings(self, new_num_position_embeddings: int): + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]: + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`" + ) + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. + """ + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + if _init_weights: + # Initialize weights + self.apply(self._initialize_weights) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + def prune_heads(self, heads_to_prune: Dict[int, List[int]]): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`Dict[int, List[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. + """ + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads + for layer, heads in heads_to_prune.items(): + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON + + self.base_model._prune_heads(heads_to_prune) + + def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): + """ + Activates gradient checkpointing for the current model. 
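Mirroring the `prune_heads` docstring above, pruning heads 0 and 2 of layer 1 and heads 2 and 3 of layer 2, assuming `model` is an encoder whose base model implements `_prune_heads` (e.g. a BERT-style model):

model.prune_heads({1: [0, 2], 2: [2, 3]})
print(model.config.pruned_heads)   # e.g. {1: [0, 2], 2: [2, 3]}, stored as lists for JSON round-tripping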
+ + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + + We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of + the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2 + + Args: + gradient_checkpointing_kwargs (dict, *optional*): + Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function. + """ + if not self.supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) + + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) + else: + self.apply(partial(self._set_gradient_checkpointing, value=True)) + logger.warn( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + + if getattr(self, "_hf_peft_config_loaded", False): + # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True + # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334 + # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate + # the gradients to make sure the gradient flows. + self.enable_input_require_grads() + + def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint): + is_gradient_checkpointing_set = False + + # Apply it on the top-level module in case the top-level modules supports it + # for example, LongT5Stack inherits from `PreTrainedModel`. + if hasattr(self, "gradient_checkpointing"): + self._gradient_checkpointing_func = gradient_checkpointing_func + self.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + for module in self.modules(): + if hasattr(module, "gradient_checkpointing"): + module._gradient_checkpointing_func = gradient_checkpointing_func + module.gradient_checkpointing = enable + is_gradient_checkpointing_set = True + + if not is_gradient_checkpointing_set: + raise ValueError( + f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute" + " `gradient_checkpointing` to modules of the model that uses checkpointing." + ) + + def gradient_checkpointing_disable(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". 
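A minimal sketch of turning checkpointing on for training and off again afterwards, using the `gradient_checkpointing_kwargs` path described above; the `use_reentrant` flag is forwarded verbatim to `torch.utils.checkpoint.checkpoint`, and `model` is assumed to be a loaded `PreTrainedModel` that supports checkpointing.

model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.enable_input_require_grads()        # needed when only adapter/LoRA parameters require grad
print(model.is_gradient_checkpointing)    # True
# ... training loop ...
model.gradient_checkpointing_disable()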
+ """ + if self.supports_gradient_checkpointing: + # For old GC format (transformers < 4.35.0) for models that live on the Hub + # we will fall back to the overwritten `_set_gradient_checkpointing` methid + _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters + if not _is_using_old_format: + self._set_gradient_checkpointing(enable=False) + else: + logger.warn( + "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)." + "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model." + ) + self.apply(partial(self._set_gradient_checkpointing, value=False)) + + if getattr(self, "_hf_peft_config_loaded", False): + self.disable_input_require_grads() + + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + push_to_hub: bool = False, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + state_dict (nested dictionary of `torch.Tensor`): + The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only + save parts of the model or if special precautions need to be taken when recovering the state dictionary + of a model (like when using model parallelism). + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + push_to_hub (`bool`, *optional*, defaults to `False`): + Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the + repository you want to push to with `repo_id` (will default to the name of `save_directory` in your + namespace). + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. 
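A round-trip sketch for `save_pretrained` as documented above: sharded safetensors output plus reload, with the directory name as a placeholder.

model.save_pretrained("checkpoints/my-model", safe_serialization=True, max_shard_size="2GB")
# reload later with the matching class (or an Auto* class)
reloaded = model.__class__.from_pretrained("checkpoints/my-model")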
+ + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None: + kwargs["token"] = token + + _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False) + + # Checks if the model has been loaded in 8-bit + if ( + getattr(self, "is_loaded_in_8bit", False) + and not getattr(self, "is_8bit_serializable", False) + and not _hf_peft_config_loaded + ): + raise ValueError( + "You are calling `save_pretrained` to a 8-bit converted model you may likely encounter unexepected" + " behaviors. If you want to save 8-bit models, make sure to have `bitsandbytes>0.37.2` installed." + ) + + # If the model has adapters attached, you can save the adapters + if getattr(self, "is_loaded_in_4bit", False) and not _hf_peft_config_loaded: + raise NotImplementedError( + "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported" + ) + + if getattr(self, "_awq_is_fused", False): + raise ValueError("You cannot save an AWQ model that uses fused modules!") + + if "save_config" in kwargs: + warnings.warn( + "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead." + ) + is_main_process = kwargs.pop("save_config") + if safe_serialization and not is_safetensors_available(): + raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.") + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self) + + # save the string version of dtype to the config, e.g. 
convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be + # loaded from the Hub. + if self._auto_class is not None: + custom_object_save(self, save_directory, config=self.config) + + # Save the config + if is_main_process: + if not _hf_peft_config_loaded: + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) + + if _hf_peft_config_loaded: + logger.info( + "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved." + ) + state_dict = model_to_save.get_adapter_state_dict() + + if save_peft_format: + logger.info( + "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`." + ) + peft_state_dict = {} + for key, value in state_dict.items(): + peft_state_dict[f"base_model.model.{key}"] = value + state_dict = peft_state_dict + + active_adapter = self.active_adapters() + + if len(active_adapter) > 1: + raise ValueError( + "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one " + "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`" + ) + active_adapter = active_adapter[0] + + current_peft_config = self.peft_config[active_adapter] + current_peft_config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + # Translate state_dict from smp to hf if saving with smp >= 1.10 + if IS_SAGEMAKER_MP_POST_1_10: + for smp_to_hf, _ in smp.state.module_manager.translate_functions: + state_dict = smp_to_hf(state_dict) + + # Handle the case where some state_dict keys shouldn't be saved + if self._keys_to_ignore_on_save is not None: + for ignore_key in self._keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + if safe_serialization: + # Safetensors does not allow tensor aliasing. + # We're going to remove aliases before saving + ptrs = collections.defaultdict(list) + for name, tensor in state_dict.items(): + # Sometimes in the state_dict we have non-tensor objects. + # e.g. in bitsandbytes we have some `str` objects in the state_dict + if isinstance(tensor, torch.Tensor): + ptrs[id_tensor_storage(tensor)].append(name) + else: + # In the non-tensor case, fall back to the pointer of the object itself + ptrs[id(tensor)].append(name) + + # These are all the pointers of shared tensors. + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + warn_names = set() + for names in shared_ptrs.values(): + # Removing the keys which are declared as known duplicates on + # load. This allows to make sure the name which is kept is consistent. 
+ if self._tied_weights_keys is not None: + found = 0 + for name in sorted(names): + matches_pattern = any(re.search(pat, name) for pat in self._tied_weights_keys) + if matches_pattern and name in state_dict: + found += 1 + if found < len(names): + del state_dict[name] + + # When not all duplicates have been cleaned, still remove those keys, but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + found = 0 + for name in names: + if name in state_dict: + found += 1 + if found > 1: + del state_dict[name] + warn_names.add(name) + if len(warn_names) > 0: + logger.warning_once( + f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", + ) + + # Shard the model if it is too big. + if not _hf_peft_config_loaded: + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + weights_name = _add_variant(weights_name, variant) + else: + weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME + + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards.keys() + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant)) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." 
+ ) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=token, + ) + + def get_memory_footprint(self, return_buffers=True): + r""" + Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. + Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the + PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 + + Arguments: + return_buffers (`bool`, *optional*, defaults to `True`): + Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers + are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch + norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 + """ + mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) + if return_buffers: + mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) + mem = mem + mem_bufs + return mem + + @wraps(torch.nn.Module.cuda) + def cuda(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "Calling `cuda()` is not supported for `4-bit` or `8-bit` quantized models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + else: + return super().cuda(*args, **kwargs) + + @wraps(torch.nn.Module.to) + def to(self, *args, **kwargs): + # Checks if the model has been loaded in 8-bit + if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + raise ValueError( + "`.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the" + " model has already been set to the correct devices and casted to the correct `dtype`." + ) + elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ: + # For GPTQ models, we prevent users from casting the model to another dytpe to restrict unwanted behaviours. + # the correct API should be to load the model with the desired dtype directly through `from_pretrained`. + dtype_present_in_args = False + + if "dtype" not in kwargs: + for arg in args: + if isinstance(arg, torch.dtype): + dtype_present_in_args = True + break + else: + dtype_present_in_args = True + + if dtype_present_in_args: + raise ValueError( + "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired" + " `dtype` by passing the correct `torch_dtype` argument." + ) + return super().to(*args, **kwargs) + + def half(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.half()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." + ) + else: + return super().half(*args) + + def float(self, *args): + # Checks if the model is quantized + if getattr(self, "is_quantized", False): + raise ValueError( + "`.float()` is not supported for quantized model. Please use the model as it is, since the" + " model has already been casted to the correct `dtype`." 
+ ) + else: + return super().float(*args) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: bool = None, + **kwargs, + ): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In + this case, `from_tf` should be set to `True` and a configuration object should be provided as + `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a + PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. + - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g, + `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to + `True`. + - `None` if you are both providing the configuration and state dictionary (resp. with keyword + arguments `config` and `state_dict`). + model_args (sequence of positional arguments, *optional*): + All remaining positional arguments will be passed to the underlying model's `__init__` method. + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + A state dictionary to use instead of a state dictionary loaded from saved weights file. 
+ + This option can be used if you want to create a model from a pretrained configuration but load your own + weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and + [`~PreTrainedModel.from_pretrained`] is not a simpler option. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + from_tf (`bool`, *optional*, defaults to `False`): + Load the model weights from a TensorFlow checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + from_flax (`bool`, *optional*, defaults to `False`): + Load the model weights from a Flax checkpoint save file (see docstring of + `pretrained_model_name_or_path` argument). + ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): + Whether or not to raise an error if some of the weights from the checkpoint do not have the same size + as the weights of the model (if for instance, you are instantiating a model with 10 labels from a + checkpoint with 3 labels). + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + token (`str` or `bool`, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use + the token generated when running `huggingface-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". + + + + mirror (`str`, *optional*): + Mirror source to accelerate downloads in China. If you are from China and have an accessibility + problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. + Please refer to the mirror site for more information. + _fast_init(`bool`, *optional*, defaults to `True`): + Whether or not to disable fast initialization. + + + + One should only disable *_fast_init* to ensure backwards compatibility with `transformers.__version__ < + 4.6.0` for seeded model initialization. This argument will be removed at the next major version. See + [pull request 11471](https://github.com/huggingface/transformers/pull/11471) for more information. + + + + > Parameters for big model inference + + low_cpu_mem_usage(`bool`, *optional*): + Tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. 
+ This is an experimental feature and a subject to change at any moment. + torch_dtype (`str` or `torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under a specific `dtype`. The different options + are: + + 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified + `dtype`, ignoring the model's `config.torch_dtype` if one exists. If not specified + - the model will get loaded in `torch.float` (fp32). + + 2. `"auto"` - A `torch_dtype` entry in the `config.json` file of the model will be + attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in + the checkpoint that's of a floating point type and use that as `dtype`. This will load the model + using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how + the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. + + + + For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or + reach out to the authors and ask them to add this information to the model's card and to insert the + `torch_dtype` entry in `config.json` on the hub. + + + + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU + RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to + `True` when there is some disk offload. + load_in_8bit (`bool`, *optional*, defaults to `False`): + If `True`, will convert the loaded model into mixed-8bit quantized model. To use this feature please + install `bitsandbytes` (`pip install -U bitsandbytes`). + load_in_4bit (`bool`, *optional*, defaults to `False`): + If `True`, will convert the loaded model into 4bit precision quantized model. To use this feature + install the latest version of `bitsandbytes` (`pip install -U bitsandbytes`). + quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*): + A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g + bitsandbytes, gptq) + subfolder (`str`, *optional*, defaults to `""`): + In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can + specify the folder name here. 
+ variant (`str`, *optional*): + If specified load weights from `variant` filename, *e.g.* pytorch_model..bin. `variant` is + ignored when using `from_tf` or `from_flax`. + use_safetensors (`bool`, *optional*, defaults to `None`): + Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors` + is not installed, it will be set to `False`. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., + `output_attentions=True`). Behaves differently depending on whether a `config` is provided or + automatically loaded: + + - If a configuration is provided with `config`, `**kwargs` will be directly passed to the + underlying model's `__init__` method (we assume all relevant updates to the configuration have + already been done) + - If a configuration is not provided, `kwargs` will be first passed to the configuration class + initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + corresponds to a configuration attribute will be used to override said attribute with the + supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute + will be passed to the underlying model's `__init__` function. + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + Examples: + + ```python + >>> from transformers import BertConfig, BertModel + + >>> # Download model and configuration from huggingface.co and cache. + >>> model = BertModel.from_pretrained("bert-base-uncased") + >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable). + >>> model = BertModel.from_pretrained("./test/saved_model/") + >>> # Update configuration during loading. + >>> model = BertModel.from_pretrained("bert-base-uncased", output_attentions=True) + >>> assert model.config.output_attentions == True + >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable). + >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json") + >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config) + >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower) + >>> model = BertModel.from_pretrained("bert-base-uncased", from_flax=True) + ``` + + * `low_cpu_mem_usage` algorithm: + + This is an experimental function that loads the model using ~1x model size CPU memory + + Here is how it works: + + 1. save which state_dict keys we have + 2. drop state_dict before the model is created, since the latter takes 1x model size CPU memory + 3. after the model has been instantiated switch to the meta device all params/buffers that + are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. 
replace the params/buffers from the state_dict + + Currently, it can't handle deepspeed ZeRO stage 3 and ignores loading errors + + """ + state_dict = kwargs.pop("state_dict", None) + from_tf = kwargs.pop("from_tf", False) + from_flax = kwargs.pop("from_flax", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + use_auth_token = kwargs.pop("use_auth_token", None) + trust_remote_code = kwargs.pop("trust_remote_code", None) + _ = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + _fast_init = kwargs.pop("_fast_init", True) + torch_dtype = kwargs.pop("torch_dtype", None) + low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", None) + device_map = kwargs.pop("device_map", None) + max_memory = kwargs.pop("max_memory", None) + offload_folder = kwargs.pop("offload_folder", None) + offload_state_dict = kwargs.pop("offload_state_dict", False) + load_in_8bit = kwargs.pop("load_in_8bit", False) + load_in_4bit = kwargs.pop("load_in_4bit", False) + quantization_config = kwargs.pop("quantization_config", None) + subfolder = kwargs.pop("subfolder", "") + commit_hash = kwargs.pop("_commit_hash", None) + variant = kwargs.pop("variant", None) + adapter_kwargs = kwargs.pop("adapter_kwargs", {}) + adapter_name = kwargs.pop("adapter_name", "default") + use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False) + + if is_fsdp_enabled(): + low_cpu_mem_usage = True + + if use_auth_token is not None: + warnings.warn( + "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", + FutureWarning, + ) + if token is not None: + raise ValueError( + "`token` and `use_auth_token` are both specified. Please set only the argument `token`." + ) + token = use_auth_token + + if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs: + adapter_kwargs["token"] = token + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + if is_bitsandbytes_available(): + is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2") + else: + is_8bit_serializable = False + + if trust_remote_code is True: + logger.warning( + "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is" + " ignored." 
+ ) + + if commit_hash is None: + if not isinstance(config, PretrainedConfig): + # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible + resolved_config_file = cached_file( + pretrained_model_name_or_path, + CONFIG_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _raise_exceptions_for_missing_entries=False, + _raise_exceptions_for_connection_errors=False, + ) + commit_hash = extract_commit_hash(resolved_config_file, commit_hash) + else: + commit_hash = getattr(config, "_commit_hash", None) + + if is_peft_available(): + _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None) + + if _adapter_model_path is None: + _adapter_model_path = find_adapter_config_file( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + _commit_hash=commit_hash, + **adapter_kwargs, + ) + if _adapter_model_path is not None and os.path.isfile(_adapter_model_path): + with open(_adapter_model_path, encoding="utf-8") as f: + _adapter_model_path = pretrained_model_name_or_path + pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"] + else: + _adapter_model_path = None + + # change device_map into a map if we passed an int, a str or a torch.device + if isinstance(device_map, torch.device): + device_map = {"": device_map} + elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + try: + device_map = {"": torch.device(device_map)} + except RuntimeError: + raise ValueError( + "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or " + f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}." + ) + elif isinstance(device_map, int): + if device_map < 0: + raise ValueError( + "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' " + ) + else: + device_map = {"": device_map} + + if device_map is not None: + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + elif not low_cpu_mem_usage: + raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`") + + if low_cpu_mem_usage: + if device_map is not None: + # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. + require_version_core("torch>=1.10") + + if is_deepspeed_zero3_enabled(): + raise ValueError( + "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." 
+ ) + elif not is_accelerate_available(): + raise ImportError( + "Using `low_cpu_mem_usage=True` or a `device_map` requires Accelerate: `pip install accelerate`" + ) + + quantization_method_from_args = None + + if quantization_config is not None: + quantization_method_from_args = getattr( + quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES + ) + + if quantization_config is None and (load_in_8bit or load_in_4bit): + quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES + quantization_config, kwargs = BitsAndBytesConfig.from_dict( + config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit": load_in_4bit}, + return_unused_kwargs=True, + **kwargs, + ) + elif quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES: + load_in_8bit = quantization_config.load_in_8bit + load_in_4bit = quantization_config.load_in_4bit + + quantization_config_kwargs = { + k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters + } + + if len(quantization_config_kwargs) > 0: + raise ValueError( + "You can't pass `load_in_8bit` or any other `BitsAndBytesConfig` argument as a kwarg when passing " + "`quantization_config` argument at the same time." + ) + + if load_in_8bit or load_in_4bit: + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not (is_accelerate_available() and is_bitsandbytes_available()): + raise ImportError( + "Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of" + " bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or" + " `pip install bitsandbytes`." + ) + + if torch_dtype is None: + # We force the `dtype` to be float16, this is a requirement from `bitsandbytes` + logger.info( + f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to " + "requirements of `bitsandbytes` to enable model loading in 8-bit or 4-bit. " + "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass" + " torch_dtype=torch.float16 to remove this warning." + ) + torch_dtype = torch.float16 + + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info( + "The device_map was not initialized. " + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + + if from_tf or from_flax: + raise ValueError( + "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make" + " sure the weights are in PyTorch format." 
+ ) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + if from_pipeline is not None: + user_agent["using_pipeline"] = from_pipeline + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + else: + model_kwargs = kwargs + + quantizer = None + quantization_method_from_config = None + if hasattr(config, "quantization_config"): + quantization_method_from_config = config.quantization_config.get( + "quant_method", QuantizationMethod.BITS_AND_BYTES + ) + + if ( + quantization_method_from_args is not None + and quantization_method_from_args == QuantizationMethod.AWQ + and quantization_method_from_config is None + ): + raise ValueError( + "You cannot quantize with AWQ a non-quantized model using transformers, please refer to the quantization documentation" + " to read more about how to quantize models with AWQ algorithm https://huggingface.co/docs/transformers/main_classes/quantization" + ) + + if quantization_method_from_config is not None and quantization_method_from_args is not None: + if quantization_method_from_config != quantization_method_from_args: + raise ValueError( + f"The model is already quantized with {quantization_method_from_config}. " + f"You can't quantize it again with {quantization_method_from_args}" + ) + + if ( + quantization_method_from_config in (QuantizationMethod.GPTQ, QuantizationMethod.AWQ) + and quantization_method_from_args is not None + ): + loading_attr_dict = quantization_config.get_loading_attributes() + for attr, val in loading_attr_dict.items(): + config.quantization_config[attr] = val + quantization_method_from_args = None + logger.warning( + f"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a " + f"`quantization_config` attribute and has already quantized weights. However, loading attributes" + f" (e.g. {list(loading_attr_dict.keys())}) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored." 
+ ) + if ( + quantization_method_from_args == QuantizationMethod.GPTQ + or quantization_method_from_config == QuantizationMethod.GPTQ + ): + gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + if not gptq_supports_cpu and not torch.cuda.is_available(): + raise RuntimeError("GPU is required to quantize or run quantize model.") + elif not (is_optimum_available() and is_auto_gptq_available()): + raise ImportError( + "Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)" + ) + elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"): + raise ImportError( + "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`" + ) + else: + # Need to protect the import + from optimum.gptq import GPTQQuantizer + if quantization_method_from_config == QuantizationMethod.GPTQ: + quantization_config = GPTQConfig.from_dict(config.quantization_config) + config.quantization_config = quantization_config + if torch_dtype is None: + torch_dtype = torch.float16 + else: + logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.") + quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict_optimum()) + elif quantization_method_from_config == QuantizationMethod.AWQ: + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run AWQ quantized model.") + + if not is_auto_awq_available(): + raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)") + + if not is_accelerate_available(): + raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)") + + if device_map is None: + logger.warning( + "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set " + "your model on a GPU device in order to run your model." + ) + elif device_map is not None: + if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()): + raise ValueError( + "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device." + " This is not supported. Please remove the CPU or disk device from the device_map." + ) + + if torch_dtype is None: + torch_dtype = torch.float16 + else: + logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.") + + # Force-set to `True` for more mem efficiency + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + + if is_8bit_serializable and quantization_method_from_args == QuantizationMethod.BITS_AND_BYTES and load_in_8bit: + if quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES: + logger.warning( + "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a" + " `quantization_config` attribute. The `quantization_config` attribute will be overwritten with the" + " one you passed to `from_pretrained`." 
+ ) + config.quantization_config = quantization_config + elif ( + is_8bit_serializable + and not load_in_8bit + and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES + ): + quantization_config = config.quantization_config + if isinstance(quantization_config, dict): + quantization_config = BitsAndBytesConfig.from_dict(quantization_config, return_unused_kwargs=False) + elif isinstance(quantization_config, BitsAndBytesConfig): + pass + else: + raise ValueError( + f"Invalid type for `quantization_config`: {type(quantization_config)}. Should be a `dict` or a" + " `BitsAndBytesConfig` instance." + ) + + load_in_8bit = quantization_config.load_in_8bit + + if load_in_8bit: + if torch_dtype is None: + torch_dtype = torch.float16 + if device_map is None: + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + logger.info( + "The device_map was not initialized. " + "Setting device_map to {'':torch.cuda.current_device()}. " + "If you want to use the model for inference, please set device_map ='auto' " + ) + if low_cpu_mem_usage is None: + low_cpu_mem_usage = True + + elif ( + not is_8bit_serializable + and not load_in_8bit + and quantization_method_from_config == QuantizationMethod.BITS_AND_BYTES + ): + logger.warning( + "Detected the presence of a `quantization_config` attribute in the model's configuration but you don't have the correct" + " `bitsandbytes` version to support int8 serialization. Please install the latest version of `bitsandbytes` with " + " `pip install --upgrade bitsandbytes`." + ) + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # index of the files. + is_sharded = False + sharded_metadata = None + # Load model + loading_info = None + + # Keep in fp32 modules + keep_in_fp32_modules = None + use_keep_in_fp32_modules = False + + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ): + # Load from a TF 1.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + elif from_tf and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + ): + # Load from a TF 2.0 checkpoint in priority if from_tf + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME) + elif from_flax and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + ): + # Load from a Flax checkpoint in priority if from_flax + archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, 
subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index") + ) or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)): + raise OSError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use" + " `from_tf=True` to load this model from those weights." + ) + elif os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)): + raise OSError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`" + " to load this model from those weights." + ) + elif use_safetensors: + raise OSError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise OSError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}," + f" {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory" + f" {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")): + if not from_tf: + raise ValueError( + f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set " + "from_tf to True to load from this checkpoint." 
+ ) + archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + # set correct filename + if from_tf: + filename = TF2_WEIGHTS_NAME + elif from_flax: + filename = FLAX_WEIGHTS_NAME + elif use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "token": token, + "user_agent": user_agent, + "revision": revision, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + } + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + elif use_safetensors: + if revision == "main": + resolved_archive_file, revision, is_sharded = auto_conversion( + pretrained_model_name_or_path, **cached_file_kwargs + ) + cached_file_kwargs["revision"] = revision + if resolved_archive_file is None: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} " + "and thus cannot be loaded with `safetensors`. Please make sure that the model has " + "been saved with `safe_serialization=True` or do not set `use_safetensors=True`." + ) + else: + # This repo has no safetensors file of any kind, we switch to PyTorch. + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file( + pretrained_model_name_or_path, filename, **cached_file_kwargs + ) + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + if resolved_archive_file is None: + # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error + # message. + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights." + " Use `from_tf=True` to load this model from those weights." 
+ ) + elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use" + " `from_flax=True` to load this model from those weights." + ) + elif variant is not None and has_file( + pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs + ): + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant" + f" {variant}. Use `variant=None` to load this model from those weights." + ) + else: + raise OSError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(WEIGHTS_NAME, variant)}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or" + f" {FLAX_WEIGHTS_NAME}." + ) + except OSError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as e: + # For any other exception, we throw a generic error. + raise OSError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + " from 'https://huggingface.co/models', make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}," + f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}." + ) from e + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + # We'll need to download and cache each checkpoint shard if the checkpoint is sharded. + if is_sharded: + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + user_agent=user_agent, + revision=revision, + subfolder=subfolder, + _commit_hash=commit_hash, + ) + + if ( + is_safetensors_available() + and isinstance(resolved_archive_file, str) + and resolved_archive_file.endswith(".safetensors") + ): + with safe_open(resolved_archive_file, framework="pt") as f: + metadata = f.metadata() + + if metadata.get("format") == "pt": + pass + elif metadata.get("format") == "tf": + from_tf = True + logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.") + elif metadata.get("format") == "flax": + from_flax = True + logger.info("A Flax safetensors file is being loaded in a PyTorch model.") + else: + raise ValueError( + f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax'] but {metadata.get('format')}" + ) + + from_pt = not (from_tf | from_flax) + + # load pt weights early so that we know which dtype to init the model under + if from_pt: + if not is_sharded and state_dict is None: + # Time to load the checkpoint + state_dict = load_state_dict(resolved_archive_file) + + # set dtype to instantiate the model under: + # 1. If torch_dtype is not None, we use that dtype + # 2. 
If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first + # weights entry that is of a floating type - we assume all floating dtype weights are of the same dtype + # we also may have config.torch_dtype available, but we won't rely on it till v5 + dtype_orig = None + + if torch_dtype is not None: + if isinstance(torch_dtype, str): + if torch_dtype == "auto": + if hasattr(config, "torch_dtype") and config.torch_dtype is not None: + torch_dtype = config.torch_dtype + logger.info(f"Will use torch_dtype={torch_dtype} as defined in model's config object") + else: + if is_sharded and "dtype" in sharded_metadata: + torch_dtype = sharded_metadata["dtype"] + elif not is_sharded: + torch_dtype = get_state_dict_dtype(state_dict) + else: + one_state_dict = load_state_dict(resolved_archive_file[0]) + torch_dtype = get_state_dict_dtype(one_state_dict) + del one_state_dict # free CPU memory + logger.info( + "Since the `torch_dtype` attribute can't be found in model's config object, " + "will use torch_dtype={torch_dtype} as derived from model's weights" + ) + else: + raise ValueError( + f'`torch_dtype` can be either `torch.dtype` or `"auto"`, but received {torch_dtype}' + ) + dtype_orig = cls._set_default_torch_dtype(torch_dtype) + + # Check if `_keep_in_fp32_modules` is not None + use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and ( + torch_dtype == torch.float16 or load_in_4bit or load_in_8bit + ) + + if is_sharded: + loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"] + else: + loaded_state_dict_keys = list(state_dict.keys()) + if low_cpu_mem_usage or (use_keep_in_fp32_modules and is_accelerate_available()): + # In case some weights need to be kept in float32 and accelerate is not installed, + # we later on want to take the path where state_dict is not None, that is the one + # that do not require accelerate. + state_dict = None + + config.name_or_path = pretrained_model_name_or_path + + # Instantiate model. + init_contexts = [no_init_weights(_enable=_fast_init)] + + if is_deepspeed_zero3_enabled(): + import deepspeed + + if ( + "mics_shard_size" in deepspeed_config()["zero_optimization"] + and deepspeed_config()["zero_optimization"]["mics_shard_size"] > 1 + ): + print("Detected DeepSpeed ZeRO-3 with MiCS: activating zero.MiCS_Init() for this model") + init_contexts = [deepspeed.zero.MiCS_Init(config_dict_or_path=deepspeed_config())] + init_contexts + else: + print("Detected DeepSpeed ZeRO-3: activating zero.init() for this model") + init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts + elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: + init_contexts.append(init_empty_weights()) + + config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained. 
+ config = cls._autoset_attn_implementation( + config, use_flash_attention_2=use_flash_attention_2, torch_dtype=torch_dtype, device_map=device_map + ) + + with ContextManagers(init_contexts): + # Let's make sure we don't run the init function of buffer modules + model = cls(config, *model_args, **model_kwargs) + + # make sure we use the model's config since the __init__ call might have copied it + config = model.config + + # Check first if we are `from_pt` + if use_keep_in_fp32_modules: + if is_accelerate_available(): + low_cpu_mem_usage = True + keep_in_fp32_modules = model._keep_in_fp32_modules + else: + keep_in_fp32_modules = [] + + if load_in_8bit or load_in_4bit: + from .integrations import get_keys_to_not_convert, replace_with_bnb_linear + + llm_int8_skip_modules = quantization_config.llm_int8_skip_modules + load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload + if load_in_8bit: + logger.info("Detected 8-bit loading: activating 8-bit loading for this model") + else: + logger.info("Detected 4-bit loading: activating 4-bit loading for this model") + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if llm_int8_skip_modules is None: + modules_to_not_convert = get_keys_to_not_convert(model) + else: + modules_to_not_convert = llm_int8_skip_modules + + if not isinstance(modules_to_not_convert, list): + modules_to_not_convert = [modules_to_not_convert] + + modules_to_not_convert.extend(keep_in_fp32_modules) + + # Extend the modules to not convert to keys that are supposed to be offloaded to `cpu` or `disk` + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + if len(keys_on_cpu) > 0 and not load_in_8bit_fp32_cpu_offload: + raise ValueError( + "If you want to offload some keys to `cpu` or `disk`, you need to set " + "`llm_int8_enable_fp32_cpu_offload=True`. Note that these modules will not be " + " converted to 8-bit but kept in 32-bit." + ) + + modules_to_not_convert.extend(keys_on_cpu) + + supports_4bit = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.39.0") + + if load_in_4bit and not supports_4bit: + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit inference and training" + " make sure you have the latest version of `bitsandbytes` installed" + ) + + model = replace_with_bnb_linear( + model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config + ) + # training in 8-bit is only available in 0.37.0+ + model._is_quantized_training_enabled = version.parse( + importlib.metadata.version("bitsandbytes") + ) >= version.parse("0.37.0") + + config.quantization_config = quantization_config + model.is_8bit_serializable = is_8bit_serializable + + if load_in_8bit and torch_dtype is None: + logger.warning( + "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute. " + "All non-linear modules will be loaded in full precision." + " If you want to load the other modules in other precision, please specify a `torch_dtype` attribute." 
+ ) + if quantization_method_from_config == QuantizationMethod.GPTQ: + model = quantizer.convert_model(model) + model._is_quantized_training_enabled = True + elif quantization_method_from_config == QuantizationMethod.AWQ: + from .integrations import fuse_awq_modules, get_keys_to_not_convert, replace_with_awq_linear + + modules_to_not_convert = get_keys_to_not_convert(model) + + if quantization_config is None: + quantization_config = AwqConfig.from_dict(config.quantization_config) + + model, has_been_replaced = replace_with_awq_linear( + model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert + ) + model._is_quantized_training_enabled = False + + if not has_been_replaced: + logger.warning( + "You are loading an AWQ model but no linear modules were found in your model." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + + if quantization_method_from_config is not None: + model.quantization_method = quantization_method_from_config + elif quantization_method_from_args is not None: + model.quantization_method = quantization_method_from_args + if hasattr(model, "quantization_method"): + model.is_quantized = True + + # We store the original dtype for quantized models as we cannot easily retrieve it + # once the weights have been quantized + # Note that once you have loaded a quantized model, you can't change its dtype so this will + # remain a single source of truth + config._pre_quantization_dtype = torch_dtype + + if isinstance(device_map, str): + special_dtypes = {} + if load_in_8bit or load_in_4bit: + special_dtypes.update( + { + name: torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in modules_to_not_convert) + } + ) + + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in keep_in_fp32_modules) + } + ) + + target_dtype = torch_dtype + + if load_in_4bit: + if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"): + from accelerate.utils import CustomDtype + + target_dtype = CustomDtype.INT4 + else: + raise ValueError( + "You are using `device_map='auto'` on a 4bit loaded version of the model. To automatically compute" + " the appropriate device map, you should upgrade your `accelerate` library, " + "`pip install --upgrade accelerate` or install it from source to support fp4 auto device map " + "calculation. You may encounter unexpected behavior, or pass your own device map" + ) + elif load_in_8bit: + target_dtype = torch.int8 + + no_split_modules = model._get_no_split_modules(device_map) + if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + + device_map_kwargs = {"no_split_module_classes": no_split_modules} + if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters: + device_map_kwargs["special_dtypes"] = special_dtypes + elif len(special_dtypes) > 0: + logger.warning( + "This model has some weights that should be kept in higher precision, you need to upgrade " + "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)." 
+ ) + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + dtype=target_dtype, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **device_map_kwargs, + ) + else: + max_memory = get_max_memory(max_memory) + if getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES: + # need more space for buffers that are created during quantization + max_memory = {key: val * 0.90 for key, val in max_memory.items()} + device_map_kwargs["max_memory"] = max_memory + + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs) + + if load_in_8bit or load_in_4bit: + # The LM head / tied weights or any last module can stay on disk / CPU + device_map_without_lm_head = { + key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert + } + if "cpu" in device_map_without_lm_head.values() or "disk" in device_map_without_lm_head.values(): + raise ValueError( + """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit + the quantized model. If you want to dispatch the model on the CPU or the disk while keeping + these modules in 32-bit, you need to set `load_in_8bit_fp32_cpu_offload=True` and pass a custom + `device_map` to `from_pretrained`. Check + https://huggingface.co/docs/transformers/main/en/main_classes/quantization#offload-between-cpu-and-gpu + for more details. + """ + ) + del device_map_without_lm_head + + elif device_map is not None: + model.tie_weights() + tied_params = find_tied_parameters(model) + # check if we don't have tied param in different devices + check_tied_parameters_on_same_device(tied_params, device_map) + + if from_tf: + if resolved_archive_file.endswith(".index"): + # Load from a TensorFlow 1.X checkpoint - provided by original authors + model = cls.load_tf_weights(model, config, resolved_archive_file[:-6]) # Remove the '.index' + else: + # Load from our TensorFlow 2.0 checkpoints + try: + from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model + + model, loading_info = load_tf2_checkpoint_in_pytorch_model( + model, resolved_archive_file, allow_missing_keys=True, output_loading_info=True + ) + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed." + " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation" + " instructions." + ) + raise + elif from_flax: + try: + from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model + + model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file) + except ImportError: + logger.error( + "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see" + " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for" + " installation instructions." + ) + raise + elif from_pt: + # restore default dtype + if dtype_orig is not None: + torch.set_default_dtype(dtype_orig) + ( + model, + missing_keys, + unexpected_keys, + mismatched_keys, + offload_index, + error_msgs, + ) = cls._load_pretrained_model( + model, + state_dict, + loaded_state_dict_keys, # XXX: rename? 
+ resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=ignore_mismatched_sizes, + sharded_metadata=sharded_metadata, + _fast_init=_fast_init, + low_cpu_mem_usage=low_cpu_mem_usage, + device_map=device_map, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + dtype=torch_dtype, + is_quantized=(getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES), + keep_in_fp32_modules=keep_in_fp32_modules, + ) + + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # Set model in evaluation mode to deactivate DropOut modules by default + model.eval() + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate() and pretrained_model_name_or_path is not None: + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." + ) + pass + + if ( + quantization_config is not None + and quantization_config.quant_method == QuantizationMethod.AWQ + and quantization_config.do_fuse + ): + model = fuse_awq_modules(model, config.quantization_config) + model._awq_is_fused = True + + # Dispatch model with hooks on all devices if necessary + if device_map is not None: + device_map_kwargs = { + "device_map": device_map, + "offload_dir": offload_folder, + "offload_index": offload_index, + } + if "skip_keys" in inspect.signature(dispatch_model).parameters: + device_map_kwargs["skip_keys"] = model._skip_keys_device_placement + dispatch_model(model, **device_map_kwargs) + + if quantization_method_from_args == QuantizationMethod.GPTQ: + if quantization_config.tokenizer is None: + quantization_config.tokenizer = pretrained_model_name_or_path + if cls.main_input_name != "input_ids": + raise RuntimeError("We can only quantize pure text model.") + quantizer.quantize_model(model, quantization_config.tokenizer) + config.quantization_config = GPTQConfig.from_dict_optimum(quantizer.to_dict()) + model._is_quantized_training_enabled = True + if quantization_method_from_config == QuantizationMethod.GPTQ: + model = quantizer.post_init_model(model) + + if _adapter_model_path is not None: + model.load_adapter( + _adapter_model_path, + adapter_name=adapter_name, + token=token, + adapter_kwargs=adapter_kwargs, + ) + + if output_loading_info: + if loading_info is None: + loading_info = { + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + "mismatched_keys": mismatched_keys, + "error_msgs": error_msgs, + } + return model, loading_info + + return model + + @classmethod + def _load_pretrained_model( + cls, + model, + state_dict, + loaded_keys, + resolved_archive_file, + pretrained_model_name_or_path, + ignore_mismatched_sizes=False, + sharded_metadata=None, + _fast_init=True, + low_cpu_mem_usage=False, + device_map=None, + offload_folder=None, + offload_state_dict=None, + dtype=None, + is_quantized=False, + keep_in_fp32_modules=None, + ): + is_safetensors = False + if is_quantized: + from .integrations import 
set_module_quantized_tensor_to_device + + if device_map is not None and "disk" in device_map.values(): + archive_file = ( + resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file + ) + is_safetensors = archive_file.endswith(".safetensors") + if offload_folder is None and not is_safetensors: + raise ValueError( + "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`" + " for them. Alternatively, make sure you have `safetensors` installed if the model you are using" + " offers the weights in this format." + ) + if offload_folder is not None: + os.makedirs(offload_folder, exist_ok=True) + if offload_state_dict is None: + offload_state_dict = True + + is_sharded_safetensors = is_safetensors and sharded_metadata is not None + + # tie the model weights before retrieving the state_dict + model.tie_weights() + + # Retrieve missing & unexpected_keys + model_state_dict = model.state_dict() + expected_keys = list(model_state_dict.keys()) + prefix = model.base_model_prefix + + def _fix_key(key): + if "beta" in key: + return key.replace("beta", "bias") + if "gamma" in key: + return key.replace("gamma", "weight") + return key + + original_loaded_keys = loaded_keys + loaded_keys = [_fix_key(key) for key in loaded_keys] + + if len(prefix) > 0: + has_prefix_module = any(s.startswith(prefix) for s in loaded_keys) + expects_prefix_module = any(s.startswith(prefix) for s in expected_keys) + else: + has_prefix_module = False + expects_prefix_module = False + + # key re-naming operations are never done on the keys + # that are loaded, but always on the keys of the newly initialized model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + if remove_prefix_from_model: + _prefix = f"{prefix}." + expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)] + expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys] + elif add_prefix_to_model: + expected_keys = [".".join([prefix, s]) for s in expected_keys] + + missing_keys = list(set(expected_keys) - set(loaded_keys)) + unexpected_keys = set(loaded_keys) - set(expected_keys) + # Remove nonpersistent buffers from unexpected keys: they are not in the state dict but will be in the model + # buffers + model_buffers = {n for n, _ in model.named_buffers()} + if remove_prefix_from_model: + model_buffers = {key[len(_prefix) :] if key.startswith(_prefix) else key for key in model_buffers} + elif add_prefix_to_model: + model_buffers = {".".join([prefix, key]) for key in model_buffers} + unexpected_keys = list(unexpected_keys - model_buffers) + + model.tie_weights() + if device_map is None and not is_fsdp_enabled(): + ptrs = collections.defaultdict(list) + for name, tensor in model.state_dict().items(): + id_tensor = id_tensor_storage(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. 
+ tied_params = [names for _, names in ptrs.items() if len(names) > 1] + else: + # id function doesn't work for meta tensor so we need this function + tied_params = find_tied_parameters(model) + + for group in tied_params: + if remove_prefix_from_model: + group = [key[len(_prefix) :] if key.startswith(_prefix) else key for key in group] + elif add_prefix_to_model: + group = [".".join([prefix, key]) for key in group] + missing_in_group = [k for k in missing_keys if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + missing_keys = [k for k in missing_keys if k not in missing_in_group] + + # Some models may have keys that are not in the state by design, removing them before needlessly warning + # the user. + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + missing_keys = [k for k in missing_keys if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] + + # retrieve weights on meta device and put them back on CPU. + # This is not ideal in terms of memory, but if we don't do that not, we can't initialize them in the next step + if low_cpu_mem_usage: + for key in missing_keys: + if key in list(model_state_dict.keys()): + key = key + elif f"{prefix}.{key}" in list(model_state_dict.keys()): + key = f"{prefix}.{key}" + elif key.startswith(prefix) and ".".join(key.split(".")[1:]) in list(model_state_dict.keys()): + key = ".".join(key.split(".")[1:]) + param = model_state_dict[key] + + # upcast in fp32 if any + target_dtype = dtype + if ( + keep_in_fp32_modules is not None + and dtype == torch.float16 + and any(module_to_keep_in_fp32 in key.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules) + ): + target_dtype = torch.float32 + + if param.device == torch.device("meta"): + if not (is_quantized): + set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)) + else: + set_module_quantized_tensor_to_device( + model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype) + ) + + # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + if remove_prefix_from_model: + _loaded_keys = [f"{prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + set_initialized_submodules(model, _loaded_keys) + # This will only initialize submodules that are not marked as initialized by the line above. + model.apply(model._initialize_weights) + + # Set some modules to fp32 if any + if keep_in_fp32_modules is not None: + for name, param in model.named_parameters(): + if any(module_to_keep_in_fp32 in name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + # param = param.to(torch.float32) does not work here as only in the local scope. + param.data = param.data.to(torch.float32) + + # Make sure we are able to load base models as well as derived models (with heads) + start_prefix = "" + model_to_load = model + if len(cls.base_model_prefix) > 0 and not hasattr(model, cls.base_model_prefix) and has_prefix_module: + start_prefix = cls.base_model_prefix + "." 
+ if len(cls.base_model_prefix) > 0 and hasattr(model, cls.base_model_prefix) and not has_prefix_module: + model_to_load = getattr(model, cls.base_model_prefix) + base_model_expected_keys = list(model_to_load.state_dict().keys()) + if any(key in expected_keys_not_prefixed and key not in base_model_expected_keys for key in loaded_keys): + raise ValueError( + "The state dictionary of the model you are trying to load is corrupted. Are you sure it was " + "properly saved?" + ) + if device_map is not None: + device_map = {k.replace(f"{cls.base_model_prefix}.", ""): v for k, v in device_map.items()} + + def _find_mismatched_keys( + state_dict, + model_state_dict, + loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ): + mismatched_keys = [] + if ignore_mismatched_sizes: + for checkpoint_key in loaded_keys: + # If the checkpoint is sharded, we may not have the key here. + if checkpoint_key not in state_dict: + continue + model_key = checkpoint_key + if remove_prefix_from_model: + # The model key starts with `prefix` but `checkpoint_key` doesn't so we add it. + model_key = f"{prefix}.{checkpoint_key}" + elif add_prefix_to_model: + # The model key doesn't start with `prefix` but `checkpoint_key` does so we remove it. + model_key = ".".join(checkpoint_key.split(".")[1:]) + + if ( + model_key in model_state_dict + and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape + ): + mismatched_keys.append( + (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if resolved_archive_file is not None: + folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1]) + else: + folder = None + if device_map is not None and is_safetensors: + param_device_map = expand_device_map(device_map, original_loaded_keys, start_prefix) + str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32" + if sharded_metadata is None: + archive_file = ( + resolved_archive_file[0] + if isinstance(resolved_archive_file, (list, tuple)) + else resolved_archive_file + ) + weight_map = {p: archive_file for p in original_loaded_keys} + else: + weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()} + offload_index = { + p[len(start_prefix) :]: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype} + for p, f in weight_map.items() + if p.startswith(start_prefix) and param_device_map[p[len(start_prefix) :]] == "disk" + } + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + offload_index = None + else: + # Sharded checkpoint or whole but low_cpu_mem_usage==True + + # This should always be a list but, just to be sure. 
+ if not isinstance(resolved_archive_file, list): + resolved_archive_file = [resolved_archive_file] + + error_msgs = [] + mismatched_keys = [] + if not is_safetensors: + offload_index = {} if device_map is not None and "disk" in device_map.values() else None + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + else: + state_dict_folder = None + state_dict_index = None + + if is_sharded_safetensors: + disk_only_shard_files = get_disk_only_shard_files( + device_map, sharded_metadata=sharded_metadata, start_prefix=start_prefix + ) + disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files] + else: + disk_only_shard_files = [] + + if len(resolved_archive_file) > 1: + resolved_archive_file = logging.tqdm(resolved_archive_file, desc="Loading checkpoint shards") + for shard_file in resolved_archive_file: + # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload. + if shard_file in disk_only_shard_files: + continue + state_dict = load_state_dict(shard_file) + + # Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not + # matching the weights in the model. + mismatched_keys += _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + add_prefix_to_model, + remove_prefix_from_model, + ignore_mismatched_sizes, + ) + if low_cpu_mem_usage: + if is_fsdp_enabled() and not is_local_dist_rank_0(): + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + if not (is_quantized): + set_module_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + set_module_quantized_tensor_to_device( + model_to_load, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + loaded_keys, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + is_quantized=is_quantized, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + error_msgs += new_error_msgs + else: + error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) + + # force memory release + del state_dict + gc.collect() + + if offload_index is not None and len(offload_index) > 0: + if model != model_to_load: + # We need to add the prefix of the base model + prefix = cls.base_model_prefix + if not is_safetensors: + for weight_name in offload_index: + shutil.move( + os.path.join(offload_folder, f"{weight_name}.dat"), + os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"), + ) + offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()} + if not is_safetensors: + save_offload_index(offload_index, offload_folder) + offload_index = None + + if offload_state_dict: + # Load back temporarily offloaded state dict + load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." 
+ ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if is_quantized: + unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem] + missing_keys = [elem for elem in missing_keys if "SCB" not in elem] + + if len(unexpected_keys) > 0: + archs = [] if model.config.architectures is None else model.config.architectures + warner = logger.warning if model.__class__.__name__ in archs else logger.info + warner( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs + + def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False): + module_keys = {".".join(key.split(".")[:-1]) for key in names} + + # torch.nn.ParameterList is a special case where two parameter keywords + # are appended to the module name, *e.g.* bert.special_embeddings.0 + module_keys = module_keys.union( + {".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()} + ) + + retrieved_modules = [] + # retrieve all modules that has at least one missing weight name + for name, module in self.named_modules(): + if remove_prefix: + _prefix = f"{self.base_model_prefix}." 
+ name = name[len(_prefix) :] if name.startswith(_prefix) else name + elif add_prefix: + name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix + + if name in module_keys: + retrieved_modules.append(module) + + return retrieved_modules + + @staticmethod + def _load_pretrained_model_low_mem(model, loaded_state_dict_keys, resolved_archive_file, start_prefix=""): + """ + This is an experimental function that loads the model using ~1.x model size CPU memory + + Before you call it do: + + 1. save which state_dict keys are available + 2. drop state_dict before model is created, since the latter takes 1x model size memory + + Here then we continue: + + 3. switch to the meta device all params/buffers that are going to be replaced from the loaded state_dict + 4. load state_dict 2nd time + 5. replace the params/buffers from the state_dict + + Currently, it doesn't handle missing_keys, unexpected_keys, mismatched_keys. It can't handle deepspeed. + """ + + _move_model_to_meta(model, loaded_state_dict_keys, start_prefix) + state_dict = load_state_dict(resolved_archive_file) + error_msgs = _load_state_dict_into_meta_model(model, state_dict, loaded_state_dict_keys, start_prefix) + return error_msgs + + @classmethod + def register_for_auto_class(cls, auto_class="AutoModel"): + """ + Register this class with a given auto class. This should only be used for custom models as the ones in the + library are already mapped with an auto class. + + + + This API is experimental and may have some slight breaking changes in the next releases. + + + + Args: + auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`): + The auto class to register this new model with. + """ + if not isinstance(auto_class, str): + auto_class = auto_class.__name__ + + import transformers.models.auto as auto_module + + if not hasattr(auto_module, auto_class): + raise ValueError(f"{auto_class} is not a valid auto class.") + + cls._auto_class = auto_class + + def to_bettertransformer(self) -> "PreTrainedModel": + """ + Converts the model to use [PyTorch's native attention + implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html), integrated to + Transformers through [Optimum library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a + subset of all Transformers models are supported. + + PyTorch's attention fastpath allows to speed up inference through kernel fusions and the use of [nested + tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog + post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2). + + Returns: + [`PreTrainedModel`]: The model converted to BetterTransformer. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.transform(self) + + def reverse_bettertransformer(self): + """ + Reverts the transformation from [`~PreTrainedModel.to_bettertransformer`] so that the original modeling is + used, for example in order to save the model. 
+ + Returns: + [`PreTrainedModel`]: The model converted back to the original modeling. + """ + if not is_optimum_available(): + raise ImportError("The package `optimum` is required to use Better Transformer.") + + from optimum.version import __version__ as optimum_version + + if version.parse(optimum_version) < version.parse("1.7.0"): + raise ImportError( + f"Please install optimum>=1.7.0 to use Better Transformer. The version {optimum_version} was found." + ) + + from optimum.bettertransformer import BetterTransformer + + return BetterTransformer.reverse(self) + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + + # Skip the check during tracing. + if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling(): + return + + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See " + "https://huggingface.co/docs/transformers/troubleshooting" + "#incorrect-output-when-padding-tokens-arent-masked." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning_once(warn_string) + + +PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub) +if PreTrainedModel.push_to_hub.__doc__ is not None: + PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format( + object="model", object_class="AutoModel", object_files="model file" + ) + + +class PoolerStartLogits(nn.Module): + """ + Compute SQuAD start logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, 1) + + def forward( + self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + Returns: + `torch.FloatTensor`: The start logits for SQuAD. 
+ """ + x = self.dense(hidden_states).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerEndLogits(nn.Module): + """ + Compute SQuAD end logits from sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dense_1 = nn.Linear(config.hidden_size, 1) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. + start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The end logits for SQuAD. + """ + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + slen, hsz = hidden_states.shape[-2:] + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz) + start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz) + + x = self.dense_0(torch.cat([hidden_states, start_states], dim=-1)) + x = self.activation(x) + x = self.LayerNorm(x) + x = self.dense_1(x).squeeze(-1) + + if p_mask is not None: + if get_parameter_dtype(self) == torch.float16: + x = x * (1 - p_mask) - 65500 * p_mask + else: + x = x * (1 - p_mask) - 1e30 * p_mask + + return x + + +class PoolerAnswerClass(nn.Module): + """ + Compute SQuAD 2.0 answer class from classification and start tokens hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model. + """ + + def __init__(self, config): + super().__init__() + self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size) + self.activation = nn.Tanh() + self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + start_states: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + ) -> torch.FloatTensor: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + The final hidden states of the model. 
+ start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*): + The hidden states of the first tokens for the labeled span. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + The position of the first token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + + + + One of `start_states` or `start_positions` should be not `None`. If both are set, `start_positions` overrides + `start_states`. + + + + Returns: + `torch.FloatTensor`: The SQuAD 2.0 answer class. + """ + # No dependency on end_feature so that we can obtain one single `cls_logits` for each sample. + hsz = hidden_states.shape[-1] + assert ( + start_states is not None or start_positions is not None + ), "One of start_states, start_positions should be not None" + if start_positions is not None: + start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + start_states = hidden_states.gather(-2, start_positions).squeeze(-2) # shape (bsz, hsz) + + if cls_index is not None: + cls_index = cls_index[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz) + cls_token_state = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, hsz) + else: + cls_token_state = hidden_states[:, -1, :] # shape (bsz, hsz) + + x = self.dense_0(torch.cat([start_states, cls_token_state], dim=-1)) + x = self.activation(x) + x = self.dense_1(x).squeeze(-1) + + return x + + +@dataclass +class SquadHeadOutput(ModelOutput): + """ + Base class for outputs of question answering models using a [`~modeling_utils.SQuADHead`]. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided): + Classification loss as the sum of start token, end token (and is_impossible if provided) classification + losses. + start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top config.start_n_top start token possibilities (beam-search). + start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top config.start_n_top start token possibilities (beam-search). + end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities + (beam-search). + end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search). + cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided): + Log probabilities for the `is_impossible` label of the answers. 
+ + """ + + loss: Optional[torch.FloatTensor] = None + start_top_log_probs: Optional[torch.FloatTensor] = None + start_top_index: Optional[torch.LongTensor] = None + end_top_log_probs: Optional[torch.FloatTensor] = None + end_top_index: Optional[torch.LongTensor] = None + cls_logits: Optional[torch.FloatTensor] = None + + +class SQuADHead(nn.Module): + r""" + A SQuAD head inspired by XLNet. + + Args: + config ([`PretrainedConfig`]): + The config used by the model, will be used to grab the `hidden_size` of the model and the `layer_norm_eps` + to use. + """ + + def __init__(self, config): + super().__init__() + self.start_n_top = config.start_n_top + self.end_n_top = config.end_n_top + + self.start_logits = PoolerStartLogits(config) + self.end_logits = PoolerEndLogits(config) + self.answer_class = PoolerAnswerClass(config) + + @replace_return_docstrings(output_type=SquadHeadOutput, config_class=PretrainedConfig) + def forward( + self, + hidden_states: torch.FloatTensor, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + cls_index: Optional[torch.LongTensor] = None, + is_impossible: Optional[torch.LongTensor] = None, + p_mask: Optional[torch.FloatTensor] = None, + return_dict: bool = False, + ) -> Union[SquadHeadOutput, Tuple[torch.FloatTensor]]: + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`): + Final hidden states of the model on the sequence tokens. + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the first token for the labeled span. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Positions of the last token for the labeled span. + cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Position of the CLS token for each sentence in the batch. If `None`, takes the last token. + is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Whether the question has a possible answer in the paragraph or not. + p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*): + Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token + should be masked. + return_dict (`bool`, *optional*, defaults to `False`): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ + Returns: + """ + start_logits = self.start_logits(hidden_states, p_mask=p_mask) + + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, let's remove the dimension added by batch splitting + for x in (start_positions, end_positions, cls_index, is_impossible): + if x is not None and x.dim() > 1: + x.squeeze_(-1) + + # during training, compute the end logits based on the ground truth of the start position + end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask) + + loss_fct = CrossEntropyLoss() + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if cls_index is not None and is_impossible is not None: + # Predict answerability from the representation of CLS and START + cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index) + loss_fct_cls = nn.BCEWithLogitsLoss() + cls_loss = loss_fct_cls(cls_logits, is_impossible) + + # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss + total_loss += cls_loss * 0.5 + + return SquadHeadOutput(loss=total_loss) if return_dict else (total_loss,) + + else: + # during inference, compute the end logits based on beam search + bsz, slen, hsz = hidden_states.size() + start_log_probs = nn.functional.softmax(start_logits, dim=-1) # shape (bsz, slen) + + start_top_log_probs, start_top_index = torch.topk( + start_log_probs, self.start_n_top, dim=-1 + ) # shape (bsz, start_n_top) + start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz) + start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz) + start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz) + + hidden_states_expanded = hidden_states.unsqueeze(2).expand_as( + start_states + ) # shape (bsz, slen, start_n_top, hsz) + p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None + end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask) + end_log_probs = nn.functional.softmax(end_logits, dim=1) # shape (bsz, slen, start_n_top) + + end_top_log_probs, end_top_index = torch.topk( + end_log_probs, self.end_n_top, dim=1 + ) # shape (bsz, end_n_top, start_n_top) + end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top) + end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top) + + start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs) + cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index) + + if not return_dict: + return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + else: + return SquadHeadOutput( + start_top_log_probs=start_top_log_probs, + start_top_index=start_top_index, + end_top_log_probs=end_top_log_probs, + end_top_index=end_top_index, + cls_logits=cls_logits, + ) + + +class SequenceSummary(nn.Module): + r""" + Compute a single vector summary of a sequence hidden states. + + Args: + config ([`PretrainedConfig`]): + The config used by the model. Relevant arguments in the config class of the model are (refer to the actual + config class of your model for the default values it uses): + + - **summary_type** (`str`) -- The method to use to make this summary. 
Accepted values are: + + - `"last"` -- Take the last token hidden state (like XLNet) + - `"first"` -- Take the first token hidden state (like Bert) + - `"mean"` -- Take the mean of all tokens hidden states + - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2) + - `"attn"` -- Not implemented now, use multi-head attention + + - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction. + - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes + (otherwise to `config.hidden_size`). + - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output, + another string or `None` will add no activation. + - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation. + - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation. + """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + + self.summary_type = getattr(config, "summary_type", "last") + if self.summary_type == "attn": + # We should use a standard multi-head attention module with absolute positional embedding for that. + # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276 + # We can probably just use the multi-head attention module of PyTorch >=1.1.0 + raise NotImplementedError + + self.summary = Identity() + if hasattr(config, "summary_use_proj") and config.summary_use_proj: + if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0: + num_classes = config.num_labels + else: + num_classes = config.hidden_size + self.summary = nn.Linear(config.hidden_size, num_classes) + + activation_string = getattr(config, "summary_activation", None) + self.activation: Callable = get_activation(activation_string) if activation_string else Identity() + + self.first_dropout = Identity() + if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0: + self.first_dropout = nn.Dropout(config.summary_first_dropout) + + self.last_dropout = Identity() + if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0: + self.last_dropout = nn.Dropout(config.summary_last_dropout) + + def forward( + self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None + ) -> torch.FloatTensor: + """ + Compute a single vector summary of a sequence hidden states. + + Args: + hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`): + The hidden states of the last layer. + cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*): + Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token. + + Returns: + `torch.FloatTensor`: The summary of the sequence hidden states. 
+ """ + if self.summary_type == "last": + output = hidden_states[:, -1] + elif self.summary_type == "first": + output = hidden_states[:, 0] + elif self.summary_type == "mean": + output = hidden_states.mean(dim=1) + elif self.summary_type == "cls_index": + if cls_index is None: + cls_index = torch.full_like( + hidden_states[..., :1, :], + hidden_states.shape[-2] - 1, + dtype=torch.long, + ) + else: + cls_index = cls_index.unsqueeze(-1).unsqueeze(-1) + cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),)) + # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states + output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size) + elif self.summary_type == "attn": + raise NotImplementedError + + output = self.first_dropout(output) + output = self.summary(output) + output = self.activation(output) + output = self.last_dropout(output) + + return output + + +def unwrap_model(model: nn.Module) -> nn.Module: + """ + Recursively unwraps a model from potential containers (as used in distributed training). + + Args: + model (`torch.nn.Module`): The model to unwrap. + """ + # since there could be multiple levels of wrapping, unwrap recursively + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model + + +def expand_device_map(device_map, param_names, start_prefix): + """ + Expand a device map to return the correspondance parameter name to device. + """ + new_device_map = {} + param_names = [p[len(start_prefix) :] for p in param_names if p.startswith(start_prefix)] + for module, device in device_map.items(): + new_device_map.update( + {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""} + ) + return new_device_map + + +def get_disk_only_shard_files(device_map, sharded_metadata, start_prefix): + """ + Returns the list of shard files containing only weights offloaded to disk. + """ + + weight_map = { + p[len(start_prefix) :]: v for p, v in sharded_metadata["weight_map"].items() if p.startswith(start_prefix) + } + files_content = collections.defaultdict(list) + for weight_name, filename in weight_map.items(): + while len(weight_name) > 0 and weight_name not in device_map: + weight_name = ".".join(weight_name.split(".")[:-1]) + files_content[filename].append(device_map[weight_name]) + + return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] diff --git a/llava/train/transformers_replace/models/gemma/__init__.py b/llava/train/transformers_replace/models/gemma/__init__.py new file mode 100755 index 00000000..df0d9931 --- /dev/null +++ b/llava/train/transformers_replace/models/gemma/__init__.py @@ -0,0 +1,111 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_flax_available, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, +) + +_import_structure = { + "configuration_gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"], +} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_gemma"] = ["GemmaTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"] + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_gemma"] = [ + "GemmaForCausalLM", + "GemmaModel", + "GemmaPreTrainedModel", + "GemmaForSequenceClassification", + ] + +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_flax_gemma"] = [ + "FlaxGemmaForCausalLM", + "FlaxGemmaModel", + "FlaxGemmaPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_gemma import GemmaTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_gemma_fast import GemmaTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_gemma import GemmaForCausalLM, GemmaForSequenceClassification, GemmaModel, GemmaPreTrainedModel + + try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_flax_gemma import FlaxGemmaForCausalLM, FlaxGemmaModel, FlaxGemmaPreTrainedModel + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/llava/train/transformers_replace/models/gemma/configuration_gemma.py b/llava/train/transformers_replace/models/gemma/configuration_gemma.py new file mode 100755 index 00000000..a54f1bd7 --- /dev/null +++ b/llava/train/transformers_replace/models/gemma/configuration_gemma.py @@ -0,0 +1,145 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
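The gemma `__init__.py` above follows the transformers lazy-import convention: `_import_structure` declares what each submodule exports, and `_LazyModule` defers the actual imports until an attribute is first accessed. A rough stand-alone approximation of that idea, using module-level `__getattr__` (PEP 562) and hypothetical module names, is sketched below; it is illustrative only and not part of the patch.

```python
# Sketch of lazy attribute resolution for a package __init__.py,
# approximating what `_LazyModule` does for the gemma package above.
import importlib

_import_structure = {
    "configuration_gemma": ["GemmaConfig"],
    "modeling_gemma": ["GemmaModel", "GemmaForCausalLM"],
}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}


def __getattr__(name):
    # Import the owning submodule only on first access to one of its symbols.
    if name in _attr_to_module:
        module = importlib.import_module(f".{_attr_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

The guarded `try`/`except OptionalDependencyNotAvailable` blocks in the real file serve a related purpose: tokenizer, torch, and flax exports are only registered when the corresponding optional dependency is installed.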
+""" Gemma model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + +GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class GemmaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the Gemma-7B. + + e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 256000): + Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GemmaModel`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 24576): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 28): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 16): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + head_dim (`int`, *optional*, defaults to 256): + The attention head dimension. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 8192): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + eos_token_id (`int`, *optional*, defaults to 1): + End of stream token id. + bos_token_id (`int`, *optional*, defaults to 2): + Beginning of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. 
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + + ```python + >>> from transformers import GemmaModel, GemmaConfig + + >>> # Initializing a Gemma gemma-7b style configuration + >>> configuration = GemmaConfig() + + >>> # Initializing a model from the gemma-7b style configuration + >>> model = GemmaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gemma" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=256000, + hidden_size=3072, + intermediate_size=24576, + num_hidden_layers=28, + num_attention_heads=16, + num_key_value_heads=16, + head_dim=256, + hidden_act="gelu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + eos_token_id=1, + bos_token_id=2, + tie_word_embeddings=True, + rope_theta=10000.0, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/llava/train/transformers_replace/models/gemma/modeling_gemma.py b/llava/train/transformers_replace/models/gemma/modeling_gemma.py new file mode 100755 index 00000000..4ff4b52e --- /dev/null +++ b/llava/train/transformers_replace/models/gemma/modeling_gemma.py @@ -0,0 +1,1331 @@ +# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
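To make the attention geometry encoded by these defaults concrete, here is a small sketch (not part of the patch) of the projection shapes that the `GemmaAttention` module defined later in this file derives from a `GemmaConfig`-style object. The numbers simply restate the documented gemma-7b defaults.

```python
# Sketch: attention bookkeeping implied by the GemmaConfig defaults above (gemma-7b style).
hidden_size = 3072
num_attention_heads = 16
num_key_value_heads = 16   # == num_attention_heads -> plain multi-head attention
head_dim = 256

num_key_value_groups = num_attention_heads // num_key_value_heads  # 1
q_proj_out = num_attention_heads * head_dim    # 4096: q_proj maps 3072 -> 4096
kv_proj_out = num_key_value_heads * head_dim   # 4096: k_proj / v_proj map 3072 -> 4096
# o_proj maps num_attention_heads * head_dim (4096) back to hidden_size (3072).

# With num_key_value_heads < num_attention_heads (grouped-query attention), the key/value
# heads are repeated num_key_value_groups times at runtime (see `repeat_kv` further below).
print(num_key_value_groups, q_proj_out, kv_proj_out)
```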
+""" PyTorch Gemma model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from ...utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from ...utils.import_utils import is_torch_fx_available +from .configuration_gemma import GemmaConfig + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "GemmaConfig" + + +def _get_unpad_data(attention_mask, seqlens_in_batch): + if seqlens_in_batch is None: + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + else: + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class GemmaRMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.zeros(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + output = self._norm(x.float()).type_as(x) + return output * (1 + self.weight) + + +ALL_LAYERNORM_LAYERS.append(GemmaRMSNorm) + + +class GemmaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.register_buffer("inv_freq", None, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.inv_freq is None: + self.inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim) + ) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See 
https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Gemma +class GemmaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class GemmaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + # Ignore copy + def __init__(self, config: GemmaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = config.head_dim + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if self.hidden_size % self.num_heads != 0: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self.rotary_emb = GemmaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = 
past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + if cache_position is not None: + causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] + else: + causal_mask = attention_mask + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.view(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Gemma +class GemmaFlashAttention2(GemmaAttention): + """ + Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
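+        # In other words, this flag only changes how `causal` is picked in `_flash_attention_forward`
+        # below: with flash_attn >= 2.1, `causal=self.is_causal` can always be passed, while older
+        # versions have to drop causality for the single-token decode step (`query_length == 1`).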
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GemmaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + seqlens_in_batch=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in GemmaFlashAttention2 __init__. 
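+            # (With a top-left aligned mask and q_len == 1, the single query would only be allowed to
+            # attend to the first cached key, which is wrong; disabling causality here is safe because
+            # the newest token may attend to the entire cache anyway.)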
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length, seqlens_in_batch=seqlens_in_batch + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length, seqlens_in_batch): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask, seqlens_in_batch=seqlens_in_batch + ) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Gemma +class GemmaSdpaAttention(GemmaAttention): + """ + Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `GemmaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Ignore copy + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
+ logger.warning_once( + "GemmaModel is using GemmaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None and cache_position is not None: + causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. 
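+        # Note also that, unlike the eager and flash-attention classes above, this forward does not
+        # accept `seqlens_in_batch`, so packed-sequence training is only wired up through
+        # `GemmaFlashAttention2`.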
+ if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +GEMMA_ATTENTION_CLASSES = { + "eager": GemmaAttention, + "flash_attention_2": GemmaFlashAttention2, + "sdpa": GemmaSdpaAttention, +} + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with LLAMA->GEMMA,Llama->Gemma +class GemmaDecoderLayer(nn.Module): + def __init__(self, config: GemmaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = GEMMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = GemmaMLP(config) + self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + seqlens_in_batch=seqlens_in_batch, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +GEMMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GemmaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Gemma Model outputting raw hidden-states without any specific head on top.", + GEMMA_START_DOCSTRING, +) +class GemmaPreTrainedModel(PreTrainedModel): + config_class = GemmaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"] + _no_split_modules = ["GemmaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values", "causal_mask"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): + if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device: + causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + + for layer in self.model.layers: + weights = layer.self_attn.o_proj.weight + layer.self_attn.past_key_value = cache_cls( + self.config, 
max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype + ) + + def _reset_cache(self): + for layer in self.model.layers: + layer.self_attn.past_key_value = None + + +GEMMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Gemma Model outputting raw hidden-states without any specific head on top.", + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->GEMMA,Llama->Gemma +class GemmaModel(GemmaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`GemmaDecoderLayer`] + + Args: + config: GemmaConfig + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [GemmaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class + causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
+ ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + past_seen_tokens = 0 + if use_cache: # kept for BC (cache positions) + if not isinstance(past_key_values, StaticCache): + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_seen_tokens = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # normalized + hidden_states = hidden_states * (self.config.hidden_size**0.5) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + seqlens_in_batch, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + ) + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask(self, attention_mask, input_tensor): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + batch_size, seq_length = input_tensor.shape[:2] + dtype = input_tensor.dtype + device = input_tensor.device + + # support going beyond cached `max_position_embedding` + if seq_length > self.causal_mask.shape[-1]: + causal_mask = torch.full((2 * self.causal_mask.shape[-1], 2 * self.causal_mask.shape[-1]), fill_value=1) + self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False) + + if hasattr(self, "causal_mask"): # we use the current dtype to avoid any overflows + causal_mask = ( + self.causal_mask[None, None, :, :].repeat(batch_size, 1, 1, 1).to(dtype) * torch.finfo(dtype).min + ) + else: + mask = torch.full( + (self.config.max_position_embeddings, self.config.max_position_embeddings), + fill_value=torch.finfo(dtype).min, + ) + causal_mask = torch.triu(mask, diagonal=1) + + causal_mask = causal_mask.to(dtype=dtype, device=device) + if attention_mask 
is not None and attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0) + causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill( + padding_mask, torch.finfo(dtype).min + ) + + if self.config._attn_implementation == "sdpa": + is_tracing = torch.jit.is_tracing() or isinstance(input_tensor, torch.fx.Proxy) + if not is_tracing and attention_mask is not None and torch.any(attention_mask != 1): + causal_mask = causal_mask.mul(~torch.all(causal_mask == causal_mask.min(), dim=-1)[..., None]).to(dtype) + + return causal_mask + + +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->GEMMA,Llama->Gemma,llama->gemma +class GemmaForCausalLM(GemmaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = GemmaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GemmaForCausalLM + + >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") + >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") + + >>> prompt = "What is your favorite condiment?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "What is your favorite condiment?" 
+ ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + past_length = 0 + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
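+            # (For example, with max_cache_length = 4096, cache_length = 4090 and 10 new input tokens,
+            # 4090 + 10 > 4096, so only the most recent 4096 columns of the mask are kept.)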
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None): + # generation with static cache + past_length = past_key_value.get_seq_length() + input_ids = input_ids[:, past_length:] + position_ids = position_ids[:, past_length:] + + # TODO @gante we should only keep a `cache_position` in generate, and do +=1. + # same goes for position ids. Could also help with continued generation. + cache_position = kwargs.get("cache_position", None) + if cache_position is None: + cache_position = torch.arange(past_length, past_length + position_ids.shape[-1], device=position_ids.device) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Gemma Model transformer with a sequence classification head on top (linear layer). + + [`GemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + GEMMA_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->GEMMA,Llama->Gemma +class GemmaForSequenceClassification(GemmaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = GemmaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = 
CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/llava/train/transformers_replace/models/llama/configuring_llama.py b/llava/train/transformers_replace/models/llama/configuring_llama.py new file mode 100755 index 00000000..234767a6 --- /dev/null +++ b/llava/train/transformers_replace/models/llama/configuring_llama.py @@ -0,0 +1,188 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LLaMA model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + +LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. 
If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
+            Llama 2 up to 4096, CodeLlama up to 16384.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            Padding token id.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            Beginning of stream token id.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            End of stream token id.
+        pretraining_tp (`int`, *optional*, defaults to 1):
+            Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
+            document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
+            necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
+            issue](https://github.com/pytorch/pytorch/issues/76232).
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether to tie the input and output word embeddings.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
+            strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
+            `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
+            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
+            these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
+            experimental feature, subject to breaking API changes in future versions.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+ + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") diff --git a/llava/train/transformers_replace/models/llama/modeling_llama.py b/llava/train/transformers_replace/models/llama/modeling_llama.py new file mode 100755 index 00000000..dd8d1df3 --- /dev/null +++ b/llava/train/transformers_replace/models/llama/modeling_llama.py @@ -0,0 +1,1234 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import ALL_LAYERNORM_LAYERS +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_llama import LlamaConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask, seqlens_in_batch): + if seqlens_in_batch is None: + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", 
emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
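For reference, a minimal sketch (not upstream code; the numbers are invented) of the base adjustment performed by `LlamaDynamicNTKScalingRotaryEmbedding` above once the sequence outgrows `max_position_embeddings`:

```python
def ntk_adjusted_base(base, dim, scaling_factor, seq_len, max_position_embeddings):
    # Mirrors the recomputation in LlamaDynamicNTKScalingRotaryEmbedding.forward.
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))

# head_dim=128, factor=2.0, 8192 tokens against a 4096-token limit:
# the base grows to roughly 3x its original value.
print(ntk_adjusted_base(10000.0, 128, 2.0, 8192, 4096))
```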
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat([F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], 
o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # kv_seq_len = key_states.shape[-2] + # if past_key_value is not None: + # kv_seq_len += past_key_value[0].shape[-2] + + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + seqlens_in_batch=seqlens_in_batch, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def flash_attn_varlen_func_helper(self): + raise NotImplementedError("This method should be used after being mocked, for supporting seq parallel.") + + def hybrid_attn_varlen_func_helper(self): + raise NotImplementedError("This method should be used after being mocked, for supporting seq parallel.") + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + seqlens_in_batch=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length, seqlens_in_batch=seqlens_in_batch + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=self.is_causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=self.is_causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length, seqlens_in_batch): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask, seqlens_in_batch=seqlens_in_batch + ) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
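As an aside (a sketch with invented values, not part of the upstream file), this is what the `_get_unpad_data` helper used by `_upad_input` produces for a small packed batch, and what is ultimately fed to `flash_attn_varlen_func`:

```python
import torch
import torch.nn.functional as F

# Two sequences of lengths 3 and 2, right-padded to length 4.
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])
seqlens_in_batch = torch.tensor([3, 2], dtype=torch.int32)

indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()           # [0, 1, 2, 4, 5]
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))  # [0, 3, 5]
max_seqlen_in_batch = seqlens_in_batch.max().item()                                   # 3
```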
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = ( + # LlamaAttention(config=config) + LlamaFlashAttention2(config=config) + ) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. 
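The control flow of `LlamaDecoderLayer.forward` above reduces to the usual pre-norm residual pattern; a compact sketch with stand-in callables (not upstream code):

```python
def decoder_block(x, self_attn, mlp, input_ln, post_attn_ln):
    # Residual connection around self-attention, then around the MLP,
    # with RMSNorm applied *before* each sub-layer (pre-norm).
    h = x + self_attn(input_ln(x))
    return h + mlp(post_attn_ln(h))
```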
+ + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, decoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. 
+ + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif 
inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + + # embed positions + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + seqlens_in_batch, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.forward_step_id = 0 + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, 
config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = outputs[0] + + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + 
attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
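The `position_ids` fallback in `prepare_inputs_for_generation` above can be traced with a tiny example (invented left-padded mask, not part of the patch):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])        # two pad tokens, then three real ones
position_ids = attention_mask.long().cumsum(-1) - 1     # [[-1, -1, 0, 1, 2]]
position_ids.masked_fill_(attention_mask == 0, 1)       # [[ 1,  1, 0, 1, 2]]
# Real tokens get positions 0, 1, 2; padded slots receive a dummy value of 1.
```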
+ """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + 
transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/llava/train/transformers_replace/models/llama/tokenization_llama.py b/llava/train/transformers_replace/models/llama/tokenization_llama.py new file mode 100755 index 00000000..e15259ef --- /dev/null +++ b/llava/train/transformers_replace/models/llama/tokenization_llama.py @@ -0,0 +1,480 @@ +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for LLaMA.""" +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...convert_slow_tokenizer import import_protobuf +from ...tokenization_utils import AddedToken, PreTrainedTokenizer +from ...utils import logging + +if TYPE_CHECKING: + from ...tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class LlamaTokenizer(PreTrainedTokenizer): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
+ unk_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + bos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + eos_token (`str` or `tokenizers.AddedToken`, *optional*, defaults to `""`): + The end of sequence token. + pad_token (`str` or `tokenizers.AddedToken`, *optional*): + A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by + attention mechanisms or loss computation. + sp_model_kwargs (`Dict[str, Any]`, `Optional`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + add_bos_token (`bool`, *optional*, defaults to `True`): + Whether or not to add an `bos_token` at the start of sequences. + add_eos_token (`bool`, *optional*, defaults to `False`): + Whether or not to add an `eos_token` at the end of sequences. + clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to cleanup spaces after decoding, cleanup consists in removing potential artifacts like + extra spaces. + use_default_system_prompt (`bool`, *optional*, defaults to `False`): + Whether or not the default system prompt for Llama should be used. + spaces_between_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to add spaces between special tokens. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + add_prefix_space (`bool`, *optional*, defaults to `True`): + Whether or not to add an initial space to the input. This allows to treat the leading word just as any + other word. 
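A short usage sketch for the options documented above (the checkpoint name is the testing tokenizer referenced earlier in this file; the behaviour noted in comments is the expected one, not something asserted by this patch):

```python
from transformers import LlamaTokenizer

tok = LlamaTokenizer.from_pretrained(
    "hf-internal-testing/llama-tokenizer",
    add_bos_token=True,    # prepend BOS to every encoded sequence
    add_eos_token=False,   # do not append EOS
    legacy=False,          # post-#24565 handling of text following special tokens
)
ids = tok("Hello world").input_ids   # begins with tok.bos_token_id
```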
+ + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=False, + spaces_between_special_tokens=False, + legacy=None, + add_prefix_space=True, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, normalized=False, special=True) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, normalized=False, special=True) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, normalized=False, special=True) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, normalized=False, special=True) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thoroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + self.add_prefix_space = add_prefix_space + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + add_prefix_space=add_prefix_space, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = 
spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + text = text.replace(SPIECE_UNDERLINE, " ") + if self.add_prefix_space: + text = SPIECE_UNDERLINE + text + + tokens = super().tokenize(text, **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE) and self.add_prefix_space: + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. 
+ + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. 
+        Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict
+        user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering
+        rather than needing special tokens. The system message is partly 'embedded' in the first user message, which
+        results in an unusual token ordering when it is present. This template should definitely be changed if you wish
+        to fine-tune a model with more flexible role ordering!
+
+        The output should look something like:
+
+        <s>[INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer </s><s>[INST] Prompt [/INST] Answer </s>
+        <s>[INST] Prompt [/INST]
+
+        The reference for this chat template is [this code
+        snippet](https://github.com/facebookresearch/llama/blob/556949fdfb72da27c2f4a40b7f0e4cf0b8153a28/llama/generation.py#L320-L362)
+        in the original repository.
+        """
+        logger.warning_once(
+            "\nNo chat template is defined for this tokenizer - using the default template "
+            f"for the {self.__class__.__name__} class. If the default is not appropriate for "
+            "your model, please set `tokenizer.chat_template` to an appropriate template. "
+            "See https://huggingface.co/docs/transformers/main/chat_templating for more information.\n"
+        )
+        template = (
+            "{% if messages[0]['role'] == 'system' %}"
+            "{% set loop_messages = messages[1:] %}"  # Extract system message if it's present
+            "{% set system_message = messages[0]['content'] %}"
+            "{% elif USE_DEFAULT_PROMPT == true and not '<<SYS>>' in messages[0]['content'] %}"
+            "{% set loop_messages = messages %}"  # Or use the default system message if the flag is set
+            "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}"
+            "{% else %}"
+            "{% set loop_messages = messages %}"
+            "{% set system_message = false %}"
+            "{% endif %}"
+            "{% for message in loop_messages %}"  # Loop over all non-system messages
+            "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+            "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+            "{% endif %}"
+            "{% if loop.index0 == 0 and system_message != false %}"  # Embed system message in first message
+            "{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}"
+            "{% else %}"
+            "{% set content = message['content'] %}"
+            "{% endif %}"
+            "{% if message['role'] == 'user' %}"  # After all of that, handle messages/roles in a fairly normal way
+            "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}"
+            "{% elif message['role'] == 'system' %}"
+            "{{ '<<SYS>>\\n' + content.strip() + '\\n<</SYS>>\\n\\n' }}"
+            "{% elif message['role'] == 'assistant' %}"
+            "{{ ' ' + content.strip() + ' ' + eos_token }}"
+            "{% endif %}"
+            "{% endfor %}"
+        )
+        template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false")
+        default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'")
+        template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message)
+
+        return template
diff --git a/llava/train/transformers_replace/models/mistral/__init__.py b/llava/train/transformers_replace/models/mistral/__init__.py
new file mode 100755
index 00000000..840af48d
--- /dev/null
+++ b/llava/train/transformers_replace/models/mistral/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2023 Mistral AI and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + +_import_structure = { + "configuration_mistral": ["MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MistralConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mistral"] = [ + "MistralForCausalLM", + "MistralModel", + "MistralPreTrainedModel", + "MistralForSequenceClassification", + ] + + +if TYPE_CHECKING: + from .configuration_mistral import MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MistralConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mistral import ( + MistralForCausalLM, + MistralForSequenceClassification, + MistralModel, + MistralPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/llava/train/transformers_replace/models/mistral/configuration_mistral.py b/llava/train/transformers_replace/models/mistral/configuration_mistral.py new file mode 100755 index 00000000..834cb0dc --- /dev/null +++ b/llava/train/transformers_replace/models/mistral/configuration_mistral.py @@ -0,0 +1,150 @@ +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mistral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + +MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", + "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", +} + + +class MistralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an + Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. + + [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) + [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. 
Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 32000):
+            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by the
+            `inputs_ids` passed when calling [`MistralModel`]
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
+            allows sequences of up to 4096*32 tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*, defaults to 4096):
+            Sliding window attention window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
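+
+    As a quick illustration of the grouped-query attention arithmetic described for `num_key_value_heads` above
+    (an editorial sketch, not from the upstream docstring): with the default 32 attention heads and 8 key/value
+    heads, every 4 query heads share one key/value head.
+
+    ```python
+    >>> from transformers import MistralConfig
+
+    >>> config = MistralConfig()  # defaults: 32 attention heads, 8 key/value heads
+    >>> config.num_attention_heads // config.num_key_value_heads  # query heads per shared key/value head
+    4
+    ```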
+ + ```python + >>> from transformers import MistralModel, MistralConfig + + >>> # Initializing a Mistral 7B style configuration + >>> configuration = MistralConfig() + + >>> # Initializing a model from the Mistral 7B style configuration + >>> model = MistralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mistral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + sliding_window=4096, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/llava/train/transformers_replace/models/mistral/modeling_mistral.py b/llava/train/transformers_replace/models/mistral/modeling_mistral.py new file mode 100755 index 00000000..c234b410 --- /dev/null +++ b/llava/train/transformers_replace/models/mistral/modeling_mistral.py @@ -0,0 +1,1376 @@ +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
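+# NOTE: illustrative sketch, not part of the original file. This patched copy threads an optional
+# `seqlens_in_batch` argument through the attention stack so that pre-packed sequences can bypass the
+# attention-mask-based length computation in `_get_unpad_data` below. A minimal, self-contained
+# illustration of the cumulative-length layout that is ultimately handed to `flash_attn_varlen_func`,
+# assuming two packed sequences of (hypothetical) lengths 3 and 5:
+#
+#     import torch
+#     import torch.nn.functional as F
+#
+#     seqlens_in_batch = torch.tensor([3, 5], dtype=torch.int32)
+#     cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+#     print(cu_seqlens)  # tensor([0, 3, 8], dtype=torch.int32) -- per-sequence start offsets plus total length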
+""" PyTorch Mistral model.""" +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from .configuration_mistral import MistralConfig + +_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MistralConfig" +Cache = Tuple[torch.Tensor] + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask, seqlens_in_batch): + if seqlens_in_batch is None: + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + else: + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mistral +class MistralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MistralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral +class MistralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", 
emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
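+
+    Example (an illustrative shape check only, not from the upstream docstring; random tensors stand in for real
+    activations, and `cos`/`sin` use the `[batch, seq_len, head_dim]` layout produced by `MistralRotaryEmbedding`):
+
+    ```python
+    >>> import torch
+
+    >>> q = torch.randn(1, 8, 16, 64)   # [batch, query heads, seq_len, head_dim]
+    >>> k = torch.randn(1, 2, 16, 64)   # fewer key/value heads (GQA)
+    >>> cos = torch.randn(1, 16, 64)
+    >>> sin = torch.randn(1, 16, 64)
+    >>> q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
+    >>> q_embed.shape, k_embed.shape    # shapes are preserved; only the values are rotated
+    (torch.Size([1, 8, 16, 64]), torch.Size([1, 2, 16, 64]))
+    ```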
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class MistralMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class MistralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MistralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class MistralFlashAttention2(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + # if past_key_value is not None: + # if self.layer_idx is None: + # raise ValueError( + # f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + # "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + # "with a layer index." + # ) + # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # # Because the input can be padded, the absolute sequence length depends on the max position id. + # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + seqlens_in_batch=seqlens_in_batch, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + seqlens_in_batch=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
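+
+        A minimal sketch of the unpad bookkeeping (illustrative values only, not from the original docstring):
+
+        ```python
+        >>> import torch
+
+        >>> attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two right-padded sequences
+        >>> indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask, seqlens_in_batch=None)
+        >>> indices  # flat positions of the non-padding tokens
+        tensor([0, 1, 2, 4, 5])
+        >>> cu_seqlens  # cumulative sequence lengths passed to `flash_attn_varlen_func`
+        tensor([0, 3, 5], dtype=torch.int32)
+        >>> max_seqlen
+        3
+        ```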
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length, seqlens_in_batch=seqlens_in_batch + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length, seqlens_in_batch): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask, seqlens_in_batch=seqlens_in_batch + ) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == 
kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral +class MistralSdpaAttention(MistralAttention): + """ + Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MistralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MistralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MistralModel is using MistralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
+ is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class MistralDecoderLayer(nn.Module): + def __init__(self, config: MistralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MistralFlashAttention2(config, layer_idx) + + self.mlp = MistralMLP(config) + self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +MISTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MistralConfig`]): + Model configuration class with all the parameters of the model. 
Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralPreTrainedModel(PreTrainedModel): + config_class = MistralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MistralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MISTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. 
+ + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mistral Model outputting raw hidden-states without any specific head on top.", + MISTRAL_START_DOCSTRING, +) +class MistralModel(MistralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`] + + Args: + config: MistralConfig + """ + + def __init__(self, config: MistralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = "flash_attention_2" + self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds 
at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + past_key_values_length = 0 + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + seqlens_in_batch, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class MistralForCausalLM(MistralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MistralModel(config) + self.vocab_size = 
config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MistralForCausalLM + + >>> model = MistralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
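+            # Here `cache_length` is the number of tokens currently stored in the cache and
+            # `max_cache_length` its capacity (`None` for the legacy tuple cache), so the slice
+            # below keeps only the most recent `max_cache_length` columns of the attention mask.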
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mistral Model transformer with a sequence classification head on top (linear layer). + + [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MISTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mistral, LLAMA->MISTRAL +class MistralForSequenceClassification(MistralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MistralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/llava/train/transformers_replace/models/mixtral/__init__.py b/llava/train/transformers_replace/models/mixtral/__init__.py new file mode 100755 index 00000000..09c0bea4 --- /dev/null +++ b/llava/train/transformers_replace/models/mixtral/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2023 Mixtral AI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available + +_import_structure = { + "configuration_mixtral": ["MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP", "MixtralConfig"], +} + + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mixtral"] = [ + "MixtralForCausalLM", + "MixtralModel", + "MixtralPreTrainedModel", + "MixtralForSequenceClassification", + ] + + +if TYPE_CHECKING: + from .configuration_mixtral import MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP, MixtralConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mixtral import ( + MixtralForCausalLM, + MixtralForSequenceClassification, + MixtralModel, + MixtralPreTrainedModel, + ) + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/llava/train/transformers_replace/models/mixtral/configuration_mixtral.py b/llava/train/transformers_replace/models/mixtral/configuration_mixtral.py new file mode 100755 index 00000000..8c247e89 --- /dev/null +++ b/llava/train/transformers_replace/models/mixtral/configuration_mixtral.py @@ -0,0 +1,167 @@ +# Copyright 2023 Mixtral AI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Mixtral model configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + +logger = logging.get_logger(__name__) + +MIXTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "mistral-ai/Mixtral-8x7B": "https://huggingface.co/mistral-ai/Mixtral-8x7B/resolve/main/config.json", +} + + +class MixtralConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`MixtralModel`]. It is used to instantiate an + Mixtral model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the Mixtral-7B-v0.1 or Mixtral-7B-Instruct-v0.1. + + [mixtralai/Mixtral-8x7B](https://huggingface.co/mixtralai/Mixtral-8x7B) + [mixtralai/Mixtral-7B-Instruct-v0.1](https://huggingface.co/mixtralai/Mixtral-7B-Instruct-v0.1) + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Mixtral model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MixtralModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. 
+        intermediate_size (`int`, *optional*, defaults to 14336):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 8):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string) in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
+            The maximum sequence length that this model might ever be used with. Mixtral's sliding window attention
+            allows sequences of up to 4096*32 tokens.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        pad_token_id (`int`, *optional*):
+            The id of the padding token.
+        bos_token_id (`int`, *optional*, defaults to 1):
+            The id of the "beginning-of-sequence" token.
+        eos_token_id (`int`, *optional*, defaults to 2):
+            The id of the "end-of-sequence" token.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 1000000.0):
+            The base period of the RoPE embeddings.
+        sliding_window (`int`, *optional*):
+            Sliding window attention window size. If not specified, will default to `4096`.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        num_experts_per_tok (`int`, *optional*, defaults to 2):
+            The number of experts to route per token; this can also be interpreted as the `top-k` routing
+            parameter.
+        num_local_experts (`int`, *optional*, defaults to 8):
+            Number of experts per Sparse MLP layer.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss. See [here]() for more details.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+ + ```python + >>> from transformers import MixtralModel, MixtralConfig + + >>> # Initializing a Mixtral 7B style configuration + >>> configuration = MixtralConfig() + + >>> # Initializing a model from the Mixtral 7B style configuration + >>> model = MixtralModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mixtral" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/llava/train/transformers_replace/models/mixtral/modeling_mixtral.py b/llava/train/transformers_replace/models/mixtral/modeling_mixtral.py new file mode 100755 index 00000000..829fb139 --- /dev/null +++ b/llava/train/transformers_replace/models/mixtral/modeling_mixtral.py @@ -0,0 +1,1632 @@ +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Mixtral model.""" +import inspect +import math +import warnings +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from flash_attn import flash_attn_func, flash_attn_varlen_func +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...modeling_outputs import ModelOutput, SequenceClassifierOutputWithPast +from ...modeling_utils import PreTrainedModel +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ...utils.import_utils import is_torch_fx_available +from .configuration_mixtral import MixtralConfig + +_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" +Cache = Tuple[torch.Tensor] + + +@dataclass +class MoeCausalLMOutputWithPast(ModelOutput): + """ + Base class for causal language model (or autoregressive) with mixture of experts outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided): + aux_loss for the sparse modules. + + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. 
+ """ + + loss: Optional[torch.FloatTensor] = None + aux_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class MoeModelOutputWithPast(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if + `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` + input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router logtis (post-softmax) that are computed by MoE routers, these terms are used to compute the auxiliary + loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None + + +def load_balancing_loss_func(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. 
+ + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + # treat `top_k` as tokens (shape is `top_k X [batch_size X sequence_length]`) + selected_experts = selected_experts.reshape(-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + expert_mask = torch.max(expert_mask, dim=-2).values + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(-1)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask, seqlens_in_batch): + if seqlens_in_batch is None: + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + else: + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + t = t / 
self.scaling_factor + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) + self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) + + @property + def sin_cached(self): + logger.warning_once( + "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._sin_cached + + @property + def cos_cached(self): + logger.warning_once( + "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " + "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" + ) + return self._cos_cached + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
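+
+    In this copy, `forward` also accepts an optional `seqlens_in_batch` argument so that pre-computed
+    per-sequence lengths (e.g. from sequence packing) are forwarded to the unpadding step instead of
+    being re-derived from the attention mask.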
+ """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + # if past_key_value is not None: + # if self.layer_idx is None: + # raise ValueError( + # f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + # "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + # "with a layer index." + # ) + # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # # Because the input can be padded, the absolute sequence length depends on the max position id. + # rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + # cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." 
+ ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + seqlens_in_batch=seqlens_in_batch, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + seqlens_in_batch=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. 
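+
+        When `seqlens_in_batch` is provided (e.g. for packed sequences), it is passed through to
+        `_upad_input` so that the cumulative sequence lengths handed to `flash_attn_varlen_func`
+        follow the packed sub-sequence boundaries rather than being recomputed from the padding mask.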
+ + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length, seqlens_in_batch=seqlens_in_batch + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length, seqlens_in_batch): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask, seqlens_in_batch=seqlens_in_batch + ) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == 
kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral +class MixtralSdpaAttention(MixtralAttention): + """ + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from MixtralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. + is_causal=self.is_causal and attention_mask is None and q_len > 1, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +class MixtralBLockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). 
It's faster since it formulates MoE operations
+    in terms of block-sparse operations to accommodate imbalanced
+    assignments of tokens to experts, whereas standard MoE either
+    (1) drops tokens at the cost of reduced performance or (2) sets
+    the capacity factor to the number of experts and thus wastes
+    computation and memory on padding.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.hidden_dim = config.hidden_size
+        self.ffn_dim = config.intermediate_size
+        self.num_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+
+        # gating
+        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
+
+        self.experts = nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)])
+
+    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """ """
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        routing_weights = routing_weights.to(hidden_states.dtype)
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            if top_x.shape[0] == 0:
+                continue
+
+            # in torch it is faster to index using lists than torch tensors
+            top_x_list = top_x.tolist()
+            idx_list = idx.tolist()
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+
+            # However `index_add_` only supports torch tensors for indexing, so we'll use
+            # the `top_x` tensor here.
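+            # NOTE: since `routing_weights` were renormalized over the selected top-k experts above,
+            # the `index_add_` accumulation below gives every token a convex combination of the
+            # outputs of the experts it was routed to.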
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MixtralFlashAttention2(config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + seqlens_in_batch: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, router_logits = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = "flash_attention_2" + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + past_key_values_length = 0 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + is_padding_right = attention_mask[:, -1].sum().item() != batch_size + if is_padding_right: + raise ValueError( + "You are attempting to perform batched generation with padding_side='right'" + " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " + " call `tokenizer.padding_side = 'left'` before tokenizing the input. " + ) + + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = torch.utils.checkpoint.checkpoint( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + seqlens_in_batch, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = 
config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + seqlens_in_batch: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + seqlens_in_batch=seqlens_in_batch, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (aux_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
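+            # Keeping only the trailing `max_cache_length` positions keeps the mask consistent with a
+            # bounded cache that discards its oldest entries (e.g. a sliding-window cache).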
+ if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The Mixtral Model transformer with a sequence classification head on top (linear layer). + + [`MixtralForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Mixtral, LLAMA->MIXTRAL +class MixtralForSequenceClassification(MixtralPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = MixtralModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/tests/python/test_packing.py b/tests/python/test_packing.py index 6b67c355..13d95d36 100644 --- a/tests/python/test_packing.py +++ b/tests/python/test_packing.py @@ -14,10 +14,10 @@ # # SPDX-License-Identifier: Apache-2.0 +import copy import unittest import torch -from parameterized import parameterized from transformers import AutoTokenizer from llava.model import LlavaLlamaConfig, LlavaLlamaModel @@ -31,16 +31,14 @@ else: device = "cpu" -MODELS = ["lmsys/vicuna-7b-v1.5", "Qwen/Qwen2-7B"] - class TestInputPacking(unittest.TestCase): - @parameterized.expand(MODELS) - def test_loss_close(self, model_name_or_path: str): + def setUp(self): + # This test is supposed to run on a single GPU if torch.cuda.is_available(): rank = 0 torch.cuda.set_device(rank) - + model_name_or_path = "lmsys/vicuna-7b-v1.5" self.model_args = 
ModelArguments( model_name_or_path=model_name_or_path, version="v1", @@ -50,10 +48,10 @@ def test_loss_close(self, model_name_or_path: str): mm_use_im_patch_token=False, ) self.data_args = DataArguments() - self.training_args = TrainingArguments(output_dir="", bf16=True) + self.training_args = TrainingArguments(output_dir="") self.config = LlavaLlamaConfig.from_pretrained(model_name_or_path) prepare_config_for_training(self.config, self.model_args, self.training_args, self.data_args) - + print("Initializing tokenizer...") self.tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, model_max_length=4096, @@ -61,30 +59,91 @@ def test_loss_close(self, model_name_or_path: str): use_fast=False, legacy=False, ) + print("Initializing LlavaLlamaModel...") + self.model = LlavaLlamaModel(config=self.config) + self.model.vision_tower = self.model.vision_tower.to(device) + self.model.mm_projector = self.model.mm_projector.to(device) + self.model = self.model.to(device) - model = LlavaLlamaModel(config=self.config, attn_implementation="flash_attention_2") - model.vision_tower = model.vision_tower.to(device) - model.mm_projector = model.mm_projector.to(device) - model = model.to(device) - model.llm.pad_token_id = self.tokenizer.pad_token_id - + print("Initializing data...") data = torch.load("tests/sample_data/test_packing.pth") + # necessary for model forward + self.model.llm.pad_token_id = self.tokenizer.pad_token_id + self.data = data + + # @requires_gpu() + def test_loss_close(self): + print("Preprocessing inputs...") + data = copy.deepcopy(self.data) data["input_ids"] = data["input_ids"].to(device) - data["images"] = data["images"].to(torch.bfloat16).to(device) + data["images"] = data["images"].to(torch.float16).to(device) data["attention_mask"] = data["attention_mask"].to(device) data["labels"] = data["labels"].to(device) + data["position_ids"] = None + data["past_key_values"] = None + + ( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) = self.model.prepare_inputs_labels_for_multimodal(**data) + + print("Packing inputs...") + ( + _, + new_position_ids, + new_attention_mask, + _, + new_inputs_embeds, + new_labels, + sorted_seqlens_in_batch, + ) = self.model.repack_multimodal_data( + input_ids, + position_ids, + attention_mask, + past_key_values, + inputs_embeds, + labels, + ) - model.train() - - output = model(**data, packing=True, use_cache=False) - output_loss = output.loss.item() - del output - - target = model(**data, packing=False, use_cache=False) - target_loss = target.loss.item() - del target - - self.assertAlmostEqual(output_loss, target_loss, places=2) + print("Running models...") + + with torch.no_grad(): + # forward results with input packing + outputs = self.model.llm.forward( + input_ids=None, + attention_mask=new_attention_mask, + position_ids=new_position_ids, + past_key_values=None, + inputs_embeds=new_inputs_embeds, + labels=new_labels, + use_cache=False, + output_attentions=True, + output_hidden_states=True, + return_dict=True, + seqlens_in_batch=sorted_seqlens_in_batch, + ) + # forward results without input packing + outputs_ref = self.model.llm.forward( + input_ids=None, + attention_mask=attention_mask.to(device), + position_ids=None, + past_key_values=None, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=False, + output_attentions=True, + output_hidden_states=True, + return_dict=True, + ) + # loss should be very similar (but not the same due to numerical precision issues) + loss = outputs.loss + loss_ref = 
outputs_ref.loss + print("loss =", loss, "loss_ref =", loss_ref) + self.assertAlmostEqual(loss.item(), loss_ref.item(), places=2) if __name__ == "__main__":