diff --git a/docs/continuous_batching.md b/docs/continuous_batching.md
new file mode 100644
index 000000000..1e59e6549
--- /dev/null
+++ b/docs/continuous_batching.md
@@ -0,0 +1,164 @@
+Continuous Batching
+=======
+
+Continuous batching is a more efficient batching mechanism for LLM serving systems than static batching of inputs and outputs. It has two main characteristics:
+- Concatenate input sequences along the `seq_len` dimension (omitting padding tokens) for `linear` operations, and split them back apart for `multi-head attention` and other sequence-level operators (`RoPE`, etc.).
+- Remove sequences from the batch as soon as they finish generation and feed new sequences into the inference engine.
+For more technical details, please refer to the [ORCA paper](https://www.usenix.org/system/files/osdi22-yu.pdf).
+
+The illustration below, taken from the ORCA paper, shows continuous batching. $x_{ij}$ denotes the $j$-th token of the $i$-th request (sequence). For simplicity, the figure only depicts the QKV Linear, Attention, and
+Attention Out Linear operations.
+
+![ORCA continuous batching inference](./imgs/ORCA_batching.png)
+
+## Offline
+Multi-batch inference is only supported in this concatenate-and-split fashion, because it avoids padding-mask side effects in some operators (`RoPE`, etc.) and saves `linear` inference time. You can use `transformers`-like code (pad the prompts -> build a `torch.Tensor` -> generate) to try it. Padding tokens are removed internally and the complete generation results are returned as a Python list, so all you need to do additionally is provide the correct `pad_token_id`.
+
+A code example:
+```python
+from transformers import AutoTokenizer
+from neural_speed import Model
+
+model_name = "meta-llama/Llama-2-7b-hf"
+prompts = [
+    "Tell me an interesting fact about llamas.",
+    "What is the best way to cook a steak?",
+    "Are you familiar with the Special Theory of Relativity and can you explain it to me?",
+    "Recommend some interesting books to read.",
+    "What is the best way to learn a new language?",
+    ]
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left")
+# if the tokenizer has no pad_token, you can specify it.
+tokenizer.pad_token = tokenizer.eos_token
+pad_token_id = tokenizer.pad_token_id
+inputs = tokenizer(prompts, padding=True, return_tensors='pt').input_ids
+
+model = Model()
+model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")
+# greedy search example; top_k_top_p sampling and beam_search are also supported
+# do not forget to pass the pad token id via `pad_token`
+outputs = model.generate(inputs, max_new_tokens=128, do_sample=False, pad_token=pad_token_id)
+ans = tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+for a in ans:
+    print(a)
+    print("===========================")
+```
+> Note: Not every model supports multi-batch inference yet; most are still under construction. Please refer to [Supported Models](#supported-models).
+
+## Server
+We provide a corresponding [script](../scripts/python_api_example_for_model_server.py) for server usage.
+You can set `max_request_num` to configure the maximum number of requests the server will accept.
+
+>Note: 1. The server system is currently a prototype; its interface and usage may change. 2. Not every model supports server mode yet; most are still under construction. Please refer to [Supported Models](#supported-models).
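+
+Below is a minimal sketch of the server flow. The model name, the quantized model path, and the body of the response callback are illustrative placeholders; see the script above for complete argument handling.
+```python
+import time
+from transformers import AutoTokenizer
+from neural_speed import ModelServer
+
+model_name = "meta-llama/Llama-2-7b-hf"
+# quantized model file produced beforehand by Model.init(..., use_quant=True)
+model_path = "./runtime_outs/ne_llama_q_int4_bestla_cint8_g32.bin"
+prompts = [
+    "Tell me an interesting fact about llamas.",
+    "What is the best way to cook a steak?",
+]
+
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+def f_response(res, working):
+    # `res` holds the finished requests and `working` is the number still in flight;
+    # decode or collect the generated token ids here as needed.
+    print(f"finished requests: {len(res)}, working size: {working}")
+
+s = ModelServer(model_name,      # used to select the matching C++ module
+                f_response,      # called when requests finish generation
+                model_path,
+                max_new_tokens=128,
+                do_sample=False,
+                continuous_batching=True,
+                max_request_num=len(prompts))
+
+added_count = 0
+for i in range(len(prompts)):
+    token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist()
+    s.issueQuery(i, token_ids)
+    added_count += 1
+    time.sleep(2)  # adjust the query sending time interval
+
+# keep the Python process alive until the C++ server has drained all requests
+while added_count != len(prompts) or not s.Empty():
+    time.sleep(1)
+del s
+```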
+
+## Supported Models
+You can refer to the [developer_document](../developer_document.md#22-inference-process) for adding the continuous batching inference feature to your own customized model.
+
+| Models | Continuous Batching Support |
+| :----- | :-------------------------: |
+| [LLaMA-7B](https://huggingface.co/decapoda-research/llama-7b-hf), [LLaMA-13B](https://huggingface.co/decapoda-research/llama-13b-hf), [LLaMA2-7B](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf), [LLaMA2-13B](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf), [LLaMA2-70B](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | ✅ |
+| [CodeLlama-7b](https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf) | ✅ |
+| [Solar-10.7B](https://huggingface.co/upstage/SOLAR-10.7B-Instruct-v1.0) | ✅ |
+| [Neural-Chat-7B-v3-1](https://huggingface.co/Intel/neural-chat-7b-v3-1), [Neural-Chat-7B-v3-2](https://huggingface.co/Intel/neural-chat-7b-v3-2) | ✅ |
+| [Magicoder-6.7B](https://huggingface.co/ise-uiuc/Magicoder-S-DS-6.7B) | ✅ |
+| [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1), [Mixtral-8x7B](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) | ✅ |
+| [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6b) | ✅ |
+| [GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b), [Dolly-v2-3B](https://huggingface.co/databricks/dolly-v2-3b) | 🚧 |
+| [Qwen-7B](https://huggingface.co/Qwen/Qwen-7B-Chat), [Qwen-14B](https://huggingface.co/Qwen/Qwen-14B-Chat), [Qwen1.5-7B](https://huggingface.co/Qwen/Qwen1.5-7B-Chat), [Qwen1.5-0.5B](https://huggingface.co/Qwen/Qwen1.5-0.5B) | 🚧 |
+| [MPT-7B](https://huggingface.co/mosaicml/mpt-7b), [MPT-30B](https://huggingface.co/mosaicml/mpt-30b) | 🚧 |
+| [Falcon-7B](https://huggingface.co/tiiuae/falcon-7b), [Falcon-40B](https://huggingface.co/tiiuae/falcon-40b) | 🚧 |
+| [BLOOM-7B](https://huggingface.co/bigscience/bloomz-7b1) | 🚧 |
+| [OPT-125m](https://huggingface.co/facebook/opt-125m), [OPT-350m](https://huggingface.co/facebook/opt-350m), [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b), [OPT-13B](https://huggingface.co/facebook/opt-13b) | 🚧 |
+| [ChatGLM-6B](https://huggingface.co/THUDM/chatglm-6b), [ChatGLM2-6B](https://huggingface.co/THUDM/chatglm2-6b), [ChatGLM3-6B](https://huggingface.co/THUDM/chatglm3-6b) | 🚧 |
+| [StarCoder-1B](https://huggingface.co/bigcode/starcoderbase-1b), [StarCoder-3B](https://huggingface.co/bigcode/starcoderbase-3b), [StarCoder-15.5B](https://huggingface.co/bigcode/starcoder) | 🚧 |
+| [Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat), [Baichuan2-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat) | 🚧 |
+| [phi-2](https://huggingface.co/microsoft/phi-2), [phi-1_5](https://huggingface.co/microsoft/phi-1_5), [phi-1](https://huggingface.co/microsoft/phi-1) | 🚧 |
+| [StableLM-3B](https://huggingface.co/stabilityai/stablelm-3b-4e1t), [StableLM2-1_6B](https://huggingface.co/stabilityai/stablelm-2-1_6b), [StableLM2-Zephyr-1_6B](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) | 🚧 |
+| [gemma-2b-it](https://huggingface.co/google/gemma-2b-it), [gemma-7b](https://huggingface.co/google/gemma-7b) | 🚧 |
+ +> ✅: Supported; 🚧: WIP diff --git a/docs/imgs/ORCA_batching.png b/docs/imgs/ORCA_batching.png new file mode 100644 index 000000000..4f9a06031 Binary files /dev/null and b/docs/imgs/ORCA_batching.png differ diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index c6b706d4b..1bd974309 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -22,20 +22,7 @@ model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder"} max_request_num_default = 1 - -class Model: - - def __init__(self): - self.module = None - self.model = None - self.model_type = None - self.bin_file = None - self.generate_round = 0 - self.max_request_num = -1 - - def __import_package(self, model_type): - if self.module: - return +def _import_package(model_type): if model_type == "gptj": import neural_speed.gptj_cpp as cpp_model elif model_type == "falcon": @@ -80,28 +67,62 @@ def __import_package(self, model_type): import neural_speed.mixtral_cpp as cpp_model else: raise TypeError("Unsupported model type {}!".format(model_type)) - self.module = cpp_model + return cpp_model + +def _get_model_config(model_name, model_hub="huggingface"): + if model_hub == "modelscope": + from modelscope import AutoConfig + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + else: + from transformers import AutoConfig + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + return config + +def _get_model_type(model_config): + model_type = model_maps.get(model_config.model_type, model_config.model_type) + if model_type == "chatglm" and "chatglm2" in model_config._name_or_path: + model_type = "chatglm2" + + # For ChatGLM3 + if model_type == "chatglm" and "chatglm3" in model_config._name_or_path: + # due to the same model architecture. + model_type = "chatglm2" + + # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ + if model_type == "RefinedWebModel" or model_type == "RefinedWeb": + model_type = "falcon" + + # for TheBloke/phi-2-GPTQ + if model_type == "phi-msft": + model_type = "phi" + + return model_type + +def _filter_model_args(valid_args, **input_kwargs): + invalid_args = [] + for k in input_kwargs.keys(): + if k not in valid_args: + invalid_args.append(k) + for k in invalid_args: + input_kwargs.pop(k) + return input_kwargs + +def get_cpp_module(model_name, model_hub="huggingface"): + model_config = _get_model_config(model_name, model_hub=model_hub) + model_type = _get_model_type(model_config) + cpp_module = _import_package(model_type) + return cpp_module - @staticmethod - def get_model_type(model_config): - model_type = model_maps.get(model_config.model_type, model_config.model_type) - if model_type == "chatglm" and "chatglm2" in model_config._name_or_path: - model_type = "chatglm2" - - # For ChatGLM3 - if model_type == "chatglm" and "chatglm3" in model_config._name_or_path: - # due to the same model architecture. 
- model_type = "chatglm2" - - # for TheBloke/falcon-40b-instruct-GPTQ & TheBloke/Falcon-7B-Instruct-GPTQ - if model_type == "RefinedWebModel" or model_type == "RefinedWeb": - model_type = "falcon" - - # for TheBloke/phi-2-GPTQ - if model_type == "phi-msft": - model_type = "phi" +class Model: - return model_type + def __init__(self): + self.module = None + self.model = None + self.model_type = None + self.bin_file = None + self.generate_round = 0 + self.max_request_num = -1 + self.reinit_from_bin = False def init(self, model_name, @@ -116,15 +137,11 @@ def init(self, compute_dtype="int8", use_ggml=False, model_hub="huggingface"): - if model_hub == "modelscope": - from modelscope import AutoConfig - self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - else: - from transformers import AutoConfig - self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - model_type = Model.get_model_type(self.config) + self.config = _get_model_config(model_name, model_hub=model_hub) + model_type = _get_model_type(self.config) self.model_type = model_type - self.__import_package(model_type) + if self.module is None: + self.module = _import_package(model_type) # check cache and quantization output_path = "runtime_outs" @@ -181,7 +198,8 @@ def init(self, os.remove(fp32_bin) def init_from_bin(self, model_type, model_path, **generate_kwargs): - self.__import_package(model_type) + if self.module is None: + self.module = _import_package(model_type) self.model = self.module.Model() if self.max_request_num == -1: @@ -271,10 +289,16 @@ def get_scratch_size_ratio(size): else: generate_kwargs["scratch_size_ratio"] = 35 - self.model.init_model(model_path, **generate_kwargs) + valid_args = {"max_new_tokens", "n_batch", "ctx_size", "seed", "threads", "repetition_penalty", + "num_beams", "do_sample", "top_k", "top_p", "temperature", "min_new_tokens", + "length_penalty", "early_stopping", "n_keep", "n_discard", "shift_roped_k", + "batch_size","pad_token", "memory_dtype", "continuous_batching", "max_request_num", + "scratch_size_ratio"} + self.model.init_model(model_path, **_filter_model_args(valid_args, **generate_kwargs)) def quant_model(self, model_type, model_path, out_path, **quant_kwargs): - self.__import_package(model_type) + if self.module is None: + self.module = _import_package(model_type) self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs) def generate(self, @@ -287,15 +311,11 @@ def generate(self, batch_size = input_ids.shape[0] max_new_tokens = generate_kwargs.get("max_new_tokens", -1) - max_request_num = generate_kwargs.pop("max_request_num", max_request_num_default) - reinit_from_bin = False - if max_request_num > self.max_request_num or batch_size > self.max_request_num: - reinit_from_bin = True - if self.max_request_num > 0: - print("Will start to reinit model from bin due to different max request num.") - self.max_request_num = max(batch_size, max_request_num) + self.reinit_from_bin = False + self._check_max_request_num(batch_size, **generate_kwargs) + generate_kwargs.pop("max_request_num", max_request_num_default) - if self.model is None or reinit_from_bin: + if self.model is None or self.reinit_from_bin: self.init_from_bin(self.model_type, self.bin_file, batch_size=batch_size, @@ -313,9 +333,6 @@ def generate(self, beam_search = False if (generate_kwargs.get("num_beams", 1) > 1) and not generate_kwargs.get("do_sample", False): beam_search = True - if not beam_search: - # TODO support multi batch - assert input_ids.shape[0] == 
1, "Unsupported multi-batch input ids." if streamer: assert input_ids.shape[0] == 1, "Streamer only supports batch size 1." @@ -328,12 +345,7 @@ def generate(self, if interactive: self.model.reset_token_end() out_count = 0 - input_list = None - pad_token_id = generate_kwargs.get("pad_token", None) - if input_ids.shape[0] > 1 and generate_kwargs.get("continuous_batching", True): - input_list = self._cont_batching_input(input_ids, pad_token_id) - else: - input_list = input_ids.tolist() + input_list = self._get_model_input_list(input_ids, **generate_kwargs) while True: response = self.model.generate(input_ids=input_list) input_list = [] # next-token stage will use previous output @@ -345,13 +357,13 @@ def generate(self, ret[i].extend(response[i]) if beam_search: break + out_count += 1 if stopping_criteria is not None: if stopping_criteria(torch.tensor(ret), None): break elif ret[0][-1] == self.__get_eos_id() or \ - (max_new_tokens != -1 and out_count > max_new_tokens): + (max_new_tokens != -1 and out_count >= max_new_tokens): break - out_count += 1 if streamer: streamer.end() @@ -377,27 +389,94 @@ def __call__(self, model_input, reinit=False, logits_all=False, **kwargs): print("Please input an audio file") return if isinstance(model_input, torch.Tensor): - if self.model is None: - self.init_from_bin(self.model_type, self.bin_file, **kwargs) + import numpy as np + batch_size = model_input.shape[0] + logits_seq_len_dim = model_input.shape[1] if logits_all else 1 + self.reinit_from_bin = False + self._check_max_request_num(batch_size, **kwargs) + kwargs.pop("max_request_num", max_request_num_default) + if self.model is None or self.reinit_from_bin: + self.init_from_bin(self.model_type, + self.bin_file, + batch_size=batch_size, + max_request_num=self.max_request_num, + **kwargs) self.generate_round = 0 elif reinit: self.model.reinit() self.generate_round = 0 - return self.model.evaluate(model_input.tolist(), logits_all) + model_input_list = self._get_model_input_list(model_input, **kwargs) + raw_logits = self.model.evaluate(model_input_list, logits_all) + ignore_padding = kwargs.get("ignore_padding", False) + if not ignore_padding and batch_size > 1: + padding_side = kwargs.get("padding_side", "left") + for i in range(len(raw_logits)): + padding_row = np.ones((logits_seq_len_dim - raw_logits[i].shape[0], raw_logits[i].shape[1])) + if padding_side == "left": + raw_logits[i] = np.vstack((padding_row * (-np.inf), raw_logits[i])) + else: + raw_logits[i] = np.vstack((raw_logits[i], padding_row * (-np.inf))) + return np.array(raw_logits) else: print("Please input torch.Tensor") return - def _cont_batching_input(self, input_ids, pad_token_id=None): + def _cont_batching_input(self, input_ids, pad_token_id=None, padding_side="left"): assert isinstance(input_ids, torch.Tensor), "Input must be torch.Tensor." input_list = input_ids.tolist() pti = pad_token_id if pti == None: pti = self.tokenizer.pad_token_id - assert pti != None, "Please supply pad token id." + assert pti != None, "Please supply pad token id with `pad_token=token_id`." 
for il in range(len(input_list)): count = input_list[il].count(pti) - # padding left - del input_list[il][0:count] + if padding_side == "left": + del input_list[il][0:count] + elif padding_side == "right": + del input_list[il][len(input_list[il]) - count :] + else: + raise ValueError("padding_side must be 'left' or 'right'.") assert input_list[il] != [], "there are all pad tokens in batch {}.".format(il) return input_list + + def _check_max_request_num(self, input_batch_size, **kwargs): + max_request_num = kwargs.get("max_request_num", max_request_num_default) + if max_request_num > self.max_request_num or input_batch_size > self.max_request_num: + self.reinit_from_bin = True + if self.max_request_num > 0: + print("Will start to reinit model from bin due to different max request num.") + self.max_request_num = max(input_batch_size, max_request_num) + + def _get_model_input_list(self, input_ids, **kwargs): + input_list = None + ignore_padding = kwargs.get("ignore_padding", False) + if input_ids.shape[0] > 1 and kwargs.get("continuous_batching", True) and not ignore_padding: + pad_token_id = kwargs.get("pad_token", None) + padding_side = kwargs.get("padding_side", "left") + input_list = self._cont_batching_input(input_ids, pad_token_id, padding_side) + else: + input_list = input_ids.tolist() + return input_list + + +class ModelServer: + def __init__(self, model_name, reponse_function, model_path, **server_kwargs): + if not os.path.exists(model_path): + raise ValueError("model file {} does not exist.".format(model_path)) + self.module = get_cpp_module(model_name) + valid_args = {"max_new_tokens", "n_batch", "ctx_size", "seed", "threads", "repetition_penalty", + "num_beams", "do_sample", "top_k", "top_p", "temperature", "min_new_tokens", + "length_penalty", "early_stopping", "n_keep", "n_discard", "shift_roped_k", + "batch_size","pad_token", "memory_dtype", "continuous_batching", "max_request_num", + "scratch_size_ratio", "return_prompt", "print_log", "init_cb"} + self.cpp_server = self.module.ModelServer(reponse_function, + model_path, + **_filter_model_args(valid_args, **server_kwargs)) + + def issueQuery(self, index, token_ids): + self.cpp_server.issueQuery([self.module.Query(index, token_ids)]) + + def Empty(self): + return self.cpp_server.Empty() + +__all__ = ["get_cpp_module", "Model", "ModelServer"] diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index d95934401..e3f2c62e6 100644 --- a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -91,7 +91,12 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_ #endif params->model_arch = model_name_to_arch::init().find(params->model_name); params->model = model_path; - params->n_predict = max_new_tokens; + if (max_new_tokens < 0) { + fprintf(stderr, "warning: max_new_tokens must be not less than 0, resetting it to 0.\n"); + params->n_predict = 0; + } else { + params->n_predict = max_new_tokens; + } params->n_batch = n_batch; params->n_ctx = ctx_size; params->seed = seed; @@ -338,13 +343,23 @@ class Model { // deprecated API std::vector> generate_tokens(const std::vector>& input_ids); const std::vector& evaluate_(const std::vector>& input_ids); - py::array_t evaluate(const std::vector>& input_ids, bool logits_all = false) { + std::vector> evaluate(const std::vector>& input_ids, + bool logits_all = false) { if (logits_all) ctx->logits_all = true; - if (!check_input_and_count_padding(input_ids)) return py::array_t(); + if 
(!check_input_and_count_padding(input_ids)) { + return std::vector>(ctx->batch_size, py::array_t()); + } const auto& logits = evaluate_(input_ids); + std::vector> rets(ctx->batch_size); + size_t off = 0; + for (int i = 0; i < ctx->batch_size; ++i) { + auto cur_seq_len = logits_all ? curr_input_ids[i].size() : 1; + rets[i] = py::array_t(cur_seq_len * n_vocab, logits.data() + off) + .reshape({py::ssize_t(-1), static_cast(n_vocab)}); + off += cur_seq_len * n_vocab; + } for (auto& input_id : curr_input_ids) input_id.clear(); // clear curr_input_ids after eval - return py::array_t(logits.size(), logits.data()) - .reshape({py::ssize_t(-1), static_cast(ctx->model.hparams.n_vocab)}); + return rets; } bool is_token_end() { return token_eos; } model_token get_eos_id() { return ctx->vocab.eos_token_id; } @@ -401,15 +416,15 @@ class Model { model_context* ctx = nullptr; gpt_params params; std::vector> curr_input_ids; - int n_past = 0; - int n_total = 0; + std::vector n_past; + std::vector n_total; int n_vocab = 0; int n_ctx = 0; std::vector> last_n_tokens; bool token_eos = false; int64_t generate_count = 0; std::vector padding_count; - uint32_t n_prompt_tokens = 0; + std::vector n_prompt_tokens; std::vector times; std::vector> beam_generate(const std::vector>& input_ids); @@ -432,8 +447,9 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching, max_request_num, scratch_size_ratio); - n_past = 0; - n_total = 0; + n_past.assign(params.max_request_num, 0); + n_total.assign(params.max_request_num, 0); + n_prompt_tokens.assign(params.max_request_num, 0); token_eos = false; curr_input_ids.clear(); curr_input_ids.resize(params.max_request_num); @@ -448,8 +464,9 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ } void Model::reinit() { - n_past = 0; - n_total = 0; + n_past.assign(params.max_request_num, 0); + n_total.assign(params.max_request_num, 0); + n_prompt_tokens.assign(params.max_request_num, 0); last_n_tokens.clear(); last_n_tokens.resize(params.max_request_num); for (int i = 0; i < params.max_request_num; ++i) { @@ -462,7 +479,6 @@ void Model::reinit() { ctx->t_sample_us = 0; generate_count = 0; padding_count.clear(); - n_prompt_tokens = 0; } bool Model::check_input_and_count_padding(const std::vector>& input_ids) { @@ -475,7 +491,7 @@ bool Model::check_input_and_count_padding(const std::vectorbatch_size = 1; - n_prompt_tokens = input_ids[STATIC_INPUT_HEAD_IDX].size(); + n_prompt_tokens[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX].size(); return true; } else { // multi-batch inputs (first token) ctx->batch_size = input_ids.size(); @@ -485,24 +501,22 @@ bool Model::check_input_and_count_padding(const std::vectorvocab.pad_token_id == -1) { - fprintf(stderr, "\nERROR: please set pad_token for static multi-batch generation (tokenizer.pad_token_id)!\n"); + if (ctx->vocab.pad_token_id == -1 && !ctx->cont_batching) { + fprintf(stderr, "\nERROR: please set pad_token for static multi-batch inference (tokenizer.pad_token_id)!\n"); return false; } if (!padding_count.empty()) padding_count.clear(); - if (ctx->cont_batching) { - padding_count.assign(input_ids.size(), 0); - return true; - } + padding_count.assign(input_ids.size(), 0); for (int bs = 0; bs < input_ids.size(); ++bs) { - model_vocab::id pad_token_id = ctx->vocab.pad_token_id; - auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), - [&pad_token_id](model_token t) { return (t != 
pad_token_id); }); - if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); - padding_count.push_back(std::distance(input_ids[bs].begin(), iter)); + n_prompt_tokens[bs] = input_ids[bs].size(); + if (!ctx->cont_batching) { + model_vocab::id pad_token_id = ctx->vocab.pad_token_id; + auto iter = std::find_if(input_ids[bs].begin(), input_ids[bs].end(), + [&pad_token_id](model_token t) { return (t != pad_token_id); }); + if (iter == input_ids[bs].end()) fprintf(stderr, "\nERROR: there are all pad tokens in batch %d!\n", bs); + padding_count[bs] = std::distance(input_ids[bs].begin(), iter); + } } - // should be same in static batching inference - n_prompt_tokens = input_ids[STATIC_INPUT_HEAD_IDX].size(); return true; } } @@ -513,7 +527,7 @@ std::vector> Model::beam_generate(const std::vector& Model::evaluate_(const std::vector n_ctx) { + if (n_past[bs] + curr_input_ids[bs].size() > n_ctx) { // always keep the first token - n_past = std::max(1, params.n_keep); + n_past[bs] = std::max(1, params.n_keep); int n_discard = params.n_discard; if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing @@ -574,19 +588,19 @@ const std::vector& Model::evaluate_(const std::vectorlogits; } @@ -595,6 +609,11 @@ std::vector> Model::generate(const std::vectorbeam_search) return beam_generate(input_ids); + if (ctx->vocab.pad_token_id == -1 && input_ids.size() > 1) { + fprintf(stderr, "\nERROR: please set pad_token for multi-batch generation (tokenizer.pad_token_id)!\n"); + return {}; + } + const auto& logits = evaluate_(input_ids); if (logits.empty()) return {}; @@ -675,9 +694,9 @@ std::vector> Model::generate_tokens(const std::vector n_ctx) { + if (n_past[STATIC_INPUT_HEAD_IDX] + curr_input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx) { // always keep the first token - n_past = std::max(1, params.n_keep); + n_past[STATIC_INPUT_HEAD_IDX] = std::max(1, params.n_keep); int n_discard = params.n_discard; if (!params.shift_roped_k) { // shift_roped_k can use ring-buffer and thus does not need re-computing @@ -695,16 +714,16 @@ std::vector> Model::generate_tokens(const std::vector next_token_id = post_process(logits); @@ -725,28 +744,7 @@ std::vector> Model::generate_tokens(const std::vector Model::post_greedy_search(const float* logits) { - std::vector ids(ctx->batch_size); - static int n_vocab_segment = 1024; - int num_segments = (n_vocab + n_vocab_segment - 1) / n_vocab_segment; - std::vector candidate_tokens(ctx->batch_size * num_segments); - std::vector candidate_logits(ctx->batch_size * num_segments); -#pragma omp parallel for collapse(2) - for (int bs = 0; bs < ctx->batch_size; ++bs) { - for (int vocab = 0; vocab < n_vocab; vocab += n_vocab_segment) { - auto max_e = - std::max_element(logits + bs * n_vocab + vocab, vocab + n_vocab_segment > n_vocab - ? 
logits + bs * n_vocab + n_vocab - : logits + bs * n_vocab + vocab + n_vocab_segment); - candidate_tokens[bs * num_segments + vocab / n_vocab_segment] = max_e - (logits + bs * n_vocab); - candidate_logits[bs * num_segments + vocab / n_vocab_segment] = *max_e; - } - } - for (int bs = 0; bs < ctx->batch_size; ++bs) { - ids[bs] = candidate_tokens[std::distance(candidate_logits.begin(), - std::max_element(candidate_logits.begin() + bs * num_segments, - candidate_logits.begin() + (bs + 1) * num_segments))]; - } - return ids; + return model_post_greedy_search(logits, ctx); } std::vector> Model::post_beam_search(model_context* lctx, const int& n_predict, @@ -763,44 +761,7 @@ std::vector> Model::post_beam_search(model_context* lct } std::vector Model::post_sample_top_k_top_p_repeat(const float* logits) { - int alpha_frequency = 0; - int alpha_presence = 0; - int repeat_last_n = 64; - int top_k = params.top_k; - float tfs_z = 1.00f; - float typical_p = 1.00f; - float top_p = params.top_p; - float temp = params.temp; - std::vector ids(ctx->batch_size); - // #pragma omp parallel for // omp will affect sampling positions in batch infer - // TODO(Zhentao): (make sample functions support batch processing) - for (int bs = 0; bs < ctx->batch_size; ++bs) { - std::vector candidates; - candidates.reserve(n_vocab); - for (model_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(model_token_data{token_id, logits[bs * n_vocab + token_id], 0.0f}); - } - model_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; - - // Apply penalties - float nl_logit = logits[bs * n_vocab + model_token_nl()]; - auto last_n_repeat = std::min(std::min(static_cast(last_n_tokens[bs].size()), repeat_last_n), n_ctx); - model_sample_repetition_penalty(ctx, &candidates_p, - last_n_tokens[bs].data() + last_n_tokens[bs].size() - last_n_repeat, last_n_repeat, - params.repeat_penalty); - model_sample_frequency_and_presence_penalties(ctx, &candidates_p, - last_n_tokens[bs].data() + last_n_tokens[bs].size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); - // int id = model_sample_token_greedy(ctx, &candidates_p); - // Temperature sampling - model_sample_top_k(ctx, &candidates_p, top_k, 1); - model_sample_tail_free(ctx, &candidates_p, tfs_z, 1); - model_sample_typical(ctx, &candidates_p, typical_p, 1); - model_sample_top_p(ctx, &candidates_p, top_p, 1); - model_sample_temperature(ctx, &candidates_p, temp); - ids[bs] = model_sample_token(ctx, &candidates_p); - } - return ids; + return model_post_sample_top_k_top_p_repeat(logits, ctx, last_n_tokens); } std::vector Model::post_process(const float* logits) { diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 454fbbb34..c0ccefdac 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -278,6 +278,12 @@ struct generation_config { // `length_penalty` < 0.0 encourages shorter sequences. 
(default = 1.0) float length_penalty = 1.0f; bool do_early_stopping = false; + // sampling parameters + bool do_sample = false; + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float temp = 0.80f; // 1.0 = disabled + float repeat_penalty = 1.10f; // 1.0 = disabled }; class beam_search_kv_cache_reorder; // forward declaration diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp index 266f9486b..824abd1c6 100644 --- a/neural_speed/models/model_utils/model_utils.cpp +++ b/neural_speed/models/model_utils/model_utils.cpp @@ -1291,10 +1291,18 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) { lparams.shift_roped_k = params.shift_roped_k; lparams.cont_batching = params.cont_batching; lparams.max_request_num = params.max_request_num; - lparams.gen_conf.max_new_tokens = params.n_predict; - lparams.gen_conf.min_new_tokens = params.min_new_tokens; - lparams.gen_conf.length_penalty = params.length_penalty; - lparams.gen_conf.do_early_stopping = params.do_early_stopping; + generation_config gen_conf = { + /*.max_new_tokens =*/(uint32_t)params.n_predict, + /*.min_new_tokens =*/params.min_new_tokens, + /*.length_penalty =*/params.length_penalty, + /*.do_early_stopping =*/params.do_early_stopping, + /*.do_sample =*/params.do_sample, + /*.top_k =*/params.top_k, + /*.top_p =*/params.top_p, + /*.temp =*/params.temp, + /*.repeat_penalty =*/params.repeat_penalty, + }; + lparams.gen_conf = gen_conf; lparams.scratch_size_ratio = params.scratch_size_ratio; NE_ASSERT(("Start size cannot be greater than the maximum context size!", lparams.n_keep < lparams.n_ctx)); @@ -2504,8 +2512,7 @@ const std::vector>& beam_search_flow::loop(const std::v for (int ni = 0; ni < next_inputs.size(); ++ni) { n_tokens[ni] = next_inputs[ni].n_tokens; if (n_tokens[ni] > model_n_ctx(ctx)) { - fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, n_tokens[ni], - model_n_ctx(ctx) - 4); + fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, n_tokens[ni], model_n_ctx(ctx)); return response; } n_prompt_tokens[ni] = next_inputs[ni].n_tokens; @@ -2784,3 +2791,76 @@ std::vector> split_inputs_into_groups(const model_input* inputs } return groups; } + +std::vector model_post_greedy_search(const float* logits, model_context* ctx) { + std::vector ids(ctx->batch_size); + const int n_vocab = model_n_vocab(ctx); + static int n_vocab_segment = 1024; + int num_segments = (n_vocab + n_vocab_segment - 1) / n_vocab_segment; + std::vector candidate_tokens(ctx->batch_size * num_segments); + std::vector candidate_logits(ctx->batch_size * num_segments); +#pragma omp parallel for collapse(2) + for (int bs = 0; bs < ctx->batch_size; ++bs) { + for (int vocab = 0; vocab < n_vocab; vocab += n_vocab_segment) { + auto max_e = + std::max_element(logits + bs * n_vocab + vocab, vocab + n_vocab_segment > n_vocab + ? 
logits + bs * n_vocab + n_vocab + : logits + bs * n_vocab + vocab + n_vocab_segment); + candidate_tokens[bs * num_segments + vocab / n_vocab_segment] = max_e - (logits + bs * n_vocab); + candidate_logits[bs * num_segments + vocab / n_vocab_segment] = *max_e; + } + } + for (int bs = 0; bs < ctx->batch_size; ++bs) { + ids[bs] = candidate_tokens[std::distance(candidate_logits.begin(), + std::max_element(candidate_logits.begin() + bs * num_segments, + candidate_logits.begin() + (bs + 1) * num_segments))]; + } + return ids; +} + +std::vector model_post_sample_top_k_top_p_repeat( + const float* logits, model_context* ctx, const std::vector>& last_n_tokens, + const std::vector& last_n_tokens_indices) { + int alpha_frequency = 0; + int alpha_presence = 0; + int repeat_last_n = 64; + int top_k = ctx->generation_conf.top_k; + float tfs_z = 1.00f; + float typical_p = 1.00f; + float top_p = ctx->generation_conf.top_p; + float temp = ctx->generation_conf.temp; + std::vector ids(ctx->batch_size); + const int n_vocab = model_n_vocab(ctx); + if (!last_n_tokens_indices.empty()) MODEL_ASSERT(last_n_tokens_indices.size() == ctx->batch_size); + // #pragma omp parallel for // omp will affect sampling positions in batch dimension + for (int bs = 0; bs < ctx->batch_size; ++bs) { + std::vector candidates; + candidates.reserve(n_vocab); + for (model_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(model_token_data{token_id, logits[bs * n_vocab + token_id], 0.0f}); + } + model_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; + + // Apply penalties + float nl_logit = logits[bs * n_vocab + model_token_nl()]; + // continuous batching will update last_n_tokens in request_idx dimension + int idx = last_n_tokens_indices.empty() ? bs : last_n_tokens_indices.at(bs); + auto last_n_repeat = + std::min(std::min(static_cast(last_n_tokens[idx].size()), repeat_last_n), model_n_ctx(ctx)); + model_sample_repetition_penalty(ctx, &candidates_p, + last_n_tokens[idx].data() + last_n_tokens[idx].size() - last_n_repeat, + last_n_repeat, ctx->generation_conf.repeat_penalty); + model_sample_frequency_and_presence_penalties(ctx, &candidates_p, + last_n_tokens[idx].data() + last_n_tokens[idx].size() - last_n_repeat, + last_n_repeat, alpha_frequency, alpha_presence); + // int id = model_sample_token_greedy(ctx, &candidates_p); + // Temperature sampling + model_sample_top_k(ctx, &candidates_p, top_k, 1); + model_sample_tail_free(ctx, &candidates_p, tfs_z, 1); + model_sample_typical(ctx, &candidates_p, typical_p, 1); + model_sample_top_p(ctx, &candidates_p, top_p, 1); + model_sample_temperature(ctx, &candidates_p, temp); + ids[bs] = model_sample_token(ctx, &candidates_p); + } + return ids; +} diff --git a/neural_speed/models/model_utils/model_utils.h b/neural_speed/models/model_utils/model_utils.h index 9fba1ac6b..e893553dc 100644 --- a/neural_speed/models/model_utils/model_utils.h +++ b/neural_speed/models/model_utils/model_utils.h @@ -341,14 +341,22 @@ struct beam_hypotheses { int len() { return beams.size(); } - void add(beam b, const uint32_t& n_prompt_tokens) { + void add(beam b, const uint32_t& n_prompt_tokens, const bool penalize_prompt = false) { auto comp = [](const beam& a, const beam& b) { return a.score > b.score; }; uint32_t cur_len = b.eos() ? 
b.token_ids.size() - 1 : b.token_ids.size(); - float score = b.score / std::pow(cur_len + n_prompt_tokens, length_penalty); + float score; + // reference: + // https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/generation/beam_search.py#L954-L960 + if (penalize_prompt) { + score = b.score / std::pow(cur_len + n_prompt_tokens, length_penalty); + } else { + score = b.score / std::pow(cur_len, length_penalty); + } #ifdef NS_BEAM_SEARCH_VERBOSE_ON printf("add beam hypos: \n"); b.print(); - printf("origin_score: %12.6f, new_score: %12.6f, sentence_len: %d \n", b.score, score, cur_len + n_prompt_tokens); + printf("origin_score: %12.6f, new_score: %12.6f, generated_len: %d, sentence_len: %d \n", b.score, score, cur_len, + cur_len + n_prompt_tokens); printf("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ \n"); #endif b.score = score; @@ -517,6 +525,11 @@ MODEL_API std::vector> beam_search(model_context* lctx, const int& n_threads); // split model inputs into continuous inference groups which have num_requests length MODEL_API std::vector> split_inputs_into_groups(const model_input* inputs, const int n_input); +// token sampling function +MODEL_API std::vector model_post_greedy_search(const float* logits, model_context* ctx); +MODEL_API std::vector model_post_sample_top_k_top_p_repeat( + const float* logits, model_context* ctx, const std::vector>& last_n_tokens, + const std::vector& last_n_tokens_indices = {}); // Internal API to be implemented by model.cpp and used by tests/benchmarks only #ifdef MODEL_API_INTERNAL diff --git a/neural_speed/models/model_utils/scheduler.cpp b/neural_speed/models/model_utils/scheduler.cpp index cf466939b..bd0ffd2e2 100644 --- a/neural_speed/models/model_utils/scheduler.cpp +++ b/neural_speed/models/model_utils/scheduler.cpp @@ -22,10 +22,18 @@ Iter_level_worker::Iter_level_worker(const gpt_params& params) : m_ctx(model_ini } if (m_ctx->beam_search && bsf == nullptr) { bsf = new beam_search_flow(m_ctx, m_ctx->max_request_num, params.n_threads); + fprintf(stdout, "%s: use beam search generation in model server.\n", __func__); + } else if (m_ctx->generation_conf.do_sample == false) { + fprintf(stdout, "%s: use greedy search generation in model server.\n", __func__); + } else { + fprintf(stdout, "%s: use top_k_top_p sampling generation in model server.\n", __func__); } + // for repetition penalizing sampling and long context if (!m_ctx->beam_search) { - fprintf(stderr, "%s: error: only supports beam search.\n", __func__); - exit(0); + last_n_tokens.resize(m_ctx->max_request_num); + for (int i = 0; i < m_ctx->max_request_num; ++i) { + last_n_tokens[i].resize(m_ctx->n_ctx, 0); + } } threads = params.n_threads; } @@ -46,31 +54,44 @@ Cont_batch_gen_worker::Cont_batch_gen_worker(const gpt_params& params) : Iter_le bool Cont_batch_gen_worker::prepare_inputs(std::vector* seqs, const int& n_input, model_input* inputs) { for (int i = 0; i < n_input; ++i) { - if ((seqs->at(i)).status != seq_status::PREFILL && (seqs->at(i)).status != seq_status::DECODING) { + if (seqs->at(i).status != seq_status::PREFILL && seqs->at(i).status != seq_status::DECODING) { fprintf(stderr, "%s: error: request %d status is unright (%d).\n", __func__, seqs->at(i).request_idx, - static_cast((seqs->at(i)).status)); + static_cast(seqs->at(i).status)); return false; - } else if ((seqs->at(i)).status == seq_status::PREFILL) { - inputs[i].tokens = (seqs->at(i)).prompt_ids.data(); - inputs[i].n_tokens = (seqs->at(i)).n_prompt_tokens; - 
inputs[i].n_prompt_tokens = (seqs->at(i)).n_prompt_tokens; + } else if (seqs->at(i).status == seq_status::PREFILL) { + if (seqs->at(i).n_prompt_tokens + seqs->at(i).gen_conf.max_new_tokens > m_ctx->n_ctx) { + fprintf(stderr, "%s: error: prompt + max_new_tokens is too long (%d tokens, max %d) for model server.\n", + __func__, seqs->at(i).n_prompt_tokens + seqs->at(i).gen_conf.max_new_tokens, m_ctx->n_ctx); + return false; + } + inputs[i].tokens = seqs->at(i).prompt_ids.data(); + inputs[i].n_tokens = seqs->at(i).n_prompt_tokens; + inputs[i].n_prompt_tokens = seqs->at(i).n_prompt_tokens; inputs[i].n_past = 0; inputs[i].n_total = 0; - inputs[i].request_idx = (seqs->at(i)).request_idx; + inputs[i].request_idx = seqs->at(i).request_idx; // do not support padding for now inputs[i].n_padding = 0; - inputs[i].gen_conf = (seqs->at(i)).gen_conf; - } else if ((seqs->at(i)).status == seq_status::DECODING) { - inputs[i].tokens = &(seqs->at(i)).generated_ids.back(); + inputs[i].gen_conf = seqs->at(i).gen_conf; + } else if (seqs->at(i).status == seq_status::DECODING) { + inputs[i].tokens = (bsf != nullptr) ? nullptr : &(seqs->at(i).generated_ids.back()); inputs[i].n_tokens = 1; - inputs[i].n_past = (seqs->at(i)).n_past; - inputs[i].n_total = (seqs->at(i)).n_total; - inputs[i].request_idx = (seqs->at(i)).request_idx; + inputs[i].n_past = seqs->at(i).n_past; + inputs[i].n_total = seqs->at(i).n_total; + inputs[i].request_idx = seqs->at(i).request_idx; // do not support padding for now inputs[i].n_padding = 0; } else { continue; } + // update last_n_tokens + if (!m_ctx->beam_search && + (seqs->at(i).status == seq_status::PREFILL || seqs->at(i).status == seq_status::DECODING)) { + int req_idx = inputs[i].request_idx; + last_n_tokens[req_idx].erase(last_n_tokens[req_idx].begin(), last_n_tokens[req_idx].begin() + inputs[i].n_tokens); + last_n_tokens[req_idx].insert(last_n_tokens[req_idx].end(), inputs[i].tokens, + inputs[i].tokens + inputs[i].n_tokens); + } } return true; } @@ -87,17 +108,66 @@ bool Cont_batch_gen_worker::beam_search_step(std::vector* seqs, const return true; } +bool Cont_batch_gen_worker::greedy_search_step(std::vector* seqs, const int& n_input) { + // prepare inputs + std::vector step_inputs(n_input); + if (!prepare_inputs(seqs, n_input, step_inputs.data())) { + return false; + } + m_ctx->batch_size = n_input; + m_ctx->request_running_bs = n_input; + // model eval + if (model_eval(m_ctx, step_inputs.data(), step_inputs.size(), threads) > 0) { + return false; + } + // greedy search + next_tokens = model_post_greedy_search(m_ctx->logits.data(), m_ctx); + return true; +} + +bool Cont_batch_gen_worker::top_k_top_p_sample_step(std::vector* seqs, const int& n_input) { + // prepare inputs + std::vector step_inputs(n_input); + if (!prepare_inputs(seqs, n_input, step_inputs.data())) { + return false; + } + m_ctx->batch_size = n_input; + m_ctx->request_running_bs = n_input; + // model eval + if (model_eval(m_ctx, step_inputs.data(), step_inputs.size(), threads) > 0) { + return false; + } + // top_k_top_p sampling + std::vector last_n_tokens_indices(n_input, 0); + for (int ni = 0; ni < n_input; ++ni) { + last_n_tokens_indices[ni] = seqs->at(ni).request_idx; + } + next_tokens = model_post_sample_top_k_top_p_repeat(m_ctx->logits.data(), m_ctx, last_n_tokens, last_n_tokens_indices); + return true; +} + bool Cont_batch_gen_worker::step(std::vector* seqs, const int& n_input) { reqidx_to_vecid.clear(); for (int ni = 0; ni < n_input; ++ni) { reqidx_to_vecid.emplace(seqs->at(ni).request_idx, ni); } - if 
(m_ctx->beam_search && bsf != nullptr) { - if (!beam_search_step(seqs, n_input)) { + // beam search + if (m_ctx->beam_search) { + if (bsf == nullptr || !beam_search_step(seqs, n_input)) { + return false; + } + // greedy search + } else if (m_ctx->generation_conf.do_sample == false) { + if (!greedy_search_step(seqs, n_input)) { + return false; + } + // top_k_top_p sampling + } else { + if (!top_k_top_p_sample_step(seqs, n_input)) { return false; } } - // TODO (YZT) greedy search and top_p-top_k sampling + return update_seqs(seqs, n_input); } @@ -117,8 +187,17 @@ bool Cont_batch_gen_worker::update_seqs(std::vector* seqs, const int& fprintf(stderr, "%s: error: wrong sequence status %d.\n", __func__, static_cast(seqs->at(ni).status)); return false; } + if (!m_ctx->beam_search) { + if (next_tokens.size() != n_input) { + fprintf(stderr, "%s: error: wrong next_tokens size %ld, which should be %d.\n", __func__, next_tokens.size(), + n_input); + return false; + } + seqs->at(ni).generated_ids.emplace_back(next_tokens[ni]); + } } - if (m_ctx->beam_search && bsf != nullptr) { + if (m_ctx->beam_search) { + if (bsf == nullptr) return false; request_done_ids = bsf->request_done_ids(); std::vector> req_done_res = bsf->request_done_reponse(); if (request_done_ids.size() != req_done_res.size()) { @@ -140,8 +219,19 @@ bool Cont_batch_gen_worker::update_seqs(std::vector* seqs, const int& seqs->at(vecid).end_time = model_time_us(); } return true; + } else { + for (int ni = 0; ni < n_input; ++ni) { + if (seqs->at(ni).status == seq_status::DECODING && !seqs->at(ni).generated_ids.empty() && + (seqs->at(ni).generated_ids.back() == m_ctx->vocab.eos_token_id || + seqs->at(ni).generated_ids.size() >= seqs->at(ni).gen_conf.max_new_tokens)) { + seqs->at(ni).status = seq_status::FINISHED; + seqs->at(ni).end_time = model_time_us(); + request_done_ids.emplace_back(seqs->at(ni).request_idx); + last_n_tokens[seqs->at(ni).request_idx].resize(m_ctx->n_ctx, 0); + } + } + return true; } - return false; // TODO (YZT) greedy search and top_p-top_k sampling } // Iter_level_scheduler @@ -240,6 +330,8 @@ bool Cont_batch_gen_scheduler::prepare_seqs() { for (int np = 0; np < n_perfill_seqs; ++np) { if (waiting_pool.pop(&executed_seqs[cur_running_num + np])) { executed_seqs[cur_running_num + np].status = seq_status::PREFILL; + executed_seqs[cur_running_num + np].generated_ids.reserve( + executed_seqs[cur_running_num + np].gen_conf.max_new_tokens); if (executed_seqs[cur_running_num + np].request_idx == -1) { const int fidx = query_free_req_idx(); if (fidx == -1) { diff --git a/neural_speed/models/model_utils/scheduler.h b/neural_speed/models/model_utils/scheduler.h index 7b34044c5..6f1aabe67 100644 --- a/neural_speed/models/model_utils/scheduler.h +++ b/neural_speed/models/model_utils/scheduler.h @@ -24,8 +24,9 @@ class Iter_level_worker { explicit Iter_level_worker(const gpt_params& params); virtual ~Iter_level_worker(); virtual bool step(std::vector* seqs, const int& n_input) = 0; - // virtual bool greedy_search_step(sequence seqs, const int& n_input) = 0; virtual bool beam_search_step(std::vector* seqs, const int& n_input) = 0; + virtual bool greedy_search_step(std::vector* seqs, const int& n_input) = 0; + virtual bool top_k_top_p_sample_step(std::vector* seqs, const int& n_input) = 0; inline void set_threads(const int& n_threads) { threads = n_threads; } inline std::vector get_request_done_ids() const { return request_done_ids; } @@ -38,6 +39,8 @@ class Iter_level_worker { model_context* m_ctx = nullptr; int threads; 
beam_search_flow* bsf = nullptr; + std::vector next_tokens; + std::vector> last_n_tokens; std::vector request_done_ids; std::unordered_map reqidx_to_vecid; }; @@ -50,8 +53,9 @@ class Cont_batch_gen_worker : public Iter_level_worker { ~Cont_batch_gen_worker() = default; bool step(std::vector* seqs, const int& n_input) override; - // bool greedy_search_step(sequence seqs, const int& n_input) override; bool beam_search_step(std::vector*, const int& n_input) override; + bool greedy_search_step(std::vector* seqs, const int& n_input) override; + bool top_k_top_p_sample_step(std::vector* seqs, const int& n_input) override; protected: bool prepare_inputs(std::vector*, const int& n_input, model_input* inputs) override; diff --git a/scripts/python_api_example_for_model_server.py b/scripts/python_api_example_for_model_server.py index e0168ca62..583038887 100644 --- a/scripts/python_api_example_for_model_server.py +++ b/scripts/python_api_example_for_model_server.py @@ -2,7 +2,7 @@ import argparse from pathlib import Path from typing import List, Optional -import neural_speed.llama_cpp as cpp +from neural_speed import ModelServer from transformers import AutoTokenizer @@ -38,7 +38,25 @@ def main(args_in: Optional[List[str]] = None) -> None: help="scale for enlarge memory for model inference: Float", required=False, default=1.0) parser.add_argument("--memory_dtype", type=str, help="KV cache memory dtype: String", - required=False, default="auto") + required=False, default="auto", choices=["f32", "f16", "auto"]) + parser.add_argument("--ctx_size", type=int, help="Size of the prompt context: "\ + "Int (default: 512, can not be larger than specific model's context window"\ + " length)", required=False, default=512) + parser.add_argument("--seed", type=int, + help="NG seed: Int (default: -1, use random seed for < 0)", + required=False, default=-1) + parser.add_argument("--repeat_penalty", type=float, + help="Penalize repeat sequence of tokens: Float (default: 1.1, 1.0 = disabled)", + required=False, default=1.1) + parser.add_argument("--top_k", type=int, + help="top_k in generated token sampling: Int (default: 40, <= 0 to use vocab size)", + required=False, default=40) + parser.add_argument("--top_p", type=float, + help="top_p in generated token sampling: Float (default: 0.95, 1.0 = disabled)", + required=False, default=0.95) + parser.add_argument("--temperature", type=float, + help="temperature in generated token sampling: Float (default: 0.8, 1.0 = disabled)", + required=False, default=0.8) args = parser.parse_args(args_in) print(args) @@ -74,24 +92,31 @@ def f_response(res, working): print("=====================================") added_count = 0 - s = cpp.ModelServer(f_response, - str(args.model_path), - max_new_tokens=args.max_new_tokens, - num_beams=args.num_beams, - min_new_tokens=args.min_new_tokens, - early_stopping=args.early_stopping, - do_sample=args.do_sample, - continuous_batching=True, - return_prompt=args.return_prompt, - threads=args.threads, - max_request_num=args.max_request_num, - print_log=args.print_log, - scratch_size_ratio = args.scratch_size_ratio, - memory_dtype= args.memory_dtype, + s = ModelServer(args.model_name, + f_response, + str(args.model_path), + max_new_tokens=args.max_new_tokens, + num_beams=args.num_beams, + min_new_tokens=args.min_new_tokens, + early_stopping=args.early_stopping, + do_sample=args.do_sample, + continuous_batching=True, + return_prompt=args.return_prompt, + threads=args.threads, + max_request_num=args.max_request_num, + print_log=args.print_log, + 
scratch_size_ratio = args.scratch_size_ratio, + memory_dtype= args.memory_dtype, + ctx_size=args.ctx_size, + seed=args.seed, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repeat_penalty, + temperature=args.temperature, ) for i in range(len(prompts)): p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist() - s.issueQuery([cpp.Query(i, p_token_ids)]) + s.issueQuery(i, p_token_ids) added_count += 1 time.sleep(2) # adjust query sending time interval diff --git a/tests/test_model_server.py b/tests/test_model_server.py index 3ec66e8ef..8c4946926 100644 --- a/tests/test_model_server.py +++ b/tests/test_model_server.py @@ -15,8 +15,7 @@ import time import unittest import shutil -from neural_speed import Model -import neural_speed.llama_cpp as cpp +from neural_speed import Model, ModelServer from transformers import AutoTokenizer class TestModelServer(unittest.TestCase): @@ -52,6 +51,23 @@ def test_model_server(self): model = Model() # get quantized model model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8") + print("=======REFERENCE RESULTS FOR COMPARISON=========", flush=True) + print("=======FOR LOOP GREEDY SEARCH GENERATION RESULTS WITH MHA==========", flush=True) + for i in range(len(prompts)): + p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids + output = model.generate(p_token_ids, num_beams=1, max_new_tokens=128, do_sample=False) + ans = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + print(ans, flush=True) + print("================================", flush=True) + print("=======FOR LOOP BEAM SEARCH GENERATION RESULTS WITH MHA==========", flush=True) + for i in range(len(prompts)): + p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids + output = model.generate(p_token_ids, num_beams=4, max_new_tokens=128, min_new_tokens=30, + early_stopping=True, do_sample=False, continuous_batching=True, + max_request_num=4) + ans = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + print(ans, flush=True) + print("================================", flush=True) del model model_path = "./runtime_outs/ne_llama_q_int4_bestla_cint8_g32.bin" @@ -64,19 +80,21 @@ def f_response(res, working): clean_up_tokenization_spaces=False) print(f"working_size: {working}, ans:", flush=True) for a in ans: - print(a) - print("=====================================") + print(a, flush=True) + print("=====================================", flush=True) + log_map = {"auto": "MHA", "f16": "NON-MHA", + "greedy": "GREEDY SEARCH", "beam": "BEAM SEARCH"} for md in ["auto", "f16"]: - if md == "auto": - print("=======MHA MODEL SERVER TESTING=========") - else: - print("=======NON-MHA MODEL SERVER TESTING=========") - added_count = 0 - s = cpp.ModelServer(f_response, + for policy in ["greedy", "beam"]: + print("============={} {} MODEL SERVER TESTING========".format(log_map[md], + log_map[policy]), flush=True) + added_count = 0 + s = ModelServer(model_name, + f_response, model_path, max_new_tokens=128, - num_beams=4, + num_beams=4 if policy == "beam" else 1, min_new_tokens=30, early_stopping=True, do_sample=False, @@ -87,19 +105,19 @@ def f_response(res, working): print_log=False, scratch_size_ratio = 1.0, memory_dtype= md, - ) - for i in range(len(prompts)): - p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist() - s.issueQuery([cpp.Query(i, p_token_ids)]) - added_count += 1 - time.sleep(2) # adjust query sending time interval 
+ ) + for i in range(len(prompts)): + p_token_ids = tokenizer(prompts[i], return_tensors='pt').input_ids.tolist() + s.issueQuery(i, p_token_ids) + added_count += 1 + time.sleep(2) # adjust query sending time interval - # recommend to use time.sleep in while loop to exit program - # let cpp server owns more resources - while (added_count != len(prompts) or not s.Empty()): - time.sleep(1) - del s - print("should finished") + # recommend to use time.sleep in while loop to exit program + # let cpp server owns more resources + while (added_count != len(prompts) or not s.Empty()): + time.sleep(1) + del s + print("should finished", flush=True) if __name__ == "__main__": unittest.main() diff --git a/tests/test_python_api.py b/tests/test_python_api.py index 1cccea15b..360b30e83 100644 --- a/tests/test_python_api.py +++ b/tests/test_python_api.py @@ -78,7 +78,7 @@ def test_llm_runtime(self): print(config_type, diff_data) - def test_beam_search(self): + def test_multi_batch_inference(self): model_name = "/tf_dataset2/models/pytorch/gpt-j-6B" # or local path to model prompts = [ "she opened the door and see", @@ -93,13 +93,13 @@ def test_beam_search(self): tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side="left") tokenizer.pad_token = tokenizer.eos_token - pad_token = tokenizer(tokenizer.pad_token)['input_ids'][0] + pad_token = tokenizer.pad_token_id inputs = tokenizer(prompts, padding=True, return_tensors='pt') # pytorch fp32 pt_generate_ids = torch.load("/tf_dataset2/inc-ut/nlptoolkit_ut_model/beam_pt_generate_ids.pth").tolist() - # llm runtime fp32 + # llm runtime fp32 beam search itrex_model = Model() itrex_model.init(model_name, use_quant=False) itrex_generate_ids_padded = itrex_model.generate( @@ -114,6 +114,19 @@ def test_beam_search(self): for i in range(len(itrex_generate_ids_cont)): self.assertListEqual(itrex_generate_ids_cont[i], itrex_generate_ids_cont[i]) + # llm runtime int4 greedy search + itrex_model = Model() + itrex_model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8") + outputs = itrex_model.generate(inputs.input_ids, num_beams=1, max_new_tokens=128, pad_token=pad_token, + continuous_batching=True, memory_dtype="f16", do_sample=False) + for i in range(len(prompts)): + input_ids = tokenizer(prompts[i], return_tensors='pt').input_ids + output = itrex_model.generate(input_ids, num_beams=1, max_new_tokens=128, pad_token=pad_token, + memory_dtype="f16", do_sample=False) + # ignore pad token + gen_len = len(output[0]) - input_ids.shape[-1] + self.assertListEqual(outputs[i][inputs.input_ids.shape[-1]: inputs.input_ids.shape[-1] + gen_len], + output[0][input_ids.shape[-1]:]) if __name__ == "__main__": unittest.main()