From b7645c41463849e710f2734fb21bd7111fbc88f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:16:50 +0800 Subject: [PATCH 1/9] Update app.py --- app.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/app.py b/app.py index 775040f..cb37438 100644 --- a/app.py +++ b/app.py @@ -8,10 +8,8 @@ print('Done'.center(64, '-')) # 加载模型 -# model_name = 'THUDM/chatglm-6b' -model_name = 'silver/chatglm-6b-int4-slim' -# model_name = 'BelleGroup/BELLE-LLAMA-7B-2M' -# model_name = 'BelleGroup/BELLE-LLAMA-7B-2M-gptq' +model_name = 'THUDM/chatglm-6b' +# model_name = 'silver/chatglm-6b-int4-slim' if 'chatglm' in model_name.lower(): from predictors.chatglm_predictor import ChatGLM From bbb1d3ea67a4d16f91f39926d0ae6798b68c00a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:30:40 +0800 Subject: [PATCH 2/9] fix cuda --- predictors/chatglm_predictor.py | 3 ++- predictors/llama.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/predictors/chatglm_predictor.py b/predictors/chatglm_predictor.py index 8f3fcd4..b7600c8 100644 --- a/predictors/chatglm_predictor.py +++ b/predictors/chatglm_predictor.py @@ -104,7 +104,8 @@ def stream_chat_continue(self, input_length = len(batch_input['input_ids'][0]) final_input_ids = torch.cat( [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]], - dim=-1).cuda() + dim=-1) + final_input_ids = final_input_ids.to(model.device) attention_mask = model.get_masks( final_input_ids, device=final_input_ids.device) diff --git a/predictors/llama.py b/predictors/llama.py index 6cc881f..37fb2b3 100644 --- a/predictors/llama.py +++ b/predictors/llama.py @@ -178,7 +178,8 @@ def stream_chat_continue(self, input_length = len(batch_input['input_ids'][0]) final_input_ids = torch.cat( [batch_input['input_ids'], batch_answer['input_ids'][:, :-2]], - dim=-1).cuda() + dim=-1) + final_input_ids = final_input_ids.to(model.device) attention_mask = torch.ones_like(final_input_ids).bool().to( model.device) attention_mask[:, input_length:] = False From 440bd1de900b65d3193afcca386bf6b0a1378b3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:31:45 +0800 Subject: [PATCH 3/9] Update chatglm_predictor.py --- predictors/chatglm_predictor.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/predictors/chatglm_predictor.py b/predictors/chatglm_predictor.py index b7600c8..d5f20a9 100644 --- a/predictors/chatglm_predictor.py +++ b/predictors/chatglm_predictor.py @@ -32,11 +32,19 @@ def __init__(self, model_name): if 'slim' in model_name: model = AutoModel.from_pretrained( model_name, trust_remote_code=True, - resume_download=True).half().to(self.device) + resume_download=True) + if self.device == 'cuda': + model = model.half().to(self.device) + else: + model = model.to(self.device) elif 'int4' in model_name: model = AutoModel.from_pretrained( model_name, trust_remote_code=True, - resume_download=True).half().to(self.device) + resume_download=True) + if self.device == 'cuda': + model = model.half().to(self.device) + else: + model = model.to(self.device) else: model = AutoModel.from_pretrained( model_name, From 57842e73cf2d03940273a67c96e2098a456de5fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:40:41 +0800 Subject: [PATCH 4/9] Update utils_env.py --- utils_env.py | 2 ++ 1 
file changed, 2 insertions(+) diff --git a/utils_env.py b/utils_env.py index 32266db..9417fff 100644 --- a/utils_env.py +++ b/utils_env.py @@ -23,6 +23,8 @@ def collect_env(): devices[torch.cuda.get_device_name(k)].append(str(k)) for name, device_ids in devices.items(): env_info['GPU ' + ','.join(device_ids)] = name + else: + env_info['CUDA available'] = False return env_info From 5099a9a6d735241d7a9010aff33b33e70573fcc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:43:30 +0800 Subject: [PATCH 5/9] Update modeling_chatglm.py --- chatglm/modeling_chatglm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index 7ae2638..bda5604 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -146,7 +146,8 @@ class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, base=10000, precision=torch.half, learnable=False): super().__init__() inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) - inv_freq = inv_freq.half() + if precision == torch.half: + inv_freq = inv_freq.half() self.learnable = learnable if learnable: self.inv_freq = torch.nn.Parameter(inv_freq) From c5c1033538c2c09f7baa0b2d5afa96de500c72e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 10:58:46 +0800 Subject: [PATCH 6/9] update --- app.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app.py b/app.py index cb37438..f03c3b6 100644 --- a/app.py +++ b/app.py @@ -95,3 +95,4 @@ def interrupt(allow_generate): outputs=[chatbot, query, continue_message]) interrupt_btn.click(interrupt, inputs=[allow_generate]) demo.queue(concurrency_count=4).launch(server_name='0.0.0.0', server_port=7860, share=False, inbrowser=False) +demo.close() From aabfb463224e26d85d7c44c7b1039e079858636e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 11:03:59 +0800 Subject: [PATCH 7/9] update --- chatglm/configuration_chatglm.py | 24 +- chatglm/modeling_chatglm.py | 488 +++++++++++++++++++------------ chatglm/quantization.py | 412 ++++---------------------- chatglm/tokenization_chatglm.py | 290 +++++++++++------- 4 files changed, 579 insertions(+), 635 deletions(-) diff --git a/chatglm/configuration_chatglm.py b/chatglm/configuration_chatglm.py index 52efead..78f3425 100644 --- a/chatglm/configuration_chatglm.py +++ b/chatglm/configuration_chatglm.py @@ -12,9 +12,12 @@ class ChatGLMConfig(PretrainedConfig): It is used to instantiate an ChatGLM model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the ChatGLM-6B [THUDM/ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. + + Args: vocab_size (`int`, *optional*, defaults to 150528): Vocabulary size of the ChatGLM-6B model. Defines the number of different tokens that can be represented by the @@ -36,13 +39,17 @@ class ChatGLMConfig(PretrainedConfig): use_cache (`bool`, *optional*, defaults to `True`): Whether the model should return the last key/values attentions (not used by all models). 
Example: + ```python - >>> from chatglm.configuration_chatglm import ChatGLMConfig - >>> from chatglm.modeling_chatglm import ChatGLMModel + >>> from configuration_chatglm import ChatGLMConfig + >>> from modeling_chatglm import ChatGLMModel + >>> # Initializing a ChatGLM-6B THUDM/ChatGLM-6B style configuration >>> configuration = ChatGLMConfig() + >>> # Initializing a model from the THUDM/ChatGLM-6B style configuration >>> model = ChatGLMModel(configuration) + >>> # Accessing the model configuration >>> configuration = model.config ``` @@ -59,12 +66,15 @@ def __init__( use_cache=False, bos_token_id=150004, eos_token_id=150005, + mask_token_id=150000, + gmask_token_id=150001, pad_token_id=0, max_sequence_length=2048, inner_hidden_size=16384, position_encoding_2d=True, quantization_bit=0, - quantization_embeddings=False, + pre_seq_len=None, + prefix_projection=False, **kwargs ): self.num_layers = num_layers @@ -78,9 +88,13 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id + self.mask_token_id = mask_token_id + self.gmask_token_id = gmask_token_id self.position_encoding_2d = position_encoding_2d - self.quantization_bit=quantization_bit - self.quantization_embeddings=quantization_embeddings + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/chatglm/modeling_chatglm.py b/chatglm/modeling_chatglm.py index bda5604..f69f26c 100644 --- a/chatglm/modeling_chatglm.py +++ b/chatglm/modeling_chatglm.py @@ -5,6 +5,7 @@ import os import warnings import re +import sys import torch import torch.utils.checkpoint @@ -12,7 +13,7 @@ from torch import nn from torch.nn import CrossEntropyLoss, LayerNorm from torch.nn.utils import skip_init -from typing import Optional, Tuple, Union, List, Callable +from typing import Optional, Tuple, Union, List, Callable, Dict, Any from transformers.utils import ( add_code_sample_docstrings, @@ -27,16 +28,17 @@ from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging from transformers.generation.logits_process import LogitsProcessor -from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput from .configuration_chatglm import ChatGLMConfig - # flags required to enable jit fusion kernels -torch._C._jit_set_profiling_mode(False) -torch._C._jit_set_profiling_executor(False) -torch._C._jit_override_can_fuse_on_cpu(True) -torch._C._jit_override_can_fuse_on_gpu(True) + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) logger = logging.get_logger(__name__) @@ -53,7 +55,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: if torch.isnan(scores).any() or torch.isinf(scores).any(): scores.zero_() - scores[..., 20005] = 5e4 + scores[..., 5] = 5e4 return scores @@ -131,6 +133,36 @@ def load_tf_weights_in_chatglm_6b(model, config, tf_checkpoint_path): return model +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, 
prefix-length, 2*layers*hidden) + """ + + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + @torch.jit.script def gelu_impl(x): """OpenAI's gelu implementation.""" @@ -193,6 +225,7 @@ def _apply(self, fn): self.sin_cached = fn(self.sin_cached) return super()._apply(fn) + def rotate_half(x): x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in earlier torch versions @@ -220,7 +253,7 @@ def attention_fn( use_cache=False, ): if layer_past is not None: - past_key, past_value = layer_past + past_key, past_value = layer_past[0], layer_past[1] key_layer = torch.cat((past_key, key_layer), dim=0) value_layer = torch.cat((past_value, value_layer), dim=0) @@ -248,10 +281,8 @@ def attention_fn( # [sk, b, np, hn] -> [sk, b * np, hn] key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - matmul_result = torch.empty( - output_size[0] * output_size[1], - output_size[2], - output_size[3], + matmul_result = torch.zeros( + 1, 1, 1, dtype=query_layer.dtype, device=query_layer.device, ) @@ -274,7 +305,7 @@ def attention_fn( if not (attention_mask == 0).all(): # if auto-regressive, skip attention_scores.masked_fill_(attention_mask, -10000.0) - dtype = attention_scores.type() + dtype = attention_scores.dtype attention_scores = attention_scores.float() attention_scores = attention_scores * query_key_layer_scaling_coeff @@ -316,10 +347,18 @@ def attention_fn( return outputs +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + class SelfAttention(torch.nn.Module): def __init__(self, hidden_size, num_attention_heads, layer_id, hidden_size_per_attention_head=None, bias=True, - params_dtype=torch.float, position_encoding_2d=True): + params_dtype=torch.float, position_encoding_2d=True, empty_init=True): + if empty_init: + init_method = skip_init + else: + init_method = default_init super(SelfAttention, self).__init__() self.layer_id = layer_id @@ -347,7 +386,7 @@ def __init__(self, hidden_size, num_attention_heads, self.inner_hidden_size = num_attention_heads * self.hidden_size_per_attention_head # Strided linear layer. 
- self.query_key_value = skip_init( + self.query_key_value = init_method( torch.nn.Linear, hidden_size, 3 * self.inner_hidden_size, @@ -355,7 +394,7 @@ def __init__(self, hidden_size, num_attention_heads, dtype=params_dtype, ) - self.dense = skip_init( + self.dense = init_method( torch.nn.Linear, self.inner_hidden_size, hidden_size, @@ -468,8 +507,12 @@ def forward(self, x): class GLU(torch.nn.Module): def __init__(self, hidden_size, inner_hidden_size=None, - layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float): + layer_id=None, bias=True, activation_func=gelu, params_dtype=torch.float, empty_init=True): super(GLU, self).__init__() + if empty_init: + init_method = skip_init + else: + init_method = default_init self.layer_id = layer_id self.activation_func = activation_func @@ -478,7 +521,7 @@ def __init__(self, hidden_size, inner_hidden_size=None, if inner_hidden_size is None: inner_hidden_size = 4 * hidden_size self.inner_hidden_size = inner_hidden_size - self.dense_h_to_4h = skip_init( + self.dense_h_to_4h = init_method( torch.nn.Linear, self.hidden_size, self.inner_hidden_size, @@ -486,7 +529,7 @@ def __init__(self, hidden_size, inner_hidden_size=None, dtype=params_dtype, ) # Project back to h. - self.dense_4h_to_h = skip_init( + self.dense_4h_to_h = init_method( torch.nn.Linear, self.inner_hidden_size, self.hidden_size, @@ -522,7 +565,8 @@ def __init__( use_bias=True, params_dtype=torch.float, num_layers=28, - position_encoding_2d=True + position_encoding_2d=True, + empty_init=True ): super(GLMBlock, self).__init__() # Set output layer initialization if not provided. @@ -542,7 +586,8 @@ def __init__( hidden_size_per_attention_head=hidden_size_per_attention_head, bias=use_bias, params_dtype=params_dtype, - position_encoding_2d=self.position_encoding_2d + position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init ) # Layernorm on the input data. 
@@ -557,6 +602,7 @@ def __init__( bias=use_bias, layer_id=layer_id, params_dtype=params_dtype, + empty_init=empty_init ) def forward( @@ -620,10 +666,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel): """ is_parallelizable = False - supports_gradient_checkpointing = False + supports_gradient_checkpointing = True config_class = ChatGLMConfig base_model_prefix = "transformer" - _no_split_modules = ["GLM6BBlock"] + _no_split_modules = ["GLMBlock"] def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -632,11 +678,51 @@ def _init_weights(self, module: nn.Module): """Initialize the weights.""" return + def get_masks(self, input_ids, device): + batch_size, seq_length = input_ids.shape + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) + attention_mask.tril_() + for i, context_length in enumerate(context_lengths): + attention_mask[i, :, :context_length] = 1 + attention_mask.unsqueeze_(1) + attention_mask = (attention_mask < 0.5).bool() + + return attention_mask + + def get_position_ids(self, input_ids, mask_positions, device, use_gmasks=None): + batch_size, seq_length = input_ids.shape + if use_gmasks is None: + use_gmasks = [False] * batch_size + context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids] + if self.position_encoding_2d: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + position_ids[i, context_length:] = mask_positions[i] + block_position_ids = [torch.cat(( + torch.zeros(context_length, dtype=torch.long, device=device), + torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1 + )) for context_length in context_lengths] + block_position_ids = torch.stack(block_position_ids, dim=0) + position_ids = torch.stack((position_ids, block_position_ids), dim=1) + else: + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + for i, context_length in enumerate(context_lengths): + if not use_gmasks[i]: + position_ids[context_length:] = mask_positions[i] + + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, ChatGLMModel): + module.gradient_checkpointing = value + CHATGLM_6B_START_DOCSTRING = r""" This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + Parameters: config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. @@ -647,28 +733,37 @@ def _init_weights(self, module: nn.Module): Args: input_ids (`torch.LongTensor` of shape `({0})`): Indices of input sequence tokens in the vocabulary. + Indices can be obtained using [`ChatGLM6BTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. 
+ [What are attention masks?](../glossary#attention-mask) token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`: + - 0 corresponds to a *sentence A* token, - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) position_ids (`torch.LongTensor` of shape `({0})`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert *input_ids* indices into associated vectors @@ -690,11 +785,13 @@ def _init_weights(self, module: nn.Module): ) class ChatGLMModel(ChatGLMPreTrainedModel): """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of cross-attention is added between the self-attention layers, following the architecture described in [Attention is all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` @@ -702,9 +799,12 @@ class ChatGLMModel(ChatGLMPreTrainedModel): `encoder_hidden_states` is then expected as an input to the forward pass. """ - def __init__(self, config: ChatGLMConfig): + def __init__(self, config: ChatGLMConfig, empty_init=True): super().__init__(config) - + if empty_init: + init_method = skip_init + else: + init_method = default_init # recording parameters self.max_sequence_length = config.max_sequence_length self.hidden_size = config.hidden_size @@ -716,12 +816,15 @@ def __init__(self, config: ChatGLMConfig): self.inner_hidden_size = config.inner_hidden_size self.hidden_size_per_attention_head = self.hidden_size // self.num_attention_heads self.position_encoding_2d = config.position_encoding_2d + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection - self.word_embeddings = skip_init( + self.word_embeddings = init_method( torch.nn.Embedding, num_embeddings=self.vocab_size, embedding_dim=self.hidden_size, dtype=self.params_dtype ) + self.gradient_checkpointing = False def get_layer(layer_id): return GLMBlock( @@ -735,6 +838,7 @@ def get_layer(layer_id): use_bias=True, params_dtype=self.params_dtype, position_encoding_2d=self.position_encoding_2d, + empty_init=empty_init ) self.layers = torch.nn.ModuleList( @@ -744,43 +848,38 @@ def get_layer(layer_id): # Final layer norm before output. 
self.final_layernorm = LayerNorm(self.hidden_size, eps=self.layernorm_epsilon) + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + # total_params = sum(p.numel() for p in self.parameters()) + # trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + # print("Using p-tuning v2: # trainable_params = {} / {}".format(trainable_params, total_params)) + def get_input_embeddings(self): return self.word_embeddings def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings - def get_masks(self, seq, device): - context_length = seq.index(self.config.bos_token_id) + 1 - - attention_mask = torch.ones((1, len(seq), len(seq)), device=device) - attention_mask.tril_() - attention_mask[..., :context_length - 1] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() - - return attention_mask - - def get_position_ids(self, seq, mask_position, device, gmask=False): - context_length = seq.index(self.config.bos_token_id) + 1 - if self.position_encoding_2d: - seq_length = seq.index(self.config.bos_token_id) - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[seq_length:] = mask_position - block_position_ids = torch.cat(( - torch.zeros(seq_length, dtype=torch.long, device=device), - torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 - )) - position_ids = torch.stack((position_ids, block_position_ids), dim=0) - else: - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[context_length - 1:] = mask_position - - position_ids = position_ids.unsqueeze(0) - - return position_ids + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + # past_key_values = [(v[0], v[1]) for v in past_key_values] + return past_key_values @add_start_docstrings_to_model_forward(CHATGLM_6B_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_code_sample_docstrings( @@ -808,40 +907,62 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape[:2] + batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - seq = input_ids[0].tolist() + if self.pre_seq_len is not None: + past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device, + dtype=inputs_embeds.dtype) + else: + past_key_values = tuple([None] * len(self.layers)) if attention_mask is None: attention_mask = self.get_masks( - seq=seq, + input_ids, device=input_ids.device ) + if position_ids is None: - MASK, gMASK = 150000, 150001 - mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else gMASK + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) - mask_position = seq.index(mask_token) position_ids = self.get_position_ids( - seq=seq, - mask_position=mask_position, + input_ids, + mask_positions=mask_positions, device=input_ids.device, - gmask=use_gmask + use_gmasks=use_gmasks ) - if inputs_embeds is None: - inputs_embeds = self.word_embeddings(input_ids) + if self.pre_seq_len is not None and attention_mask is not None: + prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to( + attention_mask.device) + prefix_attention_mask = (prefix_attention_mask < 0.5).bool() + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3) # [seq_len, batch, hidden_size] hidden_states = inputs_embeds.transpose(0, 1) @@ -850,31 +971,38 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - seq_length_with_past = seq_length - past_key_values_length = 0 - if past_key_values[0] is not None: - past_key_values_length = past_key_values[0][0].shape[0] - seq_length_with_past = seq_length_with_past + past_key_values_length if attention_mask is None: attention_mask = torch.zeros(1, 1, device=input_ids.device).bool() - else: - attention_mask = attention_mask.to(input_ids.device) + attention_mask = attention_mask.to(hidden_states.device) for i, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - - layer_ret = layer( - hidden_states, - position_ids=position_ids, - attention_mask=attention_mask, - layer_id=torch.tensor(i), - layer_past=past_key_values[i], - use_cache=use_cache, - output_attentions=output_attentions - ) + layer_past = past_key_values[i] + + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + position_ids, + attention_mask, + torch.tensor(i), + layer_past, + use_cache, + output_attentions + ) + else: + layer_ret = layer( + hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + layer_id=torch.tensor(i), + layer_past=layer_past, + 
use_cache=use_cache, + output_attentions=output_attentions + ) hidden_states = layer_ret[0] @@ -902,8 +1030,12 @@ def forward( class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): - def __init__(self, config: ChatGLMConfig): + def __init__(self, config: ChatGLMConfig, empty_init=True): super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init # self.hidden_size = config.hidden_size # self.params_dtype = torch.half @@ -912,9 +1044,9 @@ def __init__(self, config: ChatGLMConfig): self.position_encoding_2d = config.position_encoding_2d - self.transformer = ChatGLMModel(config) + self.transformer = ChatGLMModel(config, empty_init=empty_init) - self.lm_head = skip_init( + self.lm_head = init_method( nn.Linear, config.hidden_size, config.vocab_size, @@ -927,7 +1059,7 @@ def __init__(self, config: ChatGLMConfig): self.quantized = False if self.config.quantization_bit: - self.quantize(self.config.quantization_bit, self.config.quantization_embeddings, use_quantization_cache=True, empty_init=True) + self.quantize(self.config.quantization_bit, empty_init=True) def get_output_embeddings(self): return self.lm_head @@ -935,31 +1067,40 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def get_masks_and_position_ids(self, seq, mask_position, context_length, device, gmask=False): - attention_mask = torch.ones((1, context_length, context_length), device=device) - attention_mask.tril_() - attention_mask[..., :context_length - 1] = 1 - attention_mask.unsqueeze_(1) - attention_mask = (attention_mask < 0.5).bool() + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) - if self.position_encoding_2d: - seq_length = seq.index(self.config.bos_token_id) - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[seq_length:] = mask_position - block_position_ids = torch.cat(( - torch.zeros(seq_length, dtype=torch.long, device=device), - torch.arange(context_length - seq_length, dtype=torch.long, device=device) + 1 - )) - position_ids = torch.stack((position_ids, block_position_ids), dim=0) - else: - position_ids = torch.arange(context_length, dtype=torch.long, device=device) - if not gmask: - position_ids[context_length - 1:] = mask_position + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = torch.cat( + [attention_mask, attention_mask.new_ones((*attention_mask.shape[:3], 1))], dim=3) + new_attention_mask = attention_mask[:, :, -1:].clone() + new_attention_mask[..., -1] = False + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, new_attention_mask], dim=2 + ) - position_ids = position_ids.unsqueeze(0) + # update position ids + if "position_ids" in model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id[:, 1, :] += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) - return attention_mask, position_ids + return model_kwargs def prepare_inputs_for_generation( 
self, @@ -967,27 +1108,37 @@ def prepare_inputs_for_generation( past: Optional[torch.Tensor] = None, past_key_values: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, **kwargs ) -> dict: - - MASK, gMASK = 150000, 150001 - mask_token = MASK if MASK in input_ids else gMASK - use_gmask = False if MASK in input_ids else gMASK - seq = input_ids[0].tolist() - mask_position = seq.index(mask_token) - - if mask_token not in seq: - raise ValueError("You have to add either [MASK] or [gMASK] in your input") + batch_size, seq_length = input_ids.shape + MASK, gMASK = self.config.mask_token_id, self.config.gmask_token_id + seqs = input_ids.tolist() + mask_positions, use_gmasks = [], [] + for seq in seqs: + mask_token = gMASK if gMASK in seq else MASK + use_gmask = mask_token == gMASK + mask_positions.append(seq.index(mask_token)) + use_gmasks.append(use_gmask) # only last token for input_ids if past is not None if past is not None or past_key_values is not None: - context_length = seq.index(self.config.bos_token_id) last_token = input_ids[:, -1].unsqueeze(-1) - if self.position_encoding_2d: - position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long, - device=input_ids.device) + if attention_mask is not None and attention_mask.dtype == torch.bool: + attention_mask = attention_mask[:, :, -1:] else: - position_ids = torch.tensor([[mask_position]], dtype=torch.long, device=input_ids.device) + attention_mask = None + if position_ids is not None: + position_ids = position_ids[..., -1:] + else: + context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs] + if self.position_encoding_2d: + position_ids = torch.tensor( + [[mask_position, seq_length - context_length] for mask_position, context_length in + zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1) + else: + position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long, + device=input_ids.device).unsqueeze(-1) if past is None: past = past_key_values @@ -995,15 +1146,24 @@ def prepare_inputs_for_generation( "input_ids": last_token, "past_key_values": past, "position_ids": position_ids, + "attention_mask": attention_mask } else: - attention_mask, position_ids = self.get_masks_and_position_ids( - seq=seq, - mask_position=mask_position, - context_length=len(seq), - device=input_ids.device, - gmask=use_gmask - ) + if attention_mask is not None and attention_mask.dtype != torch.bool: + logger.warning_once(f"The dtype of attention mask ({attention_mask.dtype}) is not bool") + attention_mask = None + if attention_mask is None: + attention_mask = self.get_masks( + input_ids, + device=input_ids.device + ) + if position_ids is None: + position_ids = self.get_position_ids( + input_ids, + device=input_ids.device, + mask_positions=mask_positions, + use_gmasks=use_gmasks + ) return { "input_ids": input_ids, @@ -1052,7 +1212,7 @@ def forward( shift_logits = lm_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss() + loss_fct = CrossEntropyLoss(ignore_index=-100) loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) lm_logits = lm_logits.to(hidden_states.dtype) @@ -1078,6 +1238,7 @@ def _reorder_cache( This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct beam_idx at every generation step. + Output shares the same memory storage as `past`. """ return tuple( @@ -1120,10 +1281,10 @@ def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - outputs = self.generate(**input_ids, **gen_kwargs) - outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + outputs = self.generate(**inputs, **gen_kwargs) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) history = history + [(query, response)] @@ -1146,10 +1307,10 @@ def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = No for i, (old_query, response) in enumerate(history): prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) - input_ids = tokenizer([prompt], return_tensors="pt", padding=True) - input_ids = input_ids.to(self.device) - for outputs in self.stream_generate(**input_ids, **gen_kwargs): - outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):] + inputs = tokenizer([prompt], return_tensors="pt") + inputs = inputs.to(self.device) + for outputs in self.stream_generate(**inputs, **gen_kwargs): + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] response = tokenizer.decode(outputs) response = self.process_response(response) new_history = history + [(query, response)] @@ -1257,50 +1418,19 @@ def stream_generate( break yield input_ids - def quantize(self, bits: int, quantize_embeddings=False, use_quantization_cache=False, empty_init=False, **kwargs): + def quantize(self, bits: int, empty_init=False, **kwargs): if bits == 0: return - from .quantization import quantize, QuantizedEmbedding, QuantizedLinear, load_cpu_kernel + from .quantization import quantize if self.quantized: - if self.device == torch.device("cpu"): - logger.info("Already quantized, reloading cpu kernel.") - load_cpu_kernel(**kwargs) - else: - logger.info("Already quantized.") + logger.info("Already quantized.") return self self.quantized = True self.config.quantization_bit = bits - self.config.quantization_embeddings = quantize_embeddings - - self.transformer = quantize(self.transformer, bits, use_quantization_cache=use_quantization_cache, empty_init=empty_init, **kwargs) - - if quantize_embeddings: - logger.info("Applying quantization to embeddings") - self.transformer.word_embeddings = QuantizedEmbedding( - weight_bit_width=bits, - weight_tensor=self.transformer.word_embeddings.weight.to(self.device), - num_embeddings=self.transformer.word_embeddings.num_embeddings, - embedding_dim=self.transformer.word_embeddings.embedding_dim, - dtype=torch.half, - empty_init=True, - device=self.transformer.word_embeddings.weight.device, - ) - self.lm_head = QuantizedLinear( - weight_bit_width=bits, - weight_tensor=self.lm_head.weight.to(self.device), - bias_tensor=None, - in_features=self.lm_head.in_features, - out_features=self.lm_head.out_features, - bias=False, - quantized_weight=self.transformer.word_embeddings.weight, - 
quantized_weight_scale=self.transformer.word_embeddings.weight_scale, - dtype=torch.half, - empty_init=True, - device=self.lm_head.weight.device, - ) + self.transformer = quantize(self.transformer, bits, empty_init=empty_init, **kwargs) return self diff --git a/chatglm/quantization.py b/chatglm/quantization.py index 0f5f959..6f469f6 100644 --- a/chatglm/quantization.py +++ b/chatglm/quantization.py @@ -1,20 +1,20 @@ -from torch.nn import Linear, Embedding +from torch.nn import Linear from torch.nn.parameter import Parameter -import torch.nn.functional as F -import os import bz2 import torch import base64 import ctypes +from transformers.utils import logging from typing import List from functools import partial +logger = logging.get_logger(__name__) + try: from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up - class Kernel: def __init__(self, code: bytes, function_names: List[str]): self.code = code @@ -24,7 +24,6 @@ def __init__(self, code: bytes, function_names: List[str]): for name in self._function_names: setattr(self, name, KernelFunction(self._cmodule, name)) - quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYwYso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktS
bVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnYZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFY
mRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" kernels = Kernel( @@ -39,18 +38,18 @@ def __init__(self, code: bytes, function_names: List[str]): ) except Exception as exception: kernels = None - print("Failed to load cpm_kernels:", exception) + logger.warning("Failed to load cpm_kernels:" + str(exception)) class W8A16Linear(torch.autograd.Function): @staticmethod def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width): ctx.inp_shape = inp.size() - ctx.weight_shape = quant_w.size() ctx.weight_bit_width = weight_bit_width out_features = quant_w.size(0) inp = inp.contiguous().view(-1, inp.size(-1)) weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) + ctx.weight_shape = weight.size() output = inp.mm(weight.t()) ctx.save_for_backward(inp, quant_w, scale_w) return output.view(*(ctx.inp_shape[:-1] + (out_features,))) @@ -62,165 +61,28 @@ def backward(ctx, grad_output: torch.Tensor): grad_output = grad_output.contiguous().view(-1, weight.size(0)) grad_input = grad_output.mm(weight) grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None - - -class W8A16LinearCPU(torch.autograd.Function): - @staticmethod - def forward(ctx, inp: torch.Tensor, quant_w: torch.Tensor, scale_w: torch.Tensor, weight_bit_width, - quantization_cache=None): - ctx.inp_shape = inp.size() - ctx.weight_shape = quant_w.size() - ctx.weight_bit_width = weight_bit_width - out_features = quant_w.size(0) - inp = inp.contiguous().view(-1, inp.size(-1)) - weight = extract_weight_to_float(quant_w, scale_w, weight_bit_width, quantization_cache=quantization_cache) - output = inp.mm(weight.t()) - ctx.save_for_backward(inp, quant_w, scale_w) - return output.view(*(ctx.inp_shape[:-1] + (out_features,))) - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - inp, quant_w, scale_w = ctx.saved_tensors - weight = extract_weight_to_float(quant_w, scale_w, ctx.weight_bit_width) - grad_output = grad_output.contiguous().view(-1, weight.size(0)) - grad_input = grad_output.mm(weight) - grad_weight = grad_output.t().mm(inp) - return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None - - 
-default_cpu_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "quantization_kernels.c") -default_cpu_kernel_code = "QlpoOTFBWSZTWXLbSoQAAgzbgERwQXxmTwAAr/ff3kABt0Q2oRVT0hpo9RtEAAAAyBEiSQ9EGjQGQAAAwANGhowjJoNGmgMEUplMTNSMJ5TQaDJpsoMyRMj8P4mZzFSVVwqSXG8GG7MlVwiToYEQwVD7noBxMhNfkeZYtYFtbgOBUSIGtIQjhNHCEnPJsadhb3yBmRIOD3TeAtNLSaU5GgvKUBWSNuuOIHmVt0YhW6rsmDMDUjeUJGJ64R1Jm5lrh0Aa0tKjhFwPdWcGogxLDSXPWQUWTM8Sd3Qz1HMYNxx3HMeiNqNo4jeRDEfZ3gUSHIcU/heomq0vEzL1Msz5KKGxH8FrNOYw3KaxdqaEmNHYMxJFgQbR0DyRknL2L4kwUSxKRdhjRpEtUqilVfggFL1klaMS3PPRDfNqbBOPWO7m4JTVGhS9QTBDDJaEbLbrUQNB+IpJSKQbG5SZZ5gkwJEhJ3aYKJipZ/i7kinChIOW2lQg" -default_cpu_parallel_kernel_code_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), - "quantization_kernels_parallel.c") -default_cpu_parallel_kernel_code = "QlpoOTFBWSZTWZzWK2UAALXbgERwSX1mTwAAr/ff3kACNyXSbZYwBpoaNGIyAaADQwRRFT/UKDINANqAD1NABFQlPUzaaJHppGRmoAG01ARKKaaMp4gmgaNAaDQDIKVKfZ/g6v1Kem5ZsWZmZtSXS5ZwRAzKmjr1E1lKMEoQNCPkEYPACgcR5I9w/0k6JrJYHqFuHnChcD7N+DHeOQ0ajF83Tc40jgmQbOB5wt3TEHyTObDBLoxrJGBuJmNbxYZwAoKTjbIcI7GsbuVRERAR8wqwhXQjQOxiHQlgSnHjQjddXERojNmQYJJVoM2xxawMeI9asi6E1rfd7GO8S0S5vacCNGry4F1nyZbcTvSBXEMipuPfM7i0Y8kjirpbxb05jpIQjCGE8DYBNCAZyHz9EoOpDRST/I1aFCNpcjoXgyc3NjVsUvYIaYq7xopYJqcxg2g4qXofm7AaGNTzJSNguOQw4utKcEl0F1UOgI+T1hk5LusbGZ9udC1CiBeGwwFxR/QdbZDndehRPxyGt3Me1DBW45MXIY24ZD30aFNuSEUdu5LWx1sSJWLGgsmqUIFTgWhU0gfxXpzhghr2AYpV3hE06mGk1I2JyuZiFgkiz/i7kinChITmsVso" - -cpu_kernels = None - - -class CPUKernel: - def __init__(self, kernel_file="", source_code=default_cpu_kernel_code_path, compile_parallel_kernel=None, - parallel_num=None): - self.load = False - self.int8WeightExtractionFloat = None - self.int4WeightExtractionFloat = None - self.int4WeightCompression = None - self.SetNumThreads = None - - try: - if not os.path.exists(default_cpu_kernel_code_path): - with open(default_cpu_kernel_code_path, "w", encoding="utf-8") as file: - code = default_cpu_kernel_code - cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode() - file.write(cpu_quantization_code) - - if not os.path.exists(default_cpu_parallel_kernel_code_path): - with open(default_cpu_parallel_kernel_code_path, "w", encoding="utf-8") as file: - code = default_cpu_parallel_kernel_code - cpu_quantization_code = bz2.decompress(base64.b64decode(code)).decode() - file.write(cpu_quantization_code) - - except Exception as ex: - print("Error when generating default cpu kernel code(can be ignored when using custom kernels).") - - if compile_parallel_kernel is None: - compile_parallel_kernel = bool(int(os.cpu_count()) >= 4) - - if compile_parallel_kernel and source_code == default_cpu_kernel_code_path: - source_code = default_cpu_parallel_kernel_code_path - - if (not kernel_file) or (not os.path.exists(kernel_file)): - print("No compiled kernel found.") - try: - if os.path.exists(source_code): - print("Compiling kernels :", source_code) - kernel_file = source_code[:-2] + ".so" - if compile_parallel_kernel: - compile_command = "gcc -O3 -fPIC -pthread -fopenmp -std=c99 {} -shared -o {}".format( - source_code, kernel_file) - print("Compiling", compile_command) - exit_state = os.system(compile_command) - if exit_state: - print("Compile failed, using default cpu kernel code.") - compile_parallel_kernel = False - source_code = default_cpu_kernel_code_path - kernel_file = source_code[:-2] + ".so" - compile_command = "gcc -O3 -fPIC -std=c99 {} -shared -o {}".format(source_code, kernel_file) - print("Compiling", compile_command) - else: - compile_command = "gcc -O3 -fPIC 
-std=c99 {} -shared -o {}".format(source_code, kernel_file) - print("Compiling", compile_command) - exit_state = os.system(compile_command) - - print("Kernels compiled :", kernel_file) - else: - print("Kernel source code not found.") - return - except: - print("Failed to build kernel.") - return - if kernel_file: - kernels = ctypes.cdll.LoadLibrary(kernel_file) - self.int8WeightExtractionFloat = kernels.extract_int8_weight_to_float - self.int4WeightExtractionFloat = kernels.extract_int4_weight_to_float - self.int4WeightCompression = kernels.compress_int4_weight - if compile_parallel_kernel: - try: - self.SetNumThreads = kernels.set_num_threads - except: - print("No set_num_threads() found in kernel.") - self.SetNumThreads = lambda x: x - self.load = True - print("Load kernel :", kernel_file) - else: - print("Failed to load kernel.") - - if compile_parallel_kernel: - if parallel_num is None: - parallel_num = max(os.cpu_count() // 2, 1) - print("Setting CPU quantization kernel threads to", parallel_num) - if parallel_num < 4: - print("Parallel kernel is not recommended when parallel num < 4.") - self.SetNumThreads(parallel_num) - - self.parallel_num = parallel_num + return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None def compress_int4_weight(weight: torch.Tensor): # (n, m) - """compress weight on cpu or cuda to int4""" - if weight.device == torch.device("cpu"): - assert isinstance(cpu_kernels, CPUKernel) + with torch.cuda.device(weight.device): n, m = weight.size(0), weight.size(1) assert m % 2 == 0 m = m // 2 - out = torch.empty(n, m, dtype=torch.int8, device="cpu") - cpu_kernels.int4WeightCompression( - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m) - ) - return out - else: - with torch.cuda.device(weight.device): - n, m = weight.size(0), weight.size(1) - assert m % 2 == 0 - m = m // 2 - out = torch.empty(n, m, dtype=torch.int8, device="cuda") - stream = torch.cuda.current_stream() + out = torch.empty(n, m, dtype=torch.int8, device="cuda") + stream = torch.cuda.current_stream() - gridDim = (n, 1, 1) - blockDim = (min(round_up(m, 32), 1024), 1, 1) + gridDim = (n, 1, 1) + blockDim = (min(round_up(m, 32), 1024), 1, 1) - kernels.int4WeightCompression( - gridDim, - blockDim, - 0, - stream, - [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), - ctypes.c_int32(m)], - ) - return out + kernels.int4WeightCompression( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], + ) + return out def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): @@ -255,237 +117,85 @@ def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, sourc return out -def extract_weight_to_float(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int, - quantization_cache=None): - """extract weight on cpu to float32""" - if source_bit_width == 8: - func = cpu_kernels.int8WeightExtractionFloat - elif source_bit_width == 4: - func = cpu_kernels.int4WeightExtractionFloat - else: - assert False, "Unsupported bit-width" - - n, m = weight.size(0), weight.size(1) - - if quantization_cache is not None: - out = quantization_cache - func( - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(scale_list.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m) - ) - return out.tensor - else: - out = 
torch.empty(n, m * (8 // source_bit_width), dtype=torch.float, device="cpu") - func( - ctypes.c_void_p(weight.data_ptr()), - ctypes.c_void_p(scale_list.data_ptr()), - ctypes.c_void_p(out.data_ptr()), - ctypes.c_int32(n), - ctypes.c_int32(m) - ) - return out - - -class CacheTensor(): - def __init__(self, *args, **kwargs): - self.tensor = torch.empty(*args, **kwargs) - - def to(self, *args, **kwargs): - self.tensor = self.tensor.to(*args, **kwargs) - - def data_ptr(self): - return self.tensor.data_ptr() - - class QuantizedLinear(Linear): - def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, quantized_weight=None, - quantized_weight_scale=None, quantization_cache=None, empty_init=False, *args, **kwargs): + def __init__(self, weight_bit_width: int, weight_tensor=None, bias_tensor=None, empty_init=False, *args, **kwargs): super(QuantizedLinear, self).__init__(*args, **kwargs) self.weight_bit_width = weight_bit_width - self.quantization_cache = quantization_cache - if (quantized_weight is not None) and (quantized_weight_scale is not None): - del self.weight - self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False) - else: - shape = self.weight.shape - del self.weight - - if weight_tensor is None or empty_init: - self.weight = torch.empty( - shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] - ) - self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) - else: - self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).to( - kwargs["dtype"]) - self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) - if weight_bit_width == 4: - self.weight = compress_int4_weight(self.weight) + shape = self.weight.shape + del self.weight - self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) + if weight_tensor is None or empty_init: + self.weight = torch.empty( + shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] + ) + self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) + else: + self.weight_scale = (weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() + self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) + if weight_bit_width == 4: + self.weight = compress_int4_weight(self.weight) + self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) + self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) if bias_tensor is not None: self.bias = Parameter(bias_tensor.to(kwargs["device"]), requires_grad=False) else: self.bias = None - def reset_parameters(self): - """To accelerate initialization""" - pass - def forward(self, input): - if self.weight.device == torch.device("cpu"): - output = W8A16LinearCPU.apply(input, self.weight, self.weight_scale, self.weight_bit_width, - self.quantization_cache) - else: - output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) + output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) if self.bias is not None: output = output + self.bias return output - def _apply(self, fn): - self_obj = super()._apply(fn) - if 
self.quantization_cache is not None: - self.quantization_cache.to(self_obj.weight.device) - self.quantization_cache.to(self_obj.weight_scale.dtype) - return self_obj - - -class QuantizedEmbedding(Embedding): # TODO: backward, check empty_init - def __init__(self, weight_bit_width: int, weight_tensor=None, quantized_weight=None, quantized_weight_scale=None, - empty_init=False, *args, **kwargs): - super(QuantizedEmbedding, self).__init__(*args, **kwargs) - self.weight_bit_width = weight_bit_width - - if (quantized_weight is not None) and (quantized_weight_scale is not None): - del self.weight - self.weight = Parameter(quantized_weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(quantized_weight_scale.to(kwargs["device"]), requires_grad=False) - else: - shape = self.weight.shape - del self.weight - - if weight_tensor is None or empty_init: - self.weight = torch.empty( - shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=kwargs["device"] - ) - self.weight_scale = torch.empty(shape[0], dtype=kwargs["dtype"], device=kwargs["device"]) - else: - self.weight_scale = ( - weight_tensor.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half() - self.weight = torch.round(weight_tensor / self.weight_scale[:, None]).to(torch.int8) - if weight_bit_width == 4: - self.weight = compress_int4_weight(self.weight) - - self.weight = Parameter(self.weight.to(kwargs["device"]), requires_grad=False) - self.weight_scale = Parameter(self.weight_scale.to(kwargs["device"]), requires_grad=False) - - def forward(self, input): - if self.weight.device == torch.device("cpu"): - original_weight = extract_weight_to_float(weight=self.weight, scale_list=self.weight_scale, - source_bit_width=self.weight_bit_width) - else: - original_weight = extract_weight_to_half(weight=self.weight, scale_list=self.weight_scale, - source_bit_width=self.weight_bit_width) - output = F.embedding( - input, original_weight, self.padding_idx, self.max_norm, - self.norm_type, self.scale_grad_by_freq, self.sparse - ) - return output - -def load_cpu_kernel(**kwargs): - global cpu_kernels - cpu_kernels = CPUKernel(**kwargs) - assert cpu_kernels.load - - -def quantize(model, weight_bit_width, use_quantization_cache=False, empty_init=False, **kwargs): +def quantize(model, weight_bit_width, empty_init=False, **kwargs): """Replace fp16 linear with quantized linear""" - query_key_value_quantization_cache = None - dense_quantization_cache = None - dense_h_to_4h_quantization_cache = None - dense_4h_to_h_quantization_cache = None - - try: - load_cpu_kernel(**kwargs) - except: - print("Cannot load cpu kernel, don't use quantized model on cpu.") - if kernels is None: # CUDA kernels failed - print("Cannot load cuda kernel, quantization failed.") - return model - - current_device = model.device - - if model.device == torch.device("cpu"): - dtype = torch.float32 - else: - dtype = torch.half - - QuantizedLinearWithPara = partial( - QuantizedLinear, - weight_bit_width=weight_bit_width, - bias=True, - dtype=dtype, - empty_init=empty_init - ) - - if use_quantization_cache: - print("Using quantization cache") - layer = model.layers[0] - weight = layer.attention.query_key_value.weight - n, m = weight.size(0), weight.size(1) - query_key_value_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) - weight = layer.attention.dense.weight - n, m = weight.size(0), weight.size(1) - dense_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, 
requires_grad=False) - weight = layer.mlp.dense_h_to_4h.weight - n, m = weight.size(0), weight.size(1) - dense_h_to_4h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) - weight = layer.mlp.dense_4h_to_h.weight - n, m = weight.size(0), weight.size(1) - dense_4h_to_h_quantization_cache = CacheTensor(n, m, dtype=dtype, device=current_device, requires_grad=False) - - print("Applying quantization to glm layers") - for layer in model.layers: - layer.attention.query_key_value = QuantizedLinearWithPara( - weight_tensor=layer.attention.query_key_value.weight.to(current_device), + layer.attention.query_key_value = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.query_key_value.weight.to(torch.cuda.current_device()), bias_tensor=layer.attention.query_key_value.bias, in_features=layer.attention.query_key_value.in_features, out_features=layer.attention.query_key_value.out_features, + bias=True, + dtype=torch.half, device=layer.attention.query_key_value.weight.device, - quantization_cache=query_key_value_quantization_cache + empty_init=empty_init ) - layer.attention.dense = QuantizedLinearWithPara( - weight_tensor=layer.attention.dense.weight.to(current_device), + layer.attention.dense = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.attention.dense.weight.to(torch.cuda.current_device()), bias_tensor=layer.attention.dense.bias, in_features=layer.attention.dense.in_features, out_features=layer.attention.dense.out_features, + bias=True, + dtype=torch.half, device=layer.attention.dense.weight.device, - quantization_cache=dense_quantization_cache + empty_init=empty_init ) - layer.mlp.dense_h_to_4h = QuantizedLinearWithPara( - weight_tensor=layer.mlp.dense_h_to_4h.weight.to(current_device), + layer.mlp.dense_h_to_4h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_h_to_4h.weight.to(torch.cuda.current_device()), bias_tensor=layer.mlp.dense_h_to_4h.bias, in_features=layer.mlp.dense_h_to_4h.in_features, out_features=layer.mlp.dense_h_to_4h.out_features, + bias=True, + dtype=torch.half, device=layer.mlp.dense_h_to_4h.weight.device, - quantization_cache=dense_h_to_4h_quantization_cache + empty_init=empty_init ) - layer.mlp.dense_4h_to_h = QuantizedLinearWithPara( - weight_tensor=layer.mlp.dense_4h_to_h.weight.to(current_device), + layer.mlp.dense_4h_to_h = QuantizedLinear( + weight_bit_width=weight_bit_width, + weight_tensor=layer.mlp.dense_4h_to_h.weight.to(torch.cuda.current_device()), bias_tensor=layer.mlp.dense_4h_to_h.bias, in_features=layer.mlp.dense_4h_to_h.in_features, out_features=layer.mlp.dense_4h_to_h.out_features, + bias=True, + dtype=torch.half, device=layer.mlp.dense_4h_to_h.weight.device, - quantization_cache=dense_4h_to_h_quantization_cache + empty_init=empty_init ) return model diff --git a/chatglm/tokenization_chatglm.py b/chatglm/tokenization_chatglm.py index 957a6a8..1d4f0ba 100644 --- a/chatglm/tokenization_chatglm.py +++ b/chatglm/tokenization_chatglm.py @@ -1,17 +1,13 @@ """Tokenization classes for ChatGLM.""" -import sys -import unicodedata from typing import List, Optional, Union -from functools import lru_cache import os -import collections -import re from transformers.tokenization_utils import PreTrainedTokenizer -from icetk.text_tokenizer import TextTokenizer -from icetk.utils import auto_create -import icetk.sentencepiece_model_pb2 as sp_model -from transformers.utils import logging +from transformers.utils import logging, 
PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from typing import Dict +import sentencepiece as spm +import numpy as np logger = logging.get_logger(__name__) @@ -20,61 +16,52 @@ } +class TextTokenizer: + def __init__(self, model_path): + self.sp = spm.SentencePieceProcessor() + self.sp.Load(model_path) + self.num_tokens = self.sp.vocab_size() + + def encode(self, text): + return self.sp.EncodeAsIds(text) + + def decode(self, ids: List[int]): + return self.sp.DecodeIds(ids) + + def tokenize(self, text): + return self.sp.EncodeAsPieces(text) + + def convert_tokens_to_ids(self, tokens): + return [self.sp.PieceToId(token) for token in tokens] + + def convert_token_to_id(self, token): + return self.sp.PieceToId(token) + + def convert_id_to_token(self, idx): + return self.sp.IdToPiece(idx) + + def __len__(self): + return self.num_tokens + + class SPTokenizer: def __init__( - self, - vocab_file, - max_blank_length=80, - byte_fallback=True, + self, + vocab_file, + num_image_tokens=20000, + max_blank_length=80, + byte_fallback=True, ): assert vocab_file is not None self.vocab_file = vocab_file + self.num_image_tokens = num_image_tokens self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "", "", "", "", ""] self.max_blank_length = max_blank_length self.byte_fallback = byte_fallback - self.text_tokenizer = self._build_text_tokenizer(encode_special_tokens=False) - self.special_text_tokenizer = self._build_text_tokenizer(encode_special_tokens=True) + self.text_tokenizer = TextTokenizer(vocab_file) - @staticmethod - def _configure_tokenizer( - text_tokenizer: TextTokenizer, - special_tokens: List[str], - max_blank_length: int, - byte_fallback: bool, - encode_special_tokens=False, - ): - # special token - special_token_type = 4 if encode_special_tokens else 3 # 3 - CONTROL, 4 - USER_DEFINE - for token in special_tokens: - text_tokenizer.proto.pieces.append( - sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=special_token_type) - ) - # whitespaces - for token in [SPTokenizer.get_tab_token()] + [ - SPTokenizer.get_blank_token(i) for i in range(2, max_blank_length + 1) - ]: - text_tokenizer.proto.pieces.append(sp_model.ModelProto.SentencePiece(piece=token, score=0.0, type=4)) - # byte fallback - if byte_fallback: - text_tokenizer.proto.trainer_spec.byte_fallback = True - for i in range(256): - text_tokenizer.proto.pieces.append( - sp_model.ModelProto.SentencePiece(piece="<0x{:02X}>".format(i), score=0.0, type=6) - ) - text_tokenizer.refresh() - - def _build_text_tokenizer(self, encode_special_tokens=False): - tokenizer = TextTokenizer(self.vocab_file) - self._configure_tokenizer( - tokenizer, self.special_tokens, self.max_blank_length, self.byte_fallback, encode_special_tokens - ) - return tokenizer - - def _get_text_tokenizer(self, encode_special_tokens=False): - if encode_special_tokens: - return self.special_text_tokenizer - else: - return self.text_tokenizer + def _get_text_tokenizer(self): + return self.text_tokenizer @staticmethod def get_blank_token(length: int): @@ -85,10 +72,6 @@ def get_blank_token(length: int): def get_tab_token(): return f"<|tab|>" - @property - def num_image_tokens(self): - return 20000 - @property def num_text_tokens(self): return self.text_tokenizer.num_tokens @@ -112,7 +95,7 @@ def _preprocess(self, text: str, linebreak=True, whitespaces=True): return text def encode( - self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True + self, text: str, linebreak=True, 
whitespaces=True, add_dummy_prefix=True ) -> List[int]: """ @param text: Text to encode. @@ -124,14 +107,14 @@ def encode( text = self._preprocess(text, linebreak, whitespaces) if not add_dummy_prefix: text = "" + text - tmp = self._get_text_tokenizer(encode_special_tokens=special_tokens).encode(text) + tmp = self._get_text_tokenizer().encode(text) tokens = [x + self.num_image_tokens for x in tmp] return tokens if add_dummy_prefix else tokens[2:] - def decode(self, text_ids: List[int], special_tokens=False) -> str: + def decode(self, text_ids: List[int]) -> str: ids = [int(_id) - self.num_image_tokens for _id in text_ids] ids = [_id for _id in ids if _id >= 0] - text = self._get_text_tokenizer(encode_special_tokens=special_tokens).decode(ids) + text = self._get_text_tokenizer().decode(ids) text = text.replace("", "\n") text = text.replace(SPTokenizer.get_tab_token(), "\t") for i in range(2, self.max_blank_length + 1): @@ -139,7 +122,7 @@ def decode(self, text_ids: List[int], special_tokens=False) -> str: return text def tokenize( - self, text: str, linebreak=True, whitespaces=True, special_tokens=False, add_dummy_prefix=True + self, text: str, linebreak=True, whitespaces=True, add_dummy_prefix=True ) -> List[str]: """ @param text: Text to encode. @@ -151,7 +134,7 @@ def tokenize( text = self._preprocess(text, linebreak, whitespaces) if not add_dummy_prefix: text = "" + text - tokens = self._get_text_tokenizer(encode_special_tokens=special_tokens).tokenize(text) + tokens = self._get_text_tokenizer().tokenize(text) return tokens if add_dummy_prefix else tokens[2:] def __getitem__(self, x: Union[int, str]): @@ -172,6 +155,7 @@ def __getitem__(self, x: Union[int, str]): class ChatGLMTokenizer(PreTrainedTokenizer): """ Construct a ChatGLM tokenizer. Based on byte-level Byte-Pair-Encoding. + Args: vocab_file (`str`): Path to the vocabulary file. @@ -179,25 +163,36 @@ class ChatGLMTokenizer(PreTrainedTokenizer): vocab_files_names = {"vocab_file": "ice_text.model"} max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - model_input_names = ["input_ids"] + model_input_names = ["input_ids", "attention_mask", "position_ids"] def __init__( self, vocab_file, do_lower_case=False, remove_space=False, - bos_token='sop', - eos_token='eos', - eop_token='eop', + bos_token='', + eos_token='', + end_token='', mask_token='[MASK]', gmask_token='[gMASK]', padding_side="left", + pad_token="", + unk_token="", + num_image_tokens=20000, **kwargs ) -> None: super().__init__( do_lower_case=do_lower_case, remove_space=remove_space, padding_side=padding_side, + bos_token=bos_token, + eos_token=eos_token, + end_token=end_token, + mask_token=mask_token, + gmask_token=gmask_token, + pad_token=pad_token, + unk_token=unk_token, + num_image_tokens=num_image_tokens, **kwargs ) @@ -207,23 +202,29 @@ def __init__( self.bos_token = bos_token self.eos_token = eos_token - self.eop_token = eop_token + self.end_token = end_token self.mask_token = mask_token - self.gMASK_token = gmask_token + self.gmask_token = gmask_token - self.sp_tokenizer = SPTokenizer(vocab_file) + self.sp_tokenizer = SPTokenizer(vocab_file, num_image_tokens=num_image_tokens) """ Initialisation """ @property - def eop_token_id(self) -> Optional[int]: + def gmask_token_id(self) -> Optional[int]: + if self.gmask_token is None: + return None + return self.convert_tokens_to_ids(self.gmask_token) + + @property + def end_token_id(self) -> Optional[int]: """ - `Optional[int]`: Id of the end of sentence token in the vocabulary. 
Returns `None` if the token has not been + `Optional[int]`: Id of the end of context token in the vocabulary. Returns `None` if the token has not been set. """ - if self.eop_token is None: + if self.end_token is None: return None - return self.convert_tokens_to_ids(self.eop_token) + return self.convert_tokens_to_ids(self.end_token) @property def vocab_size(self): @@ -255,25 +256,20 @@ def _tokenize(self, text, **kwargs): return seq - def decode( + def _decode( self, - token_ids: Union[List[int], List[List[int]]], + token_ids: Union[int, List[int]], skip_special_tokens: bool = False, clean_up_tokenization_spaces: bool = True, - spaces_between_special_tokens: bool = True, **kwargs ) -> str: - if isinstance(token_ids[0], list): - tokens = [] - for single_token_ids in token_ids: - if self.pad_token_id in single_token_ids: # remove pad - single_token_ids = list(filter((self.pad_token_id).__ne__, single_token_ids)) - tokens.append(self.sp_tokenizer.decode(single_token_ids)) - return (tokens) - else: - if self.pad_token_id in token_ids: # remove pad - token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) - return self.sp_tokenizer.decode(token_ids) + if isinstance(token_ids, int): + token_ids = [token_ids] + if len(token_ids) == 0: + return "" + if self.pad_token_id in token_ids: # remove pad + token_ids = list(filter((self.pad_token_id).__ne__, token_ids)) + return self.sp_tokenizer.decode(token_ids) def _convert_token_to_id(self, token): """ Converts a token (str) in an id using the vocab. """ @@ -286,11 +282,13 @@ def _convert_id_to_token(self, index): def save_vocabulary(self, save_directory, filename_prefix=None): """ Save the vocabulary and special tokens file to a directory. + Args: save_directory (`str`): The directory in which to save the vocabulary. filename_prefix (`str`, *optional*): An optional prefix to add to the named of the saved files. + Returns: `Tuple(str)`: Paths to the files saved. """ @@ -315,26 +313,118 @@ def build_inputs_with_special_tokens( """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. A BERT sequence has the following format: + - single sequence: `[CLS] X [SEP]` - pair of sequences: `[CLS] A [SEP] B [SEP]` + Args: token_ids_0 (`List[int]`): List of IDs to which the special tokens will be added. token_ids_1 (`List[int]`, *optional*): Optional second list of IDs for sequence pairs. + Returns: `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
""" + gmask_id = self.sp_tokenizer[self.gmask_token] + eos_id = self.sp_tokenizer[self.eos_token] + token_ids_0 = token_ids_0 + [gmask_id, self.sp_tokenizer[self.bos_token]] if token_ids_1 is not None: - token_ids_0 += token_ids_1 - mask_ids = self.sp_tokenizer[self.mask_token] - gmask_ids = self.sp_tokenizer[self.gMASK_token] - if mask_ids not in token_ids_0 and gmask_ids not in token_ids_0: - token_ids_0 += [gmask_ids] - - if token_ids_0[-1] != mask_ids and token_ids_0[-1] != gmask_ids: - token_ids_0 += [self.sp_tokenizer[self.eos_token]] + token_ids_0 = token_ids_0 + token_ids_1 + [eos_id] + return token_ids_0 - token_ids_0 += [self.sp_tokenizer[self.bos_token]] + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) - return token_ids_0 + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). + return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + bos_token_id = self.sp_tokenizer[self.bos_token] + mask_token_id = self.sp_tokenizer[self.mask_token] + gmask_token_id = self.sp_tokenizer[self.gmask_token] + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. 
+ if max_length is not None: + if "attention_mask" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + attention_mask = np.ones((1, seq_length, seq_length)) + attention_mask = np.tril(attention_mask) + attention_mask[:, :, :context_length] = 1 + attention_mask = np.bool_(attention_mask < 0.5) + encoded_inputs["attention_mask"] = attention_mask + + if "position_ids" not in encoded_inputs: + if bos_token_id in required_input: + context_length = required_input.index(bos_token_id) + else: + context_length = seq_length + position_ids = np.arange(seq_length, dtype=np.int64) + mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id + if mask_token in required_input: + mask_position = required_input.index(mask_token) + position_ids[context_length:] = mask_position + block_position_ids = np.concatenate( + [np.zeros(context_length, dtype=np.int64), + np.arange(1, seq_length - context_length + 1, dtype=np.int64)]) + encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = np.pad(encoded_inputs["attention_mask"], + pad_width=[(0, 0), (difference, 0), (difference, 0)], + mode='constant', constant_values=True) + if "token_type_ids" in encoded_inputs: + encoded_inputs["token_type_ids"] = [self.pad_token_type_id] * difference + encoded_inputs[ + "token_type_ids" + ] + if "special_tokens_mask" in encoded_inputs: + encoded_inputs["special_tokens_mask"] = [1] * difference + encoded_inputs["special_tokens_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = np.pad(encoded_inputs["position_ids"], + pad_width=[(0, 0), (difference, 0)]) + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs From a4ddc56aeaf59cd0765c582076d999a3e875d67e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 11:16:15 +0800 Subject: [PATCH 8/9] Update chatglm_predictor.py --- predictors/chatglm_predictor.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/predictors/chatglm_predictor.py b/predictors/chatglm_predictor.py index d5f20a9..63e84dc 100644 --- a/predictors/chatglm_predictor.py +++ b/predictors/chatglm_predictor.py @@ -10,7 +10,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor): - def __init__(self, start_pos=20005): + def __init__(self, start_pos=5): self.start_pos = start_pos def __call__(self, input_ids: torch.LongTensor, @@ -29,7 +29,15 @@ def __init__(self, model_name): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, resume_download=True) - if 'slim' in model_name: + if self.device == 'cpu': + from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration + model = ChatGLMForConditionalGeneration.from_pretrained( + model_name, + trust_remote_code=True, + resume_download=True, + torch_dtype=torch.float32, + device_map={'': self.device}) + elif 'slim' in model_name: model = AutoModel.from_pretrained( model_name, trust_remote_code=True, resume_download=True) @@ -83,8 +91,7 @@ def stream_chat_continue(self, else: answer = '' logits_processor.append( - InvalidScoreLogitsProcessor( - start_pos=20005 if 'slim' not in 
self.model_name else 5)) + InvalidScoreLogitsProcessor(5)) gen_kwargs = { "max_length": max_length, "do_sample": do_sample, From ccf77d253ff0700eaf93a7902aca6ad4602c4916 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=A8=E5=9F=B9=E6=96=87?= <915505626@qq.com> Date: Sun, 23 Apr 2023 11:36:36 +0800 Subject: [PATCH 9/9] Update chatglm_predictor.py --- predictors/chatglm_predictor.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/predictors/chatglm_predictor.py b/predictors/chatglm_predictor.py index 63e84dc..38f652b 100644 --- a/predictors/chatglm_predictor.py +++ b/predictors/chatglm_predictor.py @@ -29,22 +29,14 @@ def __init__(self, model_name): self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.tokenizer = AutoTokenizer.from_pretrained( model_name, trust_remote_code=True, resume_download=True) - if self.device == 'cpu': - from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration - model = ChatGLMForConditionalGeneration.from_pretrained( - model_name, - trust_remote_code=True, - resume_download=True, - torch_dtype=torch.float32, - device_map={'': self.device}) - elif 'slim' in model_name: + if 'slim' in model_name: model = AutoModel.from_pretrained( model_name, trust_remote_code=True, resume_download=True) if self.device == 'cuda': model = model.half().to(self.device) else: - model = model.to(self.device) + model = model.float() elif 'int4' in model_name: model = AutoModel.from_pretrained( model_name, trust_remote_code=True, @@ -52,7 +44,7 @@ def __init__(self, model_name): if self.device == 'cuda': model = model.half().to(self.device) else: - model = model.to(self.device) + model = model.float() else: model = AutoModel.from_pretrained( model_name, @@ -62,6 +54,8 @@ def __init__(self, model_name): torch_dtype=torch.float16 if self.device == 'cuda' else torch.float32, device_map={'': self.device}) + if self.device == 'cpu': + model = model.float() model = model.eval() self.model = model self.model_name = model_name
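
A minimal sketch of the device/dtype selection that patches 8/9 and 9/9 converge on: fp16 on CUDA, fp32 on CPU. Only the transformers AutoModel/AutoTokenizer calls already present in the patched predictor are assumed; the helper name load_chatglm is hypothetical and the snippet is illustrative, not part of the patch series.

    import torch
    from transformers import AutoModel, AutoTokenizer

    def load_chatglm(model_name='THUDM/chatglm-6b'):
        # Same device selection as ChatGLM.__init__ in predictors/chatglm_predictor.py
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, resume_download=True)
        model = AutoModel.from_pretrained(
            model_name, trust_remote_code=True, resume_download=True)
        if device == 'cuda':
            model = model.half().to(device)  # fp16 weights on GPU
        else:
            model = model.float()            # fp32 on CPU, matching the patched branches
        return tokenizer, model.eval()

Under these assumptions, calling load_chatglm() on a CUDA machine yields a half-precision model on the GPU, while on a CPU-only machine it falls back to full precision without ever calling .cuda(), which is the failure the earlier "fix cuda" patch addressed.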