From 5df38737a5e87e885afb66e58db18289b9db6341 Mon Sep 17 00:00:00 2001 From: lvdongyi Date: Mon, 2 Dec 2024 13:16:58 +0800 Subject: [PATCH] stack commits --- llm/config/qwen2moe/lora_argument.json | 34 --- llm/config/qwen2moe/pretrain_argument.json | 40 ---- llm/config/qwen2moe/sft_argument.json | 33 --- llm/experimental/layers/cache_kv.py | 182 ++++++++++++++++ llm/experimental/layers/custom_attention.py | 91 ++++++++ llm/experimental/observer/abs_max.py | 100 +++++++++ llm/experimental/observer/abs_max_headwise.py | 92 ++++++++ llm/experimental/observer/avg.py | 102 +++++++++ llm/experimental/observer/avg_headwise.py | 105 +++++++++ llm/experimental/observer/channel_wise.py | 50 +++++ llm/experimental/observer/uniform.py | 114 ++++++++++ paddlenlp/generation/utils.py | 89 +++++++- paddlenlp/transformers/__init__.py | 1 + paddlenlp/transformers/cache_utils.py | 158 ++++++++++++++ paddlenlp/transformers/llama/modeling.py | 107 +++++---- paddlenlp/transformers/model_utils.py | 3 + tests/test_cache_utils.py | 205 ++++++++++++++++++ tests/transformers/test_generation_utils.py | 69 ++++++ 18 files changed, 1422 insertions(+), 153 deletions(-) delete mode 100644 llm/config/qwen2moe/lora_argument.json delete mode 100644 llm/config/qwen2moe/pretrain_argument.json delete mode 100644 llm/config/qwen2moe/sft_argument.json create mode 100644 llm/experimental/layers/cache_kv.py create mode 100644 llm/experimental/layers/custom_attention.py create mode 100644 llm/experimental/observer/abs_max.py create mode 100644 llm/experimental/observer/abs_max_headwise.py create mode 100644 llm/experimental/observer/avg.py create mode 100644 llm/experimental/observer/avg_headwise.py create mode 100644 llm/experimental/observer/channel_wise.py create mode 100644 llm/experimental/observer/uniform.py create mode 100644 paddlenlp/transformers/cache_utils.py mode change 100755 => 100644 paddlenlp/transformers/llama/modeling.py create mode 100644 tests/test_cache_utils.py diff --git a/llm/config/qwen2moe/lora_argument.json b/llm/config/qwen2moe/lora_argument.json deleted file mode 100644 index 47e7adb14ecd..000000000000 --- a/llm/config/qwen2moe/lora_argument.json +++ /dev/null @@ -1,34 +0,0 @@ -{ - "model_name_or_path": "Qwen/Qwen2-57B-A14B", - "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/lora_ckpts", - "per_device_train_batch_size": 4, - "gradient_accumulation_steps": 4, - "per_device_eval_batch_size": 8, - "eval_accumulation_steps":16, - "num_train_epochs": 3, - "learning_rate": 3e-04, - "warmup_steps": 30, - "logging_steps": 1, - "evaluation_strategy": "epoch", - "save_strategy": "epoch", - "src_length": 1024, - "max_length": 2048, - "bf16": true, - "fp16_opt_level": "O2", - "do_train": true, - "do_eval": true, - "disable_tqdm": true, - "load_best_model_at_end": true, - "eval_with_do_generation": false, - "metric_for_best_model": "accuracy", - "recompute": true, - "save_total_limit": 1, - "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 1, - "lora": true, - "unified_checkpoint": true, - "zero_padding": false, - "use_flash_attention": true, - "pissa": false - } diff --git a/llm/config/qwen2moe/pretrain_argument.json b/llm/config/qwen2moe/pretrain_argument.json deleted file mode 100644 index f3115a64b648..000000000000 --- a/llm/config/qwen2moe/pretrain_argument.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "model_name_or_path": "Qwen/Qwen2-57B-A14B", - "tokenizer_name_or_path": "Qwen/Qwen2-57B-A14B", - "input_dir": "./data", - "output_dir": "./checkpoints/pretrain_ckpts", - 
"per_device_train_batch_size": 2, - "gradient_accumulation_steps": 1, - "per_device_eval_batch_size": 2, - "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 1, - "sharding": "stage2", - "virtual_pp_degree": 1, - "sequence_parallel": 0, - "use_flash_attention": true, - "use_fused_rms_norm": true, - "max_seq_length": 4096, - "learning_rate": 3e-05, - "min_learning_rate": 3e-06, - "warmup_steps": 30, - "logging_steps": 1, - "max_steps": 10000, - "save_steps": 5000, - "eval_steps": 1000, - "weight_decay": 0.01, - "bf16": true, - "fp16_opt_level": "O2", - "warmup_ratio": 0.01, - "max_grad_norm": 1.0, - "dataloader_num_workers": 1, - "continue_training": 1, - "do_train": true, - "do_eval": true, - "do_predict": true, - "disable_tqdm": true, - "recompute": true, - "distributed_dataloader": 1, - "recompute_granularity": "full", - "unified_checkpoint": true, - "save_total_limit": 2 - } diff --git a/llm/config/qwen2moe/sft_argument.json b/llm/config/qwen2moe/sft_argument.json deleted file mode 100644 index c964137f2264..000000000000 --- a/llm/config/qwen2moe/sft_argument.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "model_name_or_path": "Qwen/Qwen2-57B-A14B", - "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/sft_ckpts", - "per_device_train_batch_size": 1, - "gradient_accumulation_steps": 4, - "per_device_eval_batch_size": 8, - "eval_accumulation_steps":16, - "num_train_epochs": 3, - "learning_rate": 3e-05, - "warmup_steps": 30, - "logging_steps": 1, - "evaluation_strategy": "epoch", - "save_strategy": "epoch", - "src_length": 1024, - "max_length": 2048, - "bf16": true, - "fp16_opt_level": "O2", - "do_train": true, - "do_eval": true, - "disable_tqdm": true, - "load_best_model_at_end": true, - "eval_with_do_generation": false, - "metric_for_best_model": "accuracy", - "recompute": true, - "save_total_limit": 1, - "tensor_parallel_degree": 1, - "pipeline_parallel_degree": 1, - "sharding": "stage2", - "zero_padding": false, - "unified_checkpoint": true, - "use_flash_attention": true - } diff --git a/llm/experimental/layers/cache_kv.py b/llm/experimental/layers/cache_kv.py new file mode 100644 index 000000000000..e159ae8f5096 --- /dev/null +++ b/llm/experimental/layers/cache_kv.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import paddle +from paddle import ParamAttr +from paddle.nn import Layer +from paddle.nn.initializer import Constant +from paddle.nn.quant.format import ConvertibleQuantedLayer + + +class CacheKVMatMul(Layer): + def __init__(self): + super().__init__() + + def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): + return paddle.matmul(x, y, transpose_x, transpose_y, name) + + +class QuantizedCacheKVMatMul(ConvertibleQuantedLayer): + def __init__(self, layer: Layer, q_config): + super().__init__() + # For FakeQuant + self.activation_quanter = None + self.weight_quanter = None + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward(self, x, y, transpose_x=False, transpose_y=False, name=None): + # qdq + if self.activation_quanter is not None: + y = self.activation_quanter(y) + return paddle.matmul(x, y, transpose_x, transpose_y, name) + + def weights_to_quanters(self): + return [("weight", "weight_quanter")] + + def activation_quanters(self): + return ["activation_quanter"] + + +class ShiftSmoothCacheKVMatMul(Layer): + """ + The computational logic of ShiftSmoothCacheKVMatMul is the same as CacheKVMatMul. + The only difference is that its inputs are shift. + """ + + def __init__(self): + super().__init__() + self.sequence_parallel = False + self.dtype = None + + def forward( + self, + x, + y, + transpose_x=False, + transpose_y=False, + perm_x=None, + perm_y=None, + use_smooth_x=False, + use_smooth_out=False, + name=None, + sequence_parallel=False, + ): + self.sequence_parallel = sequence_parallel + # smooth + smooth_x, smooth_y = self._smooth(x, y, use_smooth_x) + # transpose + if perm_x is not None: + smooth_x = paddle.transpose(smooth_x, perm=perm_x) + if perm_y is not None: + smooth_y = paddle.transpose(smooth_y, perm=perm_y) + # matmul output + out = paddle.matmul(smooth_x, smooth_y, transpose_x, transpose_y, name) + if not use_smooth_out: + return out + else: + # combine heads + if self.sequence_parallel: + out = paddle.transpose(out, perm=[2, 0, 1, 3]) + else: + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + return paddle.multiply(out, self.smooth_weight) + + def _smooth(self, x, y, use_smooth_x): + # For ShiftSmooth + smooth_shape = [1] + self.dtype = y.dtype + if not hasattr(self, "smooth_weight"): + self.smooth_weight = self.create_parameter( + shape=smooth_shape, attr=ParamAttr(initializer=Constant(value=1.0)), dtype=self.dtype + ) + smooth_y = y + smooth_y = paddle.divide(smooth_y, self.smooth_weight) + + if use_smooth_x: + smooth_x = x + x = paddle.multiply(smooth_x, self.smooth_weight) + return x, smooth_y + + def convert_weight(self, smooth_weight=None): + if smooth_weight is not None: + self.smooth_weight.set_value(smooth_weight.squeeze().cast(self.dtype)) + + +class QuantizedShiftSmoothCacheKVMatMul(ConvertibleQuantedLayer): + """ + The computational logic of QuantizedShiftSmoothCacheKVMatMul is the same as RowParallelLinear. + The only difference is that its inputs are shift. 
+ """ + + def __init__(self, layer: Layer, q_config): + super().__init__() + + # For FakeQuant + self.weight_quanter = None + self.activation_quanter = None + self.smooth_weight = layer.smooth_weight + if q_config.activation is not None: + self.activation_quanter = q_config.activation._instance(layer) + + def forward( + self, + x, + y, + transpose_x=False, + transpose_y=False, + perm_x=None, + perm_y=None, + use_smooth_x=False, + use_smooth_out=False, + name=None, + sequence_parallel=False, + ): + # smooth + smooth_x, smooth_y = self._smooth(x, y, use_smooth_x) + # qdq + if self.activation_quanter is not None: + smooth_y = self.activation_quanter(smooth_y) + # transpose + if perm_x is not None: + smooth_x = paddle.transpose(smooth_x, perm=perm_x) + if perm_y is not None: + smooth_y = paddle.transpose(smooth_y, perm=perm_y) + # matmul output + out = paddle.matmul(smooth_x, smooth_y, transpose_x, transpose_y, name) + if not use_smooth_out: + return out + else: + # combine heads + if sequence_parallel: + out = paddle.transpose(out, perm=[2, 0, 1, 3]) + else: + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + return paddle.multiply(out, self.smooth_weight) + + def _smooth(self, x, y, use_smooth_x): + # For ShiftSmooth + self.dtype = y.dtype + smooth_y = y + smooth_y = paddle.divide(smooth_y, self.smooth_weight) + + if use_smooth_x: + smooth_x = x + x = paddle.multiply(smooth_x, self.smooth_weight) + return x, smooth_y + + def weights_to_quanters(self): + return [("weight", "weight_quanter")] + + def activation_quanters(self): + return ["activation_quanter"] diff --git a/llm/experimental/layers/custom_attention.py b/llm/experimental/layers/custom_attention.py new file mode 100644 index 000000000000..c40c815b3346 --- /dev/null +++ b/llm/experimental/layers/custom_attention.py @@ -0,0 +1,91 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Custome Attention Layer for quantization. +""" +# import paddle +import paddle.tensor as tensor +from paddle.nn import Layer +from paddle.nn.quant.format import ConvertibleQuantedLayer + + +class QuantizedCustomAttentionLayer(ConvertibleQuantedLayer): + """ + Quantized Custom Attention Layer. + """ + + def __init__(self, layer: Layer, q_config=None): + """ + Initialize the QuantizeWrapper class. + + Args: + layer (Layer): The layer to be quantized. + q_config (QuantConfig, optional): The quantization configuration. Defaults to None. 
+ """ + super().__init__() + # hard code: get activation quanter from weight + self.activation_quanter_k = q_config.weight._instance(layer) + self.activation_quanter_v = q_config.activation._instance(layer) + self.layer = layer + self.enable_fake_quant = False + self.quant_info = None + layer_name = self.layer.full_name() + self.layer_id = int(layer_name.split("_")[-1]) + self.kv_losses = {} + + def forward( + self, + q, + config, + k, + v, + attention_mask, + output_attentions, + # alibi, + # attn_mask_startend_row_indices, + # sequence_parallel, + **kwargs + ): + """forward""" + if self.enable_fake_quant: + self.collect_kv_quant_policy(q, k, v, **kwargs) + perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3] + tmp_k = tensor.transpose(x=k, perm=perm) + tmp_v = tensor.transpose(x=v, perm=perm) + if self.activation_quanter_k is not None: + tmp_k = self.activation_quanter_k(tmp_k) + if self.activation_quanter_v is not None: + tmp_v = self.activation_quanter_v(tmp_v) + k = tensor.transpose(x=tmp_k, perm=perm) + v = tensor.transpose(x=tmp_v, perm=perm) + return self.layer( + q, + config, + k, + v, + attention_mask, + output_attentions, + # alibi, + # attn_mask_startend_row_indices, + # sequence_parallel, + **kwargs, + ) + + def weights_to_quanters(self): + """weights to quanters""" + return [] + + def activation_quanters(self): + """activation to quanters""" + return ["activation_quanter_k", "activation_quanter_v"] diff --git a/llm/experimental/observer/abs_max.py b/llm/experimental/observer/abs_max.py new file mode 100644 index 000000000000..9d30db49cba3 --- /dev/null +++ b/llm/experimental/observer/abs_max.py @@ -0,0 +1,100 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.quantization.factory import ObserverFactory + +from .uniform import UniformObserver + + +class AbsmaxObserver(ObserverFactory): + r""" + It collects maximum absolute values of target tensor. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. 
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AbsmaxObserver, self).__init__(quant_bits=quant_bits) + + def _get_class(self): + return AbsmaxObserverLayer + + +class AbsmaxObserverLayer(UniformObserver): + def __init__( + self, + layer, + quant_bits=8, + ): + super(AbsmaxObserverLayer, self).__init__(quant_bits=quant_bits) + self._quant_bits = quant_bits + self._layer = layer + self._scale = None + self._zero_point = None + self._min = None + self._max = paddle.to_tensor(1e-7, dtype="float32") + self.step = 0 + + def forward(self, inputs): + """Calculate forward pass.""" + self._min, self._max = self.cal_min_max(inputs) + return inputs + + def cal_min_max(self, inputs): + abs_max_val = paddle.max(paddle.abs(inputs.cast("float32"))) + abs_max_val = paddle.maximum(abs_max_val, self._max) + return 0, abs_max_val + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + if self._scale is not None: + self._zero_point = 0 + return + self._scale, self._zero_point = self.cal_scales_zero_points() + + def min_value(self) -> float: + return self._min + + def max_value(self) -> float: + return self._max + + def bit_length(self): + """Return the bit length of quantized data.""" + return self._quant_bits + + def quant_axis(self): + """Return quantization axis.""" + return -1 + + def scales(self): + """Return output scales.""" + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/llm/experimental/observer/abs_max_headwise.py b/llm/experimental/observer/abs_max_headwise.py new file mode 100644 index 000000000000..500fbfa1ff55 --- /dev/null +++ b/llm/experimental/observer/abs_max_headwise.py @@ -0,0 +1,92 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from experimental.observer.channel_wise import ChannelWiseObserver +from paddle.quantization.factory import ObserverFactory + + +class AbsMaxHeadwiseObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. 
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxHeadwiseObserver + quanter = AbsMaxHeadwiseObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, quant_axis=None): + super(AbsMaxHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis) + + def _get_class(self): + return AbsMaxHeadwiseObserverLayer + + +class AbsMaxHeadwiseObserverLayer(ChannelWiseObserver): + def __init__(self, layer, quant_bits=8, quant_axis=None): + super(AbsMaxHeadwiseObserverLayer, self).__init__( + layer, quant_bits=quant_bits, sign=True, symmetric=True, quant_axis=quant_axis + ) + self.quant_bits = quant_bits + self.calibration_loss = float("inf") + self.qmin, self.qmax = self.qmin_qmax + self._layer = layer + self._max = None + self._scale = None + self._zero_point = None + + def forward(self, inputs): + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) + + if self._max is not None: + abs_max_values = paddle.maximum(abs_max_values, self._max) + + return abs_max_values + + def min_value(self) -> float: + return 0.0 + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """Return output scales.""" + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/llm/experimental/observer/avg.py b/llm/experimental/observer/avg.py new file mode 100644 index 000000000000..c38b3ec45c78 --- /dev/null +++ b/llm/experimental/observer/avg.py @@ -0,0 +1,102 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.quantization.factory import ObserverFactory + +from .uniform import UniformObserver + + +class AVGObserver(ObserverFactory): + r""" + It collects maximum absolute values of target tensor. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. 
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver + quanter = FakeQuanterWithAbsMaxObserver(moving_rate=0.99) + q_config = QuantConfig(activation=quanter, weight=quanter) + """ + + def __init__(self, quant_bits=8): + super(AVGObserver, self).__init__(quant_bits=quant_bits) + + def _get_class(self): + return AVGObserverLayer + + +class AVGObserverLayer(UniformObserver): + def __init__( + self, + layer, + quant_bits=8, + ): + super(AVGObserverLayer, self).__init__(quant_bits=quant_bits) + self._quant_bits = quant_bits + self._avg_list = [] + + def forward(self, inputs): + """Calculate forward pass.""" + self._scale = None + self._zero_point = None + self._min = None + self._max = None + self._avg_min, self._avg_max = self.cal_min_max(inputs) + self._avg_list.append(self._avg_max) + + return inputs + + def cal_min_max(self, inputs): + abs_avg_value = paddle.abs(inputs.reshape((inputs.shape[0], -1))) + abs_avg_value = float(paddle.mean(paddle.max(abs_avg_value, axis=(1)))) + return 0, abs_avg_value + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + if self._scale is not None: + self._zero_point = 0 + return + self._min, self._max = self._avg_min, paddle.mean(paddle.to_tensor(self._avg_list)) + self._scale, self._zero_point = self.cal_scales_zero_points() + + def min_value(self) -> float: + return self._min + + def max_value(self) -> float: + return self._max + + def bit_length(self): + """Return the bit length of quantized data.""" + return self._quant_bits + + def quant_axis(self): + """Return quantization axis.""" + return -1 + + def scales(self): + """Return output scales.""" + if self._scale is None: + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + if self._zero_point is None: + self.cal_thresholds() + return self._zero_point diff --git a/llm/experimental/observer/avg_headwise.py b/llm/experimental/observer/avg_headwise.py new file mode 100644 index 000000000000..a25fbd770019 --- /dev/null +++ b/llm/experimental/observer/avg_headwise.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from paddle.quantization.factory import ObserverFactory + +from .abs_max_headwise import AbsMaxHeadwiseObserverLayer + + +class AvgHeadwiseObserver(ObserverFactory): + r""" + It collects channel-wise maximum absolute values of target weights. + Args: + bit_length(int, optional): Number of bits to represent an quantized integer in binary. + dtype(str, optional): The data type of input tensor. + name (str, optional): This parameter is used by developers to print debugging information. \ + For details, please refer to :ref:`api_guide_Name`. Default is None. + Examples: + .. 
code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import AbsMaxHeadwiseObserver + quanter = AbsMaxHeadwiseObserver() + q_config = QuantConfig(activation=None, weight=quanter) + """ + + def __init__(self, quant_bits=8, quant_axis=None, moving_avg=False): + super(AvgHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis, moving_avg=moving_avg) + + def _get_class(self): + return AvgHeadwiseObserverLayer + + +class AvgHeadwiseObserverLayer(AbsMaxHeadwiseObserverLayer): + def __init__(self, layer, quant_bits=8, quant_axis=None, moving_avg=True): + super(AvgHeadwiseObserverLayer, self).__init__(layer, quant_bits=quant_bits, quant_axis=quant_axis) + self.quant_bits = quant_bits + self._qmin, self._qmax = self.qmin_qmax + self._max = None + self._scale = None + self._zero_point = None + if quant_axis is not None: + self._channel_axis = quant_axis + self._current_iters = 0 + self._range_update_factor_min = 0.001 + self._moving_avg = moving_avg + self.observer_enabled = True + + def forward(self, inputs, quant_axis=None): + if self.observer_enabled: + if quant_axis is not None: + self._channel_axis = quant_axis + self._max = self._cal_abs_max(inputs) + return inputs + + def _cal_abs_max(self, inputs): + self._current_iters += 1 + reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()]) + abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32") + abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) + if self._max is not None: + if self._moving_avg: + # exponential moving average update + update_factor = 1.0 / self._current_iters + update_factor = max(update_factor, self._range_update_factor_min) + abs_max_values = self._max * (1 - update_factor) + abs_max_values * update_factor + else: + # normal average + abs_max_values = (self._max * (self._current_iters - 1) + abs_max_values) / self._current_iters + return abs_max_values + + def min_value(self) -> float: + return 0.0 + + def max_value(self) -> float: + return self._max + + def cal_thresholds(self): + """Compute thresholds for MAX function.""" + if self._scale is not None: + self._zero_point = paddle.zeros_like(self._scale) + return + self._scale = self._max + self._zero_point = paddle.zeros_like(self._scale) + + def scales(self): + """Return output scales.""" + self.cal_thresholds() + return self._scale + + def zero_points(self): + """Return output zero points.""" + self.cal_thresholds() + return self._zero_point diff --git a/llm/experimental/observer/channel_wise.py b/llm/experimental/observer/channel_wise.py new file mode 100644 index 000000000000..883a74a8f9b0 --- /dev/null +++ b/llm/experimental/observer/channel_wise.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
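+
+"""
+Channel-wise observer base class. It resolves the quantization axis from the
+wrapped layer type, or from an explicit `quant_axis`, so per-channel and
+per-head scales can be collected. (Descriptive module docstring added for clarity.)
+"""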
+ +from typing import Dict + +import paddle +from experimental.layers.cache_kv import CacheKVMatMul +from paddleslim.quant.observers.uniform import UniformObserver + +CHANNEL_AXIS: Dict[type, int] = { + paddle.nn.Conv2D: 0, + paddle.nn.Linear: 1, + paddle.distributed.fleet.meta_parallel.ColumnParallelLinear: 1, + paddle.distributed.fleet.meta_parallel.RowParallelLinear: 1, + CacheKVMatMul: 1, +} + + +class ChannelWiseObserver(UniformObserver): + def __init__(self, layer, quant_bits=8, sign=True, symmetric=True, quant_axis=None): + super(ChannelWiseObserver, self).__init__( + quant_bits=quant_bits, + sign=sign, + symmetric=symmetric, + ) + if quant_axis is not None: + self._channel_axis = quant_axis + else: + assert type(layer) in CHANNEL_AXIS, "Unsupported layer type: {}".format(type(layer)) + self._channel_axis = CHANNEL_AXIS[type(layer)] + self._quant_bits = quant_bits + + def quant_axis(self): + """Return quantization axis.""" + return self._channel_axis + + def bit_length(self): + """Return the bit length of quantized data.""" + return self._quant_bits diff --git a/llm/experimental/observer/uniform.py b/llm/experimental/observer/uniform.py new file mode 100644 index 000000000000..6c8882f5142f --- /dev/null +++ b/llm/experimental/observer/uniform.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Tuple + +import numpy as np +from paddle.quantization.base_observer import BaseObserver + + +class UniformObserver(BaseObserver): + """This is the base class for a uniform quantization observer, which provides + common functions for calculating the scale and zero-point used in uniform quantization. + Uniform quantization maps floating point values to integers, where the scale determines + the step size of the quantizer and the floating point zero is mapped to the zero-point, + an integer value ensuring that zero is quantized without error. + + Args: + quant_bits (int): The number of bits for quantization. + sign (bool): Whether the quantized integer includes a sign. + symmetric (bool): Whether it is symmetric quantization. the quantization is symmetric. + In symmetric quantization, the range of floating point values is relaxed to be symmetric + around zero and the zero-point is always 0. 
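+
+    For example, with the defaults (`quant_bits=8`, `sign=True`, `symmetric=True`)
+    the quantized integer range is `[-128, 127]` and the zero-point is fixed at 0.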
+ + """ + + def __init__( + self, + quant_bits=8, + sign=True, + symmetric=True, + ): + super(UniformObserver, self).__init__() + self._quant_bits = quant_bits + self._sign = sign + self._symmetric = symmetric + + self._min = None + self._max = None + self._qmin = None + self._qmax = None + + self._scale = None + self._zero_point = None + + @property + def qmin_qmax(self): + """Calculate the range of the quantized integer based on the specified + quant_bits, sign, and symmetric properties.""" + if isinstance(self._quant_bits, tuple): + if self._quant_bits[0] == 4 and self._quant_bits[1] == 3 and len(self._quant_bits) == 2: + self._qmin = -448.0 + self._qmax = 448.0 + elif self._quant_bits[0] == 5 and self._quant_bits[1] == 2 and len(self._quant_bits) == 2: + self._qmin = -57344.0 + self._qmax = 57344.0 + else: + raise NotImplementedError( + "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format." + ) + else: + if self._sign: + self._qmin = -(2 ** (self.bit_length() - 1)) + self._qmax = 2 ** (self.bit_length() - 1) - 1 + else: + self._qmin = 0 + self._qmax = 2 ** self.bit_length() + return self._qmin, self._qmax + + @abc.abstractmethod + def min_value(self) -> float: + """The minimum value of floating-point numbers.""" + raise NotImplementedError( + "Please implement the abstract method to get the The minimum value of floating-point numbers." + ) + + @abc.abstractmethod + def max_value(self) -> float: + """The maximum value of floating-point numbers.""" + raise NotImplementedError( + "Please implement the abstract method to get the the maximum value value of floating-point numbers." + ) + + def cal_scales_zero_points(self) -> Tuple[float, float]: + """Calculate the scales and zero points based on the min_value and max_value.""" + assert self.min_value() is not None and self.max_value() is not None + _qmin, _qmax = self.qmin_qmax + # For one-sided distributions, the range (_min , _max ) is relaxed to include zero. + # It is important to ensure that common operations like zero padding do not cause quantization errors. + _min = min(self.min_value(), 0.0) + _max = max(self.max_value(), 0.0) + + if self._symmetric: + self._scale = max(-_min, _max) + if self._sign: + self._zero_point = 0 + else: + self._zero_point = (_qmax + _qmin) / 2 + else: + self._scale = (_max - _min) / float(_qmax - _qmin) + self._zero_point = _qmin - round(_min / self._scale) + self._zero_point = np.clip(self._zero_point, _qmin, _qmax) + return self._scale, self._zero_point diff --git a/paddlenlp/generation/utils.py b/paddlenlp/generation/utils.py index e7c3dd162643..873234a500df 100644 --- a/paddlenlp/generation/utils.py +++ b/paddlenlp/generation/utils.py @@ -16,7 +16,8 @@ import copy import inspect -from typing import Optional, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import paddle import paddle.distributed as dist @@ -26,6 +27,7 @@ from paddle.common_ops_import import convert_dtype from paddle.utils import map_structure +from paddlenlp.transformers.cache_utils import Cache, DynamicCache from paddlenlp.transformers.model_outputs import ModelOutput from paddlenlp.transformers.utils import get_scale_by_dtype from paddlenlp.utils.log import logger @@ -64,6 +66,33 @@ ] +@dataclass +class GreedySearchOutput(ModelOutput): + """ + Base class for outputs of generation models using greedy search. + + Args: + sequences (`paddle.Tensor` of shape `(batch_size, sequence_length)`): + The generated sequences. 
The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(paddle.Tensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `paddle.Tensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + past_key_values (`tuple(tuple(paddle.Tensor)))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + NOTE: some models have a different `past_key_values` format, confirm with the model's documentation. + Usually a Tuple (one element for each layer of the decoder) of tuples (two elements, key tensor and value + tensor). The first Tuple is of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if + `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, + encoder_sequence_length, embed_size_per_head)`. + """ + + sequences: paddle.Tensor = None + scores: Optional[Tuple[paddle.Tensor]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[paddle.Tensor]]]] = None + + def get_unfinished_flag( input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]] ) -> Tensor: @@ -861,6 +890,14 @@ def generate( model_kwargs["attention_mask"] = self.prepare_attention_mask_for_generation( input_ids, pad_token_id, eos_token_id ) + + # If a `Cache` instance is passed, checks whether the model is compatible with it + if isinstance(model_kwargs.get("past_key_values", None), Cache) and not self._supports_cache_class: + raise ValueError( + f"{self.__class__.__name__} does not support an instance of `Cache` as `past_key_values`. Please " + "check the model documentation for supported cache formats." + ) + self.is_encoder_decoder = self.config.is_encoder_decoder if self.is_encoder_decoder: @@ -1045,6 +1082,7 @@ def greedy_search( fast_ptq_sampling=False, trunc_input=True, synced_gpus=False, + return_dict_in_generate: Optional[bool] = None, **model_kwargs ): model_kwargs["use_cache"] = model_kwargs.get("use_cache", True) @@ -1119,19 +1157,25 @@ def greedy_search( if not paddle.any(unfinished_flag): generate_end = True + model_kwargs = self.update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + # Stop when there is a in all sentences if generate_end and not synced_gpus: break - model_kwargs = self.update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder - ) if fast_ptq_sampling: break if streamer is not None: streamer.end() - + if return_dict_in_generate: + return GreedySearchOutput( + sequences=input_ids[:, origin_len:] if trunc_input else input_ids, + scores=scores, + past_key_values=model_kwargs["past_key_values"], + ) return input_ids[:, origin_len:] if trunc_input else input_ids, scores def sample( @@ -1493,6 +1537,31 @@ def reorder_cache(self, cache, beam_idx): cache = map_structure(lambda x: paddle.index_select(x, beam_idx), cache) return cache + def _temporary_reorder_cache(self, past_key_values, beam_idx): + """ + Temporary function to handle the different types of cache reordering processes while we roll out `Cache`. 
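+        Legacy tuple/list caches are reordered through the model's `reorder_cache`,
+        while `Cache` instances are reordered in place via `Cache.reorder_cache`.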
+ TODO: standardize cache formats and make all models compatible with `Cache`. It would remove the need + for this function, with `Cache.reorder_cache` being the sole remaining code path + """ + model_class = self.__class__.__name__.lower() + # Exception 1: code path for models using the legacy cache format + if isinstance(past_key_values, (tuple, list)): + past_key_values = self.reorder_cache(past_key_values, beam_idx) + # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their + # cache format is standardized, to avoid adding complexity to the codebase. + elif "bloom" in model_class or "gptbigcode" in model_class: + if not isinstance(past_key_values, DynamicCache): + raise ValueError( + f"Using an unsupported cache format with {model_class}. Currently, it only supports the " + "legacy tuple format or `DynamicCache`" + ) + past_key_values = self.reorder_cache(past_key_values, beam_idx) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # Standard code path: use the `Cache.reorder_cache` + else: + past_key_values.reorder_cache(beam_idx) + return past_key_values + def beam_search( self, input_ids, @@ -1642,10 +1711,12 @@ def beam_search( ) if "cache" in model_kwargs: # reorder the cache - model_kwargs["cache"] = self.reorder_cache(model_kwargs["cache"], beam_idx) + model_kwargs["cache"] = self._temporary_reorder_cache(model_kwargs["cache"], beam_idx) if "past_key_values" in model_kwargs: # reorder the cache - model_kwargs["past_key_values"] = self.reorder_cache(model_kwargs["past_key_values"], beam_idx) + model_kwargs["past_key_values"] = self._temporary_reorder_cache( + model_kwargs["past_key_values"], beam_idx + ) if fast_ptq_sampling: break @@ -1816,10 +1887,10 @@ def group_beam_search( if "cache" in model_kwargs: # reorder the cache - model_kwargs["cache"] = self.reorder_cache(model_kwargs["cache"], reordering_indices) + model_kwargs["cache"] = self._temporary_reorder_cache(model_kwargs["cache"], reordering_indices) if "past_key_values" in model_kwargs: # reorder the cache - model_kwargs["past_key_values"] = self.reorder_cache( + model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], reordering_indices ) diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index c5a54765b723..44dc210607df 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -15,6 +15,7 @@ from .configuration_utils import PretrainedConfig from .model_utils import PretrainedModel, register_base_model +from .cache_utils import Cache, DynamicCache from .tokenizer_utils import ( PretrainedTokenizer, BPETokenizer, diff --git a/paddlenlp/transformers/cache_utils.py b/paddlenlp/transformers/cache_utils.py new file mode 100644 index 000000000000..802b07707c3a --- /dev/null +++ b/paddlenlp/transformers/cache_utils.py @@ -0,0 +1,158 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
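+
+"""
+KV-cache utilities: the abstract `Cache` interface and the `DynamicCache`
+implementation, which stores per-layer key and value tensors and grows as new
+tokens are generated. (Descriptive module docstring added for clarity.)
+"""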
+ +from typing import Any, Dict, List, Optional, Tuple + +import paddle + + +class Cache: + """ + Base, abstract class for all caches. The actual data structure is specific to each subclass. + """ + + def update( + self, + key_states: paddle.Tensor, + value_states: paddle.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`paddle.Tensor`): + The new key states to cache. + value_states (`paddle.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. These are specific to each subclass and allow new types of + cache to be created. + + Return: + A tuple containing the updated key and value states. + """ + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + +class DynamicCache(Cache): + """ + A cache that grows dynamically as more tokens are generated. This is the default for generative models. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]`. + """ + + def __init__(self) -> None: + self.key_cache: List[paddle.Tensor] = [] + self.value_cache: List[paddle.Tensor] = [] + self.seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + + def __getitem__(self, layer_idx: int) -> List[Tuple[paddle.Tensor]]: + """ + Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + sequence length. + """ + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + """ + Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + keys and values + """ + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + """ + Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + to the number of layers in the model. + """ + return len(self.key_cache) + + def update( + self, + key_states: paddle.Tensor, + value_states: paddle.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. + + Parameters: + key_states (`paddle.Tensor`): + The new key states to cache. + value_states (`paddle.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
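+
+        Example (illustrative sketch; new states are concatenated along axis `-3`,
+        which this cache treats as the sequence dimension):
+
+        .. code-block:: python
+
+            cache = DynamicCache()
+            key = paddle.rand([2, 8, 4, 16])  # [batch_size, seq_len, num_heads, head_dim]
+            value = paddle.rand([2, 8, 4, 16])
+            key_out, value_out = cache.update(key, value, layer_idx=0)
+            assert cache.get_seq_length() == 8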
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self.seen_tokens += key_states.shape[-3] + + # Update the cache + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + else: + self.key_cache[layer_idx] = paddle.concat([self.key_cache[layer_idx], key_states], axis=-3) + self.value_cache[layer_idx] = paddle.concat([self.value_cache[layer_idx], value_states], axis=-3) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + if len(self.key_cache) <= layer_idx: + return 0 + return self.key_cache[layer_idx].shape[-3] + + def reorder_cache(self, beam_idx: paddle.Tensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(beam_idx) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(beam_idx) + + def to_legacy_cache(self) -> Tuple[Tuple[paddle.Tensor], Tuple[paddle.Tensor]]: + """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py old mode 100755 new mode 100644 index 099abbbff68c..a39cf60ef4e5 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -39,6 +39,8 @@ recompute, ) +from ..cache_utils import Cache, DynamicCache + try: from paddle.incubate.nn.functional import fused_rotary_position_embedding except ImportError: @@ -682,10 +684,17 @@ def forward(self, x): class LlamaAttention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None, layerwise_recompute: bool = False): super().__init__() self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -915,7 +924,7 @@ def forward( self, hidden_states, position_ids: Optional[Tuple[paddle.Tensor]] = None, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[paddle.Tensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -1039,8 +1048,14 @@ def forward( kv_seq_len = key_states.shape[-3] if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-3] - + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_seq_length(self.layer_idx) + sin = cos = None if self.config.rope: if self.reshard_layer is not None: batch_size, seq_length, _, _ = query_states.shape @@ -1083,17 +1098,14 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - + cache_kwargs = {} # [bs, seq_len, num_head, head_dim] if past_key_value is not None: # reuse k, v, self_attention - key_states = paddle.concat([past_key_value[0], key_states], axis=1) - value_states = paddle.concat([past_key_value[1], value_states], axis=1) - if self.config.immediate_clear_past_key_value: - past_key_value[0]._clear_data() - past_key_value[1]._clear_data() + if sin is not None and cos is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) - past_key_value = (key_states, value_states) if use_cache else None if self.kv_indices is not None: key_states = paddle.index_select(key_states, self.kv_indices, axis=2) value_states = paddle.index_select(value_states, self.kv_indices, axis=2) @@ -1168,11 +1180,11 @@ def forward( class LlamaDecoderLayer(nn.Layer): - def __init__(self, config, layerwise_recompute: bool = False): + def __init__(self, config, layer_idx: int, layerwise_recompute: bool = False): super().__init__() self.config = config self.hidden_size = config.hidden_size - self.self_attn = LlamaAttention(config, layerwise_recompute) + self.self_attn = LlamaAttention(config, layer_idx=layer_idx, layerwise_recompute=layerwise_recompute) self.mlp = LlamaMLP(config) self.input_layernorm = LlamaRMSNorm(config) self.post_attention_layernorm = LlamaRMSNorm(config) @@ -1189,7 +1201,7 @@ def forward( position_ids: Optional[Tuple[paddle.Tensor]] = None, attention_mask: Optional[paddle.Tensor] = None, output_attentions: Optional[bool] = False, - past_key_value: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Cache] = None, use_cache: Optional[bool] = False, alibi: Optional[paddle.Tensor] = None, attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, @@ -1286,6 +1298,7 @@ class LlamaPretrainedModel(PretrainedModel): pretrained_init_configuration = LLAMA_PRETRAINED_INIT_CONFIGURATION pretrained_resource_files_map = LLAMA_PRETRAINED_RESOURCE_FILES_MAP _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + _supports_cache_class = True @classmethod def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: @@ -1518,7 +1531,7 @@ def __init__(self, config: LlamaConfig): self.layers = nn.LayerList( [ 
LlamaDecoderLayer( - create_skip_config_for_refined_recompute(i, config), i not in self.no_recompute_layers + create_skip_config_for_refined_recompute(i, config), i, i not in self.no_recompute_layers ) for i in range(config.num_hidden_layers) ] @@ -1674,16 +1687,15 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - # NOTE: to make cache can be clear in-time - past_key_values = list(past_key_values) + past_key_values_length = 0 + + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_seq_length() + if seq_length == 6: + raise UnboundLocalError - seq_length_with_past = seq_length - cache_length = 0 - if past_key_values[0] is not None: - cache_length = past_key_values[0][0].shape[1] - seq_length_with_past += cache_length if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) @@ -1702,7 +1714,7 @@ def forward( attention_mask = None elif attn_mask_startend_row_indices is None and attention_mask is None: # [bs, seq_len] - attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + attention_mask = paddle.ones((batch_size, seq_length + past_key_values_length), dtype=paddle.bool) if attn_mask_startend_row_indices is None and self.config.alibi: if self.config.use_long_sequence_strategies: alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( @@ -1721,9 +1733,11 @@ def forward( * block_size : (self.config.tensor_parallel_rank + 1) * block_size, ] - alibi = alibi.reshape([batch_size * block_size, 1, seq_length_with_past]) + alibi = alibi.reshape([batch_size * block_size, 1, seq_length + past_key_values_length]) else: - alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + alibi = alibi.reshape( + [batch_size * self.config.num_attention_heads, 1, seq_length + past_key_values_length] + ) else: alibi = None @@ -1736,7 +1750,7 @@ def forward( attention_mask = None elif attn_mask_startend_row_indices is None: attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype ) # [bs, 1, seq_len, seq_len] is_casual = False @@ -1759,12 +1773,11 @@ def forward( # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, (decoder_layer) in enumerate(self.layers): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None has_gradient = not hidden_states.stop_gradient if ( @@ -1779,7 +1792,7 @@ def forward( position_ids, attention_mask, output_attentions, - past_key_value, + past_key_values, use_cache, alibi=alibi, attn_mask_startend_row_indices=attn_mask_startend_row_indices, @@ -1790,7 +1803,7 @@ def forward( position_ids, attention_mask, output_attentions, - past_key_value, + past_key_values, use_cache, alibi=alibi, attn_mask_startend_row_indices=attn_mask_startend_row_indices, @@ -1798,7 +1811,6 @@ def forward( ) # NOTE: clear outdate cache after it has been used for memory saving - past_key_value = past_key_values[idx] = None if 
type(layer_outputs) is tuple: hidden_states = layer_outputs[0] else: @@ -1808,7 +1820,7 @@ def forward( all_self_attns += (layer_outputs[1],) if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if self.config.use_last_token_for_generation: hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1) @@ -1819,7 +1831,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = next_decoder_cache if use_cache else None + next_cache = None + if use_cache: + next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -2020,9 +2034,28 @@ def prepare_inputs_for_generation( position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) attention_mask = kwargs.get("attention_mask", None) if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(axis=-1) + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + cache_length = past_length = past_key_values[0][0].shape[1] + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the + # older attention values, as their corresponding values are not part of the input. + if cache_length < past_length and attention_mask is not None: + attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :] + else: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) position_ids = position_ids[:, -1].unsqueeze(-1) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index 2d456e66a880..4e009b059a6e 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -1019,6 +1019,9 @@ class is a pretrained model class adding layers on top of the base model, _keys_to_ignore_on_save = None _tied_weights_keys = None + # Has support for a `Cache` instance as `past_key_values` + _supports_cache_class = False + def __init__(self, *args, **kwargs): super(PretrainedModel, self).__init__() diff --git a/tests/test_cache_utils.py b/tests/test_cache_utils.py new file mode 100644 index 000000000000..90061f3a9856 --- /dev/null +++ b/tests/test_cache_utils.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +import unittest + +import numpy as np +import paddle + +from paddlenlp.transformers import ( # AutoModelForCausalLM,; AutoTokenizer, + DynamicCache, + LlamaForCausalLM, +) + +# from .testing_utils import slow + + +def set_seed(seed): + """sets random seed""" + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + +class CacheTest(unittest.TestCase): + def test_cache_equivalence(self): + """Tests that we can convert back and forth between the legacy cache format and DynamicCache""" + legacy_cache = () + new_cache = DynamicCache() + + # Creates a new cache with 10 layers in both formats + for layer_idx in range(10): + new_key = paddle.rand((2, 4, 8, 16)) + new_value = paddle.rand((2, 4, 8, 16)) + new_cache.update(new_key, new_value, layer_idx) + legacy_cache += ((new_key, new_value),) + + # Sanity check 1: they must have the same shapes + self.assertTrue(len(legacy_cache), len(new_cache)) + for layer_idx in range(10): + self.assertTrue(len(legacy_cache[layer_idx]), len(new_cache[layer_idx])) + for key_value_idx in range(2): + self.assertTrue( + legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape + ) + + # Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the + # expected value + self.assertTrue(legacy_cache[0][0].shape[-3] == new_cache[0][0].shape[-3] == new_cache.get_seq_length() == 4) + + # Sanity check 3: they must be equal, and both support indexing + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + paddle.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) + ) + + # Test 1: We can convert from legacy to new with no changes + from_legacy = DynamicCache.from_legacy_cache(legacy_cache) + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + paddle.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx]) + ) + + # Test 2: We can convert from new to legacy with no changes + to_legacy = new_cache.to_legacy_cache() + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + paddle.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx]) + ) + + def test_reorder_cache_retrocompatibility(self): + """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" + legacy_reorder_fn = LlamaForCausalLM.reorder_cache # An example of a legacy `reorder_cache` function + + legacy_cache = () + new_cache = DynamicCache() + + # Creates a new cache with 10 layers in both formats + for layer_idx in range(10): + new_key = paddle.rand((4, 4, 8, 16)) + new_value = paddle.rand((4, 4, 8, 16)) + new_cache.update(new_key, new_value, layer_idx) + legacy_cache += ((new_key, new_value),) + + # Let's create some dummy beam indices.
From the shape above, it is equivalent to the case where num_beams=4 + # and batch_size=1 + # beam_idx = paddle.randint(low=0, high=4, size=(4,)) + beam_idx = paddle.randint(low=0, high=4, shape=(4,)) + beam_idx = beam_idx.astype(paddle.int64) + print(f"beam_idx = {beam_idx}") + legacy_cache_reordered = legacy_reorder_fn(None, legacy_cache, beam_idx) + new_cache.reorder_cache(beam_idx) + + # Let's check that the results are the same + for layer_idx in range(10): + for key_value_idx in range(2): + self.assertTrue( + paddle.allclose( + new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx] + ) + ) + + +# @slow +# class CacheIntegrationTest(unittest.TestCase): +# def test_dynamic_cache_hard(self): +# tokenizer = AutoTokenizer.from_pretrained( +# "meta-llama/Llama-2-7b-hf", padding_side="left", from_hf_hub=True, use_fast=True +# ) +# model = AutoModelForCausalLM.from_pretrained( +# "meta-llama/Llama-2-7b-hf", +# dtype=paddle.float16, +# from_hf_hub=True, +# ) +# inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="np") +# for key in inputs: +# inputs[key] = paddle.to_tensor(inputs[key]) + +# # DynamicCache and the legacy cache format should be equivalent +# set_seed(0) +# gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256) +# set_seed(0) +# gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()) +# self.assertListEqual(gen_out_legacy[0].tolist(), gen_out[0].tolist()) +# self.assertListEqual(gen_out_legacy[1].tolist(), gen_out[1].tolist()) + +# decoded = tokenizer.batch_decode(gen_out[0], skip_special_tokens=True) + +# expected_text = ( +# "Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like " +# "to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n" +# "Cats are also very independent. They don't like to be told what to do, and they don't like to be told " +# "what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats " +# "are also very curious. They like to explore, and they like to play. They are also very fast. They can " +# "run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they " +# "can solve problems. They are also very playful. They like to play with toys, and they like to play with " +# "other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They " +# "also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to " +# "clean their litter box.\nCats are also very independent.
They don't" +# ) +# self.assertEqual(decoded[0], expected_text) + +# def test_dynamic_cache_batched(self): +# tokenizer = AutoTokenizer.from_pretrained( +# "meta-llama/Llama-2-7b-hf", padding_side="left", from_hf_hub=True, use_fast=True +# ) +# tokenizer.pad_token = tokenizer.eos_token +# model = AutoModelForCausalLM.from_pretrained( +# "meta-llama/Llama-2-7b-hf", +# device_map="auto", +# dtype=paddle.float16, +# from_hf_hub=True, +# ) +# inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="np").to( +# model.device +# ) +# for key in inputs: +# inputs[key] = paddle.to_tensor(inputs[key]) + +# gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache()) +# decoded = tokenizer.batch_decode(gen_out[0], skip_special_tokens=True) +# expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"] +# self.assertListEqual(decoded, expected_text) + +# def test_dynamic_cache_beam_search(self): +# tokenizer = AutoTokenizer.from_pretrained( +# "meta-llama/Llama-2-7b-hf", padding_side="left", from_hf_hub=True, use_fast=True +# ) +# model = AutoModelForCausalLM.from_pretrained( +# "meta-llama/Llama-2-7b-hf", +# device_map="auto", +# dtype=paddle.float16, +# from_hf_hub=True, +# ) + +# inputs = tokenizer(["The best color is"], return_tensors="np") +# for key in inputs: +# inputs[key] = paddle.to_tensor(inputs[key]) +# gen_out = model.generate( +# **inputs, +# do_sample=False, +# max_new_tokens=20, +# num_beams=2, +# num_return_sequences=2, +# ) +# decoded = tokenizer.batch_decode(gen_out[0], skip_special_tokens=True) +# expected_text = [ +# "The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good", +# "The best color is the one that suits you.\nThe best color is the one that suits you. The", +# ] +# self.assertListEqual(decoded, expected_text) diff --git a/tests/transformers/test_generation_utils.py b/tests/transformers/test_generation_utils.py index d36254d88bb2..2ff76992d8e2 100644 --- a/tests/transformers/test_generation_utils.py +++ b/tests/transformers/test_generation_utils.py @@ -14,9 +14,12 @@ # limitations under the License. from __future__ import annotations +import random import unittest +import numpy as np import paddle +from parameterized import parameterized from paddlenlp.generation import ( BeamSearchScorer, @@ -39,9 +42,17 @@ PretrainedConfig, PretrainedTokenizer, ) +from paddlenlp.transformers.cache_utils import DynamicCache from tests.testing_utils import slow +def set_seed(seed): + """sets random seed""" + random.seed(seed) + np.random.seed(seed) + paddle.seed(seed) + + def top_k_top_p_filtering( logits, top_k=0, @@ -628,6 +639,64 @@ def test_group_beam_search_generate(self): ) self.assertListEqual(output_generate[0].tolist(), output_group_beam_search[0].tolist()) + @parameterized.expand([(1, False), (1, True), (4, False)]) + def test_new_cache_format(self, num_beams, do_sample): + # Tests that generating with the new format is exactly the same as the legacy one (for models that support it). + # 👉 tests with and without beam search so that we can test with and without cache reordering. + # 👉 tests with and without sampling so we can cover the most common use cases. 
+ for model_class in self.all_generative_model_classes: + if not model_class._supports_cache_class: + self.skipTest("This model does not support the new cache format") + config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() + config.use_cache = True + model = self._make_model_instance(config, model_class) + model.eval() + + generation_kwargs = { + "max_new_tokens": 5, + "do_sample": do_sample, + "num_beams": num_beams, + "num_return_sequences": 1, # `num_return_sequences` has to be 1 when doing greedy search + "return_dict_in_generate": True, # Required to return `past_key_values` + } + # Sets seed before calling `generate` for the case with do_sample=True + seed = paddle.randint(0, 1000000, (1,)).item() + set_seed(seed) + legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + set_seed(seed) + new_results = model.generate( + input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs + ) + # The two sets of generated sequences must match, despite the cache format between forward passes being + # different + + self.assertListEqual(legacy_results[0].tolist(), new_results[0].tolist()) + self.assertTrue(isinstance(legacy_results.past_key_values, tuple)) + self.assertTrue(isinstance(new_results.past_key_values, DynamicCache)) + # The contents of the two caches, when converted to the same format (in both directions!), must match + legacy_cache = legacy_results.past_key_values + # print(f"legacy_cache: {legacy_cache}") + new_cache_converted = new_results.past_key_values.to_legacy_cache() + # print(f"new_cache_converted: {new_cache_converted[:,:8,:,:]}") + for layer_idx in range(len(legacy_cache)): + for kv_idx in range(len(legacy_cache[layer_idx])): + self.assertTrue( + paddle.allclose( + legacy_cache[layer_idx][kv_idx], + new_cache_converted[layer_idx][kv_idx], + ) + ) + new_cache = new_results.past_key_values + legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values) + for layer_idx in range(len(new_cache)): + for kv_idx in range(len(new_cache[layer_idx])): + self.assertTrue( + paddle.allclose( + new_cache[layer_idx][kv_idx], + legacy_cache_converted[layer_idx][kv_idx], + ) + ) + def _check_sequence_inside_sequence(self, tensor_1, tensor_2): # check if tensor_1 inside tensor_2 or tensor_2 inside tensor_1. # set to same device. we don't care what device.
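For reference, a minimal usage sketch of the cache path exercised by the tests above. The checkpoint, prompt, and generation settings mirror the commented-out integration test and are placeholders; it assumes the target model class opts in with `_supports_cache_class = True`. Passing a `DynamicCache` as `past_key_values` switches `generate` to the new cache format, while omitting it keeps the legacy tuple-of-tuples cache.

import paddle
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Placeholder checkpoint; any causal LM whose class sets `_supports_cache_class = True` should work.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", from_hf_hub=True)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", dtype=paddle.float16, from_hf_hub=True)

inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="np")
inputs = {key: paddle.to_tensor(value) for key, value in inputs.items()}

# Passing a DynamicCache opts into the new cache format; leaving `past_key_values`
# unset keeps the legacy tuple cache, and both are expected to generate the same tokens.
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=16, past_key_values=DynamicCache())
print(tokenizer.batch_decode(gen_out[0], skip_special_tokens=True))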