diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1925d96 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +tmp*/ +build/ +*.bmodel +.vscode +*.npz diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..23cb245 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,6 @@ +[submodule "models/Qwen/demo/third_party/abseil-cpp"] + path = models/Qwen/demo/third_party/abseil-cpp + url = https://github.com/abseil/abseil-cpp.git +[submodule "models/Qwen/demo/third_party/re2"] + path = models/Qwen/demo/third_party/re2 + url = https://github.com/google/re2.git diff --git a/models/Qwen/compile/files/Qwen-14B-Chat/config.json b/models/Qwen/compile/files/Qwen-14B-Chat/config.json new file mode 100755 index 0000000..9acee3f --- /dev/null +++ b/models/Qwen/compile/files/Qwen-14B-Chat/config.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "QWenLMHeadModel" + ], + "auto_map": { + "AutoConfig": "configuration_qwen.QWenConfig", + "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" + }, + "attn_dropout_prob": 0.0, + "bf16": true, + "emb_dropout_prob": 0.0, + "fp16": false, + "fp32": false, + "hidden_size": 5120, + "intermediate_size": 27392, + "initializer_range": 0.02, + "kv_channels": 128, + "layer_norm_epsilon": 1e-06, + "max_position_embeddings": 8192, + "model_type": "qwen", + "no_bias": true, + "num_attention_heads": 40, + "num_hidden_layers": 40, + "onnx_safe": null, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "scale_attn_weights": true, + "seq_length": 2048, + "tie_word_embeddings": false, + "tokenizer_class": "QWenTokenizer", + "transformers_version": "4.32.0", + "use_cache": true, + "use_dynamic_ntk": true, + "use_flash_attn": "auto", + "use_logn_attn": true, + "vocab_size": 152064 +} \ No newline at end of file diff --git a/models/Qwen/compile/files/Qwen-14B-Chat/modeling_qwen.py b/models/Qwen/compile/files/Qwen-14B-Chat/modeling_qwen.py new file mode 100755 index 0000000..69a70b0 --- /dev/null +++ b/models/Qwen/compile/files/Qwen-14B-Chat/modeling_qwen.py @@ -0,0 +1,1346 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
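# Illustrative only, not part of this diff: a minimal sketch of how
# QWenAttention / QWenMLP below derive their layer shapes from the
# Qwen-14B-Chat config.json added above. All numbers are taken from that
# config; nothing here is imported from the model code.
hidden_size = 5120          # config["hidden_size"]
num_attention_heads = 40    # config["num_attention_heads"]
kv_channels = 128           # config["kv_channels"]
intermediate_size = 27392   # config["intermediate_size"]

head_dim = hidden_size // num_attention_heads        # 128, equals kv_channels
projection_size = kv_channels * num_attention_heads  # 5120 -> c_attn is Linear(5120, 3 * 5120)
ff_dim_in = intermediate_size // 2                    # 13696, SwiGLU half-width used by w1/w2/c_proj

assert head_dim == kv_channels and projection_size == hidden_size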
+ +import copy +import importlib +import math +import pathlib +from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import warnings + +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList +from transformers.generation.logits_process import LogitsProcessorList + +if TYPE_CHECKING: + from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +try: + from einops import rearrange +except ImportError: + rearrange = None +from torch import nn + +SUPPORT_CUDA = torch.cuda.is_available() +SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() +SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 +SUPPORT_TORCH2 = False #hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 + + +from .configuration_qwen import QWenConfig +from .qwen_generation_utils import ( + HistoryType, + make_context, + decode_tokens, + get_stop_words_ids, + StopWordsLogitsProcessor, +) + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "qwen" +_CONFIG_FOR_DOC = "QWenConfig" + +QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] + +_ERROR_BAD_CHAT_FORMAT = """\ +We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". +If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). +我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 +如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 +""" + +_SENTINEL = object() +_ERROR_STREAM_IN_CHAT = """\ +Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). +向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 +""" + +_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ +We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
+检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 +""" + +apply_rotary_emb_func = None +rms_norm = None +flash_attn_unpadded_func = None +flash_attn_func = None + +def _import_flash_attn(): + global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func, flash_attn_func + try: + from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func + apply_rotary_emb_func = __apply_rotary_emb_func + except ImportError: + logger.warn( + "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" + ) + + try: + from flash_attn.ops.rms_norm import rms_norm as __rms_norm + rms_norm = __rms_norm + except ImportError: + logger.warn( + "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" + ) + + try: + import flash_attn + _flash_attn_func = None + if not hasattr(flash_attn, '__version__'): + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + else: + if int(flash_attn.__version__.split(".")[0]) >= 2: + if int(flash_attn.__version__.split(".")[1]) >= 1: + from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func + from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + flash_attn_unpadded_func = __flash_attn_unpadded_func + flash_attn_func = _flash_attn_func + except ImportError: + logger.warn( + "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention" + ) + +def quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + +def dequantize_cache_torch(qdata, scale, zero): + data = scale * (qdata - zero) + return data + +class FlashSelfAttention(torch.nn.Module): + def __init__( + self, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + ): + super().__init__() + assert flash_attn_unpadded_func is not None, ( + "Please install FlashAttention first, " "e.g., with pip install flash-attn" + ) + assert ( + rearrange is not None + ), "Please install einops first, e.g., with pip install einops" + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def unpad_input(self, hidden_states, attention_mask): + valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) + seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(valid_mask.flatten(), 
as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + hidden_states = hidden_states[indices] + return hidden_states, indices, cu_seqlens, max_seqlen_in_batch + + def pad_input(self, hidden_states, indices, batch, seqlen): + output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, + dtype=hidden_states.dtype) + output[indices] = hidden_states + return rearrange(output, '(b s) ... -> b s ...', b=batch) + + def forward(self, q, k, v, attention_mask=None): + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) + assert all((i.is_cuda for i in (q, k, v))) + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + seqlen_out = seqlen_q + + if flash_attn_func is not None and batch_size == 1: + dropout_p = self.dropout_p if self.training else 0 + output = flash_attn_func(q, k, v, dropout_p, softmax_scale=self.softmax_scale, causal=self.causal) + return output + + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q.device, + ) + + if batch_size > 1 and attention_mask is not None: + k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) + if q.size(0) == v.size(0): + q = q[indices_k] + cu_seqlens_q = cu_seqlens_k + seqlen_q = seqlen_k + v = v[indices_k] + else: + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=q.device, + ) + + if self.training: + assert seqlen_k == seqlen_q + is_causal = self.causal + dropout_p = self.dropout_p + else: + is_causal = seqlen_q == seqlen_k + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlen_q, + seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, + causal=is_causal, + ) + if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: + output = self.pad_input(output, indices_k, batch_size, seqlen_out) + else: + new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] + output = output.view(new_shape) + return output + + +class QWenAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + self.seq_length = config.seq_length + + self.hidden_size = config.hidden_size + self.split_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + + self.use_flash_attn = config.use_flash_attn + self.scale_attn_weights = True + + self.projection_size = config.kv_channels * config.num_attention_heads + + assert self.projection_size % config.num_attention_heads == 0 + self.hidden_size_per_attention_head = ( + self.projection_size // config.num_attention_heads + ) + + self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) + + self.c_proj = nn.Linear( + config.hidden_size, self.projection_size, bias=not config.no_bias + ) + + self.is_fp32 = not (config.bf16 or config.fp16) + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attn_dropout_prob + ) + self.bf16 = config.bf16 + + self.use_dynamic_ntk = config.use_dynamic_ntk + self.use_logn_attn = config.use_logn_attn + + logn_list = 
[ + math.log(i, self.seq_length) if i > self.seq_length else 1 + for i in range(1, 32768) + ] + logn_tensor = torch.tensor(logn_list)[None, :, None, None] + self.register_buffer("logn_tensor", logn_tensor, persistent=False) + + self.attn_dropout = nn.Dropout(config.attn_dropout_prob) + self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False + self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False + self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + cache_dtype = torch.float + if self.bf16: + cache_dtype=torch.bfloat16 + elif config.fp16: + cache_dtype = torch.float16 + self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) + self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) + + if config.use_cache_quantization and config.use_cache_kernel: + # pre check if the support files existing + module_root = pathlib.Path(__file__).parent + src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") + if any(not (module_root/src).is_file() for src in src_files): + warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") + self.cache_kernels = None + else: + try: + from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 + except ImportError: + warnings.warn("Failed to import KV cache kernels.") + self.cache_kernels = None + + def _attn(self, query, key, value, causal_mask=None, attention_mask=None, head_mask=None): + device = query.device + if self.use_cache_quantization: + qk, qk_scale, qk_zero = key + if self.use_cache_kernel and self.cache_kernels is not None: + shape = query.shape[:-1] + (qk.shape[-2],) + attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_faster_old( + query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + qk.transpose(-1, -2).contiguous(), + attn_weights, + qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + # attn_weights = attn_weights.to(query.dtype).contiguous() + else: + key = dequantize_cache_torch(qk, qk_scale, qk_zero) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + else: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + if self.use_cache_quantization: + size_temp = value[0].size(-1) + else: + size_temp = value.size(-1) + attn_weights = attn_weights / (size_temp ** 0.5) + + # mask_value = torch.finfo(attn_weights.dtype).min + # if causal_mask is not None: + # attn_weights = torch.where( + # causal_mask, attn_weights.to(attn_weights.dtype), mask_value + # ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + if self.softmax_in_fp32: + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = attn_weights.type(query.dtype) + attn_weights = self.attn_dropout(attn_weights) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + if self.use_cache_quantization: + qv, qv_scale, qv_zero = value + if self.use_cache_kernel and self.cache_kernels is not None: + shape = attn_weights.shape[:-1] + (query.shape[-1],) + attn_output = torch.zeros(shape, 
dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( + attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + qv.contiguous(), # dtype: int32 + attn_output, + qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + if attn_output.dtype != query.dtype: + attn_output = attn_output.to(query.dtype) + attn_weights = attn_weights.to(query.dtype) + else: + value = dequantize_cache_torch(qv, qv_scale, qv_zero) + attn_output = torch.matmul(attn_weights, value) + else: + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ): + mixed_x_layer = self.c_attn(hidden_states) + + query, key, value = mixed_x_layer.split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] + key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = quantize_cache_v(key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + value = quantize_cache_v(value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + + if use_cache: + present = (key, value) + else: + present = None + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = 
(torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2)) + value = (torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2)) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and query.is_cuda + ): + q, k, v = query, key, value + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) + else: + causal_mask = None + query = query.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + if ( + causal_mask is None + and self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and not query.is_cuda + ): + raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) + + if not self.use_cache_quantization and SUPPORT_TORCH2: + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, query.size(2), -1) + if causal_mask is not None: + attention_mask = attention_mask.masked_fill(~causal_mask, torch.finfo(query.dtype).min) + else: + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn( + query, key, value, causal_mask, attention_mask, head_mask + ) + context_layer = self._merge_heads( + attn_output, self.num_heads, self.head_dim + ) + + attn_output = self.c_proj(context_layer) + + outputs = (attn_output, present) + if output_attentions: + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + raise ValueError("Cannot output attentions while using flash-attn") + elif not self.use_cache_quantization and SUPPORT_TORCH2: + raise ValueError("Cannot output attentions while using scaled_dot_product_attention") + else: + outputs += (attn_weight,) + + return outputs + + +class QWenMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + self.w2 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + ff_dim_in = config.intermediate_size // 2 + self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) + + def forward(self, hidden_states): + a1 = self.w1(hidden_states) + a2 = self.w2(hidden_states) + intermediate_parallel = a1 * F.silu(a2) + output = self.c_proj(intermediate_parallel) + return output + + +class QWenBlock(nn.Module): + def __init__(self, config): + super().__init__() + hidden_size = config.hidden_size + self.bf16 = config.bf16 + + self.ln_1 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + self.attn = QWenAttention(config) + self.ln_2 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + + self.mlp = QWenMLP(config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: 
Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + layernorm_output = self.ln_1(hidden_states) + + attn_outputs = self.attn( + layernorm_output, + rotary_pos_emb_list, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + residual = hidden_states + layernorm_input = attn_output + residual + + layernorm_output = self.ln_2(layernorm_input) + + residual = layernorm_input + mlp_output = self.mlp(layernorm_output) + hidden_states = residual + mlp_output + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs + + +class QWenPreTrainedModel(PreTrainedModel): + config_class = QWenConfig + base_model_prefix = "transformer" + is_parallelizable = False + supports_gradient_checkpointing = True + _no_split_modules = ["QWenBlock"] + _skip_keys_device_placement = "past_key_values" + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, RMSNorm): + module.weight.data.fill_(1.0) + + for name, p in module.named_parameters(): + if name == "c_proj.weight": + p.data.normal_( + mean=0.0, + std=( + self.config.initializer_range + / math.sqrt(2 * self.config.num_hidden_layers) + ), + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, QWenModel): + module.gradient_checkpointing = value + + +class QWenModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.embed_dim = config.hidden_size + self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False + + self.gradient_checkpointing = False + self.use_dynamic_ntk = config.use_dynamic_ntk + self.seq_length = config.seq_length + + self.wte = nn.Embedding(self.vocab_size, self.embed_dim) + + self.drop = nn.Dropout(config.emb_dropout_prob) + + if config.rotary_pct == 1.0: + self.rotary_ndims = None + else: + assert config.rotary_pct < 1 + self.rotary_ndims = int( + config.kv_channels * config.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else config.kv_channels + ) + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) + + self.use_flash_attn = config.use_flash_attn + self.is_fp32 = not (config.bf16 or config.fp16) + + self.h = nn.ModuleList( + [ + QWenBlock( + config + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = RMSNorm( + self.embed_dim, + eps=config.layer_norm_epsilon, + ) + + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def get_ntk_alpha(self, true_seq_len): + context_value = math.log(true_seq_len / self.seq_length, 2) + 1 + 
ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + return ntk_alpha + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + + kv_seq_len = hidden_states.size()[1] + if past_key_values[0] is not None: + # past key values[0][0] shape: bs * seq_len * head_num * dim + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] + + if self.training or not self.use_dynamic_ntk: + ntk_alpha_list = [1.0] + elif kv_seq_len != hidden_states.size()[1]: + ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list + else: + ntk_alpha_list = [] + if attention_mask is not None and kv_seq_len > self.seq_length: + 
true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) + for i in range(hidden_states.size()[0]): + true_seq_len = true_seq_lens[i].item() + ntk_alpha = self.get_ntk_alpha(true_seq_len) + ntk_alpha_list.append(ntk_alpha) + else: + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) + self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] + + hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + rotary_pos_emb_list, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + rotary_pos_emb_list=rotary_pos_emb_list, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QWenLMHeadModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] + + def __init__(self, config): + super().__init__(config) + assert ( + config.bf16 + config.fp16 + config.fp32 <= 1 + ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" + + autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 + + if autoset_precision: + if SUPPORT_BF16: + logger.warn( + "The model is automatically converting to bf16 for faster inference. " + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + ) + config.bf16 = True + elif SUPPORT_FP16: + logger.warn( + "The model is automatically converting to fp16 for faster inference. 
" + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + ) + config.fp16 = True + else: + config.fp32 = True + + if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: + logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") + if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: + logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") + if config.fp32: + if SUPPORT_BF16: + logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") + elif SUPPORT_FP16: + logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") + + if config.use_flash_attn == "auto": + if config.bf16 or config.fp16: + logger.warn("Try importing flash-attention for faster inference...") + config.use_flash_attn = True + else: + config.use_flash_attn = False + if config.use_flash_attn and config.fp32: + logger.warn("Flash attention will be disabled because it does NOT support fp32.") + + if config.use_flash_attn: + _import_flash_attn() + + self.transformer = QWenModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + if config.bf16: + self.transformer.bfloat16() + self.lm_head.bfloat16() + if config.fp16: + self.transformer.half() + self.lm_head.half() + self.post_init() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + + if input_ids.size(0) == 1: + attention_mask = None + else: + attention_mask = kwargs.get("attention_mask", None) + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + 
use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past_key_values + ) + + def chat( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stream: Optional[bool] = _SENTINEL, + stop_words_ids: Optional[List[List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Tuple[str, HistoryType]: + generation_config = generation_config if generation_config is not None else self.generation_config + + assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + else: + # make a copy of the user's input such that is is left untouched + history = copy.deepcopy(history) + + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + input_ids = torch.tensor([context_tokens]).to(self.device) + outputs = self.generate( + input_ids, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=generation_config, + **kwargs, + ) + + response = decode_tokens( + outputs[0], + tokenizer, + raw_text_len=len(raw_text), + context_length=len(context_tokens), + chat_format=generation_config.chat_format, + verbose=False, + errors='replace' + ) + + # as history is a copy of the user inputs, + # we can always return the new turn to the user. 
+ # separating input history and output history also enables the user + # to implement more complex history management + history.append((query, response)) + + return response, history + + def chat_stream( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stop_words_ids: Optional[List[List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Generator[str, Any, None]: + generation_config = generation_config if generation_config is not None else self.generation_config + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + input_ids = torch.tensor([context_tokens]).to(self.device) + + from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + self.__class__.generate_stream = NewGenerationMixin.generate + self.__class__.sample_stream = NewGenerationMixin.sample_stream + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + + def stream_generator(): + outputs = [] + for token in self.generate_stream( + input_ids, + return_dict_in_generate=False, + generation_config=stream_config, + logits_processor=logits_processor, + seed=-1, + **kwargs): + outputs.append(token.item()) + yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') + + return stream_generator() + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config = generation_config if generation_config is not None else self.generation_config + + # Process stop_words_ids. 
+ stop_words_ids = kwargs.pop("stop_words_ids", None) + if stop_words_ids is None and generation_config is not None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + if stop_words_ids is None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + + return super().generate( + inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs, + ) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + if importlib.util.find_spec("einops") is None: + raise RuntimeError("einops is required for Rotary Embedding") + + self._rotary_pos_emb_cache = None + self._seq_len_cached = 0 + self._ntk_alpha_cached = 1.0 + self._ntk_alpha_cached_list = [1.0] + + def update_rotary_pos_emb_cache(self, seqlen, ntk_alpha=1.0): + if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() + / self.dim + ) + ) + self._seq_len_cached = max(2 * seqlen, 16) + self._ntk_alpha_cached = ntk_alpha + seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) + freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + from einops import rearrange + + emb = rearrange(emb, "n d -> 1 n 1 d") + + cos, sin = emb.cos(), emb.sin() + self._rotary_pos_emb_cache = [cos, sin] + + def forward(self, max_seq_len, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, ntk_alpha) + cos, sin = self._rotary_pos_emb_cache + return [cos[:, :max_seq_len], sin[:, :max_seq_len]] + + +def _rotate_half(x): + from einops import rearrange + + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + """ Apply rotary embedding to the first rotary_dim of the iput + + Arguments: + t (tensor(batch_size, seq_len, n_head, head_dim)): + the input embedding/hidden states + freqs (list[tensor(1, seq_len, 1, rotary_dim), tensor(1, seq_len, 1, rotary_dim)]): + the cached cos/sin position embeddings + """ + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_float = t.float() + if apply_rotary_emb_func is not None and t.is_cuda: + # apply_rotary_emb in flash_attn requires cos/sin to be of + # shape (seqlen, rotary_dim / 2) and apply rotary embedding + # to the first rotary_dim of the input + cos = cos.squeeze(0).squeeze(1)[:, : rot_dim // 2] + sin = sin.squeeze(0).squeeze(1)[:, : rot_dim // 2] + return apply_rotary_emb_func(t_float, cos, sin).type_as(t) + else: + t_rot, t_pass = t_float[..., :rot_dim], t_float[..., rot_dim:] + t_rot = (t_rot * cos) + (_rotate_half(t_rot) * sin) + return torch.cat((t_rot, t_pass), dim=-1).type_as(t) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + if rms_norm is not None and x.is_cuda: + return rms_norm(x, self.weight, self.eps) + else: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/models/Qwen/compile/files/Qwen-1_8B-Chat/config.json b/models/Qwen/compile/files/Qwen-1_8B-Chat/config.json new file mode 100755 index 0000000..06745a8 --- /dev/null +++ b/models/Qwen/compile/files/Qwen-1_8B-Chat/config.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "QWenLMHeadModel" + ], + "auto_map": { + "AutoConfig": "configuration_qwen.QWenConfig", + "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" + }, + "attn_dropout_prob": 0.0, + "bf16": true, + "emb_dropout_prob": 0.0, + "fp16": false, + "fp32": false, + "hidden_size": 2048, + "intermediate_size": 11008, + "initializer_range": 0.02, + "kv_channels": 128, + "layer_norm_epsilon": 1e-06, + "max_position_embeddings": 512, + "model_type": "qwen", + "no_bias": true, + "num_attention_heads": 16, + "num_hidden_layers": 24, + "onnx_safe": null, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "scale_attn_weights": true, + "seq_length": 512, + "tie_word_embeddings": false, + "tokenizer_class": "QWenTokenizer", + "transformers_version": "4.32.0", + "use_cache": true, + "use_dynamic_ntk": true, + "use_flash_attn": "auto", + "use_logn_attn": true, + "vocab_size": 151936 +} \ No newline at end of file diff --git a/models/Qwen/compile/files/Qwen-1_8B-Chat/modeling_qwen.py b/models/Qwen/compile/files/Qwen-1_8B-Chat/modeling_qwen.py new file mode 100755 index 0000000..d8c884a --- /dev/null +++ b/models/Qwen/compile/files/Qwen-1_8B-Chat/modeling_qwen.py @@ -0,0 +1,1445 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import importlib +import math +import pathlib +from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import warnings +from torch.cuda.amp import autocast + +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList +from transformers.generation.logits_process import LogitsProcessorList + +if TYPE_CHECKING: + from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +try: + from einops import rearrange +except ImportError: + rearrange = None +from torch import nn + +SUPPORT_CUDA = torch.cuda.is_available() +SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() +SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 +SUPPORT_TORCH2 = False #hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 + + +from .configuration_qwen import QWenConfig +from .qwen_generation_utils import ( + HistoryType, + make_context, + decode_tokens, + get_stop_words_ids, + StopWordsLogitsProcessor, +) + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "qwen" +_CONFIG_FOR_DOC = "QWenConfig" + +QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] + +_ERROR_BAD_CHAT_FORMAT = """\ +We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". +If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). +我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 +如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 +""" + +_SENTINEL = object() +_ERROR_STREAM_IN_CHAT = """\ +Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). +向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 +""" + +_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ +We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
+检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 +""" + +apply_rotary_emb_func = None +rms_norm = None +flash_attn_unpadded_func = None + +def _import_flash_attn(): + global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func + try: + from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func + apply_rotary_emb_func = __apply_rotary_emb_func + except ImportError: + logger.warn( + "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" + ) + + try: + from flash_attn.ops.rms_norm import rms_norm as __rms_norm + rms_norm = __rms_norm + except ImportError: + logger.warn( + "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" + ) + + try: + import flash_attn + if not hasattr(flash_attn, '__version__'): + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + else: + if int(flash_attn.__version__.split(".")[0]) >= 2: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + flash_attn_unpadded_func = __flash_attn_unpadded_func + except ImportError: + logger.warn( + "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention" + ) + +def quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + +def dequantize_cache_torch(qdata, scale, zero): + data = scale * (qdata - zero) + return data + +class FlashSelfAttention(torch.nn.Module): + def __init__( + self, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + ): + super().__init__() + assert flash_attn_unpadded_func is not None, ( + "Please install FlashAttention first, " "e.g., with pip install flash-attn" + ) + assert ( + rearrange is not None + ), "Please install einops first, e.g., with pip install einops" + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def unpad_input(self, hidden_states, attention_mask): + valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) + seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + hidden_states = hidden_states[indices] + return hidden_states, indices, 
cu_seqlens, max_seqlen_in_batch + + def pad_input(self, hidden_states, indices, batch, seqlen): + output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, + dtype=hidden_states.dtype) + output[indices] = hidden_states + return rearrange(output, '(b s) ... -> b s ...', b=batch) + + def forward(self, q, k, v, attention_mask=None): + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) + assert all((i.is_cuda for i in (q, k, v))) + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + seqlen_out = seqlen_q + + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q.device, + ) + + if batch_size > 1 and attention_mask is not None: + k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) + if q.size(0) == v.size(0): + q = q[indices_k] + cu_seqlens_q = cu_seqlens_k + seqlen_q = seqlen_k + v = v[indices_k] + else: + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=q.device, + ) + + if self.training: + assert seqlen_k == seqlen_q + is_causal = self.causal + dropout_p = self.dropout_p + else: + is_causal = seqlen_q == seqlen_k + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlen_q, + seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, + causal=is_causal, + ) + if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: + output = self.pad_input(output, indices_k, batch_size, seqlen_out) + else: + new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] + output = output.view(new_shape) + return output + + +class QWenAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + self.seq_length = config.seq_length + + self.hidden_size = config.hidden_size + self.split_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + + self.use_flash_attn = config.use_flash_attn + self.scale_attn_weights = True + + self.projection_size = config.kv_channels * config.num_attention_heads + + assert self.projection_size % config.num_attention_heads == 0 + self.hidden_size_per_attention_head = ( + self.projection_size // config.num_attention_heads + ) + + self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) + + self.c_proj = nn.Linear( + config.hidden_size, self.projection_size, bias=not config.no_bias + ) + + self.is_fp32 = not (config.bf16 or config.fp16) + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attn_dropout_prob + ) + self.bf16 = config.bf16 + + self.use_dynamic_ntk = config.use_dynamic_ntk + self.use_logn_attn = config.use_logn_attn + + logn_list = [ + math.log(i, self.seq_length) if i > self.seq_length else 1 + for i in range(1, 32768) + ] + logn_tensor = torch.tensor(logn_list)[None, :, None, None] + self.register_buffer("logn_tensor", logn_tensor, persistent=False) + + self.attn_dropout = nn.Dropout(config.attn_dropout_prob) + self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False + self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 
'use_cache_quantization') else False + self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + cache_dtype = torch.float + if self.bf16: + cache_dtype=torch.bfloat16 + elif config.fp16: + cache_dtype = torch.float16 + self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) + self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) + + if config.use_cache_quantization and config.use_cache_kernel: + # pre check if the support files existing + module_root = pathlib.Path(__file__).parent + src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") + if any(not (module_root/src).is_file() for src in src_files): + warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") + self.cache_kernels = None + else: + try: + from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 + except ImportError: + warnings.warn("Failed to import KV cache kernels.") + self.cache_kernels = None + + def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None): + device = query.device + if self.use_cache_quantization: + qk, qk_scale, qk_zero = key + if self.use_cache_kernel and self.cache_kernels is not None: + shape = query.shape[:-1] + (qk.shape[-2],) + attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_faster_old( + query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + qk.transpose(-1, -2).contiguous(), + attn_weights, + qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + # attn_weights = attn_weights.to(query.dtype).contiguous() + else: + key = dequantize_cache_torch(qk, qk_scale, qk_zero) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + else: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + if self.use_cache_quantization: + size_temp = value[0].size(-1) + else: + size_temp = value.size(-1) + attn_weights = attn_weights / (size_temp ** 0.5) + + # torch.full( + # [], + # size_temp ** 0.5, + # dtype=attn_weights.dtype, + # device=attn_weights.device, + # ) + # if self.use_cache_quantization: + # query_length, key_length = query.size(-2), key[0].size(-2) + # else: + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = registered_causal_mask[ + # :, :, key_length - query_length : key_length, :key_length + # ] + # mask_value = torch.finfo(attn_weights.dtype).min + # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( + # attn_weights.device + # ) + # attn_weights = torch.where( + # causal_mask, attn_weights.to(attn_weights.dtype), mask_value + # ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + if self.softmax_in_fp32: + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = attn_weights.type(query.dtype) + attn_weights = self.attn_dropout(attn_weights) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + if self.use_cache_quantization: + qv, qv_scale, qv_zero = value + if self.use_cache_kernel and self.cache_kernels is not None: + shape = attn_weights.shape[:-1] + (query.shape[-1],) + attn_output = 
torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( + attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + qv.contiguous(), # dtype: int32 + attn_output, + qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + if attn_output.dtype != query.dtype: + attn_output = attn_output.to(query.dtype) + attn_weights = attn_weights.to(query.dtype) + else: + value = dequantize_cache_torch(qv, qv_scale, qv_zero) + attn_output = torch.matmul(attn_weights, value) + else: + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn( + self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None + ): + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + attn_weights = torch.empty( + bsz * num_heads, + q_seq_len, + k_seq_len, + dtype=torch.float32, + device=query.device, + ) + + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( + -1, dk, k_seq_len + ) + attn_weights = torch.baddbmm( + attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor + ) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = registered_causal_mask[ + # :, :, key_length - query_length : key_length, :key_length + # ] + # mask_value = torch.finfo(attn_weights.dtype).min + # mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( + # attn_weights.device + # ) + # attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if attn_weights.dtype != torch.float32: + raise RuntimeError( + "Error with upcasting, attn_weights does not have dtype torch.float32" + ) + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + registered_causal_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ): + mixed_x_layer = self.c_attn(hidden_states) + + 
query, key, value = mixed_x_layer.split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] + key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = quantize_cache_v(key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + value = quantize_cache_v(value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + if use_cache: + present = (key, value) + else: + present = None + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = (torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2)) + value = (torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2)) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) + query = query * logn_tensor.expand_as(query) + + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and query.is_cuda + ): + q, k, v = query, key, value + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) + else: + query = query.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + if ( + registered_causal_mask is None + and self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and not query.is_cuda + ): + raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) + + if not self.use_cache_quantization and SUPPORT_TORCH2: + causal_mask = registered_causal_mask[ + :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2) + ] + if attention_mask is not None: + attention_mask = attention_mask.expand( + -1, -1, causal_mask.size(2), -1 + ).masked_fill(~causal_mask, torch.finfo(query.dtype).min) + else: + attention_mask = causal_mask + attn_output = 
F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn( + query, key, value, registered_causal_mask, attention_mask, head_mask + ) + context_layer = self._merge_heads( + attn_output, self.num_heads, self.head_dim + ) + + attn_output = self.c_proj(context_layer) + + outputs = (attn_output, present) + if output_attentions: + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + raise ValueError("Cannot output attentions while using flash-attn") + else: + outputs += (attn_weight,) + + return outputs + + +class QWenMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + self.w2 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + ff_dim_in = config.intermediate_size // 2 + self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) + + def forward(self, hidden_states): + a1 = self.w1(hidden_states) + a2 = self.w2(hidden_states) + intermediate_parallel = a1 * F.silu(a2) + output = self.c_proj(intermediate_parallel) + return output + +class QWenBlock(nn.Module): + def __init__(self, config): + super().__init__() + hidden_size = config.hidden_size + self.bf16 = config.bf16 + + self.ln_1 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + self.attn = QWenAttention(config) + self.ln_2 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + + self.mlp = QWenMLP(config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + registered_causal_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + layernorm_output = self.ln_1(hidden_states) + + attn_outputs = self.attn( + layernorm_output, + rotary_pos_emb_list, + registered_causal_mask=registered_causal_mask, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + residual = hidden_states + layernorm_input = attn_output + residual + + layernorm_output = self.ln_2(layernorm_input) + + residual = layernorm_input + mlp_output = self.mlp(layernorm_output) + hidden_states = residual + mlp_output + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs + + +class QWenPreTrainedModel(PreTrainedModel): + config_class = QWenConfig + base_model_prefix = "transformer" + is_parallelizable = False + supports_gradient_checkpointing = True + _no_split_modules = ["QWenBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + 
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, RMSNorm): + module.weight.data.fill_(1.0) + + for name, p in module.named_parameters(): + if name == "c_proj.weight": + p.data.normal_( + mean=0.0, + std=( + self.config.initializer_range + / math.sqrt(2 * self.config.num_hidden_layers) + ), + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, QWenModel): + module.gradient_checkpointing = value + + +class QWenModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.embed_dim = config.hidden_size + self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False + + self.gradient_checkpointing = False + self.use_dynamic_ntk = config.use_dynamic_ntk + self.seq_length = config.seq_length + + self.wte = nn.Embedding(self.vocab_size, self.embed_dim) + + self.drop = nn.Dropout(config.emb_dropout_prob) + + if config.rotary_pct == 1.0: + self.rotary_ndims = None + else: + assert config.rotary_pct < 1 + self.rotary_ndims = int( + config.kv_channels * config.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else config.kv_channels + ) + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) + + self.use_flash_attn = config.use_flash_attn + self.is_fp32 = not (config.bf16 or config.fp16) + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + self.registered_causal_mask = None + else: + max_positions = config.max_position_embeddings + self.register_buffer( + "registered_causal_mask", + torch.tril( + torch.ones((max_positions, max_positions), dtype=torch.bool) + ).view(1, 1, max_positions, max_positions), + persistent=False, + ) + + self.h = nn.ModuleList( + [ + QWenBlock( + config + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = RMSNorm( + self.embed_dim, + eps=config.layer_norm_epsilon, + ) + + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def get_ntk_alpha(self, true_seq_len): + context_value = math.log(true_seq_len / self.seq_length, 2) + 1 + ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + return ntk_alpha + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = 
use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + + kv_seq_len = hidden_states.size()[1] + if past_key_values[0] is not None: + # past key values[0][0] shape: bs * seq_len * head_num * dim + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] + + if self.training or not self.use_dynamic_ntk: + ntk_alpha_list = [1.0] + elif kv_seq_len != hidden_states.size()[1]: + ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list + else: + ntk_alpha_list = [] + if attention_mask is not None and kv_seq_len > self.seq_length: + true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) + for i in range(hidden_states.size()[0]): + true_seq_len = true_seq_lens[i].item() + ntk_alpha = self.get_ntk_alpha(true_seq_len) + ntk_alpha_list.append(ntk_alpha) + else: + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) + self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] + + hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + rotary_pos_emb_list, + self.registered_causal_mask, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + rotary_pos_emb_list=rotary_pos_emb_list, + registered_causal_mask=self.registered_causal_mask, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QWenLMHeadModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] + + def __init__(self, config): + super().__init__(config) + assert ( + config.bf16 + config.fp16 + config.fp32 <= 1 + ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" + logger.warn( + "Warning: please make sure that you are using the latest codes and checkpoints, " + "especially if you used Qwen-7B before 09.25.2023." + "请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。" + ) + + autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 + + if autoset_precision: + if SUPPORT_BF16: + logger.warn( + "The model is automatically converting to bf16 for faster inference. " + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + ) + config.bf16 = True + elif SUPPORT_FP16: + logger.warn( + "The model is automatically converting to fp16 for faster inference. " + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
+ ) + config.fp16 = True + else: + config.fp32 = True + + if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: + logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") + if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: + logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") + if config.fp32: + if SUPPORT_BF16: + logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") + elif SUPPORT_FP16: + logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") + + if config.use_flash_attn == "auto": + if config.bf16 or config.fp16: + logger.warn("Try importing flash-attention for faster inference...") + config.use_flash_attn = True + else: + config.use_flash_attn = False + if config.use_flash_attn and config.fp32: + logger.warn("Flash attention will be disabled because it does NOT support fp32.") + + if config.use_flash_attn: + _import_flash_attn() + + self.transformer = QWenModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + if config.bf16: + self.transformer.bfloat16() + self.lm_head.bfloat16() + if config.fp16: + self.transformer.half() + self.lm_head.half() + self.post_init() + + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + token_type_ids = kwargs.get("token_type_ids", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + 
transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past_key_values + ) + + def chat( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stream: Optional[bool] = _SENTINEL, + stop_words_ids: Optional[List[List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Tuple[str, HistoryType]: + generation_config = generation_config if generation_config is not None else self.generation_config + + assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + else: + # make a copy of the user's input such that is is left untouched + history = copy.deepcopy(history) + + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + input_ids = torch.tensor([context_tokens]).to(self.device) + outputs = self.generate( + input_ids, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=generation_config, + **kwargs, + ) + + response = decode_tokens( + outputs[0], + tokenizer, + raw_text_len=len(raw_text), + context_length=len(context_tokens), + chat_format=generation_config.chat_format, + verbose=False, + errors='replace' + ) + + # as history is a copy of the user inputs, + # we can always return the new turn to the user. 
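+        # (illustrative, hypothetical call pattern from the caller's side:
+        #      response, history = model.chat(tokenizer, "Hi there", history=None)
+        #      response, history = model.chat(tokenizer, "Tell me more", history=history)
+        #  each call returns the updated history to feed into the next turn)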
+ # separating input history and output history also enables the user + # to implement more complex history management + history.append((query, response)) + + return response, history + + def chat_stream( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stop_words_ids: Optional[List[List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Generator[str, Any, None]: + generation_config = generation_config if generation_config is not None else self.generation_config + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + input_ids = torch.tensor([context_tokens]).to(self.device) + + from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + self.__class__.generate_stream = NewGenerationMixin.generate + self.__class__.sample_stream = NewGenerationMixin.sample_stream + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + + def stream_generator(): + outputs = [] + for token in self.generate_stream( + input_ids, + return_dict_in_generate=False, + generation_config=stream_config, + logits_processor=logits_processor, + seed=-1, + **kwargs): + outputs.append(token.item()) + yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') + + return stream_generator() + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config = generation_config if generation_config is not None else self.generation_config + + # Process stop_words_ids. 
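+        # stop_words_ids may come either from the caller's kwargs or from
+        # generation_config; whichever is found is wrapped in a
+        # StopWordsLogitsProcessor and appended to logits_processor below
+        # (or used to create a new LogitsProcessorList if none was passed).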
+ stop_words_ids = kwargs.pop("stop_words_ids", None) + if stop_words_ids is None and generation_config is not None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + if stop_words_ids is None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + + return super().generate( + inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs, + ) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + if importlib.util.find_spec("einops") is None: + raise RuntimeError("einops is required for Rotary Embedding") + + self._rotary_pos_emb_cache = None + self._seq_len_cached = 0 + self._ntk_alpha_cached = 1.0 + self._ntk_alpha_cached_list = [1.0] + + def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): + seqlen = max_seq_len + offset + if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() + / self.dim + ) + ) + self._seq_len_cached = max(2 * seqlen, 16) + self._ntk_alpha_cached = ntk_alpha + seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) + freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + from einops import rearrange + + emb = rearrange(emb, "n d -> 1 n 1 d") + + cos, sin = emb.cos(), emb.sin() + self._rotary_pos_emb_cache = [cos, sin] + + def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) + cos, sin = self._rotary_pos_emb_cache + return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]] + + +def _rotate_half(x): + from einops import rearrange + + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + cos, sin = freqs + if apply_rotary_emb_func is not None and t.is_cuda: + t_ = t.float() + cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2] + sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2] + output = apply_rotary_emb_func(t_, cos, sin).type_as(t) + return output + else: + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] + t_ = t_.float() + t_pass_ = t_pass_.float() + t_ = (t_ * cos) + (_rotate_half(t_) * sin) + return torch.cat((t_, t_pass_), dim=-1).type_as(t) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + if rms_norm is not None and x.is_cuda: + return rms_norm(x, self.weight, self.eps) + else: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/models/Qwen/compile/files/Qwen-7B-Chat/config.json b/models/Qwen/compile/files/Qwen-7B-Chat/config.json new file mode 100755 index 0000000..91bcf88 --- /dev/null +++ b/models/Qwen/compile/files/Qwen-7B-Chat/config.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "QWenLMHeadModel" + ], + "auto_map": { + "AutoConfig": "configuration_qwen.QWenConfig", + "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" + }, + "attn_dropout_prob": 0.0, + "bf16": true, + "emb_dropout_prob": 0.0, + "fp16": false, + "fp32": false, + "hidden_size": 4096, + "intermediate_size": 22016, + "initializer_range": 0.02, + "kv_channels": 128, + "layer_norm_epsilon": 1e-06, + "max_position_embeddings": 512, + "model_type": "qwen", + "no_bias": true, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "onnx_safe": null, + "rotary_emb_base": 10000, + "rotary_pct": 1.0, + "scale_attn_weights": true, + "seq_length": 512, + "tie_word_embeddings": false, + "tokenizer_class": "QWenTokenizer", + "transformers_version": "4.32.0", + "use_cache": true, + "use_dynamic_ntk": true, + "use_flash_attn": "auto", + "use_logn_attn": true, + "vocab_size": 151936 +} \ No newline at end of file diff --git a/models/Qwen/compile/files/Qwen-7B-Chat/modeling_qwen.py b/models/Qwen/compile/files/Qwen-7B-Chat/modeling_qwen.py new file mode 100755 index 0000000..d8c884a --- /dev/null +++ b/models/Qwen/compile/files/Qwen-7B-Chat/modeling_qwen.py @@ -0,0 +1,1445 @@ +# Copyright (c) Alibaba Cloud. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import copy +import importlib +import math +import pathlib +from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import warnings +from torch.cuda.amp import autocast + +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList +from transformers.generation.logits_process import LogitsProcessorList + +if TYPE_CHECKING: + from transformers.generation.streamers import BaseStreamer +from transformers.generation.utils import GenerateOutput +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging + +try: + from einops import rearrange +except ImportError: + rearrange = None +from torch import nn + +SUPPORT_CUDA = torch.cuda.is_available() +SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported() +SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7 +SUPPORT_TORCH2 = False #hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2 + + +from .configuration_qwen import QWenConfig +from .qwen_generation_utils import ( + HistoryType, + make_context, + decode_tokens, + get_stop_words_ids, + StopWordsLogitsProcessor, +) + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "qwen" +_CONFIG_FOR_DOC = "QWenConfig" + +QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"] + +_ERROR_BAD_CHAT_FORMAT = """\ +We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml". +If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat(). +我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。 +如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。 +""" + +_SENTINEL = object() +_ERROR_STREAM_IN_CHAT = """\ +Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True). +向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。 +""" + +_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED = """\ +We detect you have activated flash attention support, but running model computation on CPU. Please make sure that your input data has been placed on GPU. If you actually want to run CPU computation, please following the readme and set device_map="cpu" to disable flash attention when loading the model (calling AutoModelForCausalLM.from_pretrained). 
+检测到您的模型已激活了flash attention支持,但正在执行CPU运算任务。如使用flash attention,请您确认模型输入已经传到GPU上。如果您确认要执行CPU运算,请您在载入模型(调用AutoModelForCausalLM.from_pretrained)时,按照readme说法,指定device_map="cpu"以禁用flash attention。 +""" + +apply_rotary_emb_func = None +rms_norm = None +flash_attn_unpadded_func = None + +def _import_flash_attn(): + global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func + try: + from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func + apply_rotary_emb_func = __apply_rotary_emb_func + except ImportError: + logger.warn( + "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary" + ) + + try: + from flash_attn.ops.rms_norm import rms_norm as __rms_norm + rms_norm = __rms_norm + except ImportError: + logger.warn( + "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm" + ) + + try: + import flash_attn + if not hasattr(flash_attn, '__version__'): + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + else: + if int(flash_attn.__version__.split(".")[0]) >= 2: + from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func + flash_attn_unpadded_func = __flash_attn_unpadded_func + except ImportError: + logger.warn( + "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency " + "https://github.com/Dao-AILab/flash-attention" + ) + +def quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + zero = zero.unsqueeze(-1).repeat(1,1,shape[2],1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + +def dequantize_cache_torch(qdata, scale, zero): + data = scale * (qdata - zero) + return data + +class FlashSelfAttention(torch.nn.Module): + def __init__( + self, + causal=False, + softmax_scale=None, + attention_dropout=0.0, + ): + super().__init__() + assert flash_attn_unpadded_func is not None, ( + "Please install FlashAttention first, " "e.g., with pip install flash-attn" + ) + assert ( + rearrange is not None + ), "Please install einops first, e.g., with pip install einops" + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def unpad_input(self, hidden_states, attention_mask): + valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0) + seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + hidden_states = hidden_states[indices] + return hidden_states, indices, 
cu_seqlens, max_seqlen_in_batch + + def pad_input(self, hidden_states, indices, batch, seqlen): + output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device, + dtype=hidden_states.dtype) + output[indices] = hidden_states + return rearrange(output, '(b s) ... -> b s ...', b=batch) + + def forward(self, q, k, v, attention_mask=None): + assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v))) + assert all((i.is_cuda for i in (q, k, v))) + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = k.shape[1] + seqlen_out = seqlen_q + + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q.device, + ) + + if batch_size > 1 and attention_mask is not None: + k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask) + if q.size(0) == v.size(0): + q = q[indices_k] + cu_seqlens_q = cu_seqlens_k + seqlen_q = seqlen_k + v = v[indices_k] + else: + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=q.device, + ) + + if self.training: + assert seqlen_k == seqlen_q + is_causal = self.causal + dropout_p = self.dropout_p + else: + is_causal = seqlen_q == seqlen_k + dropout_p = 0 + + output = flash_attn_unpadded_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + seqlen_q, + seqlen_k, + dropout_p, + softmax_scale=self.softmax_scale, + causal=is_causal, + ) + if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k: + output = self.pad_input(output, indices_k, batch_size, seqlen_out) + else: + new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:] + output = output.view(new_shape) + return output + + +class QWenAttention(nn.Module): + def __init__(self, config): + super().__init__() + + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + self.seq_length = config.seq_length + + self.hidden_size = config.hidden_size + self.split_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + + self.use_flash_attn = config.use_flash_attn + self.scale_attn_weights = True + + self.projection_size = config.kv_channels * config.num_attention_heads + + assert self.projection_size % config.num_attention_heads == 0 + self.hidden_size_per_attention_head = ( + self.projection_size // config.num_attention_heads + ) + + self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size) + + self.c_proj = nn.Linear( + config.hidden_size, self.projection_size, bias=not config.no_bias + ) + + self.is_fp32 = not (config.bf16 or config.fp16) + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + self.core_attention_flash = FlashSelfAttention( + causal=True, attention_dropout=config.attn_dropout_prob + ) + self.bf16 = config.bf16 + + self.use_dynamic_ntk = config.use_dynamic_ntk + self.use_logn_attn = config.use_logn_attn + + logn_list = [ + math.log(i, self.seq_length) if i > self.seq_length else 1 + for i in range(1, 32768) + ] + logn_tensor = torch.tensor(logn_list)[None, :, None, None] + self.register_buffer("logn_tensor", logn_tensor, persistent=False) + + self.attn_dropout = nn.Dropout(config.attn_dropout_prob) + self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False + self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 
'use_cache_quantization') else False + self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False + cache_dtype = torch.float + if self.bf16: + cache_dtype=torch.bfloat16 + elif config.fp16: + cache_dtype = torch.float16 + self.cache_qmax = torch.tensor(torch.iinfo(torch.uint8).max, dtype=cache_dtype) + self.cache_qmin = torch.tensor(torch.iinfo(torch.uint8).min, dtype=cache_dtype) + + if config.use_cache_quantization and config.use_cache_kernel: + # pre check if the support files existing + module_root = pathlib.Path(__file__).parent + src_files = ("cache_autogptq_cuda_256.cpp", "cache_autogptq_cuda_kernel_256.cu") + if any(not (module_root/src).is_file() for src in src_files): + warnings.warn("KV cache kernel source files (.cpp and .cu) not found.") + self.cache_kernels = None + else: + try: + from .cpp_kernels import cache_autogptq_cuda_256 + self.cache_kernels = cache_autogptq_cuda_256 + except ImportError: + warnings.warn("Failed to import KV cache kernels.") + self.cache_kernels = None + + def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None): + device = query.device + if self.use_cache_quantization: + qk, qk_scale, qk_zero = key + if self.use_cache_kernel and self.cache_kernels is not None: + shape = query.shape[:-1] + (qk.shape[-2],) + attn_weights = torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_faster_old( + query.contiguous() if query.dtype == torch.float16 else query.to(torch.float16).contiguous(), + qk.transpose(-1, -2).contiguous(), + attn_weights, + qk_scale.contiguous() if qk_scale.dtype == torch.float16 else qk_scale.to(torch.float16).contiguous(), + qk_zero.contiguous()if qk_zero.dtype == torch.float16 else qk_zero.to(torch.float16).contiguous()) + # attn_weights = attn_weights.to(query.dtype).contiguous() + else: + key = dequantize_cache_torch(qk, qk_scale, qk_zero) + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + else: + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + if self.use_cache_quantization: + size_temp = value[0].size(-1) + else: + size_temp = value.size(-1) + attn_weights = attn_weights / (size_temp ** 0.5) + + # torch.full( + # [], + # size_temp ** 0.5, + # dtype=attn_weights.dtype, + # device=attn_weights.device, + # ) + # if self.use_cache_quantization: + # query_length, key_length = query.size(-2), key[0].size(-2) + # else: + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = registered_causal_mask[ + # :, :, key_length - query_length : key_length, :key_length + # ] + # mask_value = torch.finfo(attn_weights.dtype).min + # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to( + # attn_weights.device + # ) + # attn_weights = torch.where( + # causal_mask, attn_weights.to(attn_weights.dtype), mask_value + # ) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + if self.softmax_in_fp32: + attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1) + else: + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + attn_weights = attn_weights.type(query.dtype) + attn_weights = self.attn_dropout(attn_weights) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + if self.use_cache_quantization: + qv, qv_scale, qv_zero = value + if self.use_cache_kernel and self.cache_kernels is not None: + shape = attn_weights.shape[:-1] + (query.shape[-1],) + attn_output = 
torch.zeros(shape, dtype=torch.float16, device=device) + self.cache_kernels.vecquant8matmul_batched_column_compression_faster_old( + attn_weights.contiguous() if attn_weights.dtype == torch.float16 else attn_weights.to(torch.float16).contiguous(), + qv.contiguous(), # dtype: int32 + attn_output, + qv_scale.contiguous() if qv_scale.dtype == torch.float16 else qv_scale.to(torch.float16).contiguous(), + qv_zero.contiguous() if qv_zero.dtype == torch.float16 else qv_zero.to(torch.float16).contiguous()) + if attn_output.dtype != query.dtype: + attn_output = attn_output.to(query.dtype) + attn_weights = attn_weights.to(query.dtype) + else: + value = dequantize_cache_torch(qv, qv_scale, qv_zero) + attn_output = torch.matmul(attn_weights, value) + else: + attn_output = torch.matmul(attn_weights, value) + + attn_output = attn_output.transpose(1, 2) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn( + self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None + ): + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + attn_weights = torch.empty( + bsz * num_heads, + q_seq_len, + k_seq_len, + dtype=torch.float32, + device=query.device, + ) + + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape( + -1, dk, k_seq_len + ) + attn_weights = torch.baddbmm( + attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor + ) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = registered_causal_mask[ + # :, :, key_length - query_length : key_length, :key_length + # ] + # mask_value = torch.finfo(attn_weights.dtype).min + # mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to( + # attn_weights.device + # ) + # attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if attn_weights.dtype != torch.float32: + raise RuntimeError( + "Error with upcasting, attn_weights does not have dtype torch.float32" + ) + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + tensor = tensor.contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + registered_causal_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ): + mixed_x_layer = self.c_attn(hidden_states) + + 
query, key, value = mixed_x_layer.split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = apply_rotary_pos_emb(query, q_pos_emb) + key = apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [apply_rotary_pos_emb(query[i:i+1, :, :], q_pos_emb)] + key_list += [apply_rotary_pos_emb(key[i:i+1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = quantize_cache_v(key.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + value = quantize_cache_v(value.permute(0, 2, 1, 3), + bits=8, + qmin=self.cache_qmin, + qmax=self.cache_qmax) + if use_cache: + present = (key, value) + else: + present = None + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = (torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2)) + value = (torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2)) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) + query = query * logn_tensor.expand_as(query) + + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and query.is_cuda + ): + q, k, v = query, key, value + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) + else: + query = query.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + if ( + registered_causal_mask is None + and self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + and not query.is_cuda + ): + raise Exception(_ERROR_INPUT_CPU_QUERY_WITH_FLASH_ATTN_ACTIVATED) + + if not self.use_cache_quantization and SUPPORT_TORCH2: + causal_mask = registered_causal_mask[ + :, :, key.size(-2) - query.size(-2): key.size(-2), :key.size(-2) + ] + if attention_mask is not None: + attention_mask = attention_mask.expand( + -1, -1, causal_mask.size(2), -1 + ).masked_fill(~causal_mask, torch.finfo(query.dtype).min) + else: + attention_mask = causal_mask + attn_output = 
F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask + ).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn( + query, key, value, registered_causal_mask, attention_mask, head_mask + ) + context_layer = self._merge_heads( + attn_output, self.num_heads, self.head_dim + ) + + attn_output = self.c_proj(context_layer) + + outputs = (attn_output, present) + if output_attentions: + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + raise ValueError("Cannot output attentions while using flash-attn") + else: + outputs += (attn_weight,) + + return outputs + + +class QWenMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.w1 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + self.w2 = nn.Linear( + config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias + ) + ff_dim_in = config.intermediate_size // 2 + self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias) + + def forward(self, hidden_states): + a1 = self.w1(hidden_states) + a2 = self.w2(hidden_states) + intermediate_parallel = a1 * F.silu(a2) + output = self.c_proj(intermediate_parallel) + return output + +class QWenBlock(nn.Module): + def __init__(self, config): + super().__init__() + hidden_size = config.hidden_size + self.bf16 = config.bf16 + + self.ln_1 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + self.attn = QWenAttention(config) + self.ln_2 = RMSNorm( + hidden_size, + eps=config.layer_norm_epsilon, + ) + + self.mlp = QWenMLP(config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + registered_causal_mask: Optional[torch.Tensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + layernorm_output = self.ln_1(hidden_states) + + attn_outputs = self.attn( + layernorm_output, + rotary_pos_emb_list, + registered_causal_mask=registered_causal_mask, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + residual = hidden_states + layernorm_input = attn_output + residual + + layernorm_output = self.ln_2(layernorm_input) + + residual = layernorm_input + mlp_output = self.mlp(layernorm_output) + hidden_states = residual + mlp_output + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs + + +class QWenPreTrainedModel(PreTrainedModel): + config_class = QWenConfig + base_model_prefix = "transformer" + is_parallelizable = False + supports_gradient_checkpointing = True + _no_split_modules = ["QWenBlock"] + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + 
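+            # Embedding weights use the same normal(0, initializer_range)
+            # initialisation; the padding row, if one is set, is re-zeroed below.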
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, RMSNorm): + module.weight.data.fill_(1.0) + + for name, p in module.named_parameters(): + if name == "c_proj.weight": + p.data.normal_( + mean=0.0, + std=( + self.config.initializer_range + / math.sqrt(2 * self.config.num_hidden_layers) + ), + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, QWenModel): + module.gradient_checkpointing = value + + +class QWenModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = ["attn.masked_bias"] + + def __init__(self, config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.embed_dim = config.hidden_size + self.use_cache_quantization = self.config.use_cache_quantization if hasattr(self.config, 'use_cache_quantization') else False + + self.gradient_checkpointing = False + self.use_dynamic_ntk = config.use_dynamic_ntk + self.seq_length = config.seq_length + + self.wte = nn.Embedding(self.vocab_size, self.embed_dim) + + self.drop = nn.Dropout(config.emb_dropout_prob) + + if config.rotary_pct == 1.0: + self.rotary_ndims = None + else: + assert config.rotary_pct < 1 + self.rotary_ndims = int( + config.kv_channels * config.rotary_pct + ) + dim = ( + self.rotary_ndims + if self.rotary_ndims is not None + else config.kv_channels + ) + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) + + self.use_flash_attn = config.use_flash_attn + self.is_fp32 = not (config.bf16 or config.fp16) + if ( + self.use_flash_attn + and flash_attn_unpadded_func is not None + and not self.is_fp32 + ): + self.registered_causal_mask = None + else: + max_positions = config.max_position_embeddings + self.register_buffer( + "registered_causal_mask", + torch.tril( + torch.ones((max_positions, max_positions), dtype=torch.bool) + ).view(1, 1, max_positions, max_positions), + persistent=False, + ) + + self.h = nn.ModuleList( + [ + QWenBlock( + config + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = RMSNorm( + self.embed_dim, + eps=config.layer_norm_epsilon, + ) + + self.post_init() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def get_ntk_alpha(self, true_seq_len): + context_value = math.log(true_seq_len / self.seq_length, 2) + 1 + ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + return ntk_alpha + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = 
use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + if self.use_cache_quantization: + past_length = past_key_values[0][0][0].size(2) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange( + past_length, + input_shape[-1] + past_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + if attention_mask is not None: + if batch_size <= 0: + raise ValueError("batch_size has to be defined and > 0") + attention_mask = attention_mask.view(batch_size, -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + encoder_attention_mask = None + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + hidden_states = inputs_embeds + + kv_seq_len = hidden_states.size()[1] + if past_key_values[0] is not None: + # past key values[0][0] shape: bs * seq_len * head_num * dim + if self.use_cache_quantization: + kv_seq_len += past_key_values[0][0][0].shape[2] + else: + kv_seq_len += past_key_values[0][0].shape[1] + + if self.training or not self.use_dynamic_ntk: + ntk_alpha_list = [1.0] + elif kv_seq_len != hidden_states.size()[1]: + ntk_alpha_list = self.rotary_emb._ntk_alpha_cached_list + else: + ntk_alpha_list = [] + if attention_mask is not None and kv_seq_len > self.seq_length: + true_seq_lens = attention_mask.squeeze(1).squeeze(1).eq(0).sum(dim=-1, dtype=torch.int32) + for i in range(hidden_states.size()[0]): + true_seq_len = true_seq_lens[i].item() + ntk_alpha = self.get_ntk_alpha(true_seq_len) + ntk_alpha_list.append(ntk_alpha) + else: + ntk_alpha = self.get_ntk_alpha(kv_seq_len) + ntk_alpha_list.append(ntk_alpha) + self.rotary_emb._ntk_alpha_cached_list = ntk_alpha_list + rotary_pos_emb_list = [ + self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha) for ntk_alpha in ntk_alpha_list + ] + + hidden_states = self.drop(hidden_states) + output_shape = input_shape + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, use_cache, output_attentions) + + return custom_forward + + outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + rotary_pos_emb_list, + self.registered_causal_mask, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + rotary_pos_emb_list=rotary_pos_emb_list, + registered_causal_mask=self.registered_causal_mask, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, presents, all_hidden_states] if v is not None + ) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QWenLMHeadModel(QWenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] + _keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"] + + def __init__(self, config): + super().__init__(config) + assert ( + config.bf16 + config.fp16 + config.fp32 <= 1 + ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true" + logger.warn( + "Warning: please make sure that you are using the latest codes and checkpoints, " + "especially if you used Qwen-7B before 09.25.2023." + "请使用最新模型和代码,尤其如果你在9月25日前已经开始使用Qwen-7B,千万注意不要使用错误代码和模型。" + ) + + autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0 + + if autoset_precision: + if SUPPORT_BF16: + logger.warn( + "The model is automatically converting to bf16 for faster inference. " + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." + ) + config.bf16 = True + elif SUPPORT_FP16: + logger.warn( + "The model is automatically converting to fp16 for faster inference. " + "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"." 
+ ) + config.fp16 = True + else: + config.fp32 = True + + if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16: + logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".") + if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16: + logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster") + if config.fp32: + if SUPPORT_BF16: + logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".") + elif SUPPORT_FP16: + logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".") + + if config.use_flash_attn == "auto": + if config.bf16 or config.fp16: + logger.warn("Try importing flash-attention for faster inference...") + config.use_flash_attn = True + else: + config.use_flash_attn = False + if config.use_flash_attn and config.fp32: + logger.warn("Flash attention will be disabled because it does NOT support fp32.") + + if config.use_flash_attn: + _import_flash_attn() + + self.transformer = QWenModel(config) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + if config.bf16: + self.transformer.bfloat16() + self.lm_head.bfloat16() + if config.fp16: + self.transformer.half() + self.lm_head.half() + self.post_init() + + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs + ): + token_type_ids = kwargs.get("token_type_ids", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + 
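+        # Run the decoder stack, project the final hidden states with lm_head,
+        # and, when labels are given, compute a shifted cross-entropy loss:
+        # the logits at position t are scored against the label at position t+1.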
transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + + return tuple( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ) + for layer_past in past_key_values + ) + + def chat( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stream: Optional[bool] = _SENTINEL, + stop_words_ids: Optional[List[List[int]]] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Tuple[str, HistoryType]: + generation_config = generation_config if generation_config is not None else self.generation_config + + assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + else: + # make a copy of the user's input such that is is left untouched + history = copy.deepcopy(history) + + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + input_ids = torch.tensor([context_tokens]).to(self.device) + outputs = self.generate( + input_ids, + stop_words_ids=stop_words_ids, + return_dict_in_generate=False, + generation_config=generation_config, + **kwargs, + ) + + response = decode_tokens( + outputs[0], + tokenizer, + raw_text_len=len(raw_text), + context_length=len(context_tokens), + chat_format=generation_config.chat_format, + verbose=False, + errors='replace' + ) + + # as history is a copy of the user inputs, + # we can always return the new turn to the user. 
+ # separating input history and output history also enables the user + # to implement more complex history management + history.append((query, response)) + + return response, history + + def chat_stream( + self, + tokenizer: PreTrainedTokenizer, + query: str, + history: Optional[HistoryType], + system: str = "You are a helpful assistant.", + stop_words_ids: Optional[List[List[int]]] = None, + logits_processor: Optional[LogitsProcessorList] = None, + generation_config: Optional[GenerationConfig] = None, + **kwargs, + ) -> Generator[str, Any, None]: + generation_config = generation_config if generation_config is not None else self.generation_config + assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT + if history is None: + history = [] + if stop_words_ids is None: + stop_words_ids = [] + + max_window_size = kwargs.get('max_window_size', None) + if max_window_size is None: + max_window_size = generation_config.max_window_size + raw_text, context_tokens = make_context( + tokenizer, + query, + history=history, + system=system, + max_window_size=max_window_size, + chat_format=generation_config.chat_format, + ) + + stop_words_ids.extend(get_stop_words_ids( + generation_config.chat_format, tokenizer + )) + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + input_ids = torch.tensor([context_tokens]).to(self.device) + + from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + self.__class__.generate_stream = NewGenerationMixin.generate + self.__class__.sample_stream = NewGenerationMixin.sample_stream + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + + def stream_generator(): + outputs = [] + for token in self.generate_stream( + input_ids, + return_dict_in_generate=False, + generation_config=stream_config, + logits_processor=logits_processor, + seed=-1, + **kwargs): + outputs.append(token.item()) + yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore') + + return stream_generator() + + def generate( + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ) -> Union[GenerateOutput, torch.LongTensor]: + generation_config = generation_config if generation_config is not None else self.generation_config + + # Process stop_words_ids. 
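+        # Stop words can come from the call site (kwargs) or from the
+        # generation config; either way they are wrapped in a
+        # StopWordsLogitsProcessor and appended to logits_processor before
+        # delegating to the standard GenerationMixin.generate().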
+ stop_words_ids = kwargs.pop("stop_words_ids", None) + if stop_words_ids is None and generation_config is not None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + if stop_words_ids is None: + stop_words_ids = getattr(generation_config, "stop_words_ids", None) + + if stop_words_ids is not None: + stop_words_logits_processor = StopWordsLogitsProcessor( + stop_words_ids=stop_words_ids, + eos_token_id=generation_config.eos_token_id, + ) + if logits_processor is None: + logits_processor = LogitsProcessorList([stop_words_logits_processor]) + else: + logits_processor.append(stop_words_logits_processor) + + return super().generate( + inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + **kwargs, + ) + + +class RotaryEmbedding(torch.nn.Module): + def __init__(self, dim, base=10000): + super().__init__() + self.dim = dim + self.base = base + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + if importlib.util.find_spec("einops") is None: + raise RuntimeError("einops is required for Rotary Embedding") + + self._rotary_pos_emb_cache = None + self._seq_len_cached = 0 + self._ntk_alpha_cached = 1.0 + self._ntk_alpha_cached_list = [1.0] + + def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): + seqlen = max_seq_len + offset + if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() + / self.dim + ) + ) + self._seq_len_cached = max(2 * seqlen, 16) + self._ntk_alpha_cached = ntk_alpha + seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device) + freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + from einops import rearrange + + emb = rearrange(emb, "n d -> 1 n 1 d") + + cos, sin = emb.cos(), emb.sin() + self._rotary_pos_emb_cache = [cos, sin] + + def forward(self, max_seq_len, offset=0, ntk_alpha=1.0): + self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha) + cos, sin = self._rotary_pos_emb_cache + return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]] + + +def _rotate_half(x): + from einops import rearrange + + x = rearrange(x, "... (j d) -> ... 
j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(t, freqs): + cos, sin = freqs + if apply_rotary_emb_func is not None and t.is_cuda: + t_ = t.float() + cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2] + sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2] + output = apply_rotary_emb_func(t_, cos, sin).type_as(t) + return output + else: + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] + t_ = t_.float() + t_pass_ = t_pass_.float() + t_ = (t_ * cos) + (_rotate_half(t_) * sin) + return torch.cat((t_, t_pass_), dim=-1).type_as(t) + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + if rms_norm is not None and x.is_cuda: + return rms_norm(x, self.weight, self.eps) + else: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/models/Qwen/demo/CMakeLists.txt b/models/Qwen/demo/CMakeLists.txt old mode 100644 new mode 100755 index dbfc639..1a1874f --- a/models/Qwen/demo/CMakeLists.txt +++ b/models/Qwen/demo/CMakeLists.txt @@ -8,8 +8,6 @@ if (NOT DEFINED TARGET_ARCH) endif() include_directories(${PROJECT_SOURCE_DIR}/../support/include) -# include_directories(${PROJECT_SOURCE_DIR}/../support/third_party/abseil-cpp) -# include_directories(${PROJECT_SOURCE_DIR}/../support/third_party/re2) if (${CMAKE_HOST_SYSTEM_PROCESSOR} STREQUAL "aarch64") add_definitions(-DSOC_TARGET) @@ -33,14 +31,8 @@ set(CMAKE_BUILD_TYPE "Debug") set(ABSL_ENABLE_INSTALL ON) set(ABSL_PROPAGATE_CXX_STD ON) -# find_package(re2 REQUIRED) add_subdirectory(third_party/abseil-cpp) add_subdirectory(third_party/re2) -# link_directories(${PROJECT_SOURCE_DIR}/../support/third_party/abseil-cpp) -# link_directories(${PROJECT_SOURCE_DIR}/../support/third_party/re2) - -# add_executable(tokenizer tokenizer.cpp) -# target_link_libraries(tokenizer re2) add_executable(qwen demo.cpp tokenizer.cpp) target_link_libraries(qwen bmrt bmlib re2) diff --git a/models/Qwen/demo/demo.cpp b/models/Qwen/demo/demo.cpp old mode 100644 new mode 100755 diff --git a/models/Qwen/demo/demo_parallel.cpp b/models/Qwen/demo/demo_parallel.cpp new file mode 100755 index 0000000..3e405be --- /dev/null +++ b/models/Qwen/demo/demo_parallel.cpp @@ -0,0 +1,641 @@ +//===----------------------------------------------------------------------===// +// +// Copyright (C) 2023 Sophgo Technologies Inc. All rights reserved. +// +// TPU-MLIR is licensed under the 2-Clause BSD License except for the +// third-party components. 
+// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include +#include "memory.h" +#include "tokenizer.h" +#include "bmruntime_interface.h" +#include + +// #define EXPORT_RESULTS +#ifdef EXPORT_RESULTS +#include "cnpy.h" +static cnpy::npz_t map; +static int x = 0; + +template +static void add_array(std::string name, bm_handle_t bm_handle, + const bm_tensor_t &dst) { + std::vector data(dst.device_mem.size / sizeof(T)); + bm_memcpy_d2s(bm_handle, data.data(), dst.device_mem); + std::vector shape; + for (int i = 0; i < dst.shape.num_dims; ++i) { + shape.push_back(dst.shape.dims[i]); + } + cnpy::npz_add_array(map, name, data); +} + +static void save_array(std::string filename) { + cnpy::npz_save_all(filename, map); +} + +static void clear_array() { + map.clear(); +} +#endif + +void dump_tensor(bm_handle_t bm_handle, bm_tensor_t &tensor) { + auto shape = tensor.shape; + int size = 1; + for (int i = 0; i < shape.num_dims; ++i){ + size *= shape.dims[i]; + } + std::vector data(size); + bm_memcpy_d2s(bm_handle, data.data(), tensor.device_mem); + // std::cout<< data[0] << "\t" << data[data.size()-1] << std::endl; + auto ptr = data.data(); + ptr[0] = ptr[0]; +} + + +static const uint16_t BF16_NEG_10000 = 0xC61C; // -9984 by bfloat16 + +static const std::string TOKENIZER_MODEL = "qwen.tiktoken"; + +class QwenChat { +public: + void init(const std::vector &devid, std::string model); + void chat(); + void deinit(); + +private: + void answer(const std::string &input_str); + void tokenizer_encode(const std::string &input_str, std::vector &tokens); + int forward_first(std::vector &tokens); + int forward_next(int cur_token); + void move2end(const bm_tensor_t &kv); + void load_tiktoken(); + +private: + std::vector handles; + bm_handle_t bm_handle; + void *p_bmrt; + std::vector net_blocks; + std::vector net_blocks_cache; + const bm_net_info_t *net_embed; + const bm_net_info_t *net_embed_cache; + const bm_net_info_t *net_lm; + std::vector inputs_embed_512, outputs_embed_512; + std::vector inputs_pid, next_pid, inputs_attention, next_attention; + std::vector> past_keys, past_values; + std::vector present_key_cache, present_value_cache; + std::vector inputs_lm, outputs_lm; + std::string name_embed; + std::string name_embed_cache; + std::string name_lm; + std::vector name_blocks; + std::vector name_blocks_cache; + + int device_num; + int token_length; + int SEQLEN; // read from bmodel + int NUM_LAYERS; // read from bmodel + std::unique_ptr tk; + std::vector history; +}; + +void QwenChat::load_tiktoken() { + printf("Load %s ... 
\n", TOKENIZER_MODEL.c_str()); + tk = std::make_unique(TOKENIZER_MODEL); +} + +void QwenChat::init(const std::vector &devices, std::string model) { + device_num = devices.size(); + load_tiktoken(); + // request bm_handle + std::cout << "Device [ "; + for (auto d : devices) { + std::cout << d << " "; + } + std::cout << "] loading ....\n"; + for (auto d : devices) { + bm_handle_t h; + bm_status_t status = bm_dev_request(&h, d); + assert(BM_SUCCESS == status); + handles.push_back(h); + } + bm_handle = handles[0]; +// create bmruntime +#ifdef SOC_TARGET + p_bmrt = bmrt_create(handles[0]); +#else + p_bmrt = bmrt_create_ex(handles.data(), handles.size()); +#endif + assert(NULL != p_bmrt); + + // load bmodel by file + printf("Model[%s] loading ....\n", model.c_str()); + bool ret = bmrt_load_bmodel(p_bmrt, model.c_str()); + assert(true == ret); + printf("Done!\n"); + + + // embed, lm_head + name_embed = "embedding_1"; + name_embed_cache = "embedding_0"; + name_lm = "lm_head"; + net_embed = bmrt_get_network_info(p_bmrt, name_embed.c_str()); + net_embed_cache = bmrt_get_network_info(p_bmrt, name_embed_cache.c_str()); + net_lm = bmrt_get_network_info(p_bmrt, name_lm.c_str()); + SEQLEN = net_embed->stages[0].input_shapes[0].dims[1]; // real seqlen + auto num_nets = bmrt_get_network_number(p_bmrt); + NUM_LAYERS = (num_nets - 3) / 2; + + name_blocks.resize(NUM_LAYERS); + name_blocks_cache.resize(NUM_LAYERS); + net_blocks.resize(NUM_LAYERS); + net_blocks_cache.resize(NUM_LAYERS); + past_keys.resize(NUM_LAYERS); + past_values.resize(NUM_LAYERS); + + // blocks + for (int i = 0; i < NUM_LAYERS; i++) { + name_blocks[i] = "qwen_block_" + std::to_string(i); + name_blocks_cache[i] = "qwen_block_cache_" + std::to_string(i); + } + for (int i = 0; i < NUM_LAYERS; i++) { + net_blocks[i] = bmrt_get_network_info(p_bmrt, name_blocks[i].c_str()); + net_blocks_cache[i] = + bmrt_get_network_info(p_bmrt, name_blocks_cache[i].c_str()); + } + + // net device mem + inputs_embed_512.resize(net_embed->input_num); + for (int i = 0; i < device_num; ++i) { + ret = bmrt_tensor_ex(&inputs_embed_512[i], p_bmrt, + net_embed->input_loc_devices[i], + net_embed->input_dtypes[i], + net_embed->stages[0].input_shapes[i]); + assert(true == ret); + } + + outputs_embed_512.resize(net_embed->output_num); + for (int i = 0; i < device_num; ++i) { + ret = bmrt_tensor_ex(&outputs_embed_512[i], p_bmrt, + net_embed->output_loc_devices[i], + net_embed->output_dtypes[i], + net_embed->stages[0].output_shapes[i]); + assert(true == ret); + } + + inputs_pid.resize(device_num); + inputs_attention.resize(device_num); + int in_num = net_blocks[0]->input_num / device_num; + for (int i = 0; i < device_num; ++i) { + ret = bmrt_tensor_ex(&inputs_pid[i], p_bmrt, + net_blocks[0]->input_loc_devices[1 + i * in_num], + net_blocks[0]->input_dtypes[1 + i * in_num], + net_blocks[0]->stages[0].input_shapes[1 + i * in_num]); + assert(true == ret); + + ret = bmrt_tensor_ex(&inputs_attention[i], p_bmrt, + net_blocks[0]->input_loc_devices[2 + i * in_num], + net_blocks[0]->input_dtypes[2 + i * in_num], + net_blocks[0]->stages[0].input_shapes[2 + i * in_num]); + assert(true == ret); + } + + + next_pid.resize(device_num); + next_attention.resize(device_num); + int in_num_cache = net_blocks_cache[0]->input_num / device_num; + for (int i = 0; i < device_num; ++i) { + ret = bmrt_tensor_ex(&next_pid[i], p_bmrt, + net_blocks_cache[0]->input_loc_devices[1 + i * in_num_cache], + net_blocks_cache[0]->input_dtypes[1 + i * in_num_cache], + net_blocks_cache[0]->stages[0].input_shapes[1 + 
i * in_num_cache]); + assert(true == ret); + + ret = bmrt_tensor_ex(&next_attention[i], p_bmrt, + net_blocks_cache[0]->input_loc_devices[2 + i * in_num_cache], + net_blocks_cache[0]->input_dtypes[2 + i * in_num_cache], + net_blocks_cache[0]->stages[0].input_shapes[2 + i * in_num_cache]); + assert(true == ret); + } + + int out_num = net_blocks[0]->output_num / device_num; + for (int i = 0; i < NUM_LAYERS; i++) { + past_keys[i].resize(device_num); + past_values[i].resize(device_num); + for (int j = 0; j < device_num; j++) { + ret = bmrt_tensor_ex(&past_keys[i][j], p_bmrt, + net_blocks[0]->output_loc_devices[1 + j * out_num], + net_blocks[0]->output_dtypes[1 + j * out_num], + net_blocks[0]->stages[0].output_shapes[1 + j * out_num]); + assert(true == ret); + ret = bmrt_tensor_ex(&past_values[i][j], p_bmrt, + net_blocks[0]->output_loc_devices[2 + j * out_num], + net_blocks[0]->output_dtypes[2 + j * out_num], + net_blocks[0]->stages[0].output_shapes[2 + j * out_num]); + assert(true == ret); + } + } + + present_key_cache.resize(device_num); + present_value_cache.resize(device_num); + inputs_lm.resize(device_num); + outputs_lm.resize(device_num); + // int out_num_cache = net_blocks_cache[0]->output_num / device_num; + for (int i = 0; i < device_num; ++i) { + present_key_cache[i] = past_keys[0][i]; + present_value_cache[i] = past_values[0][i]; + present_key_cache[i].shape.dims[1] = 1; + present_value_cache[i].shape.dims[1] = 1; + + ret = bmrt_tensor_ex(&inputs_lm[i], p_bmrt, i, net_lm->input_dtypes[0], + net_lm->stages[0].input_shapes[0]); + assert(true == ret); + ret = bmrt_tensor_ex(&outputs_lm[i], p_bmrt, i, net_lm->output_dtypes[0], + net_lm->stages[0].output_shapes[0]); + assert(true == ret); + } +} + +void QwenChat::deinit() { + for (int i = 0; i < device_num; ++i) { + bm_free_device(handles[i], inputs_embed_512[i].device_mem); + bm_free_device(handles[i], outputs_embed_512[i].device_mem); + bm_free_device(handles[i], inputs_pid[i].device_mem); + bm_free_device(handles[i], next_pid[i].device_mem); + bm_free_device(handles[i], inputs_attention[i].device_mem); + bm_free_device(handles[i], next_attention[i].device_mem); + // bm_free_device(handles[i], present_key_cache[i].device_mem); + // bm_free_device(handles[i], present_value_cache[i].device_mem); + bm_free_device(handles[i], inputs_lm[i].device_mem); + bm_free_device(handles[i], outputs_lm[i].device_mem); + } + for (int i = 0; i < NUM_LAYERS; i++) { + for (int j = 0; j < device_num; j++) { + bm_free_device(handles[j], past_keys[i][j].device_mem); + bm_free_device(handles[j], past_values[i][j].device_mem); + } + } + bmrt_destroy(p_bmrt); + for (auto h : handles) { + bm_dev_free(h); + } +} + +// after first block, move real result to end of mem +void QwenChat::move2end(const bm_tensor_t &kv) { + if (token_length >= SEQLEN) { + return; + } + auto total_size = bm_mem_get_device_size(kv.device_mem); + auto bytes = total_size / SEQLEN; + auto real_size = token_length * bytes; + auto mem = + bm_mem_from_device(bm_mem_get_device_addr(kv.device_mem), real_size); + auto buffer = new uint8_t[real_size]; + auto dst = new uint8_t[total_size]; + bm_memcpy_d2s(bm_handle, (void *)buffer, mem); + memset(dst, 0, total_size - real_size); + memcpy(dst + total_size - real_size, buffer, real_size); + bm_memcpy_s2d(bm_handle, kv.device_mem, (void *)dst); + delete[] buffer; + delete[] dst; +} + +int QwenChat::forward_first(std::vector &tokens) { + std::vector input_ids(SEQLEN, 0); + std::vector position_id(SEQLEN, 0); + std::vector attention_mask(SEQLEN * SEQLEN, 
BF16_NEG_10000); + std::copy(tokens.begin(), tokens.end(), input_ids.data()); + + for (int i = 0; i < token_length; i++) { + position_id[i] = i; + } + for (int i = 0; i < token_length; i++) { + for (int j = 0; j < SEQLEN; j++) { + if (j <= i) { + attention_mask[i * SEQLEN + j] = 0; + } + } + } + + // forward embeding + std::vector input_nums(device_num, 1); + std::vector datas(device_num, (void*)input_ids.data()); + bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed_512.data(), datas.data(), + input_nums.data(), device_num); + auto ret = + bmrt_launch_tensor_ex(p_bmrt, name_embed.c_str(), + inputs_embed_512.data(), inputs_embed_512.size(), + outputs_embed_512.data(), outputs_embed_512.size(), + true, false); + assert(ret); + bm_thread_sync(bm_handle); + + // forward blocks + std::vector pos_id_datas(device_num, (void*)position_id.data()); + std::vector in_attn_datas(device_num, (void*)attention_mask.data()); + bmrt_memcpy_s2d_parallel(p_bmrt, inputs_pid.data(), pos_id_datas.data(), + input_nums.data(), device_num); + bmrt_memcpy_s2d_parallel(p_bmrt, inputs_attention.data(),in_attn_datas.data(), + input_nums.data(), device_num); + auto embed_512 = outputs_embed_512; + std::vector inputs_block; + std::vector outputs_block; + for (int i = 0; i < device_num; ++i) { + embed_512[i].shape = net_blocks[0]->stages[0].input_shapes[0]; + inputs_block.push_back(embed_512[i]); + inputs_block.push_back(inputs_pid[i]); + inputs_block.push_back(inputs_attention[i]); + outputs_block.push_back(embed_512[i]); + outputs_block.push_back(past_keys[0][i]); + outputs_block.push_back(past_values[0][i]); +#ifdef EXPORT_RESULTS + add_array(net_blocks[0]->input_names[0 + i * 3], handles[i], inputs_block[0 + i * 3]); + add_array(net_blocks[0]->input_names[1 + i * 3], handles[i], inputs_block[1 + i * 3]); + add_array(net_blocks[0]->input_names[2 + i * 3], handles[i], inputs_block[2 + i * 3]); + } + save_array(std::to_string(device_num) + "dev_block_0_inputs.npz"); + clear_array(); +#else + } +#endif + for (int i = 0; i < NUM_LAYERS; i++) { + for (int j = 0; j < device_num; ++j) { + outputs_block[1 + j * 3] = past_keys[i][j]; + outputs_block[2 + j * 3] = past_values[i][j]; + } + ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks[i].c_str(), + inputs_block.data(), inputs_block.size(), + outputs_block.data(), outputs_block.size(), + true, false); + assert(ret); + bm_thread_sync(bm_handle); +#ifdef EXPORT_RESULTS + // if (i == 0) { + // add_array(net_blocks[0]->output_names[0], handles[i], outputs_block[0]); + // add_array(net_blocks[0]->output_names[1], handles[i], outputs_block[1]); + // add_array(net_blocks[0]->output_names[2], handles[i], outputs_block[2]); + // save_array("block_0_outputs.npz"); + // } +#endif + } + + int bytes = embed_512[0].device_mem.size / SEQLEN; + bm_memcpy_d2d_byte(bm_handle, inputs_lm[0].device_mem, 0, + embed_512[0].device_mem, (token_length - 1) * bytes, + bytes); + ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1, + &outputs_lm[0], 1, + true, false); + assert(ret); + bm_thread_sync(bm_handle); + + int token = 0; + bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem); + return token; +} + +int QwenChat::forward_next(int cur_token) { + std::vector attention_mask(SEQLEN + 1, 0); + for (int i = token_length - 1; i < SEQLEN; i++) { + attention_mask[i] = BF16_NEG_10000; + } + int32_t position_id = token_length - 1; + + // embedding + std::vector inputs_embed; + std::vector input_datas; + std::vector input_nums(device_num, 1); + for (int i = 0; i < device_num; ++i) { + 
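+    // Decode step: each device gets only the newly generated token id; the
+    // lm_head output buffers (outputs_lm) are reused as the one-token inputs
+    // of the cached embedding net, whose result is written into inputs_lm.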
inputs_embed.push_back(outputs_lm[i]); // token_id + inputs_embed[i].shape = net_embed_cache->stages[0].input_shapes[0]; + input_datas.push_back((void*)(&cur_token)); + } + bmrt_memcpy_s2d_parallel(p_bmrt, inputs_embed.data(), input_datas.data(), + input_nums.data(), device_num); + auto ret = bmrt_launch_tensor_ex(p_bmrt, name_embed_cache.c_str(), + inputs_embed.data(), inputs_embed.size(), + inputs_lm.data(), inputs_lm.size(), true, false); + assert(ret); + bm_thread_sync(bm_handle); + + // blocks + std::vector attn_datas(device_num, attention_mask.data()); + std::vector pid_datas(device_num, &position_id); + bmrt_memcpy_s2d_parallel(p_bmrt, next_attention.data(), attn_datas.data(), + input_nums.data(), device_num); + bmrt_memcpy_s2d_parallel(p_bmrt, next_pid.data(), pid_datas.data(), + input_nums.data(), device_num); + // WARNING: make inputs_lm device_num + std::vector embed_1 = inputs_lm; + for (int i = 0; i < device_num; ++i) { + embed_1[i].shape = net_blocks_cache[0]->stages[0].input_shapes[0]; + } + std::vector inputs_block; + std::vector outputs_block; + for (int i = 0; i < device_num; ++i) { + inputs_block.push_back(embed_1[i]); + inputs_block.push_back(next_pid[i]); + inputs_block.push_back(next_attention[i]); + inputs_block.push_back(past_keys[0][i]); + inputs_block.push_back(past_values[0][i]); + outputs_block.push_back(embed_1[i]); + outputs_block.push_back(present_key_cache[i]); + outputs_block.push_back(present_value_cache[i]); +#ifdef EXPORT_RESULTS + if (x == 0) { + add_array(net_blocks_cache[0]->input_names[0 + i * 5], handles[i], inputs_block[0 + i * 5]); + add_array(net_blocks_cache[0]->input_names[1 + i * 5], handles[i], inputs_block[1 + i * 5]); + add_array(net_blocks_cache[0]->input_names[2 + i * 5], handles[i], inputs_block[2 + i * 5]); + add_array(net_blocks_cache[0]->input_names[3 + i * 5], handles[i], inputs_block[3 + i * 5]); + add_array(net_blocks_cache[0]->input_names[4 + i * 5], handles[i], inputs_block[4 + i * 5]); + } + } + if (x == 0) { + save_array(std::to_string(device_num) + "dev_block_cache_0_inputs.npz"); + clear_array(); + x++; + } +#else +} +#endif + for (int i = 0; i < NUM_LAYERS; i++) { + for (int j = 0; j < device_num; ++j) { + inputs_block[3 + j * 5] = past_keys[i][j]; + inputs_block[4 + j * 5] = past_values[i][j]; + int bytes = bm_mem_get_device_size(past_keys[0][j].device_mem) / SEQLEN; + int token_offset = (token_length - 1) * bytes; + bm_set_device_mem(&outputs_block[1 + j * 3].device_mem, bytes, + bm_mem_get_device_addr(past_keys[i][j].device_mem) + token_offset); + bm_set_device_mem(&outputs_block[2 + j * 3].device_mem, bytes, + bm_mem_get_device_addr(past_values[i][j].device_mem) + token_offset); + } + ret = bmrt_launch_tensor_ex(p_bmrt, name_blocks_cache[i].c_str(), + inputs_block.data(), inputs_block.size(), + outputs_block.data(), outputs_block.size(), + true, false); + assert(ret); + bm_thread_sync(bm_handle); + } + + ret = bmrt_launch_tensor_ex(p_bmrt, name_lm.c_str(), &inputs_lm[0], 1, + &outputs_lm[0], 1, true, false); + assert(ret); + bm_thread_sync(bm_handle); + + int token = 0; + bm_memcpy_d2s(bm_handle, (void *)&token, outputs_lm[0].device_mem); + return token; +} + +void QwenChat::chat() { + while (true) { + std::cout << "\nQuestion: "; + std::string input_str; + std::getline(std::cin, input_str); + if (input_str.empty()) { + continue; + } + if (input_str == "exit" || input_str == "quit") { + break; + } + if (input_str == "clear") { + history.clear(); + continue; + } + std::cout << "\nAnswer: " << std::flush; + 
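+    // answer() runs forward_first() over the whole prompt, then forward_next()
+    // one token at a time until <|im_end|> or SEQLEN is reached, streaming the
+    // decoded text to stdout as it is generated.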
answer(input_str);
+    std::cout << std::endl;
+  }
+}
+
+void QwenChat::answer(const std::string &input_str) {
+  int tok_num = 0;
+  history.emplace_back(std::move(input_str));
+  auto input_ids = tk->encode_history(history, SEQLEN);
+  token_length = input_ids.size();
+  auto time_1 = std::chrono::system_clock::now();
+  int pre_token = 0;
+  int token = forward_first(input_ids);
+  auto time_2 = std::chrono::system_clock::now();
+  std::string result;
+  while (token != tk->im_end_id && token_length < SEQLEN) {
+    // re-decode the new token behind a fixed pre_token prefix and print only
+    // the appended suffix
+    std::vector<int> pre_ids = {pre_token};
+    std::vector<int> ids = {pre_token, token};
+    auto pre_word = tk->decode(pre_ids);
+    auto word = tk->decode(ids);
+    std::string diff = word.substr(pre_word.size());
+    result += diff;
+    std::cout << diff << std::flush;
+    if (token_length < SEQLEN) {
+      token_length++;
+    }
+    tok_num++;
+    token = forward_next(token);
+  }
+  auto time_3 = std::chrono::system_clock::now();
+  auto ftl_dur =
+      std::chrono::duration_cast<std::chrono::microseconds>(time_2 - time_1);
+  auto tps_dur =
+      std::chrono::duration_cast<std::chrono::microseconds>(time_3 - time_2);
+  double tps = tok_num / (tps_dur.count() * 1e-6);
+  if (token_length >= SEQLEN) {
+    printf(" ......\nWarning: cleanup early history\n");
+  }
+  // double tht = tokens.size() / (tht_dur.count() * 1e-6);
+  printf("\nFTL:%f s, TPS: %f tokens/s\n", ftl_dur.count() * 1e-6, tps);
+  history.emplace_back(result);
+  if (token_length + 128 >= SEQLEN) {
+    // drop the oldest turns, in whole user/assistant pairs, to free context
+    int num = (history.size() + 3) / 4 * 2;
+    history.erase(history.begin(), history.begin() + num);
+  }
+}
+
+static void split(const std::string &s, const std::string &delim,
+                  std::vector<std::string> &ret) {
+  size_t last = 0;
+  size_t index = s.find_first_of(delim, last);
+  while (index != std::string::npos) {
+    ret.push_back(s.substr(last, index - last));
+    last = index + 1;
+    index = s.find_first_of(delim, last);
+  }
+  if (last < s.length()) {
+    ret.push_back(s.substr(last));
+  }
+}
+
+static std::vector<int> parseCascadeDevices(const std::string &str) {
+  std::vector<int> devices;
+  std::vector<std::string> sub_str;
+  split(str, ",", sub_str);
+  for (auto &s : sub_str) {
+    devices.push_back(std::atoi(s.c_str()));
+  }
+  return devices;
+}
+
+void Usage() {
+  printf("Usage:\n"
+         " --help : Show help info.\n"
+         " --model : Set model path \n"
+         " --devid : Set devices to run for model, e.g. 1,2. 
if not " + "set, use 0\n"); +} + +void processArguments(int argc, char *argv[], std::string &qwen_model, + std::vector &devices) { + struct option longOptions[] = {{"model", required_argument, nullptr, 'm'}, + {"devid", required_argument, nullptr, 'd'}, + {"help", no_argument, nullptr, 'h'}, + {nullptr, 0, nullptr, 0}}; + + int optionIndex = 0; + int option; + + while ((option = getopt_long(argc, argv, "m:d:h:", longOptions, + &optionIndex)) != -1) { + switch (option) { + case 'm': + qwen_model = optarg; + break; + case 'd': + devices = parseCascadeDevices(optarg); + break; + case 'h': + Usage(); + exit(EXIT_SUCCESS); + case '?': + Usage(); + exit(EXIT_FAILURE); + default: + exit(EXIT_FAILURE); + } + } +} + +int main(int argc, char **argv) { + // set your bmodel path here + printf("Demo for QwenChat in BM1684X\n"); + std::string qwen_model; + std::vector devices = {0}; + processArguments(argc, argv, qwen_model, devices); + if (qwen_model.empty()) { + Usage(); + exit(EXIT_FAILURE); + } + + QwenChat qwen; + printf("Init Environment ...\n"); + qwen.init(devices, qwen_model); + printf("==========================\n"); + qwen.chat(); + qwen.deinit(); + return 0; +} diff --git a/models/Qwen/demo/third_party/abseil-cpp b/models/Qwen/demo/third_party/abseil-cpp new file mode 160000 index 0000000..36442dd --- /dev/null +++ b/models/Qwen/demo/third_party/abseil-cpp @@ -0,0 +1 @@ +Subproject commit 36442dd87ed5f568f483f702252c6c5e6028aeb3 diff --git a/models/Qwen/demo/third_party/re2 b/models/Qwen/demo/third_party/re2 new file mode 160000 index 0000000..ab7c591 --- /dev/null +++ b/models/Qwen/demo/third_party/re2 @@ -0,0 +1 @@ +Subproject commit ab7c5918b418428ed17dbe564e0d8402bd7d743d diff --git a/models/Qwen/demo/tokenizer.cpp b/models/Qwen/demo/tokenizer.cpp new file mode 100755 index 0000000..2522189 --- /dev/null +++ b/models/Qwen/demo/tokenizer.cpp @@ -0,0 +1,103 @@ +#include "tokenizer.h" + +static const std::string PAT_STR = + R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|[^\S])|\s+)"; + +static std::pair _parse(const std::string &line) { + auto pos = line.find(" "); + if (pos == std::string::npos) { + throw std::runtime_error("invalid encoder line: " + line); + } + + auto token = base64::decode({line.data(), pos}); + int rank = 0; + try { + rank = std::stoul(line.substr(pos + 1)); + } catch (const std::exception &) { + throw std::runtime_error("invalid encoder rank: " + line); + } + + return {std::move(token), rank}; +} + +QwenTokenizer::QwenTokenizer(const std::string &tiktoken_path) { + std::ifstream file(tiktoken_path); + if (!file) { + throw std::runtime_error("failed to open encoder file: " + tiktoken_path); + } + + ankerl::unordered_dense::map encoder; + std::string line; + while (std::getline(file, line)) { + auto [token, rank] = _parse(line); + + if (!encoder.emplace(std::move(token), rank).second) { + throw std::runtime_error("duplicate item: " + line); + } + } + + std::vector special_tokens_s{"<|endoftext|>", "<|im_start|>", + "<|im_end|>"}; + char buffer[14]; + for (size_t i = 0; i < 205; i++) { + snprintf(buffer, 14, "<|extra_%zu|>", i); + special_tokens_s.push_back(buffer); + } + size_t encoder_size = encoder.size(); + ankerl::unordered_dense::map special_tokens; + special_tokens.reserve(special_tokens_s.size()); + for (size_t i = 0; i < special_tokens_s.size(); i++) { + special_tokens[special_tokens_s[i]] = encoder_size + i; + } + + tokenizer = tiktoken::tiktoken(std::move(encoder), special_tokens, PAT_STR); +} + 
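+// build_prompt() below renders the chat history in chatml form; for a
+// single-turn query the prompt looks like:
+//   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
+//   <|im_start|>user\n{query}<|im_end|>\n
+//   <|im_start|>assistant\n
+// encode() keeps only the last max_length token ids, so encode_history()
+// drops tokens from the front of the prompt once it exceeds that window.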
+auto QwenTokenizer::build_prompt(const std::vector &history) const + -> std::string { + if (history.size() % 2 != 1) { + std::cout << "invalid history size " << history.size(); + exit(-1); + } + + std::ostringstream oss_prompt; + oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"; + for (size_t i = 0; i < history.size() - 1; i += 2) { + oss_prompt << "\n<|im_start|>user\n" + << history[i] << "<|im_end|>\n<|im_start|>" << history[i + 1] + << "<|im_end|>"; + } + oss_prompt << "\n<|im_start|>user\n" + << history.back() << "<|im_end|>\n<|im_start|>assistant\n"; + + return oss_prompt.str(); +} + +auto QwenTokenizer::encode(const std::string &text, int max_length) const + -> std::vector { + auto ids = tokenizer.encode(text); + if ((int)ids.size() > max_length) { + ids.erase(ids.begin(), ids.end() - max_length); + } + return ids; +} + +auto QwenTokenizer::decode(const std::vector &ids) const -> std::string { + std::vector normal_ids(ids); + normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), + [this](int id) { return is_special_id(id); }), + normal_ids.end()); + auto text = tokenizer.decode(normal_ids); + return text; +} + +auto QwenTokenizer::encode_history(const std::vector &history, + int max_length) const -> std::vector { + std::string prompt = build_prompt(history); + std::vector input_ids = encode(prompt, max_length); + return input_ids; +} + +auto QwenTokenizer::is_special_id(int id) const -> bool { + return id == eod_id || id == im_start_id || id == im_end_id; +} \ No newline at end of file diff --git a/models/Qwen/support/include/base64.h b/models/Qwen/support/include/base64.h new file mode 100755 index 0000000..393090b --- /dev/null +++ b/models/Qwen/support/include/base64.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +namespace base64 { + +static auto pos_of_char(const unsigned char chr) -> size_t { + if (chr >= 'A' && chr <= 'Z') return chr - 'A'; + else if (chr >= 'a' && chr <= 'z') return chr - 'a' + ('Z' - 'A') + 1; + else if (chr >= '0' && chr <= '9') return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2; + else if (chr == '+' || chr == '-') return 62; + else if (chr == '/' || chr == '_') return 63; + else throw std::runtime_error("Input is not valid base64-encoded data."); +} + +inline auto decode(std::string_view s) -> std::string { + if (s.empty()) throw std::runtime_error("empty input"); + size_t length = s.length(); + size_t idx = 0; + + std::string out; + out.reserve(length / 4 * 3); + + while (idx < length) { + size_t pos_of_char_1 = pos_of_char(s.at(idx + 1)); + out.push_back(static_cast(((pos_of_char(s.at(idx+0))) << 2 ) + ((pos_of_char_1 & 0x30) >> 4))); + if ((idx + 2 < length) && s.at(idx + 2) != '=' && s.at(idx + 2) != '.') { + size_t pos_of_char_2 = pos_of_char(s.at(idx + 2)); + out.push_back(static_cast(((pos_of_char_1 & 0x0f) << 4) + ((pos_of_char_2 & 0x3c) >> 2))); + if ((idx + 3 < length) && s.at(idx + 3) != '=' && s.at(idx + 3) != '.') { + out.push_back(static_cast(((pos_of_char_2 & 0x03) << 6) + pos_of_char(s.at(idx+3)))); + } + } + idx += 4; + } + return out; +} + +} // namespace base64 diff --git a/models/Qwen/support/include/bmdef.h b/models/Qwen/support/include/bmdef.h new file mode 100755 index 0000000..a7e1db0 --- /dev/null +++ b/models/Qwen/support/include/bmdef.h @@ -0,0 +1,131 @@ +/***************************************************************************** + * + * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved. 
+ * + * The material in this file is confidential and contains trade secrets + * of Sophgo Technologies Inc. This is proprietary information owned by + * Sophgo Technologies Inc. No part of this work may be disclosed, + * reproduced, copied, transmitted, or used in any way for any purpose, + * without the express written permission of Sophgo Technologies Inc. + * + *****************************************************************************/ + +#ifndef __BMRUNTIME_DEFINE_H__ +#define __BMRUNTIME_DEFINE_H__ + +#include "bmlib_runtime.h" +#include +#include + +#if defined(__cplusplus) +extern "C" { +#endif + +/* --------------------------------------------------------------------------*/ +/* basic definitions */ + +/* bm_data_type_t holds the type for a scalar value */ +typedef enum bm_data_type_e { + BM_FLOAT32 = 0, + BM_FLOAT16 = 1, + BM_INT8 = 2, + BM_UINT8 = 3, + BM_INT16 = 4, + BM_UINT16 = 5, + BM_INT32 = 6, + BM_UINT32 = 7, + BM_BFLOAT16 = 8, + BM_INT4 = 9, + BM_UINT4 = 10, +} bm_data_type_t; + +/* store mode definitions */ +typedef enum bm_store_mode_e { + BM_STORE_1N = 0, /* default, if not sure, use 0 */ + BM_STORE_2N = 1, + BM_STORE_4N = 2, +} bm_store_mode_t; + +/* bm_shape_t holds the shape info */ +#define BM_MAX_DIMS_NUM 8 +typedef struct bm_shape_s { + int num_dims; + int dims[BM_MAX_DIMS_NUM]; +} bm_shape_t; + +typedef struct bm_shape_ex_s { + bm_shape_t shape; + int elem_num; +} bm_shape_ex_t; + +/* +bm_tensor_t holds a multi-dimensional array of elements of a single data type +and tensor are in device memory */ +typedef struct bm_tensor_s { + bm_data_type_t dtype; + bm_shape_t shape; + bm_device_mem_t device_mem; + bm_store_mode_t st_mode; /* user can set 0 as default store mode */ +} bm_tensor_t; + +/* --------------------------------------------------------------------------*/ +/* network information structure */ + +/* bm_stage_info_t holds input/output shapes and device mems; every network can contain one or more + * stages */ +typedef struct bm_stage_info_s { + bm_shape_t *input_shapes; /* input_shapes[0] / [1] / ... / [input_num-1] */ + bm_shape_t *output_shapes; /* output_shapes[0] / [1] / ... / [output_num-1] */ + bm_device_mem_t *input_mems; /* input_mems[0] / [1] / ... / [input_num-1] */ + bm_device_mem_t *output_mems; /* output_mems[0] / [1] / ... / [output_num-1] */ +} bm_stage_info_t; + +/* bm_tensor_info_t holds all information of one net. + * scale for float type is 1.0 as default */ +typedef struct bm_net_info_s { + const char* name; /* net name */ + bool is_dynamic; /* dynamic or static */ + int input_num; /* number of inputs */ + char const** input_names; /* input_names[0] / [1] / .../ [input_num-1] */ + bm_data_type_t* input_dtypes; /* input_dtypes[0] / [1] / .../ [input_num-1] */ + float* input_scales; /* input_scales[0] / [1] / .../ [input_num-1] */ + int output_num; /* number of outputs */ + char const** output_names; /* output_names[0] / [1] / .../ [output_num-1] */ + bm_data_type_t* output_dtypes; /* output_dtypes[0] / [1] / .../ [output_num-1] */ + float* output_scales; /* output_scales[0] / [1] / .../ [output_num-1] */ + int stage_num; /* number of stages */ + bm_stage_info_t* stages; /* stages[0] / [1] / ... / [stage_num-1] */ + size_t* max_input_bytes; /* max_input_bytes[0]/ [1] / ... / [input_num-1] */ + size_t* max_output_bytes; /* max_output_bytes[0] / [1] / ... 
/ [output_num-1] */ + int* input_zero_point; /* input_zero_point[0] / [1] / .../ [input_num-1] */ + int* output_zero_point; /* output_zero_point[0] / [1] / .../ [output_num-1] */ + int *input_loc_devices; /* input_loc_device[0] / [1] / .../ [input_num-1] */ + int *output_loc_devices; /* output_loc_device[0] / [1] / .../ [output_num-1] */ + int core_num; /* core number */ + bool io_alone; /* whether io is alone from neuron space */ +} bm_net_info_t; + +typedef struct api_info_s { + /// @brief api_id to be sent to driver + int32_t api_id; + /// @brief api data to be sent to driver + uint8_t **api_data; + /// @brief size of the api data to be sent to driver + size_t api_data_size; + /// @brief subsize of the api data to be sent to driver + size_t *api_data_subsize; + /// @brief offset of input tensors' addr in api_data + uint32_t *input_addr_offset; + /// @brief number of the offset of input tensors' addr in api_data + size_t input_addr_offset_number; + /// @brief offset of output tensors' addr in api_data + uint32_t *output_addr_offset; + /// @brief number of the offset of output tensors' addr in api_data + size_t output_addr_offset_number; +} api_info_c; + +#if defined(__cplusplus) +} +#endif + +#endif /* __BM_NET_H__ */ diff --git a/models/Qwen/support/include/bmlib_runtime.h b/models/Qwen/support/include/bmlib_runtime.h new file mode 100755 index 0000000..60094e1 --- /dev/null +++ b/models/Qwen/support/include/bmlib_runtime.h @@ -0,0 +1,2579 @@ +/***************************************************************************** + * + * Copyright (c) 2016-2026 by Bitmain Technologies Inc. All rights reserved. + * + * The material in this file is confidential and contains trade secrets + * of Bitmain Technologies Inc. This is proprietary information owned by + * Bitmain Technologies Inc. No part of this work may be disclosed, + * reproduced, copied, transmitted, or used in any way for any purpose, + * without the express written permission of Bitmain Technologies Inc. + * + *****************************************************************************/ + +/************************************************************************** + * bmlib_runtime defines interfaces that operate TPU devices. + * The functions can be divided into serveral categories. 
+ * 1) device handle creation and destroy + * 2) memory help functions + * 3) global memory allocation and free + * 4) data transfer between host and device + * 5) data transfer within device memory + * 6) api send and synchronization + * 7) global memory map and coherence + * 8) trace and profile + * 9) power management + * 10) miscellaneous functions + *************************************************************************/ + +#ifndef BMLIB_RUNTIME_H_ +#define BMLIB_RUNTIME_H_ +#if defined(_WIN32) && !defined(__MINGW32__) + #include + #define DECL_EXPORT __declspec(dllexport) + #define DECL_IMPORT __declspec(dllimport) +#else + #include + #include + #include + #define DECL_EXPORT + #define DECL_IMPORT +#endif + +#if defined(__cplusplus) +extern "C" { +#endif + +typedef enum { + MODULE_CDMA = 0, + MODULE_GDMA = 1, + MODULE_TPU = 2, + MODULE_SMMU = 3, + MODULE_SRAM = 4, + MODULE_END = 5 +} MODULE_ID; + +#define BM_MEM_ADDR_NULL (0xfffffffff) + +#ifndef BM_MEM_DESC_T_ +#define BM_MEM_DESC_T_ +/* BM function return code definitions */ +typedef enum { + BM_SUCCESS = 0, + BM_ERR_DEVNOTREADY = 1, /* Device not ready yet */ + BM_ERR_FAILURE = 2, /* General failure */ + BM_ERR_TIMEOUT = 3, /* Timeout */ + BM_ERR_PARAM = 4, /* Parameters invalid */ + BM_ERR_NOMEM = 5, /* Not enough memory */ + BM_ERR_DATA = 6, /* Data error */ + BM_ERR_BUSY = 7, /* Busy */ + BM_ERR_NOFEATURE = 8, /* Not supported yet */ + BM_NOT_SUPPORTED = 9 +} bm_status_t; + +/* BM memory type definitions */ +typedef enum { + BM_MEM_TYPE_DEVICE = 0, + BM_MEM_TYPE_HOST = 1, + BM_MEM_TYPE_SYSTEM = 2, + BM_MEM_TYPE_INT8_DEVICE = 3, + BM_MEM_TYPE_INVALID = 4 +} bm_mem_type_t; + +typedef enum { + PERF_MONITOR_GDMA = 0, + PERF_MONITOR_TPU = 1 +} PERF_MONITOR_ID; + +typedef enum { + BMCPU_IDLE = 0, + BMCPU_RUNNING = 1, + BMCPU_FAULT = 2 +} bm_cpu_status_t; + +/* +* bm performace monitor +*/ +typedef struct bm_perf_monitor { + long long buffer_start_addr; /*buffer address to store perf data*/ + int buffer_size; /*buffer size*/ + PERF_MONITOR_ID monitor_id; /*PERF_MONITOR_GDMA or PERF_MONITOR_TPU*/ +} bm_perf_monitor_t; + +typedef union { + struct { + bm_mem_type_t mem_type : 3; + unsigned int gmem_heapid : 3; + unsigned int reserved : 26; + } u; + unsigned int rawflags; +} bm_mem_flags_t; + +/* BM memory descriptor definition*/ +typedef struct bm_mem_desc { + union { + struct { +#ifdef __linux__ + unsigned long device_addr; +#else + unsigned long long device_addr; +#endif + unsigned int reserved; + int dmabuf_fd; + } device; + + struct { + void *system_addr; + unsigned int reserved0; + int reserved1; + } system; + } u; + + bm_mem_flags_t flags; + unsigned int size; +} bm_mem_desc_t; + +typedef struct bm_mem_desc bm_device_mem_t; +typedef struct bm_mem_desc bm_system_mem_t; + +typedef struct sg_mem_desc { + union { + struct { +#ifdef __linux__ + unsigned long device_addr; +#else + unsigned long long device_addr; +#endif + unsigned int reserved; + int dmabuf_fd; + } device; + + struct { + void *system_addr; + unsigned int reserved0; + int reserved1; + } system; + } u; + + bm_mem_flags_t flags; + unsigned long long size; +} sg_mem_desc_t; + +typedef struct sg_mem_desc sg_device_mem_t; +typedef struct sg_mem_desc sg_system_mem_t; +#endif + +struct bm_context; +typedef struct bm_context *bm_handle_t; + +#define MD5SUM_LEN 16 +#define LIB_MAX_NAME_LEN 64 +#define FUNC_MAX_NAME_LEN 64 + +typedef struct bm_module +{ + // void *lib_handle; + char lib_name[LIB_MAX_NAME_LEN]; + unsigned char md5[MD5SUM_LEN]; +}bm_module; + +typedef struct 
bm_module *tpu_kernel_module_t; +typedef int tpu_kernel_function_t; + +/** + * @name tpu_kernel_load_module_file + * @brief To load dyn file + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] module_file dyn file + * @retval dyn lib ptr + */ +tpu_kernel_module_t tpu_kernel_load_module_file(bm_handle_t handle, const char *module_file); + +/** + * @name tpu_kernel_load_module_file_key + * @brief To load dyn file with key + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] module_file dyn file + * @param [in] key identification str + * @param [in] size key size + * @retval dyn lib ptr + */ +tpu_kernel_module_t tpu_kernel_load_module_file_key(bm_handle_t handle, const char *module_file, const char *key, int size); + +/** + * @name tpu_kernel_unload_module + * @brief To unload dyn file + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] p_module dyn lib ptr + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_unload_module(bm_handle_t handle, tpu_kernel_module_t p_module); + +/** + * @name tpu_kernel_free_module + * @brief To free p_module when not use + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] p_module dyn lib ptr + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_free_module(bm_handle_t handle, tpu_kernel_module_t p_module); + +/** + * @name tpu_kernel_load_module + * @brief To load dyn module + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] data dyn module + * @param [in] length dyn module size + * @retval dyn lib ptr + */ +tpu_kernel_module_t tpu_kernel_load_module(bm_handle_t handle, const char *data, size_t length); + +/** + * @name tpu_kernel_get_function + * @brief To get function from lib + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] module dyn module + * @param [in] function funtion name + * @retval function id + */ +tpu_kernel_function_t tpu_kernel_get_function(bm_handle_t handle, tpu_kernel_module_t module, const char *function); + +/** + * @name tpu_kernel_launch + * @brief To launch function with sync + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] function function id + * @param [in] args funtion args + * @param [in] size args size + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_launch(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size); + +/** + * @name tpu_kernel_launch_async + * @brief To launch function with async + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] function function id + * @param [in] args funtion args + * @param [in] size args size + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_launch_async(bm_handle_t handle, tpu_kernel_function_t function, void *args, size_t size); + +/** + * @name tpu_kernel_launch_async_multi_cores + * @brief To launch function with async for multi cores + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] func_name function name + * @param [in] api_param funtion params + * @param [in] api_size params size + * @param [in] core_list list of core ids + * @param [in] core_num number of cores + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
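+ *
+ * Usage sketch (illustrative; the kernel name "sample_kernel", its parameter
+ * struct and the chosen core ids are assumptions, not part of this API):
+ *
+ *   struct { int loops; } param = { 16 };
+ *   int cores[2] = { 0, 1 };
+ *   bm_status_t ret = tpu_kernel_launch_async_multi_cores(
+ *       handle, "sample_kernel", &param, sizeof(param), cores, 2);
+ *   if (ret == BM_SUCCESS)
+ *     ret = tpu_kernel_sync(handle);  // one way to wait before reading results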
+ */ +bm_status_t tpu_kernel_launch_async_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param, + size_t api_size, const int* core_list, const int core_num); + +/** + * @name tpu_kernel_launch_sync_multi_cores + * @brief To launch function with sync for multi cores + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] func_name function name + * @param [in] api_param funtion params + * @param [in] api_size params size + * @param [in] core_list list of core ids + * @param [in] core_num number of cores + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_launch_sync_multi_cores(bm_handle_t handle, const char *func_name, const void *api_param, + size_t api_size, const int* core_list, const int core_num); + +/** + * @name tpu_kernel_sync + * @brief To sync + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +bm_status_t tpu_kernel_sync(bm_handle_t handle); +void show_md5(unsigned char md5[]); + +DECL_EXPORT void bmlib_log(const char *tag, int level, const char *fmt, ...); + +#ifndef USING_CMODEL +#define BM_CHECK_RET(call) \ + do { \ + bm_status_t ret = (bm_status_t)call; \ + if (ret != BM_SUCCESS) { \ + bmlib_log("BM_CHECK",16,"BM_CHECK_RET fail %s: %s: %d\n", __FILE__, __func__, __LINE__); \ + return ret; \ + } \ + } while (0) +#else +#define BM_CHECK_RET(call) \ + do { \ + bm_status_t ret = call; \ + if (ret != BM_SUCCESS) { \ + bmlib_log("BM_CHECK",16,"BM_CHECK_RET failed %d\n", ret);\ + ASSERT(0); \ + exit(-ret); \ + } \ + } while (0) +#endif + +/*******************handle releated functions *********************************/ +/** + * @name bm_dev_getcount + * @brief To get the number of sophon devices in system. + * If N is got, valid devid is [0, N-1] + * @ingroup bmlib_runtime + * + * @param [out] count The result number of sophon devices + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_dev_getcount(int *count); + +/** + * @name bm_dev_query + * @brief To query if a device is present + * @ingroup bmlib_runtime + * + * @param [in] devid The id of the device to query + * @retval BM_SUCCESS Device is present + * Other code Devcie is not present + */ +DECL_EXPORT bm_status_t bm_dev_query(int devid); + +/** + * @name bm_dev_request + * @brief To create a handle for the given device + * @ingroup bmlib_runtime + * + * @param [out] handle The created handle + * @param [in] devid Specify on which device to create handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_dev_request(bm_handle_t *handle, int devid); + +/** + * @name bm_get_devid + * @brief To get device index for the given handle + * @ingroup bmlib_runtime + * + * @param [in] handle The given handle + * @retval int device index that the handle points to. 
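+ *
+ * Usage sketch (illustrative; device id 0 is an arbitrary choice):
+ *
+ *   bm_handle_t handle;
+ *   if (bm_dev_request(&handle, 0) == BM_SUCCESS) {
+ *     printf("opened device %d\n", bm_get_devid(handle));
+ *     bm_dev_free(handle);
+ *   }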
+ */ +DECL_EXPORT int bm_get_devid(bm_handle_t handle); + +/** + * @name bm_dev_free + * @brief To free a handle + * @ingroup bmlib_runtime + * + * @param [in] handle The handle to free + */ +DECL_EXPORT void bm_dev_free(bm_handle_t handle); + +/*******************memory help functions ************************************/ +/** + * @name bm_mem_get_type + * @brief To get a memory descriptor's type + * @ingroup bmlib_runtime + * + * @param [in] mem The memory descriptor queried + * @retval BM_MEM_TYPE_DEVICE Device global memory + * @retval BM_MEM_TYPE_SYSTEM Host user memory + */ +DECL_EXPORT bm_mem_type_t bm_mem_get_type(struct bm_mem_desc mem); + +/** + * @name sg_mem_get_type + * @brief To get a memory descriptor's type + * @ingroup bmlib_runtime + * + * @param [in] mem The memory descriptor queried + * @retval BM_MEM_TYPE_DEVICE Device global memory + * @retval BM_MEM_TYPE_SYSTEM Host user memory + */ +DECL_EXPORT bm_mem_type_t sg_mem_get_type(struct sg_mem_desc mem); + +/** + * @name bm_mem_get_device_addr + * @brief To get a device memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] mem The device memory descriptor queried + * @retval unsigned long long The device memory address + */ +DECL_EXPORT unsigned long long bm_mem_get_device_addr(struct bm_mem_desc mem); + +/** + * @name sg_mem_get_device_addr + * @brief To get a device memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] mem The device memory descriptor queried + * @retval unsigned long long The device memory address + */ +DECL_EXPORT unsigned long long sg_mem_get_device_addr(struct sg_mem_desc mem); + +/** + * @name bm_mem_set_device_addr + * @brief To set a device memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] pmem The device memory descriptor pointer + * @param ]in] addr The new device address of the device memory + */ +DECL_EXPORT void bm_mem_set_device_addr(struct bm_mem_desc* pmem, unsigned long long addr); + +/** + * @name sg_mem_set_device_addr + * @brief To set a device memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] pmem The device memory descriptor pointer + * @param ]in] addr The new device address of the device memory + */ +DECL_EXPORT void sg_mem_set_device_addr(struct sg_mem_desc* pmem, unsigned long long addr); + +/** + * @name bm_mem_get_device_size + * @brief To get a device memory descriptor's size + * @ingroup bmlib_runtime + * + * @param [in] mem The device memory descriptor queried + * @retval unsigned int The device memory's size in bytes + */ +DECL_EXPORT unsigned int bm_mem_get_device_size(struct bm_mem_desc mem); + +/** + * @name sg_mem_get_device_size + * @brief To get a device memory descriptor's size + * @ingroup bmlib_runtime + * + * @param [in] mem The device memory descriptor queried + * @retval unsigned int The device memory's size in bytes + */ +DECL_EXPORT unsigned long long sg_mem_get_device_size(struct sg_mem_desc mem); + +/** + * @name bm_mem_set_device_size + * @brief To set a device memory descriptor's size + * @ingroup bmlib_runtime + * + * @param [out] pmem The device memory descriptor pointer + * @param [in] size The new device memory size (in bytes) of the device memory + */ +DECL_EXPORT void bm_mem_set_device_size(struct bm_mem_desc* pmem, unsigned int size); + +/** + * @name sg_mem_set_device_size + * @brief To set a device memory descriptor's size + * @ingroup bmlib_runtime + * + * @param [out] pmem The device memory descriptor pointer + * @param [in] size The new device memory 
size (in bytes) of the device memory + */ +DECL_EXPORT void sg_mem_set_device_size(struct sg_mem_desc* pmem, unsigned long long size); + +/** + * @name bm_set_device_mem + * @brief To fill in a device memory descriptor with size and address + * @ingroup bmlib_runtime + * + * @param [in] pmem The device memory descriptor pointer + * @param [in] size The device memory descriptor's size + * @param [in] addr The device memory descriptor's address + */ +DECL_EXPORT void bm_set_device_mem(bm_device_mem_t* pmem, unsigned int size, + unsigned long long addr); + +/** + * @name sg_set_device_mem + * @brief To fill in a device memory descriptor with size and address + * @ingroup bmlib_runtime + * + * @param [in] pmem The device memory descriptor pointer + * @param [in] size The device memory descriptor's size + * @param [in] addr The device memory descriptor's address + */ +DECL_EXPORT void sg_set_device_mem(sg_device_mem_t* pmem, unsigned long long size, + unsigned long long addr); + +/** + * @name bm_mem_from_device + * @brief To create a device memory descriptor from address and size + * @ingroup bmlib_runtime + * + * @param [in] device_addr The device memory address + * @param [in] len The device memory size + * @retval bm_device_mem_t The device memory descriptor created + */ +DECL_EXPORT bm_device_mem_t bm_mem_from_device(unsigned long long device_addr, + unsigned int len); + +/** + * @name sg_mem_from_device + * @brief To create a device memory descriptor from address and size + * @ingroup bmlib_runtime + * + * @param [in] device_addr The device memory address + * @param [in] len The device memory size + * @retval bm_device_mem_t The device memory descriptor created + */ +DECL_EXPORT sg_device_mem_t sg_mem_from_device(unsigned long long device_addr, + unsigned long long len); + +/** + * @name bm_mem_get_system_addr + * @brief To get a system memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] mem The system memory descriptor + * @retval void * The system memory descriptor's address + */ +DECL_EXPORT void *bm_mem_get_system_addr(struct bm_mem_desc mem); + +/** + * @name sg_mem_get_system_addr + * @brief To get a system memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] mem The system memory descriptor + * @retval void * The system memory descriptor's address + */ +DECL_EXPORT void *sg_mem_get_system_addr(struct sg_mem_desc mem); + +/** + * @name bm_mem_set_system_addr + * @brief To set a system memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] pmem The system memory descriptor pointer + * @param [in] addr The system memory address + */ +DECL_EXPORT void bm_mem_set_system_addr(struct bm_mem_desc* pmem, void *addr); + +/** + * @name sg_mem_set_system_addr + * @brief To set a system memory descriptor's address + * @ingroup bmlib_runtime + * + * @param [in] pmem The system memory descriptor pointer + * @param [in] addr The system memory address + */ +DECL_EXPORT void sg_mem_set_system_addr(struct sg_mem_desc* pmem, void *addr); + +/** + * @name bm_mem_from_system + * @brief To create a system memory descriptor with the given system address + * @ingroup bmlib_runtime + * + * @param [in] system_addr The system address in the descriptor + * @retval bm_system_mem_t The system memory descriptor created + */ +DECL_EXPORT bm_system_mem_t bm_mem_from_system(void *system_addr); + +/*******************memory alloc and free functions ***************************/ +/** + * @name bm_mem_null + * @brief Return an illegal device memory 
descriptor + * @ingroup bmlib_runtime + * + * @retval bm_device_mem_t An invalid device memory descriptor + */ +DECL_EXPORT bm_device_mem_t bm_mem_null(void); +#define BM_MEM_NULL (bm_mem_null()) + +/** + * @name bm_malloc_neuron_device + * @brief To malloc device memory according to a tensor shape + * (each neuron is 32 bits) + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result devcie memory descriptor + * @param [in] n, c, h, w The shape of the input tensor + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_malloc_neuron_device(bm_handle_t handle, bm_device_mem_t *pmem, + int n, int c, int h, int w); + +/** + * @name sg_malloc_neuron_device + * @brief To malloc device memory according to a tensor shape + * (each neuron is 32 bits) + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result devcie memory descriptor + * @param [in] n, c, h, w The shape of the input tensor + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_malloc_neuron_device(bm_handle_t handle, sg_device_mem_t *pmem, + unsigned long long n, unsigned long long c, + unsigned long long h, unsigned long long w); + +/** + * @name bm_malloc_device_dword + * @brief To malloc device memory in size of dword (32 bits) + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] count The number of dwords(32bits) to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_malloc_device_dword(bm_handle_t handle, bm_device_mem_t *pmem, + int count); + +/** + * @name sg_malloc_device_dword + * @brief To malloc device memory in size of dword (32 bits) + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] count The number of dwords(32bits) to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_malloc_device_dword(bm_handle_t handle, sg_device_mem_t *pmem, + unsigned long long count); + +/** + * @name bm_malloc_device_byte + * @brief To malloc device memory in size of byte + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_malloc_device_byte(bm_handle_t handle, bm_device_mem_t *pmem, + unsigned int size); + +/** + * @name sg_malloc_device_byte + * @brief To malloc device memory in size of byte + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_malloc_device_byte(bm_handle_t handle, sg_device_mem_t *pmem, + unsigned long long size); + +/** + * @name bm_malloc_device_byte_heap + * @brief To malloc device memory in size of byte within the specified heap + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] heap_id The heap where to allocate 0/1/2 + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
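+ *
+ * Usage sketch (illustrative; heap id 1 and the 4 MB size are arbitrary):
+ *
+ *   bm_device_mem_t mem;
+ *   if (bm_malloc_device_byte_heap(handle, &mem, 1, 4 * 1024 * 1024) == BM_SUCCESS) {
+ *     // ... use mem ...
+ *     bm_free_device(handle, mem);
+ *   }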
+ */ +DECL_EXPORT bm_status_t bm_malloc_device_byte_heap(bm_handle_t handle, bm_device_mem_t *pmem, + int heap_id, unsigned int size); + +/** + * @name sg_malloc_device_byte_heap + * @brief To malloc device memory in size of byte within the specified heap + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] heap_id The heap where to allocate 0/1/2 + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_malloc_device_byte_heap(bm_handle_t handle, sg_device_mem_t *pmem, + int heap_id, unsigned long long size); + +/** + * @name bm_malloc_device_byte_heap_mask + * @brief To malloc device memory in size of byte within the specified heaps + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] heap_id_mask The mask which heaps allocate from. each bit indicate one heap + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_malloc_device_byte_heap_mask(bm_handle_t handle, bm_device_mem_t *pmem, + int heap_id_mask, unsigned int size); + +/** + * @name sg_malloc_device_byte_heap_mask + * @brief To malloc device memory in size of byte within the specified heaps + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmem The result device memory descriptor + * @param [in] heap_id_mask The mask which heaps allocate from. each bit indicate one heap + * @param [in] size The number of bytes to allocate + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_malloc_device_byte_heap_mask(bm_handle_t handle, sg_device_mem_t *pmem, + int heap_id_mask, unsigned long long size); + +/** + * @name bm_free_device + * @brief To free device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] mem The device memory descriptor to free + */ +DECL_EXPORT void bm_free_device(bm_handle_t handle, bm_device_mem_t mem); + +/** + * @name sg_free_device + * @brief To free device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] mem The device memory descriptor to free + */ +DECL_EXPORT void sg_free_device(bm_handle_t handle, sg_device_mem_t mem); + +/** + * @name bm_gmem_arm_reserved_request + * @brief To obtain the address of global memory reserved for arm926 + * @param [in] handle The device handle + * + * @retval unsigned long long The absolute address of gmem reserved for arm926 + */ +DECL_EXPORT unsigned long long bm_gmem_arm_reserved_request(bm_handle_t handle); + +/** + * @name bm_gmem_arm_reserved_release + * @brief To release the global memory reserved for arm926 + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + */ +DECL_EXPORT void bm_gmem_arm_reserved_release(bm_handle_t handle); + +/*******************memory copy functions *************************************/ +/** + * @name bm_memcpy_s2d + * @brief To copy data from system memory to device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor ) + * @param [in] src The source memory (system memory, a void* pointer) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
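+ *
+ * Usage sketch (illustrative; the host buffer and its size are assumptions):
+ *
+ *   float host_buf[1024] = { 0 };
+ *   bm_device_mem_t dev_buf;
+ *   if (bm_malloc_device_byte(handle, &dev_buf, sizeof(host_buf)) == BM_SUCCESS) {
+ *     bm_memcpy_s2d(handle, dev_buf, host_buf);
+ *     bm_free_device(handle, dev_buf);
+ *   }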
+ */ +DECL_EXPORT bm_status_t bm_memcpy_s2d(bm_handle_t handle, bm_device_mem_t dst, void *src); + +/** + * @name bm_memcpy_p2p + * @brief To copy data from one chip to another chip + * @ingroup bmlib_runtime + * + * @param [in] handle_src The source device handle + * @param [in] src The source memory (device memory descriptor ) + * @param [in] handle_dst The destination device handle + * @param [in] dst The destination memory (device memory descriptor ) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_p2p(bm_handle_t handle_src, bm_device_mem_t src, bm_handle_t handle_dst,bm_device_mem_t dst); + +/** + * @name sg_memcpy_s2d + * @brief To copy data from system memory to device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor ) + * @param [in] src The source memory (system memory, a void* pointer) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_s2d(bm_handle_t handle, sg_device_mem_t dst, void *src); + +/** + * @name bm_memcpy_s2d_partial_offset + * @brief To copy specified bytes of data from system memory to device memory + * with an offset in device memory address. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor) + * @param [in] src The source memory (system memory, a void* pointer) + * @param [in] size The size of data to copy (in bytes) + * @param [in] offset The offset of the device memory address + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_s2d_partial_offset(bm_handle_t handle, + bm_device_mem_t dst, void *src, + unsigned int size, + unsigned int offset); + +/** + * @name sg_memcpy_s2d_partial_offset + * @brief To copy specified bytes of data from system memory to device memory + * with an offset in device memory address. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor) + * @param [in] src The source memory (system memory, a void* pointer) + * @param [in] size The size of data to copy (in bytes) + * @param [in] offset The offset of the device memory address + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_s2d_partial_offset(bm_handle_t handle, + sg_device_mem_t dst, void *src, + unsigned long long size, + unsigned long long offset); + +/** + * @name bm_memcpy_s2d_partial + * @brief To copy specified bytes of data from system memory to device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor) + * @param [in] src The source memory (system memory, a void* pointer) + * @param [in] size The size of data to copy (in bytes) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
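+ *
+ * Usage sketch (illustrative): stage only the first 256 bytes of a larger
+ * host buffer; names and sizes are assumptions.
+ *
+ *   unsigned char host_buf[1024] = { 0 };
+ *   bm_device_mem_t dev_buf;
+ *   bm_malloc_device_byte(handle, &dev_buf, sizeof(host_buf));
+ *   bm_memcpy_s2d_partial(handle, dev_buf, host_buf, 256);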
+ */ +DECL_EXPORT bm_status_t bm_memcpy_s2d_partial(bm_handle_t handle, bm_device_mem_t dst, + void *src, unsigned int size); + +/** + * @name sg_memcpy_s2d_partial + * @brief To copy specified bytes of data from system memory to device memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (device memory descriptor) + * @param [in] src The source memory (system memory, a void* pointer) + * @param [in] size The size of data to copy (in bytes) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_s2d_partial(bm_handle_t handle, sg_device_mem_t dst, + void *src, unsigned long long size); + +/** + * @name bm_memcpy_d2s + * @brief To copy data from device memory to system memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2s(bm_handle_t handle, void *dst, bm_device_mem_t src); + +/** + * @name sg_memcpy_d2s + * @brief To copy data from device memory to system memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_d2s(bm_handle_t handle, void *dst, sg_device_mem_t src); + +/** + * @name bm_memcpy_d2s_partial_offset + * @brief To copy specified bytes of data from device memory to system memory + * with an offset in device memory address. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * @param [in] size The size of data to copy (in bytes) + * @param [in] offset The offset of the device memory address + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst, + bm_device_mem_t src, unsigned int size, + unsigned int offset); + +/** + * @name sg_memcpy_d2s_partial_offset + * @brief To copy specified bytes of data from device memory to system memory + * with an offset in device memory address. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * @param [in] size The size of data to copy (in bytes) + * @param [in] offset The offset of the device memory address + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_d2s_partial_offset(bm_handle_t handle, void *dst, + sg_device_mem_t src, unsigned long long size, + unsigned long long offset); + +/** + * @name bm_memcpy_d2s_partial + * @brief To copy specified bytes of data from device memory to system memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * @param [in] size The size of data to copy (in bytes) + * + * @retval BM_SUCCESS Data transfer succeeds. + * Other code Data transfer fails. 
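+ *
+ * Usage sketch (illustrative; "output_mem" stands for a device buffer that
+ * already holds results, e.g. network logits):
+ *
+ *   float logits[16];
+ *   bm_memcpy_d2s_partial(handle, logits, output_mem, sizeof(logits));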
+ */ +DECL_EXPORT bm_status_t bm_memcpy_d2s_partial(bm_handle_t handle, void *dst, + bm_device_mem_t src, unsigned int size); + +/** + * @name sg_memcpy_d2s_partial + * @brief To copy specified bytes of data from device memory to system memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination memory (system memory, a void* pointer) + * @param [in] src The source memory (device memory descriptor) + * @param [in] size The size of data to copy (in bytes) + * + * @retval BM_SUCCESS Data transfer succeeds. + * Other code Data transfer fails. + */ +DECL_EXPORT bm_status_t sg_memcpy_d2s_partial(bm_handle_t handle, void *dst, + sg_device_mem_t src, unsigned long long size); + +/** + * @name bm_memcpy_d2d + * @brief To copy specified dwords of data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_offset The offset of destination device memory address + * @param [in] src The source device memory + * @param [in] src_offset The offset of source device memory address + * @param [in] len Length of data to copy (in DWORD 4 bytes) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d(bm_handle_t handle, bm_device_mem_t dst, + int dst_offset, bm_device_mem_t src, int src_offset, + int len); + +/** + * @name bm_memcpy_d2d_with_core + * @brief To copy specified dwords of data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_offset The offset of destination device memory address + * @param [in] src The source device memory + * @param [in] src_offset The offset of source device memory address + * @param [in] len Length of data to copy (in DWORD 4 bytes) + * @param [in] core_id The core id to copy + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d_with_core(bm_handle_t handle, bm_device_mem_t dst, + int dst_offset, bm_device_mem_t src, int src_offset, + int len, int core_id); + +/** + * @name bm_memcpy_d2d_byte + * @brief To copy specified bytes of data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_offset The offset of destination device memory address (in bytes) + * @param [in] src The source device memory + * @param [in] src_offset The offset of source device memory address (in bytes) + * @param [in] size Size of data to copy (in bytes) + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d_byte(bm_handle_t handle, bm_device_mem_t dst, + size_t dst_offset, bm_device_mem_t src, + size_t src_offset, size_t size); + +/** + * @name bm_memcpy_d2d_byte_with_core + * @brief To copy specified bytes of data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. 
+ * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_offset The offset of destination device memory address (in bytes) + * @param [in] src The source device memory + * @param [in] src_offset The offset of source device memory address (in bytes) + * @param [in] size Size of data to copy (in bytes) + * @param [in] core_id The core id to copy + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d_byte_with_core(bm_handle_t handle, bm_device_mem_t dst, + size_t dst_offset, bm_device_mem_t src, + size_t src_offset, size_t size, int core_id); + +/** + * @name bm_memcpy_d2d_stride + * @brief To copy specified data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_stride The data stride of destination data + * @param [in] src The source device memory + * @param [in] src_stride The data stride of source data + * @param [in] count Count of data to copy + * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc. + * format_size only support 1/2/4. + * + * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1 + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d_stride(bm_handle_t handle, + bm_device_mem_t dst, + int dst_stride, + bm_device_mem_t src, + int src_stride, + int count, + int format_size); + +/** + * @name bm_memcpy_d2d_stride + * @brief To copy specified data from one piece of device memory + * to another piece of device memory within one device. Both source + * and destination offsets can be specified. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dst The destination device memory + * @param [in] dst_stride The data stride of destination data + * @param [in] src The source device memory + * @param [in] src_stride The data stride of source data + * @param [in] count Count of data to copy + * @param [in] format_size Data format byte size, such as sizeof(uint8_t), sizeof(float), etc. + * format_size only support 1/2/4. + * @param [in] core_id The core id to copy. + * + * dst_stride MUST be 1, EXCEPT: dst_stride == 4 && src_stride == 1 && format_size ==1 + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memcpy_d2d_stride_with_core(bm_handle_t handle, + bm_device_mem_t dst, + int dst_stride, + bm_device_mem_t src, + int src_stride, + int count, + int format_size, + int core_id); + +/** + * @name bm_memcpy_c2c + * @brief To copy data from one chip to another chip. + * (Used in multi-chip card scenario) + * @ingroup bmlib_runtime + * + * @param [in] src_handle The source device handle + * @param [in] dst_handle The destination device handle + * @param [in] src The source device memory descriptor + * @param [in] dst The destination device memory descriptor + * @param [in] force_dst_cdma If use the CDMA engine of the destination device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
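+ *
+ * Usage sketch (illustrative; assumes src_handle/dst_handle were requested on
+ * two chips of one card and src_mem already holds data on the source chip):
+ *
+ *   bm_device_mem_t dst_mem;
+ *   bm_malloc_device_byte(dst_handle, &dst_mem, bm_mem_get_device_size(src_mem));
+ *   bm_memcpy_c2c(src_handle, dst_handle, src_mem, dst_mem, false);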
+ */ +DECL_EXPORT bm_status_t bm_memcpy_c2c(bm_handle_t src_handle, bm_handle_t dst_handle, + bm_device_mem_t src, bm_device_mem_t dst, + bool force_dst_cdma); + +/** + * @name bm_memset_device + * @brief To fill in specified device memory with the given value + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] value The value used to fill. (int type) + * @param [in] mem The device memory which will be filled in + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memset_device(bm_handle_t handle, const int value, + bm_device_mem_t mem); + +/** + * @name bm_memset_device_ext + * @brief To fill in specified device memory with the given value and mode + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] value The pointer of value used to fill + * @param [in] mode The valid bytes of *value + * @param [in] mem The device memory which will be filled in + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_memset_device_ext(bm_handle_t handle, void* value, int mode, + bm_device_mem_t mem); + +/** + * @name bm_mem_convert_system_to_device_neuron + * @brief To malloc a piece of device memory according to the shape of + * neuron(in DWORD 4 bytes); copy neuron from system memory to + * device memory if need_copy is true. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory descriptor + * @param [in] sys_mem The system memory descriptor + * @param [in] need_copy If copy from system to device is needed + * @param [in] n,c,h,w Neuron shape size + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron(bm_handle_t handle, + struct bm_mem_desc *dev_mem, + struct bm_mem_desc sys_mem, + bool need_copy, int n, int c, + int h, int w); + +/** + * @name bm_mem_convert_system_to_device_neuron_byte + * @brief To malloc a piece of device memory according to the shape of + * neuron(in bytes); copy neuron from system memory to + * device memory if need_copy is true. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory descriptor + * @param [in] sys_mem The system memory descriptor + * @param [in] need_copy If copy from system to device is needed + * @param [in] n,c,h,w Neuron shape size + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_neuron_byte( + bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem, + bool need_copy, int n, int c, int h, int w); + +/** + * @name bm_mem_convert_system_to_device_coeff + * @brief To malloc a piece of device memory according to the size of + * coefficient (in DWORD 4 bytes); copy coefficient from system + * memory to device memory if need_copy is true. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory descriptor + * @param [in] sys_mem The system memory descriptor + * @param [in] need_copy If copy from system to device is needed + * @param [in] coeff_count Coefficient size + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
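+ *
+ * Usage sketch (illustrative; 4096 float coefficients, i.e. 4096 dwords):
+ *
+ *   float coeffs[4096] = { 0 };
+ *   bm_system_mem_t sys_mem = bm_mem_from_system(coeffs);
+ *   bm_device_mem_t dev_mem;
+ *   bm_mem_convert_system_to_device_coeff(handle, &dev_mem, sys_mem, true, 4096);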
+ */ +DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff(bm_handle_t handle, + struct bm_mem_desc *dev_mem, + struct bm_mem_desc sys_mem, + bool need_copy, + int coeff_count); +/** + * @name bm_mem_convert_system_to_device_coeff_byte + * @brief To malloc a piece of device memory according to the size of + * coefficient (in bytes); copy coefficient from system + * memory to device memory if need_copy is true. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory descriptor + * @param [in] sys_mem The system memory descriptor + * @param [in] need_copy If copy from system to device is needed + * @param [in] coeff_count Coefficient size + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_convert_system_to_device_coeff_byte( + bm_handle_t handle, struct bm_mem_desc *dev_mem, struct bm_mem_desc sys_mem, + bool need_copy, int coeff_count); + +/*******************memory map functions *************************************/ +/** + * @name bm_mem_mmap_device_mem + * @brief To map a piece of device memory to user space with cache enabled. + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory to map + * @param [out] vmem The virtual address of the mapped device memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_mmap_device_mem(bm_handle_t handle, bm_device_mem_t *dmem, + + unsigned long long *vmem); + +/** + * @name sg_mem_mmap_device_mem + * @brief To map a piece of device memory to user space with cache enabled. + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory to map + * @param [out] vmem The virtual address of the mapped device memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_mem_mmap_device_mem(bm_handle_t handle, sg_device_mem_t *dmem, + unsigned long long *vmem); + +/*******************memory map functions *************************************/ +/** + * @name bm_mem_mmap_device_mem_no_cache + * @brief To map a piece of device memory to user space with cache disabled. + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory to map + * @param [out] vmem The virtual address of the mapped device memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_mmap_device_mem_no_cache(bm_handle_t handle, bm_device_mem_t *dmem, + + unsigned long long *vmem); + +/** + * @name sg_mem_mmap_device_mem_no_cache + * @brief To map a piece of device memory to user space with cache disabled. + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dev_mem The device memory to map + * @param [out] vmem The virtual address of the mapped device memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_mem_mmap_device_mem_no_cache(bm_handle_t handle, sg_device_mem_t *dmem, + unsigned long long *vmem); + +/** + * @name bm_mem_vir_to_phy + * @brief To get device mem address through the mapped virtual address . + * (only valid in SoC mode; Not supported in PCIE mode). 
+ * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] vmem The virtual address of the mapped device memory + * @param [out] dev_mem The device memory address + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_vir_to_phy(bm_handle_t handle, unsigned long long vmem, + unsigned long long *device_mem); +/** + * @name bm_mem_invalidate_device_mem + * @brief To invalidate a piece of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to invalidate + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ + +DECL_EXPORT bm_status_t bm_mem_invalidate_device_mem(bm_handle_t handle, + bm_device_mem_t *dmem); + +/** + * @name sg_mem_invalidate_device_mem + * @brief To invalidate a piece of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to invalidate + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ + +DECL_EXPORT bm_status_t sg_mem_invalidate_device_mem(bm_handle_t handle, + sg_device_mem_t *dmem); + +/** + * @name bm_mem_invalidate_partial_device_mem + * @brief To invalidate part of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to invalidate + * @param [in] offset The offset of device memory address + * @param [in] len The length of memory to invalidate in bytes + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_invalidate_partial_device_mem(bm_handle_t handle, + bm_device_mem_t *dmem, + unsigned int offset, + unsigned int len); + +/** + * @name sg_mem_invalidate_partial_device_mem + * @brief To invalidate part of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to invalidate + * @param [in] offset The offset of device memory address + * @param [in] len The length of memory to invalidate in bytes + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_mem_invalidate_partial_device_mem(bm_handle_t handle, + sg_device_mem_t *dmem, + unsigned long long offset, + unsigned long long len); + +/** + * @name bm_mem_flush_device_mem + * @brief To flush a piece of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to flush + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_flush_device_mem(bm_handle_t handle, bm_device_mem_t *dmem); + +/** + * @name sg_mem_flush_device_mem + * @brief To flush a piece of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to flush + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
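+ *
+ * Usage sketch (illustrative, SoC mode only; assumes dmem was allocated with
+ * sg_malloc_device_byte and is at least 4096 bytes):
+ *
+ *   unsigned long long vmem = 0;
+ *   sg_mem_mmap_device_mem(handle, &dmem, &vmem);
+ *   memset((void *)(uintptr_t)vmem, 0, 4096);  // CPU writes via the cached mapping
+ *   sg_mem_flush_device_mem(handle, &dmem);    // make the writes visible to the device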
+ */ +DECL_EXPORT bm_status_t sg_mem_flush_device_mem(bm_handle_t handle, sg_device_mem_t *dmem); + +/** + * @name bm_mem_flush_partial_device_mem + * @brief To flush part of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to flush + * @param [in] offset The offset of device memory address + * @param [in] len The length of memory to flush in bytes + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_flush_partial_device_mem(bm_handle_t handle, + bm_device_mem_t *dmem, + unsigned int offset, + unsigned int len); + +/** + * @name sg_mem_flush_partial_device_mem + * @brief To flush part of mapped device memory to maintain + * cache coherence + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] dmem The device memory to flush + * @param [in] offset The offset of device memory address + * @param [in] len The length of memory to flush in bytes + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_mem_flush_partial_device_mem(bm_handle_t handle, + sg_device_mem_t *dmem, + unsigned long long offset, + unsigned long long len); + +/** + * @name bm_mem_unmap_device_mem + * @brief To unmap a piece of mapped device memory + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] vmem The virtual address of the mapped device memory + * @param [in] size The size of unmapped memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_mem_unmap_device_mem(bm_handle_t handle, void *vmem, int size); + +/** + * @name sg_mem_unmap_device_mem + * @brief To unmap a piece of mapped device memory + * (only valid in SoC mode; Not supported in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] vmem The virtual address of the mapped device memory + * @param [in] size The size of unmapped memory + * + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t sg_mem_unmap_device_mem(bm_handle_t handle, void *vmem, unsigned long long size); + +/*******************api(kernel) functions *************************************/ +/** + * @name bm_flush + * @brief To synchronize APIs of the current thread. The thread will block + * until all the outstanding APIs of the current thread are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + */ +DECL_EXPORT void bm_flush(bm_handle_t handle); + +/** + * @name bm_device_sync + * @brief To synchronize APIs of the device. The thread will block + * until all the outstanding APIs of the device are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_device_sync(bm_handle_t handle); + +/** + * @name bm_handle_sync + * @brief To synchronize APIs of the handle. The thread will block + * until all the outstanding APIs of the handle are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
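+ *
+ * Usage sketch (illustrative; the "APP" log tag is arbitrary):
+ *
+ *   bm_status_t ret = bm_handle_sync(handle);  // wait for all APIs issued on this handle
+ *   if (ret != BM_SUCCESS)
+ *     bmlib_log("APP", BMLIB_LOG_ERROR, "bm_handle_sync failed: %d\n", ret);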
+ */ +DECL_EXPORT bm_status_t bm_handle_sync(bm_handle_t handle); + +/** + * @name bm_handle_sync_from_core + * @brief To synchronize APIs of the handle. The thread will block + * until all the outstanding APIs of the handle are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] core_id The core id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_handle_sync_from_core(bm_handle_t handle, int core_id); + +/** + * @name bm_thread_sync + * @brief To synchronize APIs of the current thread. The thread will block + * until all the outstanding APIs of the current thread are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_thread_sync(bm_handle_t handle); + +/** + * @name bm_thread_sync_from_core + * @brief To synchronize APIs of the current thread. The thread will block + * until all the outstanding APIs of the current thread are finished. + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] core_id The core id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_thread_sync_from_core(bm_handle_t handle, int core_id); + +/*******************trace and profile releated functions **********************/ +typedef struct bm_profile { +#ifdef __linux__ + unsigned long cdma_in_time; + unsigned long cdma_in_counter; + unsigned long cdma_out_time; + unsigned long cdma_out_counter; + unsigned long tpu_process_time; + unsigned long sent_api_counter; + unsigned long completed_api_counter; +#else + unsigned long long cdma_in_time; + unsigned long long cdma_in_counter; + unsigned long long cdma_out_time; + unsigned long long cdma_out_counter; + unsigned long long tpu_process_time; + unsigned long long sent_api_counter; + unsigned long long completed_api_counter; +#endif +} bm_profile_t; +/** + * @name bm_get_profile + * @brief To get the profile data at the moment + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] profile The result profile data + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_profile(bm_handle_t handle, bm_profile_t *profile); + +typedef struct bootloader_version{ + char *bl1_version; + char *bl2_version; + char *bl31_version; + char *uboot_version; +} boot_loader_version; + +/** + * @name bm_get_boot_loader_version + * @brief To get the boot_loader_version + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] version The result version data + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_boot_loader_version(bm_handle_t handle, boot_loader_version *version); + +/** + * @name bm_get_vpu_instant_usage + * @brief To get vpu usage + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] smi_attr The result vpu usage + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_vpu_instant_usage(bm_handle_t handle, int *vpu_usage); + +/** + * @name bm_get_jpu_core_usage + * @brief To get the jpu usage + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] smi_attr The result jpu usage + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
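+ *
+ * Usage sketch (illustrative):
+ *
+ *   int jpu_usage = 0;
+ *   if (bm_get_jpu_core_usage(handle, &jpu_usage) == BM_SUCCESS)
+ *     printf("jpu usage: %d\n", jpu_usage);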
+ */ +DECL_EXPORT bm_status_t bm_get_jpu_core_usage(bm_handle_t handle, int *jpu_usage); + +/** + * @name bm_get_vpp_instant_usage + * @brief To get the vpp usage + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] smi_attr The result vpp usage + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_vpp_instant_usage(bm_handle_t handle, int *vpp_usage); +/** + * @name bm_get_last_api_process_time_us + * @brief This function is abandoned. + */ +#ifdef __linux__ +DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle, + unsigned long *time_us); +#else +DECL_EXPORT bm_status_t bm_get_last_api_process_time_us(bm_handle_t handle, + unsigned long long *time_us); +#endif +/*******************tpu clock and module reset releated functions *************/ + +/** + * @name bm_set_clk_tpu_freq + * @brief To set the clock frequency of TPU (only valid in PCIE mode). + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] freq The TPU target frequency + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_set_clk_tpu_freq(bm_handle_t handle, int freq); + +/** + * @name bm_get_clk_tpu_freq + * @brief To get the clock frequency of TPU + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] freq The current TPU frequency + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_clk_tpu_freq(bm_handle_t handle, int *freq); + +/*******************misc functions ********************************************/ +struct bm_misc_info { + int pcie_soc_mode; /*0---pcie; 1---soc*/ + int ddr_ecc_enable; /*0---disable; 1---enable*/ + long long ddr0a_size; + long long ddr0b_size; + long long ddr1_size; + long long ddr2_size; + unsigned int chipid; +#define BM1682_CHIPID_BIT_MASK (0X1 << 0) +#define BM1684_CHIPID_BIT_MASK (0X1 << 1) +#define BM1686_CHIPID_BIT_MASK (0X1 << 2) +#ifdef __linux__ + unsigned long chipid_bit_mask; +#else + unsigned long long chipid_bit_mask; +#endif + unsigned int driver_version; + int domain_bdf; + int board_version; /*hardware board version [23:16]-mcu sw version, [15:8]-board type, [7:0]-hw version*/ + int a53_enable; + int dyn_enable; +}; + +/** + * @name bm_get_misc_info + * @brief To get miscellaneous information of the device + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] pmisc_info The fetched misc info + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_misc_info(bm_handle_t handle, struct bm_misc_info *pmisc_info); + +/** + * @name bm_get_chipid + * @brief To get the chipid of the device. (0x1682 / 0x1684 / 0x168?) + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] p_chipid The chip id of the device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
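+ *
+ * Usage sketch (illustrative):
+ *
+ *   unsigned int chipid = 0;
+ *   if (bm_get_chipid(handle, &chipid) == BM_SUCCESS && chipid == 0x1684)
+ *     printf("running on a BM1684 device\n");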
+ */ +DECL_EXPORT bm_status_t bm_get_chipid(bm_handle_t handle, unsigned int *p_chipid); + +#define BMLIB_LOG_QUIET -8 +#define BMLIB_LOG_PANIC 0 +#define BMLIB_LOG_FATAL 8 +#define BMLIB_LOG_ERROR 16 +#define BMLIB_LOG_WARNING 24 +#define BMLIB_LOG_INFO 32 +#define BMLIB_LOG_VERBOSE 40 +#define BMLIB_LOG_DEBUG 48 +#define BMLIB_LOG_TRACE 56 + +/** + * @name bmlib_log_get_level + * @brief To get the bmlib log level + * @ingroup bmlib_log + * + * @param void + * @retval The level of bmlib log level + */ +DECL_EXPORT int bmlib_log_get_level(void); + +/** + * @name bmlib_log_set_level + * @brief To set the bmlib log level + * @ingroup bmlib_log + * + * @param [in] level The level of bmlib log level + * @retval void + */ +DECL_EXPORT void bmlib_log_set_level(int level); + +/** + * @name bmlib_log_set_callback + * @brief To set callback to get bmlib log + * @ingroup bmlib_log + * + * @param [in] callback The callback function to get bmlib log + * @retval void + */ +DECL_EXPORT void bmlib_log_set_callback(void (*callback)(const char*, int, const char*, va_list args)); + +/** + * @name bm_set_debug_mode + * @brief To set the debug mode for firmware log for tpu + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] mode The debug mode of fw log, 0/1 for disable/enable log + * @retval void + */ +DECL_EXPORT void bm_set_debug_mode(bm_handle_t handle, int mode); + +/** + * @name bmlib_api_dbg_callback + * @brief To set debug callback to get firmware log + * @ingroup bmlib_log + * + * @param [in] bmlib_api_dbg_callback callback to get firmware log + * @retval void + */ +typedef void (*bmlib_api_dbg_callback)(int, int, int, const char*); +// api, result, duratioin, log, third int for api duration for future +DECL_EXPORT void bmlib_set_api_dbg_callback(bmlib_api_dbg_callback callback); + +/** + * @name bmcpu_get_cpu_status + * @brief Get bmcpu status + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @retval BMCPU_RUNNING bmcpu is running. + * Other code Fails. + */ +DECL_EXPORT bm_cpu_status_t bmcpu_get_cpu_status(bm_handle_t handle); + +/** + * @name bmcpu_start_cpu + * @brief Start cpu in pcie mode + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] boot_file Fip file + * @param [in] core_file Itb file + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_start_cpu(bm_handle_t handle, char *boot_file, char *core_file); + +/** + * @name bmcpu_open_process + * @brief Open a process to do some work + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] flags Process flags + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval >= 0 process handle + * < 0 Other code Fails. + */ +DECL_EXPORT int bmcpu_open_process(bm_handle_t handle, unsigned int flags, int timeout); + +/** + * @name bmcpu_load_library + * @brief Load a share library(so) to specific process + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] process_handle Process handle + * @param [in] library_file Library file path + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
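+ *
+ * Illustrative usage sketch ("libdemo.so" and the flag/timeout values are
+ * placeholders; the device CPU is assumed to be running, see bmcpu_start_cpu):
+ * @code
+ *   int process = bmcpu_open_process(handle, 0, -1);
+ *   if (process >= 0 &&
+ *       bmcpu_load_library(handle, process, (char *)"libdemo.so", -1) == BM_SUCCESS) {
+ *     // functions exported by the library can now be run via bmcpu_exec_function()
+ *   }
+ * @endcode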
+ */
+DECL_EXPORT bm_status_t bmcpu_load_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
+
+/**
+ * @name bmcpu_unload_library
+ * @brief Unload a shared library (.so) from a specific process
+ * @ingroup bmlib_log
+ *
+ * @param [in] handle The device handle
+ * @param [in] process_handle Process handle
+ * @param [in] library_file Library file path
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
+ * @retval BM_SUCCESS Succeeds.
+ * Other code Fails.
+ */
+DECL_EXPORT bm_status_t bmcpu_unload_library(bm_handle_t handle, int process_handle, char *library_file, int timeout);
+
+/**
+ * @name bmcpu_exec_function
+ * @brief Execute a specific function in a specific process
+ * @ingroup bmlib_log
+ *
+ * @param [in] handle The device handle
+ * @param [in] process_handle Process handle
+ * @param [in] function_name Function name
+ * @param [in] function_param Function parameters
+ * @param [in] param_size Parameters size in bytes
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
+ * @retval 0 success.
+ * >0 code fails from bmlib
+ * <0 code fails from function
+ */
+DECL_EXPORT int bmcpu_exec_function(bm_handle_t handle,
+                                    int process_handle,
+                                    char *function_name,
+                                    void *function_param,
+                                    unsigned int param_size,
+                                    int timeout);
+
+#define BMCPU_EXEC_OPT_NO_FLUSH_CACHE 1
+/**
+ * @name bmcpu_exec_function_ext
+ * @brief Execute a specific function in a specific process
+ * @ingroup bmlib_log
+ *
+ * @param [in] handle The device handle
+ * @param [in] process_handle Process handle
+ * @param [in] function_name Function name
+ * @param [in] function_param Function parameters
+ * @param [in] param_size Parameters size in bytes
+ * @param [in] opt exec options
+ * @param [in] timeout Timeout value in millisecond, -1 means default value of this device
+ * @retval 0 success.
+ * >0 code fails from bmlib
+ * <0 code fails from function
+ */
+DECL_EXPORT int bmcpu_exec_function_ext(bm_handle_t handle,
+                                        int process_handle,
+                                        char *function_name,
+                                        void *function_param,
+                                        unsigned int param_size,
+                                        unsigned int opt,
+                                        int timeout);
+
+/**
+ * @name bmcpu_exec_function_async
+ * @brief Execute a specific function in a specific process asynchronously;
+ *        the user should use bmcpu_query_exec_function_result to query the result
+ * @ingroup bmlib_log
+ *
+ * @param [in] handle The device handle
+ * @param [in] process_handle Process handle
+ * @param [in] function_name Function name
+ * @param [in] function_param Function param
+ * @param [in] param_size Param size in bytes
+ * @param [out] api_handle Api handle used by bmcpu_query_exec_function_result
+ * @retval BM_SUCCESS Succeeds.
+ * Other code Fails.
+ */
+DECL_EXPORT bm_status_t bmcpu_exec_function_async(bm_handle_t handle,
+                                                  int process_handle,
+                                                  char *function_name,
+                                                  void *function_param,
+                                                  unsigned int param_size,
+                                                  unsigned long long *api_handle);
+
+/**
+ * @name bmcpu_exec_function_async_ext
+ * @brief Execute a specific function in a specific process asynchronously;
+ *        the user should use bmcpu_query_exec_function_result to query the result
+ * @ingroup bmlib_log
+ *
+ * @param [in] handle The device handle
+ * @param [in] process_handle Process handle
+ * @param [in] function_name Function name
+ * @param [in] function_param Function param
+ * @param [in] param_size Param size in bytes
+ * @param [in] opt exec options
+ * @param [out] api_handle Api handle used by bmcpu_query_exec_function_result
+ * @retval BM_SUCCESS Succeeds.
+ * Other code Fails.
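+ *
+ * Illustrative usage sketch ("my_func" and args are placeholders for a function
+ * and parameter block provided by a library loaded with bmcpu_load_library):
+ * @code
+ *   unsigned long long api_handle = 0;
+ *   if (bmcpu_exec_function_async_ext(handle, process, (char *)"my_func",
+ *                                     &args, sizeof(args), 0,
+ *                                     &api_handle) == BM_SUCCESS) {
+ *     int ret = bmcpu_query_exec_function_result(handle, api_handle, -1);
+ *     // ret == 0: success, ret > 0: bmlib error, ret < 0: error from my_func
+ *   }
+ * @endcode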
+ */ +DECL_EXPORT bm_status_t bmcpu_exec_function_async_ext(bm_handle_t handle, + int process_handle, + char *function_name, + void *function_param, + unsigned int param_size, + unsigned int opt, + unsigned long long *api_handle); + +/** + * @name bmcpu_query_exec_function_result + * @brief Query result from function called by bm_exec_function + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] api_handle Api handle return by bm_exec_function_async + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval 0 success. + * >0 code fails from bmlib + * <0 code fails from function + */ +DECL_EXPORT int bmcpu_query_exec_function_result(bm_handle_t handle, unsigned long long api_handle, int timeout); + +/** + * @name bmcpu_map_phys_addr + * @brief Map physical address in specific process + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] process_handle Process handle + * @param [in] phys_addr Physical address + * @param [in] size Map size in bytes + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval >0 virtual address + * 0 fails + */ +DECL_EXPORT void *bmcpu_map_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, unsigned int size, int timeout); + +/** + * @name bmcpu_unmap_phys_addr + * @brief Unmap physical address in specific process + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] process_handle Process handle + * @param [in] phys_addr Physical address + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval <0 fail + * 0 success + */ +DECL_EXPORT bm_status_t bmcpu_unmap_phys_addr(bm_handle_t handle, int process_handle, void *phys_addr, int timeout); + +/** + * @name bmcpu_close_process + * @brief Close process + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] process_handle Process handle + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_close_process(bm_handle_t handle, int process_handle, int timeout); + +/** + * @name bmcpu_reset_cpu + * @brief Reset cpu in pcie mode + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_reset_cpu(bm_handle_t handle); + +/** + * @name bm_enable_perf_monitor + * @brief enable perf monitor to get gdma and tpu performance data + * @ingroup bmlib_perf + * + * @param [in] handle The device handle + * @param [in] perf_monitor The monitor to perf + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_enable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor); + +/** + * @name bm_disable_perf_monitor + * @brief disable perf monitor to get gdma and tpu performance data + * @ingroup bmlib_perf + * + * @param [in] handle The device handle + * @param [in] perf_monitor The monitor to perf + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
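+ *
+ * Illustrative usage sketch (the fields of bm_perf_monitor_t are defined
+ * elsewhere in the bmlib headers; the initialization below is a placeholder):
+ * @code
+ *   bm_perf_monitor_t monitor;            // fill fields as required
+ *   bm_enable_perf_monitor(handle, &monitor);
+ *   // ... run the workload to be profiled ...
+ *   bm_disable_perf_monitor(handle, &monitor);
+ * @endcode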
+ */ +DECL_EXPORT bm_status_t bm_disable_perf_monitor(bm_handle_t handle, bm_perf_monitor_t *perf_monitor); + +/** + * @name bmcpu_set_log + * @brief Set cpu log options + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] log_level 0: DEBUG 1:INFO 2:WARN 3:ERROR 4:FATAL + * @param [in] log_to_console 1: YES 0: No + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_set_log(bm_handle_t handle, unsigned int log_level, unsigned int log_to_console, int timeout); + +/** + * @name bmcpu_get_log + * @brief Get cpu log file + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @param [in] process_handle Process handle + * @param [in] log_file save log as file + * @param [in] timeout Timeout value in millisecond, -1 means default value of this device + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_get_log(bm_handle_t handle, int process_handle, char *log_file, int timeout); + +/** + * @name bmcpu_sync_time + * @brief Sync device cpu time with host + * @ingroup bmlib_log + * + * @param [in] handle The device handle + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmcpu_sync_time(bm_handle_t handle); + +/*******************trace and profile releated functions **********************/ +struct bm_heap_stat { + unsigned int mem_total; + unsigned int mem_avail; + unsigned int mem_used; +}; + +typedef struct bm_heap_stat_byte { + unsigned int heap_id; + unsigned long long mem_total; + unsigned long long mem_avail; + unsigned long long mem_used; + unsigned long long mem_start_addr; +} bm_heap_stat_byte_t; + +typedef struct bm_dev_stat { + int mem_total; + int mem_used; + int tpu_util; + int heap_num; + struct bm_heap_stat heap_stat[4]; +} bm_dev_stat_t; + +/** + * @name bm_get_stat + * @brief To get the stat data at the moment + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] profile The result stat data + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_stat(bm_handle_t handle, bm_dev_stat_t *stat); + +/** + * @name bm_get_gmem_heap_id + * @brief To get the heap id of allocated global memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] pmem The allocted global memory + * @param [out] heapid The result of get heap id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ + +DECL_EXPORT bm_status_t bm_get_gmem_heap_id(bm_handle_t handle, bm_device_mem_t *pmem, unsigned int *heapid); + +/** + * @name sg_get_gmem_heap_id + * @brief To get the heap id of allocated global memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] pmem The allocted global memory + * @param [out] heapid The result of get heap id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ + +DECL_EXPORT bm_status_t sg_get_gmem_heap_id(bm_handle_t handle, sg_device_mem_t *pmem, unsigned int *heapid); + +/** + * @name bm_get_gmem_total_heap_num + * @brief To get the total heap num of global memory + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] heap_num The result of get total num + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
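+ *
+ * Illustrative usage sketch combining this call with
+ * bm_get_gmem_heap_stat_byte_by_id (declared below):
+ * @code
+ *   unsigned int heap_num = 0;
+ *   if (bm_get_gmem_total_heap_num(handle, &heap_num) == BM_SUCCESS) {
+ *     for (unsigned int i = 0; i < heap_num; ++i) {
+ *       bm_heap_stat_byte_t stat;
+ *       if (bm_get_gmem_heap_stat_byte_by_id(handle, &stat, i) == BM_SUCCESS) {
+ *         printf("heap %u: used %llu / total %llu bytes\n",
+ *                i, stat.mem_used, stat.mem_total);
+ *       }
+ *     }
+ *   }
+ * @endcode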
+ */ +DECL_EXPORT bm_status_t bm_get_gmem_total_heap_num(bm_handle_t handle, unsigned int *heap_num); + +/** + * @name bm_get_gmem_heap_stat_byte_by_id + * @brief To get the heap stat by heap id + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] heap_id The heap index to get heap status + * @param [out] pheap_byte The result of get heap status + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_gmem_heap_stat_byte_by_id(bm_handle_t handle, bm_heap_stat_byte_t *pheap_byte, unsigned int heap_id); + +DECL_EXPORT bm_status_t bm_load_firmware( + bm_handle_t handle, + const char *firmware_tcm, + const char *firmware_ddr); + +#define bmkernel_load_firmware okkernel_load_firmware +DECL_EXPORT bm_status_t okkernel_load_firmware( + bm_handle_t handle, + const char *firmware_tcm, + const char *firmware_ddr); + +DECL_EXPORT bm_status_t okkernel_launch_async( + bm_handle_t handle, + const char *func_name, + const void *args, + unsigned int size); + +DECL_EXPORT bm_status_t okkernel_launch_sync( + bm_handle_t handle, + const char *func_name, + const void *args, + unsigned int size); + +DECL_EXPORT bm_status_t tpu_kernel_launch_sync( + bm_handle_t handle, + const char *func_name, + const void *args, + unsigned int size); + +DECL_EXPORT bm_status_t okkernel_sync(bm_handle_t handle); + +/** + * @name bmkernel_launch + * @brief send api to device and launch function + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] api cmd struct pointer + * @param [in] api cmd length + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmkernel_launch(bm_handle_t handle, const void *args, + unsigned int size); + +/** + * @name bmkernel_load_lookup_table + * @brief load lookup table to l2-sram + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [in] table which loaded to l2-sram + * @param [in] table size + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bmkernel_load_lookup_table(bm_handle_t handle, const void* table, unsigned int size); + +/*******************device management api functions ********************************************/ +/** + * @name bm_get_tpu_current + * @brief get tpu current + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] tpuc(mA) The pointer for tpu current + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_tpu_current(bm_handle_t handle, unsigned int *tpuc); + +/** + * @name bm_get_board_max_power + * @brief get board support max power + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] maxp The pointer for maxp + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_board_max_power(bm_handle_t handle, unsigned int *maxp); + +/** + * @name bm_get_board_power + * @brief get board power + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] boardp The pointer for boardp + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_board_power(bm_handle_t handle, unsigned int *boardp); + +/** + * @name bm_get_fan_speed + * @brief get board fan speed + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] fan The pointer for fan speed + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
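+ *
+ * Illustrative usage sketch (handle is assumed valid; the reported values are
+ * board dependent):
+ * @code
+ *   unsigned int fan = 0, boardp = 0;
+ *   if (bm_get_fan_speed(handle, &fan) == BM_SUCCESS &&
+ *       bm_get_board_power(handle, &boardp) == BM_SUCCESS) {
+ *     printf("fan: %u, board power: %u\n", fan, boardp);
+ *   }
+ * @endcode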
+ */ +DECL_EXPORT bm_status_t bm_get_fan_speed(bm_handle_t handle, unsigned int *fan); + +/** + * @name bm_get_ecc_correct_num + * @brief get ecc_correct_num + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] ecc_correct_num + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +#ifdef __linux__ +DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long *ecc_correct_num); +#else +DECL_EXPORT bm_status_t bm_get_ecc_correct_num(bm_handle_t handle, unsigned long long *ecc_correct_num); +#endif +/** + * @name bm_get_12v_atx + * @brief get atx_12v + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] atx_12v + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_12v_atx(bm_handle_t handle, int *atx_12v); + +/** + * @name bm_get_product_sn + * @brief get SE5 sn + * @ingroup device management api + * + * @param [out] product_sn + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_product_sn(char *product_sn); + +/** + * @name bm_get_sn + * @brief get sn + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] sn + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_sn(bm_handle_t handle, char *sn); + +/** + * @name bm_get_status + * @brief get chip status + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] status The board error status, each bit represents an error state + * status == 0x0, borad is nornal, staus > 0, borad is abnormal; + * bit0 == 1, tpu is hang + * bit1 == 1, pcie link abnormal + * bit2 == 1, board temperature is too high + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_status(bm_handle_t handle, int *status); + +/** + * @name bm_get_tpu_maxclk + * @brief get tpu_maxclk + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] tpu_maxclk + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_tpu_maxclk(bm_handle_t handle, unsigned int *tpu_maxclk); + +/** + * @name bm_get_tpu_minclk + * @brief get tpu_minclk + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] tpu_minclk + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_tpu_minclk(bm_handle_t handle, unsigned int *tpu_minclk); + +/** + * @name bm_get_driver_version + * @brief get driver version + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] driver_version + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_driver_version(bm_handle_t handle, int *driver_version); + +/** + * @name bm_get_board_name + * @brief get device board name + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] board_name + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_board_name(bm_handle_t handle, char *name); + +/** + * @name bm_get_board_temp + * @brief get board temperature + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] board_temp + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
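+ *
+ * Illustrative usage sketch reading board and chip temperature together
+ * (bm_get_chip_temp is declared below):
+ * @code
+ *   unsigned int board_temp = 0, chip_temp = 0;
+ *   if (bm_get_board_temp(handle, &board_temp) == BM_SUCCESS &&
+ *       bm_get_chip_temp(handle, &chip_temp) == BM_SUCCESS) {
+ *     printf("board: %u C, chip: %u C\n", board_temp, chip_temp);
+ *   }
+ * @endcode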
+ */ +DECL_EXPORT bm_status_t bm_get_board_temp(bm_handle_t handle, unsigned int *board_temp); + +/** + * @name bm_get_chip_temp + * @brief get chip temperature + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] chip_temp + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_chip_temp(bm_handle_t handle, unsigned int *chip_temp); + +/** + * @name bm_get_tpu_power + * @brief get TPU power + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] tpu_power + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_tpu_power(bm_handle_t handle, float *tpu_power); + +/** + * @name bm_get_tpu_volt + * @brief get TPU voltage + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] tpu_volt + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_tpu_volt(bm_handle_t handle, unsigned int *tpu_volt); + +/** + * @name bm_get_card_id + * @brief get card id + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] card_id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_card_id(bm_handle_t handle, unsigned int *card_id); + +/** + * @name bm_get_card_num + * @brief get card number + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] card_id + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_card_num(unsigned int *card_num); + +/** + * @name bm_get_chip_num_from_card + * @brief get chip number and start chip id from card + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] chip_num + * @param [out] dev_start_index + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_chip_num_from_card(unsigned int card_id, unsigned int *chip_num, unsigned int *dev_start_index); + +/** + * @name bm_get_dynfreq_status + * @brief get chip dynamic freq status + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [out] dynfreq_status + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_get_dynfreq_status(bm_handle_t handle, int *dynfreq_status); + +/** + * @name bm_change_dynfreq_status + * @brief change(enable/disable) chip dynamic freq status + * @ingroup device management api + * + * @param [in] handle The device handle + * @param [in] new_status + * @retval BM_SUCCESS Succeeds. + * Other code Fails. + */ +DECL_EXPORT bm_status_t bm_change_dynfreq_status(bm_handle_t handle, int new_status); + +/** + * @name bm_get_tpu_scalar_num + * @brief To get the core number of TPU scalar + * @ingroup bmlib_runtime + * + * @param [in] handle The device handle + * @param [out] core_num The core number of TPU scalar + * @retval BM_SUCCESS Succeeds. + * Other code Fails. 
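+ *
+ * Illustrative usage sketch (a typical use is sizing the core_list argument of
+ * bmrt_launch_tensor_multi_cores in bmruntime_interface.h):
+ * @code
+ *   unsigned int core_num = 0;
+ *   if (bm_get_tpu_scalar_num(handle, &core_num) == BM_SUCCESS) {
+ *     printf("TPU scalar cores: %u\n", core_num);
+ *     // core ids 0 .. core_num-1 may be passed as core_list to the
+ *     // multi-core launch API
+ *   }
+ * @endcode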
+ */ +DECL_EXPORT bm_status_t bm_get_tpu_scalar_num(bm_handle_t handle, unsigned int *core_num); + +#define bm_get_tpu_core_num bm_get_tpu_scalar_num + +#if defined(__cplusplus) +} +#endif + +#endif /* BM_RUNTIME_H_ */ diff --git a/models/Qwen/support/include/bmruntime_interface.h b/models/Qwen/support/include/bmruntime_interface.h new file mode 100755 index 0000000..54fd90d --- /dev/null +++ b/models/Qwen/support/include/bmruntime_interface.h @@ -0,0 +1,352 @@ +/***************************************************************************** + * + * Copyright (c) 2016-2026 by Sophgo Technologies Inc. All rights reserved. + * + * The material in this file is confidential and contains trade secrets + * of Sophgo Technologies Inc. This is proprietary information owned by + * Sophgo Technologies Inc. No part of this work may be disclosed, + * reproduced, copied, transmitted, or used in any way for any purpose, + * without the express written permission of Sophgo Technologies Inc. + * + *****************************************************************************/ + +/***************************************************************************** + * BMRuntime Interface is mainly for inference. + * Also we can use it for device computation from BMLang programming. + * Note: please use interface from bmlib_runtime.h for device memory operation. + ****************************************************************************/ + +#ifndef BMRUNTIME_INTERFACE_H_ +#define BMRUNTIME_INTERFACE_H_ + +#include "bmdef.h" + +#ifdef _WIN32 +#define DECL_EXPORT _declspec(dllexport) +#define DECL_IMPORT _declspec(dllimport) +#else +#define DECL_EXPORT +#define DECL_IMPORT +#endif + +#if defined(__cplusplus) +extern "C" { +#endif + +/* --------------------------------------------------------------------------*/ +/* interface for basic data type */ + +/* get data type byte size */ +DECL_EXPORT size_t bmrt_data_type_size(bm_data_type_t dtype); + +/* +dims array to bm_shape_t, +shape and dims should not be NULL, num_dims should not be larger than BM_MAX_DIMS_NUM */ +DECL_EXPORT void bmrt_shape(bm_shape_t* shape, const int* dims, int num_dims); + +/* +number of shape elements, shape should not be NULL and num_dims should not large than +BM_MAX_DIMS_NUM */ +DECL_EXPORT uint64_t bmrt_shape_count(const bm_shape_t* shape); + +/* compare whether two shape is same */ +DECL_EXPORT bool bmrt_shape_is_same(const bm_shape_t* left, const bm_shape_t* right); + +/* +fill a tensor with data type and shape, and st_mode = 0 as default. +tensor and p_bmrt should not be NULL, shape count should not be 0. 
+it will alloc device mem to tensor->device_mem, so user should bmrt_free_device(p_bmrt, +tensor->device_mem) to free it.*/ +DECL_EXPORT bool bmrt_tensor(bm_tensor_t* tensor, void* p_bmrt, bm_data_type_t dtype, bm_shape_t shape); + +/* fill a tensor with device mem existed, tensor byte size should not large than device mem size */ +DECL_EXPORT void bmrt_tensor_with_device(bm_tensor_t* tensor, bm_device_mem_t device_mem, + bm_data_type_t dtype, bm_shape_t shape); + +/* get tensor bytes size, tensor should not be NULL */ +DECL_EXPORT size_t bmrt_tensor_bytesize(const bm_tensor_t* tensor); + +/* get tensor mem size allocated in device mem, tensor should not be NULL */ +DECL_EXPORT size_t bmrt_tensor_device_size(const bm_tensor_t* tensor); + +/* print net info for debug */ +DECL_EXPORT void bmrt_print_network_info(const bm_net_info_t* net_info); + +/* --------------------------------------------------------------------------*/ +/** + * @name bmrt_create + * @brief To create the bmruntime with bm_handle. + * @ingroup bmruntime + * + * This API creates the bmruntime. It returns a void* pointer which is the pointer + * of bmruntime. Device id is set when get bm_handle; + * + * @param [in] bm_handle bm handle. It must be initialized by using bmlib. + * + * @retval void* the pointer of bmruntime + */ +DECL_EXPORT void* bmrt_create(bm_handle_t bm_handle); + +/* --------------------------------------------------------------------------*/ +/** + * @name bmrt_create_ex + * @brief To create the bmruntime with one or more bm_handle. + * @ingroup bmruntime + * + * This API creates the bmruntime. It returns a void* pointer which is the pointer + * of bmruntime. + * + * @param [in] bm_handles bm handles. They must be initialized by using bmlib. + * @param [in] num_handles number of bm_handles. + * + * @retval void* the pointer of bmruntime + */ +DECL_EXPORT void *bmrt_create_ex(bm_handle_t *bm_handles, int num_handles); + +/** + * @name bmrt_destroy + * @brief To destroy the bmruntime pointer + * @ingroup bmruntime + * + * This API destroy the bmruntime. + * + * @param [in] p_bmrt Bmruntime that had been created + */ +DECL_EXPORT void bmrt_destroy(void* p_bmrt); + +/** + * @name bmrt_get_bm_handle + * @brief To get the BM runtime context. + * @ingroup bmruntime + * + * This API get the BM runtime context for using BMDNN, BMCV or BMLIB + * + * @param [in] p_bmrt Bmruntime that had been created + */ +DECL_EXPORT void * bmrt_get_bm_handle(void* p_bmrt); + +/** + * @name bmrt_load_bmodel + * @brief To load the bmodel which is created by BM compiler + * @ingroup bmruntime + * + * This API is to load bmodel created by BM compiler. + * After loading bmodel, we can run the inference of neuron network. + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [in] bmodel_path Bmodel file directory. + * + * @retval true Load context sucess. + * @retval false Load context failed. + */ +DECL_EXPORT bool bmrt_load_bmodel(void* p_bmrt, const char *bmodel_path); + +/** + * @name bmrt_load_bmodel_data + * @brief To load the bmodel which is created by BM compiler from buffer + * @ingroup bmruntime + * + * This API is to load bmodel created by BM compiler. + * After loading bmodel, we can run the inference of neuron network. + * Different with bmrt_load_bmodel, bmodel is the data in host memory. + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [in] bmodel_data Bmodel data pointer to buffer + * @param [in] size Bmodel data size + * + * @retval true Load context sucess. 
+ * @retval false Load context failed. + */ +DECL_EXPORT bool bmrt_load_bmodel_data(void* p_bmrt, const void * bmodel_data, size_t size); + +/** + * @name bmrt_show_neuron_network + * @brief To print the name of all neuron network + * @ingroup bmruntime + * + * @param [in] p_bmrt Bmruntime that had been created + */ +DECL_EXPORT void bmrt_show_neuron_network(void* p_bmrt); + +/** + * @name bmrt_get_network_number + * @brief To get the number of neuron network in the bmruntime + * @ingroup bmruntime + * + * @param [in] p_bmrt Bmruntime that had been created + * + * @retval int value The number of neuron networks. + */ +DECL_EXPORT int bmrt_get_network_number(void* p_bmrt); + +/** + * @name bmrt_get_network_names + * @brief To get the names of all neuron network in the bmruntime + * @ingroup bmruntime + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [out] network_names The names of all neuron networks. It should be declare as (const char** networks_ = NULL), + * and use as the param &networks_. After this API, user need to free(networks_) if user + * do not need it. + */ +DECL_EXPORT void bmrt_get_network_names(void* p_bmrt, const char*** network_names); + +/** + * @name bmrt_get_network_info + * @brief To get network info by net name + * @ingroup bmruntime + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [in] net_name Network name + * + * @retval bm_net_info_t* Pointer to net info, needn't free by user; if net name not found, will return NULL. + */ +DECL_EXPORT const bm_net_info_t* bmrt_get_network_info(void* p_bmrt, const char* net_name); + +/** + * @name bmrt_launch_tensor + * @brief To launch the inference of the neuron network with setting input tensors + * @ingroup bmruntime + * + * This API supports the neuron nework that is static-compiled or dynamic-compiled + * After calling this API, inference on TPU is launched. And the CPU program will not + * be blocked. bm_thread_sync should be called to make sure inference finished. + * This API support multiple inputs, and multi thread safety + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [in] net_name The name of the neuron network + * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num]. + * User should initialize each input tensor. + * @param [in] input_num Input number + * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num]. + * This interface will alloc devcie mem to store output data. User should free each + * device mem by bm_free_device after the result data not used. + * @param [in] output_num Output number + * + * @retval true Launch success. + * @retval false Launch failed. + */ +DECL_EXPORT bool bmrt_launch_tensor(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num, + bm_tensor_t output_tensors[], int output_num); + +/** + * @name bmrt_launch_tensor_ex + * @brief To launch the inference of the neuron network with setting input tensors + * @ingroup bmruntime + * + * This API supports the neuron nework that is static-compiled or dynamic-compiled + * After calling this API, inference on TPU is launched. And the CPU program will not + * be blocked. bm_thread_sync should be called to make sure inference finished. 
+ * This API supports multiple inputs and is thread-safe.
+ *
+ * @param [in]    p_bmrt         Bmruntime that had been created
+ * @param [in]    net_name       The name of the neuron network
+ * @param [in]    input_tensors  Array of input tensor, defined like bm_tensor_t input_tensors[input_num].
+ *                               User should initialize each input tensor.
+ * @param [in]    input_num      Input number
+ * @param [out]   output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num].
+ *                               User can set device_mem or stmode of the output tensors. If user_mem is true,
+ *                               this interface will use the device mem of output_tensors to store output data
+ *                               and will not allocate device mem; otherwise it will allocate device mem to
+ *                               store the output. If user_stmode is true, the stmode in each output tensor is
+ *                               used; otherwise stmode defaults to BM_STORE_1N.
+ * @param [in]    output_num     Output number
+ * @param [in]    user_mem       whether device_mem of the output tensors is set
+ * @param [in]    user_stmode    whether stmode of the output tensors is set
+ *
+ * @retval true    Launch success.
+ * @retval false   Launch failed.
+ */
+DECL_EXPORT bool bmrt_launch_tensor_ex(void* p_bmrt, const char * net_name, const bm_tensor_t input_tensors[], int input_num,
+                                       bm_tensor_t output_tensors[], int output_num, bool user_mem, bool user_stmode);
+
+/**
+ * @name    bmrt_launch_data
+ * @brief   To launch the inference of the neuron network with input data in system memory
+ * @ingroup bmruntime
+ *
+ * This API supports the neuron network that is static-compiled or dynamic-compiled.
+ * After calling this API, inference on TPU is launched and the CPU
+ * program will be blocked until it finishes.
+ * This API supports multiple inputs and is thread-safe.
+ *
+ * @param [in]    p_bmrt         Bmruntime that had been created
+ * @param [in]    net_name       The name of the neuron network
+ * @param [in]    input_datas    Array of input data, defined like void * input_datas[input_num]. User should
+ *                               initialize each data pointer as input.
+ * @param [in]    input_shapes   Array of input shape, defined like bm_shape_t input_shapes[input_num].
+ *                               User should set each input shape.
+ * @param [in]    input_num      Input number
+ * @param [out]   output_datas   Array of output data, defined like void * output_datas[output_num].
+ *                               If user_mem is false, this API will allocate memory for each output, and the
+ *                               user should free each output buffer when the data is no longer needed.
+ *                               Alternatively, the user can allocate system memory for each output and set
+ *                               user_mem = true.
+ * @param [out]   output_shapes  Array of output shape, defined like bm_shape_t output_shapes[output_num].
+ *                               It will store each output shape.
+ * @param [in]    output_num     Output number
+ * @param [in]    user_mem       whether output_datas[i] already has allocated memory
+ *
+ * @retval true    Launch success.
+ * @retval false   Launch failed.
+ */
+DECL_EXPORT bool bmrt_launch_data(void* p_bmrt, const char* net_name, void* const input_datas[],
+                                  const bm_shape_t input_shapes[], int input_num, void * output_datas[],
+                                  bm_shape_t output_shapes[], int output_num, bool user_mem);
+
+/**
+ * @name    bmrt_trace
+ * @brief   To check the runtime environment and collect info for DEBUG
+ * @ingroup bmruntime
+ *
+ * This API collects runtime info for DEBUG. Especially when a launch suddenly produces wrong
+ * results, calling bmrt_trace will show whether device mems are broken, along with other check info.
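+ *
+ * End-to-end sketch of a typical call sequence (illustrative only: the bmodel
+ * path, network name, shape and dtype are placeholders, bm_handle is a device
+ * handle obtained via bmlib, and the host-to-device input copy is omitted):
+ * @code
+ *   void *p_bmrt = bmrt_create(bm_handle);
+ *   bmrt_load_bmodel(p_bmrt, "example.bmodel");
+ *   bm_tensor_t in, out;
+ *   bm_shape_t shape;
+ *   int dims[2] = {1, 512};
+ *   bmrt_shape(&shape, dims, 2);
+ *   bmrt_tensor(&in, p_bmrt, BM_INT32, shape);   // dtype value from bmdef.h
+ *   // ... copy input data into in.device_mem ...
+ *   if (bmrt_launch_tensor(p_bmrt, "example_net", &in, 1, &out, 1)) {
+ *     bm_thread_sync(bm_handle);                 // wait for inference to finish
+ *     bm_free_device(bm_handle, out.device_mem); // output mem was allocated by the launch
+ *   } else {
+ *     bmrt_trace(p_bmrt);                        // dump debug/check info
+ *   }
+ *   bmrt_free_device(p_bmrt, in.device_mem);
+ *   bmrt_destroy(p_bmrt);
+ * @endcode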
+ * + * @param [in] p_bmrt Bmruntime that had been created + */ +DECL_EXPORT void bmrt_trace(void* p_bmrt); + +/** + * @name bmrt_launch_tensor_multi_cores + * @brief To launch the inference of the neuron network with setting input tensors, and support multi core inference. + * @ingroup bmruntime + * + * This API supports the neuron nework that is static-compiled or dynamic-compiled + * After calling this API, inference on TPU is launched. And the CPU program will not + * be blocked. bm_thread_sync_from_core should be called to make sure inference is finished. + * This API support multiple inputs, and multi thread safety + * + * @param [in] p_bmrt Bmruntime that had been created + * @param [in] net_name The name of the neuron network + * @param [in] input_tensors Array of input tensor, defined like bm_tensor_t input_tensors[input_num], + * User should initialize each input tensor. + * @param [in] input_num Input number + * @param [out] output_tensors Array of output tensor, defined like bm_tensor_t output_tensors[output_num]. + * User can set device_mem or stmode of output tensors. If user_mem is true, this interface + * will use device mem of output_tensors to store output data, and not alloc device mem; + * Or it will alloc device mem to store output. If user_stmode is true, it will use stmode in + * each output tensor; Or stmode will be BM_STORE_1N as default. + * @param [in] output_num Output number + * @param [in] user_mem whether device_mem of output tensors are set + * @param [in] user_stmode whether stmode of output tensors are set + * @param [in] core_list core id list those will be used to inference + * @param [in] core_num number of the core list + * + * @retval true Launch success. + * @retval false Launch failed. + */ +DECL_EXPORT bool bmrt_launch_tensor_multi_cores( + void *p_bmrt, + const char *net_name, + const bm_tensor_t input_tensors[], + int input_num, + bm_tensor_t output_tensors[], + int output_num, + bool user_mem, + bool user_stmode, + const int *core_list, + int core_num); + +#if defined (__cplusplus) +} +#endif + +#endif diff --git a/models/Qwen/support/include/tiktoken.h b/models/Qwen/support/include/tiktoken.h new file mode 100755 index 0000000..2481cfa --- /dev/null +++ b/models/Qwen/support/include/tiktoken.h @@ -0,0 +1,271 @@ +#pragma once + +#include +#include "unordered_dense.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tiktoken { + +static auto +_byte_pair_merge(const std::string &piece, + const ankerl::unordered_dense::map &ranks, + std::function func) -> std::vector { + std::vector> parts; + parts.reserve(piece.size() + 1); + for (auto idx = 0U; idx < piece.size() + 1; ++idx) { + parts.emplace_back(idx, std::numeric_limits::max()); + } + + auto get_rank = [&piece, + &ranks](const std::vector> &parts, + int start_idx, int skip) -> std::optional { + if (start_idx + skip + 2 < (int)parts.size()) { + auto s = parts[start_idx].first; + auto e = parts[start_idx + skip + 2].first; + auto key = piece.substr(s, e - s); + auto iter = ranks.find(key); + if (iter != ranks.end()) { + return iter->second; + } + } + return std::nullopt; + }; + + for (auto i = 0U; i < parts.size() - 2; ++i) { + auto rank = get_rank(parts, i, 0); + if (rank) { + assert(*rank != std::numeric_limits::max()); + parts[i].second = *rank; + } + } + + while (true) { + if (parts.size() == 1) + break; + + auto min_rank = + std::make_pair(std::numeric_limits::max(), 0); + for (auto i = 0U; i < parts.size() - 1; ++i) { + auto rank = 
parts[i].second; + if (rank < min_rank.first) { + min_rank = {rank, i}; + } + } + + if (min_rank.first != std::numeric_limits::max()) { + auto i = min_rank.second; + auto rank = get_rank(parts, i, 1); + if (rank) { + parts[i].second = *rank; + } else { + parts[i].second = std::numeric_limits::max(); + } + if (i > 0) { + auto rank = get_rank(parts, i - 1, 1); + if (rank) { + parts[i - 1].second = *rank; + } else { + parts[i - 1].second = std::numeric_limits::max(); + } + } + + parts.erase(parts.begin() + (i + 1)); + } else { + break; + } + } + std::vector out; + out.reserve(parts.size() - 1); + for (auto i = 0U; i < parts.size() - 1; ++i) { + out.push_back(func(parts[i].first, parts[i + 1].first)); + } + return out; +} + +static auto +byte_pair_encode(const std::string &piece, + const ankerl::unordered_dense::map &ranks) + -> std::vector { + if (piece.size() == 1) { + return {ranks.at(piece)}; + } + + auto func = [&piece, &ranks](int start, int stop) -> int { + std::string key = piece.substr(start, stop - start); + return ranks.at(key); + }; + + return _byte_pair_merge(piece, ranks, func); +} + +class tiktoken { +public: + tiktoken() = default; + tiktoken(ankerl::unordered_dense::map encoder, + ankerl::unordered_dense::map special_encoder, + const std::string &pattern) { + regex_ = std::make_unique("(" + pattern + ")"); + + std::string special_pattern; + for (const auto &item : special_encoder) { + if (!special_pattern.empty()) { + special_pattern += "|"; + } + special_pattern += re2::RE2::QuoteMeta(item.first); + } + if (special_pattern.empty()) { + special_regex_ = nullptr; + } else { + special_regex_ = std::make_unique("(" + special_pattern + ")"); + } + + encoder_ = std::move(encoder); + special_tokens_encoder = std::move(special_encoder); + + for (const auto &[k, v] : encoder_) { + decoder_.emplace(v, k); + } + assert(encoder_.size() == decoder_.size()); + + for (const auto &[k, v] : special_tokens_encoder) { + special_tokens_decoder.emplace(v, k); + } + } + + auto encode_ordinary(const std::string &text) const -> std::vector { + return _encode_ordinary_native(text); + } + + auto encode(const std::string &text) const -> std::vector { + return _encode_native(text, special_tokens_encoder).first; + } + + auto encode_single_piece(const std::string &text) const -> std::vector { + auto iter = encoder_.find(text); + if (iter != encoder_.end()) { + return {iter->second}; + } + return byte_pair_encode(text, encoder_); + } + + auto decode(const std::vector &tokens) const -> std::string { + return _decode_native(tokens); + } + +private: + auto split_with_allowed_special_token( + re2::StringPiece &input, + const ankerl::unordered_dense::map &allowed_special) + const -> std::pair, re2::StringPiece> { + if (special_regex_ == nullptr) + return {std::nullopt, input}; + + auto start = input.begin(); + std::string special; + while (true) { + if (!re2::RE2::FindAndConsume(&input, *special_regex_, &special)) { + break; + } + + if (allowed_special.count(special) == 1) { + return { + std::move(special), + re2::StringPiece(start, input.begin() - start - special.size())}; + } + } + + return {std::nullopt, input}; + } + + auto _encode_ordinary_native(const std::string &text) const + -> std::vector { + std::vector ret; + re2::StringPiece input(text); + + std::string piece; + while (re2::RE2::FindAndConsume(&input, *regex_, &piece)) { + auto iter = encoder_.find(piece); + if (iter != encoder_.end()) { + ret.push_back(iter->second); + continue; + } + auto tokens = byte_pair_encode(piece, encoder_); + 
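+      // the piece has no direct entry in the encoder table, so it was split by
+      // byte-pair merging above; append every resulting sub-token id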
ret.insert(ret.end(), tokens.begin(), tokens.end()); + } + return ret; + } + + auto _encode_native(const std::string &text, + const ankerl::unordered_dense::map + &allowed_special) const + -> std::pair, int> { + std::vector ret; + int last_piece_token_len = 0; + re2::StringPiece input(text); + + while (true) { + auto [special, sub_input] = + split_with_allowed_special_token(input, allowed_special); + std::string piece; + while (re2::RE2::FindAndConsume(&sub_input, *regex_, &piece)) { + auto iter = encoder_.find(piece); + if (iter != encoder_.end()) { + last_piece_token_len = 1; + ret.push_back(iter->second); + continue; + } + auto tokens = byte_pair_encode(piece, encoder_); + last_piece_token_len = tokens.size(); + ret.insert(ret.end(), tokens.begin(), tokens.end()); + } + + if (special) { + int token = special_tokens_encoder.at(*special); + ret.push_back(token); + last_piece_token_len = 0; + } else { + break; + } + } + + return {ret, last_piece_token_len}; + } + + auto _decode_native(const std::vector &tokens) const -> std::string { + std::string ret; + ret.reserve(tokens.size() * 2); + for (auto token : tokens) { + std::string token_bytes; + auto iter = decoder_.find(token); + if (iter != decoder_.end()) { + token_bytes = iter->second; + } else { + iter = special_tokens_decoder.find(token); + if (iter != special_tokens_decoder.end()) { + token_bytes = iter->second; + } else { + throw std::runtime_error("unknown token: " + std::to_string(token)); + } + } + ret += token_bytes; + } + return ret; + } + + ankerl::unordered_dense::map encoder_; + ankerl::unordered_dense::map special_tokens_encoder; + ankerl::unordered_dense::map decoder_; + ankerl::unordered_dense::map special_tokens_decoder; + std::unique_ptr regex_; + std::unique_ptr special_regex_; +}; + +} // namespace tiktoken diff --git a/models/Qwen/support/include/tokenizer.h b/models/Qwen/support/include/tokenizer.h new file mode 100755 index 0000000..f2723d8 --- /dev/null +++ b/models/Qwen/support/include/tokenizer.h @@ -0,0 +1,33 @@ +#pragma once + +#include "tiktoken.h" +#include "base64.h" +#include +#include +#include +#include +#include +#include + +class QwenTokenizer { +public: + QwenTokenizer(const std::string &tiktoken_path); + + auto encode(const std::string &text, int max_length) const + -> std::vector; + + auto decode(const std::vector &ids) const -> std::string; + + auto encode_history(const std::vector &history, + int max_length) const -> std::vector; + + auto build_prompt(const std::vector &history) const + -> std::string; + + auto is_special_id(int id) const -> bool; + + tiktoken::tiktoken tokenizer; + const int eod_id = 151643; + const int im_start_id = 151644; + const int im_end_id = 151645; +}; \ No newline at end of file diff --git a/models/Qwen/support/include/unordered_dense.h b/models/Qwen/support/include/unordered_dense.h new file mode 100755 index 0000000..1e635b0 --- /dev/null +++ b/models/Qwen/support/include/unordered_dense.h @@ -0,0 +1,1936 @@ +///////////////////////// ankerl::unordered_dense::{map, set} ///////////////////////// + +// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion. +// Version 4.1.2 +// https://github.com/martinus/unordered_dense +// +// Licensed under the MIT License . 
+// SPDX-License-Identifier: MIT +// Copyright (c) 2022-2023 Martin Leitner-Ankerl +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef ANKERL_UNORDERED_DENSE_H +#define ANKERL_UNORDERED_DENSE_H + +// see https://semver.org/spec/v2.0.0.html +#define ANKERL_UNORDERED_DENSE_VERSION_MAJOR 4 // NOLINT(cppcoreguidelines-macro-usage) incompatible API changes +#define ANKERL_UNORDERED_DENSE_VERSION_MINOR 1 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible functionality +#define ANKERL_UNORDERED_DENSE_VERSION_PATCH 2 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible bug fixes + +// API versioning with inline namespace, see https://www.foonathan.net/2018/11/inline-namespaces/ + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) v##major##_##minor##_##patch +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT(major, minor, patch) ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) +#define ANKERL_UNORDERED_DENSE_NAMESPACE \ + ANKERL_UNORDERED_DENSE_VERSION_CONCAT( \ + ANKERL_UNORDERED_DENSE_VERSION_MAJOR, ANKERL_UNORDERED_DENSE_VERSION_MINOR, ANKERL_UNORDERED_DENSE_VERSION_PATCH) + +#if defined(_MSVC_LANG) +# define ANKERL_UNORDERED_DENSE_CPP_VERSION _MSVC_LANG +#else +# define ANKERL_UNORDERED_DENSE_CPP_VERSION __cplusplus +#endif + +#if defined(__GNUC__) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_PACK(decl) decl __attribute__((__packed__)) +#elif defined(_MSC_VER) +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_PACK(decl) __pragma(pack(push, 1)) decl __pragma(pack(pop)) +#endif + +// exceptions +#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND) +# define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 1 // NOLINT(cppcoreguidelines-macro-usage) +#else +# define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 0 // NOLINT(cppcoreguidelines-macro-usage) +#endif +#ifdef _MSC_VER +# define ANKERL_UNORDERED_DENSE_NOINLINE __declspec(noinline) +#else +# define ANKERL_UNORDERED_DENSE_NOINLINE __attribute__((noinline)) +#endif + +// defined in unordered_dense.cpp +#if !defined(ANKERL_UNORDERED_DENSE_EXPORT) +# define ANKERL_UNORDERED_DENSE_EXPORT +#endif + +#if ANKERL_UNORDERED_DENSE_CPP_VERSION < 201703L +# error ankerl::unordered_dense requires C++17 or higher +#else +# include // for array +# include // for uint64_t, uint32_t, uint8_t, UINT64_C +# 
include // for size_t, memcpy, memset +# include // for equal_to, hash +# include // for initializer_list +# include // for pair, distance +# include // for numeric_limits +# include // for allocator, allocator_traits, shared_ptr +# include // for out_of_range +# include // for basic_string +# include // for basic_string_view, hash +# include // for forward_as_tuple +# include // for enable_if_t, declval, conditional_t, ena... +# include // for forward, exchange, pair, as_const, piece... +# include // for vector +# if ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() == 0 +# include // for abort +# endif + +# if defined(__has_include) +# if __has_include() +# define ANKERL_UNORDERED_DENSE_PMR std::pmr // NOLINT(cppcoreguidelines-macro-usage) +# include // for polymorphic_allocator +# elif __has_include() +# define ANKERL_UNORDERED_DENSE_PMR std::experimental::pmr // NOLINT(cppcoreguidelines-macro-usage) +# include // for polymorphic_allocator +# endif +# endif + +# if defined(_MSC_VER) && defined(_M_X64) +# include +# pragma intrinsic(_umul128) +# endif + +# if defined(__GNUC__) || defined(__INTEL_COMPILER) || defined(__clang__) +# define ANKERL_UNORDERED_DENSE_LIKELY(x) __builtin_expect(x, 1) // NOLINT(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_UNLIKELY(x) __builtin_expect(x, 0) // NOLINT(cppcoreguidelines-macro-usage) +# else +# define ANKERL_UNORDERED_DENSE_LIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_UNLIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) +# endif + +namespace ankerl::unordered_dense { +inline namespace ANKERL_UNORDERED_DENSE_NAMESPACE { + +namespace detail { + +# if ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() + +// make sure this is not inlined as it is slow and dramatically enlarges code, thus making other +// inlinings more difficult. Throws are also generally the slow path. +[[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_key_not_found() { + throw std::out_of_range("ankerl::unordered_dense::map::at(): key not found"); +} +[[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_bucket_overflow() { + throw std::overflow_error("ankerl::unordered_dense: reached max bucket size, cannot increase size"); +} +[[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_too_many_elements() { + throw std::out_of_range("ankerl::unordered_dense::map::replace(): too many elements"); +} + +# else + +[[noreturn]] inline void on_error_key_not_found() { + abort(); +} +[[noreturn]] inline void on_error_bucket_overflow() { + abort(); +} +[[noreturn]] inline void on_error_too_many_elements() { + abort(); +} + +# endif + +} // namespace detail + +// hash /////////////////////////////////////////////////////////////////////// + +// This is a stripped-down implementation of wyhash: https://github.com/wangyi-fudan/wyhash +// No big-endian support (because different values on different machines don't matter), +// hardcodes seed and the secret, reformats the code, and clang-tidy fixes. 
+namespace detail::wyhash { + +inline void mum(uint64_t* a, uint64_t* b) { +# if defined(__SIZEOF_INT128__) + __uint128_t r = *a; + r *= *b; + *a = static_cast(r); + *b = static_cast(r >> 64U); +# elif defined(_MSC_VER) && defined(_M_X64) + *a = _umul128(*a, *b, b); +# else + uint64_t ha = *a >> 32U; + uint64_t hb = *b >> 32U; + uint64_t la = static_cast(*a); + uint64_t lb = static_cast(*b); + uint64_t hi{}; + uint64_t lo{}; + uint64_t rh = ha * hb; + uint64_t rm0 = ha * lb; + uint64_t rm1 = hb * la; + uint64_t rl = la * lb; + uint64_t t = rl + (rm0 << 32U); + auto c = static_cast(t < rl); + lo = t + (rm1 << 32U); + c += static_cast(lo < t); + hi = rh + (rm0 >> 32U) + (rm1 >> 32U) + c; + *a = lo; + *b = hi; +# endif +} + +// multiply and xor mix function, aka MUM +[[nodiscard]] inline auto mix(uint64_t a, uint64_t b) -> uint64_t { + mum(&a, &b); + return a ^ b; +} + +// read functions. WARNING: we don't care about endianness, so results are different on big endian! +[[nodiscard]] inline auto r8(const uint8_t* p) -> uint64_t { + uint64_t v{}; + std::memcpy(&v, p, 8U); + return v; +} + +[[nodiscard]] inline auto r4(const uint8_t* p) -> uint64_t { + uint32_t v{}; + std::memcpy(&v, p, 4); + return v; +} + +// reads 1, 2, or 3 bytes +[[nodiscard]] inline auto r3(const uint8_t* p, size_t k) -> uint64_t { + return (static_cast(p[0]) << 16U) | (static_cast(p[k >> 1U]) << 8U) | p[k - 1]; +} + +[[maybe_unused]] [[nodiscard]] inline auto hash(void const* key, size_t len) -> uint64_t { + static constexpr auto secret = std::array{UINT64_C(0xa0761d6478bd642f), + UINT64_C(0xe7037ed1a0b428db), + UINT64_C(0x8ebc6af09c88c6e3), + UINT64_C(0x589965cc75374cc3)}; + + auto const* p = static_cast(key); + uint64_t seed = secret[0]; + uint64_t a{}; + uint64_t b{}; + if (ANKERL_UNORDERED_DENSE_LIKELY(len <= 16)) { + if (ANKERL_UNORDERED_DENSE_LIKELY(len >= 4)) { + a = (r4(p) << 32U) | r4(p + ((len >> 3U) << 2U)); + b = (r4(p + len - 4) << 32U) | r4(p + len - 4 - ((len >> 3U) << 2U)); + } else if (ANKERL_UNORDERED_DENSE_LIKELY(len > 0)) { + a = r3(p, len); + b = 0; + } else { + a = 0; + b = 0; + } + } else { + size_t i = len; + if (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 48)) { + uint64_t see1 = seed; + uint64_t see2 = seed; + do { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + see1 = mix(r8(p + 16) ^ secret[2], r8(p + 24) ^ see1); + see2 = mix(r8(p + 32) ^ secret[3], r8(p + 40) ^ see2); + p += 48; + i -= 48; + } while (ANKERL_UNORDERED_DENSE_LIKELY(i > 48)); + seed ^= see1 ^ see2; + } + while (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 16)) { + seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); + i -= 16; + p += 16; + } + a = r8(p + i - 16); + b = r8(p + i - 8); + } + + return mix(secret[1] ^ len, mix(a ^ secret[1], b ^ seed)); +} + +[[nodiscard]] inline auto hash(uint64_t x) -> uint64_t { + return detail::wyhash::mix(x, UINT64_C(0x9E3779B97F4A7C15)); +} + +} // namespace detail::wyhash + +ANKERL_UNORDERED_DENSE_EXPORT template +struct hash { + auto operator()(T const& obj) const noexcept(noexcept(std::declval>().operator()(std::declval()))) + -> uint64_t { + return std::hash{}(obj); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string const& str) const noexcept -> uint64_t { + return detail::wyhash::hash(str.data(), sizeof(CharT) * str.size()); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::basic_string_view const& sv) const noexcept -> uint64_t { + return detail::wyhash::hash(sv.data(), sizeof(CharT) * sv.size()); + } +}; + 
+template +struct hash { + using is_avalanching = void; + auto operator()(T* ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr)); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::unique_ptr const& ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash> { + using is_avalanching = void; + auto operator()(std::shared_ptr const& ptr) const noexcept -> uint64_t { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return detail::wyhash::hash(reinterpret_cast(ptr.get())); + } +}; + +template +struct hash::value>::type> { + using is_avalanching = void; + auto operator()(Enum e) const noexcept -> uint64_t { + using underlying = typename std::underlying_type_t; + return detail::wyhash::hash(static_cast(e)); + } +}; + +// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) +# define ANKERL_UNORDERED_DENSE_HASH_STATICCAST(T) \ + template <> \ + struct hash { \ + using is_avalanching = void; \ + auto operator()(T const& obj) const noexcept -> uint64_t { \ + return detail::wyhash::hash(static_cast(obj)); \ + } \ + } + +# if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuseless-cast" +# endif +// see https://en.cppreference.com/w/cpp/utility/hash +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(bool); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(signed char); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned char); +# if ANKERL_UNORDERED_DENSE_CPP_VERSION >= 202002L && defined(__cpp_char8_t) +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char8_t); +# endif +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char16_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char32_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(wchar_t); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(short); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned short); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(int); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned int); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long); +ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long long); + +# if defined(__GNUC__) && !defined(__clang__) +# pragma GCC diagnostic pop +# endif + +// bucket_type ////////////////////////////////////////////////////////// + +namespace bucket_type { + +struct standard { + static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint + + uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + uint32_t m_value_idx; // index into the m_values vector. +}; + +ANKERL_UNORDERED_DENSE_PACK(struct big { + static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint + static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint + + uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash + size_t m_value_idx; // index into the m_values vector. +}); + +} // namespace bucket_type + +namespace detail { + +struct nonesuch {}; + +template class Op, class... 
Args> +struct detector { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +template