diff --git a/developer_document.md b/developer_document.md index 3b6e41cde..19ad23cf4 100644 --- a/developer_document.md +++ b/developer_document.md @@ -8,7 +8,7 @@ For simplicity, we take [polyglot](https://huggingface.co/EleutherAI/polyglot-ko Firstly, we need to add its temp buffer in its [related model-arch header file](neural_speed/models/gptneox/gptneox.h) and [re-compile](README.md#Install). ```diff -static const model_scratch gptneox_mem_req(int n_layers) { +static const model_scratch gptneox_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 44: return {2048ull * MB, 2048ull * MB, 4096ull * MB}; @@ -167,7 +167,7 @@ and update [model_name_to_arch()](neural_speed/models/model_utils/model_types.h# + NEW_MODEL_13B, +}; -+static const model_scratch new_model_mem_req(int n_layers) { ++static const model_scratch new_model_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { + switch (n_layers) { + case N: + return {8192ull * MB, 8192ull * MB, 8192ull * MB}; @@ -390,7 +390,7 @@ We recommend to use continuous batching way since it has no padding effect and c + ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous), ne_element_size(KQV_merged_contiguous) * off_sl))); + off_sl += head_size * n_head * attn_sl * attn_bs; ``` ->Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`model_scratch_enlarge_scale`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger. +>Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`scratch_size_ratio`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger. ## 2.3. Application - Q4_0 quant : We can quantize the model generated by convert by adding a quant layer class to quantize it into an int4 low-bit file, so as to obtain better inference performance. Register quant layer class in your new_model_utils.cpp, just like [gptneox_utils.cpp](neural_speed/models/gptneox/gptneox_utils.cpp#L163), replace `gptneox_quant_layer` to your `new_model_quant_layer`. diff --git a/docs/gptq_and_awq.md b/docs/gptq_and_awq.md index ec66f5e43..d8dfbc43f 100644 --- a/docs/gptq_and_awq.md +++ b/docs/gptq_and_awq.md @@ -13,7 +13,7 @@ Validated GPTQ & AWQ models directly from the HuggingFace: * [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ) & * [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4) * [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ) -Please check more validated GPTQ & AWQ models in the list of [supported_models](./docs/supported_models.md). +Please check more validated GPTQ & AWQ models in the list of [supported_models](./supported_models.md). 
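For instance, a validated GPTQ checkpoint can be loaded directly through the Python API by passing `use_gptq=True` (AWQ and AutoRound models work the same way via `use_awq=True` / `use_autoround=True`). The snippet below is a minimal sketch, not part of the patch; the model id is taken from the validated list above and the tokenizer comes from `transformers`:

```python
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "TheBloke/SOLAR-10.7B-v1.0-GPTQ"   # any validated GPTQ model listed above
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids

model = Model()
model.init(model_name, use_gptq=True)            # use_awq=True / use_autoround=True for AWQ / AutoRound
output_ids = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```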
## Examples diff --git a/docs/supported_models.md b/docs/supported_models.md index ef3f7d362..4aad26d29 100644 --- a/docs/supported_models.md +++ b/docs/supported_models.md @@ -10,6 +10,7 @@ Neural Speed supports the following models: INT8 INT4 Transformer Version + Max tokens length RTN @@ -36,6 +37,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 4096 LLaMA-7B, @@ -49,6 +51,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 2048 CodeLlama-7b ✅ @@ -60,6 +63,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 16384 Solar-10.7B @@ -72,6 +76,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 4096 Neural-Chat-7B-v3-1, @@ -85,6 +90,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 32768 Mistral-7B, @@ -98,6 +104,7 @@ Neural Speed supports the following models: ✅ ✅ 4.36.0 or newer + 32768 Qwen-7B, @@ -113,6 +120,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 8192 / 32768 GPT-J-6B @@ -125,6 +133,7 @@ Neural Speed supports the following models: ✅ ✅ Latest + 2048 GPT-NeoX-20B @@ -137,6 +146,7 @@ Neural Speed supports the following models: Latest + 2048 Dolly-v2-3B @@ -149,6 +159,7 @@ Neural Speed supports the following models: 4.28.1 or newer + 2048 MPT-7B, @@ -162,6 +173,7 @@ Neural Speed supports the following models: Latest + 2048 Falcon-7B, @@ -175,6 +187,7 @@ Neural Speed supports the following models: Latest + 2048 BLOOM-7B @@ -187,6 +200,7 @@ Neural Speed supports the following models: Latest + 2048 OPT-125m, @@ -201,6 +215,7 @@ Neural Speed supports the following models: Latest + 2048 ChatGLM-6B, @@ -214,6 +229,7 @@ Neural Speed supports the following models: 4.33.1 + 2048 / 32768 Baichuan-13B-Chat, @@ -227,6 +243,7 @@ Neural Speed supports the following models: 4.33.1 + 4096 phi-2, @@ -241,6 +258,7 @@ Neural Speed supports the following models: Latest + 2048 Whisper-tiny, @@ -257,6 +275,7 @@ Neural Speed supports the following models: Latest + 448 diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py index e9f4203fb..dda41c270 100644 --- a/neural_speed/__init__.py +++ b/neural_speed/__init__.py @@ -24,6 +24,7 @@ class Model: + def __init__(self): self.module = None self.model = None @@ -84,9 +85,19 @@ def get_model_type(model_config): model_type = "chatglm2" return model_type - def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False, - weight_dtype="int4", alg="sym", group_size=32, - scale_dtype="fp32", compute_dtype="int8", use_ggml=False, model_hub="huggingface"): + def init(self, + model_name, + use_quant=True, + use_gptq=False, + use_awq=False, + use_autoround=False, + weight_dtype="int4", + alg="sym", + group_size=32, + scale_dtype="fp32", + compute_dtype="int8", + use_ggml=False, + model_hub="huggingface"): if model_hub == "modelscope": from modelscope import AutoConfig self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) @@ -124,8 +135,7 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au self.bin_file = quant_bin if os.path.exists(self.bin_file): - print("{} existed, will use cache file. Otherwise please remove the file". - format(self.bin_file)) + print("{} existed, will use cache file. 
Otherwise please remove the file".format(self.bin_file)) return if use_gptq or use_awq or use_autoround: @@ -133,15 +143,20 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au return if not os.path.exists(fp32_bin): - convert_model(model_name, fp32_bin, "f32", model_hub = model_hub) + convert_model(model_name, fp32_bin, "f32", model_hub=model_hub) assert os.path.exists(fp32_bin), "Fail to convert pytorch model" if not use_quant: print("FP32 model will be used.") return - self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin, - weight_dtype=weight_dtype, alg=alg, group_size=group_size, - scale_dtype=scale_dtype, compute_dtype=compute_dtype, use_ggml=use_ggml) + self.module.Model.quant_model(model_path=fp32_bin, + out_path=quant_bin, + weight_dtype=weight_dtype, + alg=alg, + group_size=group_size, + scale_dtype=scale_dtype, + compute_dtype=compute_dtype, + use_ggml=use_ggml) assert os.path.exists(quant_bin), "Fail to quantize model" # clean @@ -150,9 +165,11 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au def init_from_bin(self, model_type, model_path, **generate_kwargs): self.__import_package(model_type) self.model = self.module.Model() + if self.max_request_num == -1: - self.max_request_num = max(generate_kwargs.get("max_request_num", - max_request_num_default), generate_kwargs.get("batch_size", 1)) + self.max_request_num = max(generate_kwargs.get("max_request_num", max_request_num_default), + generate_kwargs.get("batch_size", 1)) + if "threads" not in generate_kwargs: threads = os.getenv("OMP_NUM_THREADS") import platform @@ -165,29 +182,107 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs): generate_kwargs["threads"] = len(os.sched_getaffinity(0)) else: generate_kwargs["threads"] = int(threads) - self.model.init_model(model_path, **generate_kwargs) + # Setting scratch_size_ratio according to the ctx_size & tokens_length + # If scratch_size_ratio has been set, will not enter this branch. + if generate_kwargs.get("ctx_size") is not None and generate_kwargs.get( + "ctx_size") > 2048 and generate_kwargs.get("scratch_size_ratio") is None: + + def get_max_seq_length(): + config = self.config.to_dict() + # chatglm2, bloom + if 'seq_length' in config: + return config['seq_length'] + # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi + elif 'max_position_embeddings' in config: + return config['max_position_embeddings'] + # baichuan, baichuan2 + elif 'model_max_length' in config: + return config['model_max_length'] + # gptj + elif 'n_positions' in config: + return config['n_positions'] + # mpt + elif 'max_seq_len' in config: + return config['max_seq_len'] + # chatglm + elif 'max_sequence_length' in config: + return config['max_sequence_length'] + # whisper + elif 'max_length' in config: + return config['max_length'] + # Falcon does not have these parameters. + elif model_type == "falcon": + return 2048 + else: + print("Not found max seq length, setting to default 512") + return 512 + + # when tokens less than 10240 + def get_scratch_size_ratio(size): + if size > 2048 and size <= 4096: + return 2 + elif size > 4096 and size <= 8192: + return 4 + elif size > 8192 and size <= 10240: + return 8 + else: + # more than 10240 + return -1 + + max_seq_length = get_max_seq_length() + ctx_size = generate_kwargs.get("ctx_size") + + if ctx_size > max_seq_length: + print(f'max_seq_length is {max_seq_length}, but ctx_size is {ctx_size}. 
Please reduce ctx_size.') + exit(0) + + if max_seq_length > 2048 and max_seq_length <= 4096: + generate_kwargs["scratch_size_ratio"] = 2 + elif max_seq_length > 4096 and max_seq_length <= 8192: + generate_kwargs["scratch_size_ratio"] = 4 + elif max_seq_length > 8192: + if get_scratch_size_ratio(ctx_size) != -1: + generate_kwargs["scratch_size_ratio"] = get_scratch_size_ratio(ctx_size) + else: + if max_seq_length == 16384: + generate_kwargs["scratch_size_ratio"] = 12 + elif max_seq_length == 32768: + if ctx_size < 20480: + generate_kwargs["scratch_size_ratio"] = 20 + else: + generate_kwargs["scratch_size_ratio"] = 35 + + self.model.init_model(model_path, **generate_kwargs) def quant_model(self, model_type, model_path, out_path, **quant_kwargs): self.__import_package(model_type) self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs) + def generate(self, + input_ids, + streamer=None, + interactive=False, + ignore_prompt=False, + stopping_criteria=None, + **generate_kwargs): + batch_size = input_ids.shape[0] - def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False, - stopping_criteria=None, **generate_kwargs): max_new_tokens = generate_kwargs.get("max_new_tokens", -1) - input_bs = input_ids.shape[0] max_request_num = generate_kwargs.pop("max_request_num", max_request_num_default) reinit_from_bin = False - if max_request_num > self.max_request_num or input_bs > self.max_request_num: + if max_request_num > self.max_request_num or batch_size > self.max_request_num: reinit_from_bin = True if self.max_request_num > 0: print("Will start to reinit model from bin due to different max request num.") - self.max_request_num = max(input_bs, max_request_num) + self.max_request_num = max(batch_size, max_request_num) if self.model is None or reinit_from_bin: - self.init_from_bin(self.model_type, self.bin_file, batch_size=input_bs, - max_request_num = self.max_request_num, **generate_kwargs) + self.init_from_bin(self.model_type, + self.bin_file, + batch_size=batch_size, + max_request_num=self.max_request_num, + **generate_kwargs) self.generate_round = 0 elif not interactive: self.model.reinit() @@ -208,6 +303,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa assert input_ids.shape[0] == 1, "Streamer only supports batch size 1." assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \ Make sure that `num_beams` is set to 1." 
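The `scratch_size_ratio` auto-selection added to `init_from_bin` above condenses to the following rule of thumb. This is an illustrative restatement (the helper name is hypothetical, not part of the patch); it assumes `ctx_size > 2048`, `ctx_size <= max_seq_length`, and that the caller did not pass `scratch_size_ratio` explicitly:

```python
def pick_scratch_size_ratio(max_seq_length: int, ctx_size: int) -> float:
    # Mirrors the branches in init_from_bin; the default stays 1.0 when nothing matches.
    if max_seq_length <= 4096:
        return 2
    if max_seq_length <= 8192:
        return 4
    # max_seq_length > 8192: bucket by the requested ctx_size first
    if ctx_size <= 4096:
        return 2
    if ctx_size <= 8192:
        return 4
    if ctx_size <= 10240:
        return 8
    # ctx_size > 10240: fall back to known long-context lengths
    if max_seq_length == 16384:
        return 12
    if max_seq_length == 32768:
        return 20 if ctx_size < 20480 else 35
    return 1.0  # unknown long-context model: keep the default
```

The heuristic can be bypassed by passing the ratio explicitly on the first call, e.g. `model.generate(input_ids, ctx_size=32768, scratch_size_ratio=35, ...)`, since `generate` forwards these keyword arguments to `init_from_bin`.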
+ if self.generate_round == 0 and not ignore_prompt: streamer.put(input_ids) @@ -284,6 +380,6 @@ def _cont_batching_input(self, input_ids, pad_token_id=None): for il in range(len(input_list)): count = input_list[il].count(pti) # padding left - del input_list[il][0: count] + del input_list[il][0:count] assert input_list[il] != [], "there are all pad tokens in batch {}.".format(il) return input_list diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp index f82fd803b..9149cd439 100644 --- a/neural_speed/application/main_pybind.cpp +++ b/neural_speed/application/main_pybind.cpp @@ -84,7 +84,7 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_ bool early_stopping = false, int n_keep = 0, int n_discard = -1, bool shift_roped_k = false, int batch_size = 1, model_vocab::id pad_token = -1, const std::string& memory_dtype = "auto", bool continuous_batching = true, const int& max_request_num = MODEL_MAX_REQUEST_NUM, - const float& model_scratch_enlarge_scale = 1.0f) { + const float& scratch_size_ratio = 1.0f) { MODEL_ASSERT(params != nullptr); #ifdef MODEL_NAME params->model_name = MODEL_NAME; @@ -115,6 +115,7 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_ params->memory_type = KV_MEM_TYPE_AUTO; else fprintf(stderr, "Unexpected memory dtype %s!", memory_dtype.c_str()); + // TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs params->cont_batching = continuous_batching; if (params->shift_roped_k) params->cont_batching = false; @@ -126,13 +127,19 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_ params->min_new_tokens = min_new_tokens; params->length_penalty = length_penalty; params->do_early_stopping = early_stopping; - params->model_scratch_enlarge_scale = model_scratch_enlarge_scale; + params->scratch_size_ratio = scratch_size_ratio; + + // TODO(Yi): MHA FOR LONG TOKENS + int32_t tokens_length = 6144; + if (params->n_ctx > tokens_length) { + params->memory_type = KV_MEM_TYPE_F16; + } printf( - "beam_size: %d, do_sample: %d, top_k: %d, top_p: %f, continuous_batching: %d, max_request_num: %d, " - "early_stopping: %d\n", + "beam_size: %d, do_sample: %d, top_k: %d, top_p: %.3f, continuous_batching: %d, max_request_num: %d, " + "early_stopping: %d, scratch_size_ratio: %.3f\n", params->beam_size, params->do_sample, params->top_k, params->top_p, params->cont_batching, - params->max_request_num, params->do_early_stopping); + params->max_request_num, params->do_early_stopping, params->scratch_size_ratio); } class ModelServer { @@ -142,7 +149,7 @@ class ModelServer { int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, - const float& model_scratch_enlarge_scale, const std::string& policy, bool print_log, + const float& scratch_size_ratio, const std::string& policy, bool print_log, const std::function& init_cb) : response(response), waiting(), @@ -161,7 +168,7 @@ class ModelServer { this->InitServerParams(model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, - true, max_request_num, model_scratch_enlarge_scale); + true, max_request_num, 
scratch_size_ratio); Cont_batch_gen_scheduler scheduler(this->params, policy, print_log ? 0 : 1); std::vector added_seqs; while (running) { @@ -263,11 +270,11 @@ class ModelServer { float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, - const float& model_scratch_enlarge_scale) { + const float& scratch_size_ratio) { init_gpt_params(¶ms, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching, - max_request_num, model_scratch_enlarge_scale); + max_request_num, scratch_size_ratio); if (cont_batching_model_archs.count(params.model_arch) == 0) { fprintf(stderr, "\nERROR: ModelServer only supports gpt-j, llama!\n"); running = false; @@ -325,7 +332,7 @@ class Model { float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype, - bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale); + bool continuous_batching, const int& max_request_num, const float& scratch_size_ratio); void reinit(); std::vector> generate(const std::vector>& input_ids); // deprecated API @@ -419,11 +426,11 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_ float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype, bool continuous_batching, const int& max_request_num, - const float& model_scratch_enlarge_scale) { + const float& scratch_size_ratio) { init_gpt_params(¶ms, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching, max_request_num, - model_scratch_enlarge_scale); + scratch_size_ratio); n_past = 0; n_total = 0; @@ -533,12 +540,13 @@ const std::vector& Model::evaluate_(const std::vector n_ctx - 4) { // long input_id_cb and empty curr_input_ids[bs] + } else if (input_id_cb.size() > n_ctx - params.n_keep) { // long input_id_cb and empty curr_input_ids[bs] fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__, - input_id_cb.size(), n_ctx - 4); - curr_input_ids[bs].resize(n_ctx - 4); - std::copy(input_id_cb.end() - n_ctx - 8, input_id_cb.end(), curr_input_ids[bs].begin() + 4); - std::copy(input_id_cb.begin(), input_id_cb.begin() + 4, curr_input_ids[bs].begin()); + input_id_cb.size(), n_ctx - params.n_keep); + curr_input_ids[bs].resize(n_ctx - params.n_keep); + std::copy(input_id_cb.end() - n_ctx - params.n_keep * 2, input_id_cb.end(), + curr_input_ids[bs].begin() + params.n_keep); + std::copy(input_id_cb.begin(), input_id_cb.begin() + params.n_keep, curr_input_ids[bs].begin()); } else { // good input_id_cb and empty curr_input_ids[bs] curr_input_ids[bs] = input_id_cb; } @@ -648,13 +656,13 @@ std::vector> 
Model::generate_tokens(const std::vector n_ctx - 4) { + if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - params.n_keep) { fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__, - input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - 4); - curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - 4); - std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - 8, input_ids[STATIC_INPUT_HEAD_IDX].end(), - curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + 4); - std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + 4, + input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - params.n_keep); + curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - params.n_keep); + std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - params.n_keep * 2, + input_ids[STATIC_INPUT_HEAD_IDX].end(), curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep); + std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep, curr_input_ids[STATIC_INPUT_HEAD_IDX].begin()); } else { curr_input_ids[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX]; @@ -922,7 +930,7 @@ PYBIND11_MODULE(mixtral_cpp, m) py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false, py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto", py::arg("continuous_batching") = true, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, - py::arg("model_scratch_enlarge_scale") = 1.0f) + py::arg("scratch_size_ratio") = 1.0f) .def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids")) .def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits", py::arg("input_ids") = std::vector>{}, py::arg("logits_all") = false) @@ -962,7 +970,7 @@ PYBIND11_MODULE(mixtral_cpp, m) py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false, py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false, py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto", py::arg("continuous_batching") = true, - py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, py::arg("model_scratch_enlarge_scale") = 1.0f, + py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, py::arg("scratch_size_ratio") = 1.0f, py::arg("policy") = "fcfs", py::arg("print_log") = false, py::arg("init_cb") = std::function{[]() {}}) .def("issueQuery", &ModelServer::issueQuery, "desc placeholder", py::arg("qs")) diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp index 7a74197e4..55cd4fec8 100644 --- a/neural_speed/application/main_run.cpp +++ b/neural_speed/application/main_run.cpp @@ -250,9 +250,9 @@ int main(int argc, char** argv) { // NOLINT const int n_ctx = model_n_ctx(ctx); - if (static_cast(embd_inp.size()) > n_ctx - 4) { + if (static_cast(embd_inp.size()) > n_ctx - params.n_keep) { fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast(embd_inp.size()), - n_ctx - 4); + n_ctx - params.n_keep); return 1; } @@ -352,8 +352,8 @@ int main(int argc, char** argv) { // NOLINT params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); - fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, - params.n_predict, params.n_keep); 
+ fprintf(stderr, "generate: n_ctx = %d, tokens_length = %ld, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, + embd_inp.size(), params.n_batch, params.n_predict, params.n_keep); fprintf(stderr, "\n\n"); // TODO(Bo): replace with ring-buffer diff --git a/neural_speed/application/pybind_gptj.cpp b/neural_speed/application/pybind_gptj.cpp index 9779b6913..f98dc7a0a 100644 --- a/neural_speed/application/pybind_gptj.cpp +++ b/neural_speed/application/pybind_gptj.cpp @@ -35,9 +35,8 @@ static model_context** g_ctx; bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval, size_t n_past, size_t n_threads) { const int n_ctx = model_n_ctx(ctx); - if (static_cast(n_eval) > n_ctx - 4) { - fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast(n_eval), - n_ctx - 4); + if (static_cast(n_eval) > n_ctx) { + fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast(n_eval), n_ctx); return true; } diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 9030e67cd..b8110b74c 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -911,8 +911,10 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, size_needed += sizeof(struct ne_tensor); if (cur_end + size_needed + NE_OBJECT_SIZE > ctx->mem_size) { - NE_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__, - cur_end + size_needed + NE_OBJECT_SIZE, ctx->mem_size); + NE_PRINT( + "%s: %d Context's memory pool is not enough(current %zu MB, ctx->mem_size available %zu MB), please increase " + "the scratch_size_ratio.\n", + __func__, __LINE__, (cur_end + size_needed + NE_OBJECT_SIZE) / 1024 / 1024, ctx->mem_size / 1024 / 1024); assert(false); return NULL; } @@ -924,14 +926,17 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type, }; } else { if (ctx->scratch.offs + size_needed > ctx->scratch.size) { - NE_PRINT("%s: not enough space in the scratch memory\n", __func__); + NE_PRINT( + "%s: %d scratch.size pool is not enough(current %zu MB, ctx->scratch.size available %zu MB), please increase " + "the scratch_size_ratio.\n", + __func__, __LINE__, (ctx->scratch.offs + size_needed) / 1024 / 1024, ctx->scratch.size / 1024 / 1024); assert(false); return NULL; } if (cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE > ctx->mem_size) { - NE_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__, - cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size); + NE_PRINT("%s: %d not enough space in the context's memory pool (needed %zu, ctx->mem_size available %zu)\n", + __func__, __LINE__, cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size); assert(false); return NULL; } diff --git a/neural_speed/models/baichuan/baichuan.h b/neural_speed/models/baichuan/baichuan.h index d2cb1ad5d..2803d9617 100644 --- a/neural_speed/models/baichuan/baichuan.h +++ b/neural_speed/models/baichuan/baichuan.h @@ -23,10 +23,14 @@ enum baichuan_model { BAICHUAN_13B, }; -static const model_scratch baichuan_mem_req(int n_layers) { +static const model_scratch baichuan_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 40: - return {8192ull * MB, 8192ull * MB, 8192ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; default: 
MODEL_ASSERT(false); } diff --git a/neural_speed/models/baichuan/baichuan_utils.cpp b/neural_speed/models/baichuan/baichuan_utils.cpp index e755833cb..155bc1e61 100644 --- a/neural_speed/models/baichuan/baichuan_utils.cpp +++ b/neural_speed/models/baichuan/baichuan_utils.cpp @@ -75,7 +75,7 @@ void BAICHUAN::init(const char* path_model, model_context* ctx, int n_gpu_layer_ n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = baichuan_mem_req(n_layer); + scratch = baichuan_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -89,7 +89,7 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); ctx_size = ctx_size * 2; - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); const auto& hparams = model.hparams; const int head_dim = n_embd / hparams.n_head; @@ -153,6 +153,9 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/bloom/bloom.h b/neural_speed/models/bloom/bloom.h index e66ababe5..86b480bbd 100644 --- a/neural_speed/models/bloom/bloom.h +++ b/neural_speed/models/bloom/bloom.h @@ -23,10 +23,14 @@ enum bloom_model { BLOOM_7B, }; -static const model_scratch bloom_mem_req(int n_layers) { +static const model_scratch bloom_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 30: - return {4 * 2048ull * MB, 4 * 2048ull * MB, 4 * 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/bloom/bloom_utils.cpp b/neural_speed/models/bloom/bloom_utils.cpp index 10d72bb9a..4f57767b0 100644 --- a/neural_speed/models/bloom/bloom_utils.cpp +++ b/neural_speed/models/bloom/bloom_utils.cpp @@ -73,7 +73,7 @@ void BLOOM::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = bloom_mem_req(n_layer); + scratch = bloom_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -86,7 +86,7 @@ void BLOOM::load(model_context* ctx, model_progress_callback progress_callback, size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); // create the ne context lctx.model.buf.resize(ctx_size); @@ -195,6 +195,9 @@ void BLOOM::load(model_context* ctx, model_progress_callback progress_callback, // this is the total memory required to run the inference const size_t 
mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/chatglm/chatglm.h b/neural_speed/models/chatglm/chatglm.h index 853086bb3..a26194cde 100644 --- a/neural_speed/models/chatglm/chatglm.h +++ b/neural_speed/models/chatglm/chatglm.h @@ -23,11 +23,14 @@ enum chatglm_model { CHATGLM_6B, }; -static const model_scratch chatglm_mem_req(int n_layers) { +static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 28: - return {4096ull * MB, 4096ull * MB, 8192ull * MB}; - // TODO(hengyu): add more variants besides 6B + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/chatglm/chatglm2.h b/neural_speed/models/chatglm/chatglm2.h index 35db8175b..328ae7d7f 100644 --- a/neural_speed/models/chatglm/chatglm2.h +++ b/neural_speed/models/chatglm/chatglm2.h @@ -23,10 +23,14 @@ enum chatglm2_model { CHATGLM2_6B, }; -static const model_scratch chatglm_mem_req(int n_layers) { +static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 28: - return {4096ull * MB, 4096ull * MB, 8192ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp index 3fd38d2f3..b2202b06a 100644 --- a/neural_speed/models/chatglm/chatglm2_utils.cpp +++ b/neural_speed/models/chatglm/chatglm2_utils.cpp @@ -78,7 +78,7 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_ n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = chatglm_mem_req(n_layer); + scratch = chatglm_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -92,7 +92,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); ctx_size = ctx_size * 2; - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); const auto& hparams = model.hparams; MODEL_ASSERT(("chatglm uses multi_query_group_num rather than n_head_kv", @@ -174,6 +174,9 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = 
%7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/chatglm/chatglm_utils.cpp b/neural_speed/models/chatglm/chatglm_utils.cpp index e11579e90..665498cee 100644 --- a/neural_speed/models/chatglm/chatglm_utils.cpp +++ b/neural_speed/models/chatglm/chatglm_utils.cpp @@ -72,7 +72,7 @@ void CHATGLM::init(const char* path_model, model_context* ctx, int n_gpu_layer_, n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = chatglm_mem_req(n_layer); + scratch = chatglm_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -86,7 +86,7 @@ void CHATGLM::load(model_context* ctx, model_progress_callback progress_callback size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); ctx_size = ctx_size * 2; - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); const auto& hparams = model.hparams; const int head_dim = n_embd / hparams.n_head; @@ -160,6 +160,9 @@ void CHATGLM::load(model_context* ctx, model_progress_callback progress_callback // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/falcon/falcon.h b/neural_speed/models/falcon/falcon.h index b837c141a..3d6ea8b36 100644 --- a/neural_speed/models/falcon/falcon.h +++ b/neural_speed/models/falcon/falcon.h @@ -23,14 +23,26 @@ enum falcon_model { FALCON_7B, }; -static const model_scratch falcon_mem_req(int n_layers) { +static const model_scratch falcon_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 32: - return {2048ull * MB, 2048ull * MB, 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; case 60: - return {2 * 2048ull * MB, 2 * 2048ull * MB, 2 * 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 2 * 3072) * MB, + static_cast(scratch_size_ratio * 2 * 2048) * MB, + static_cast(scratch_size_ratio * 2 * 3072) * MB, + }; case 80: - return {3 * 2048ull * MB, 3 * 2048ull * MB, 3 * 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 3 * 3072) * MB, + static_cast(scratch_size_ratio * 3 * 2048) * MB, + static_cast(scratch_size_ratio * 3 * 3072) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/falcon/falcon_utils.cpp b/neural_speed/models/falcon/falcon_utils.cpp index b52919d7f..ea27d7946 100644 --- a/neural_speed/models/falcon/falcon_utils.cpp +++ b/neural_speed/models/falcon/falcon_utils.cpp @@ -77,7 +77,7 @@ void FALCON::init(const char* path_model, model_context* ctx, int n_gpu_layer_, n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; n_head_kv = hparams.n_head_kv; - scratch = falcon_mem_req(n_layer); + scratch = 
falcon_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -90,7 +90,7 @@ void FALCON::load(model_context* ctx, model_progress_callback progress_callback, size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); // create the ne context lctx.model.buf.resize(ctx_size); @@ -204,6 +204,9 @@ void FALCON::load(model_context* ctx, model_progress_callback progress_callback, // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/gptj/gptj.h b/neural_speed/models/gptj/gptj.h index f1edcf97c..dacafc2b4 100644 --- a/neural_speed/models/gptj/gptj.h +++ b/neural_speed/models/gptj/gptj.h @@ -26,14 +26,14 @@ enum gptj_model { GPTJ_65B, }; -static const model_scratch gptj_mem_req(int n_layers, float enlarge_scale = 1.0f) { +static const model_scratch gptj_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 28: // should be enough for batch=8 * beam=4 return { - static_cast(enlarge_scale * 3072) * MB, - static_cast(enlarge_scale * 2048) * MB, - static_cast(enlarge_scale * 3072) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, }; default: MODEL_ASSERT(false); diff --git a/neural_speed/models/gptj/gptj_utils.cpp b/neural_speed/models/gptj/gptj_utils.cpp index 2e6702e48..004a5b423 100644 --- a/neural_speed/models/gptj/gptj_utils.cpp +++ b/neural_speed/models/gptj/gptj_utils.cpp @@ -75,7 +75,7 @@ void GPTJ::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = gptj_mem_req(n_layer, lctx.model_scratch_enlarge_scale); + scratch = gptj_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -88,7 +88,7 @@ void GPTJ::load(model_context* ctx, model_progress_callback progress_callback, v size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); // create the ne context lctx.model.buf.resize(ctx_size); @@ -154,6 +154,9 @@ void GPTJ::load(model_context* ctx, model_progress_callback progress_callback, v // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 
1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/gptneox/gptneox.h b/neural_speed/models/gptneox/gptneox.h index 1304b386f..617827618 100644 --- a/neural_speed/models/gptneox/gptneox.h +++ b/neural_speed/models/gptneox/gptneox.h @@ -23,14 +23,26 @@ enum gptneox_model { GPTNEOX_7B, }; -static const model_scratch gptneox_mem_req(int n_layers) { +static const model_scratch gptneox_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 44: - return {2048ull * MB, 2048ull * MB, 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; case 32: - return {512ull * MB, 512ull * MB, 1026ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; case 28: // 5.8B - return {512ull * MB, 512ull * MB, 1024ull * MB}; + return { + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; default: MODEL_ASSERT(false); } diff --git a/neural_speed/models/gptneox/gptneox_utils.cpp b/neural_speed/models/gptneox/gptneox_utils.cpp index 401c42803..7a7f0c967 100644 --- a/neural_speed/models/gptneox/gptneox_utils.cpp +++ b/neural_speed/models/gptneox/gptneox_utils.cpp @@ -74,7 +74,7 @@ void GPTNEOX::init(const char* path_model, model_context* ctx, int n_gpu_layer_, n_embd = hparams.n_embd; n_vocab = hparams.n_vocab; n_layer = hparams.n_layer; - scratch = gptneox_mem_req(n_layer); + scratch = gptneox_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -87,7 +87,7 @@ void GPTNEOX::load(model_context* ctx, model_progress_callback progress_callback size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); // create the ne context lctx.model.buf.resize(ctx_size); @@ -152,6 +152,9 @@ void GPTNEOX::load(model_context* ctx, model_progress_callback progress_callback // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/llama/llama.h b/neural_speed/models/llama/llama.h index 5c9f07e58..99fb65a72 100644 --- a/neural_speed/models/llama/llama.h +++ b/neural_speed/models/llama/llama.h @@ -26,37 +26,37 @@ enum llama_model { LLAMA_65B, }; -static const model_scratch llama_mem_req(int n_layers, float enlarge_scale = 1.0f) { +static const model_scratch llama_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 32: return { - static_cast(enlarge_scale * 1024) * MB, - static_cast(enlarge_scale * 1024) * MB, - static_cast(enlarge_scale * 1608) * MB, + 
static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, }; case 40: return { - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 1608) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, }; case 48: return { - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 2366) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, }; case 60: return { - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 512) * MB, - static_cast(enlarge_scale * 3124) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, }; case 80: return { - static_cast(enlarge_scale * 2048) * MB, - static_cast(enlarge_scale * 2048) * MB, - static_cast(enlarge_scale * 10240) * MB, + static_cast(scratch_size_ratio * 3072) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 3072 * 3) * MB, }; default: MODEL_ASSERT(false); diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp index 9757156c0..6685dacdd 100644 --- a/neural_speed/models/llama/llama_utils.cpp +++ b/neural_speed/models/llama/llama_utils.cpp @@ -81,7 +81,7 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b n_head = hparams.n_head; n_expert = hparams.n_experts; n_expert_used = hparams.n_experts_used; - scratch = llama_mem_req(n_layer, lctx.model_scratch_enlarge_scale); + scratch = llama_mem_req(n_layer, lctx.scratch_size_ratio); model.scratchs = scratch; } @@ -93,7 +93,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); + fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0); // create the ne context lctx.model.buf.resize(ctx_size); @@ -226,6 +226,9 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback, // this is the total memory required to run the inference const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory scratch.scratch0 + scratch.scratch1 + scratch.eval; + fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0); + fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0); fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0); (void)n_gpu_layer; diff --git a/neural_speed/models/model_utils/model_config.h b/neural_speed/models/model_utils/model_config.h index 816780e32..a30935dc0 100644 --- a/neural_speed/models/model_utils/model_config.h +++ b/neural_speed/models/model_utils/model_config.h @@ -103,7 +103,7 @@ struct gpt_params { float length_penalty = 1.0f; // exponential penalty to the length in beam search generation bool do_early_stopping = false; // early stopping in beam search generation - float model_scratch_enlarge_scale = 1.0f; // model memory 
scratch enlarge scale + float scratch_size_ratio = 1.0f; // model memory scratch enlarge scale }; bool gpt_params_parse(int argc, char** argv, gpt_params& params); diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index e71ee94f0..11ac03260 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -1014,16 +1014,16 @@ struct model_file_loader { } else if (model_magic == NE) { std::cout << "Loading the bin file with NE format..." << std::endl; fseek(file.fp, 0, SEEK_SET); - read_ne_magic(); - read_ne_hparams(); - read_ne_vocab(); + load_ne_magic(); + load_ne_hparams(); + load_ne_vocab(); read_tensor_metadata(file_idx, tensors_map); } else { throw format("unknown file format model_maigc = %d", model_magic); } } - void read_ne_magic() { + void load_ne_magic() { uint32_t magic = file.read_u32(); if (magic == MODEL_FILE_MAGIC_NE) { @@ -1075,7 +1075,7 @@ struct model_file_loader { return model_magic; } - void read_ne_hparams() { + void load_ne_hparams() { unsigned int count = 0; hparams.n_vocab = file.read_u32(); hparams.n_embd = file.read_u32(); @@ -1101,8 +1101,8 @@ struct model_file_loader { hparams.do_layer_norm_before = bool(file.read_u32()); printf("%-16s %d.hparams.ftype = %-30d\n", __func__, count++, hparams.ftype); printf("%-16s %d.hparams.max_seq_len = %-30d\n", __func__, count++, hparams.max_seq_len); - printf("%-16s %d.hparams.alibi_bias_max = %-30f\n", __func__, count++, hparams.alibi_bias_max); - printf("%-16s %d.hparams.clip_qkv = %-30f\n", __func__, count++, hparams.clip_qkv); + printf("%-16s %d.hparams.alibi_bias_max = %-30.3f\n", __func__, count++, hparams.alibi_bias_max); + printf("%-16s %d.hparams.clip_qkv = %-30.3f\n", __func__, count++, hparams.clip_qkv); printf("%-16s %d.hparams.par_res = %-30d\n", __func__, count++, hparams.par_res); printf("%-16s %d.hparams.word_embed_proj_dim = %-30d\n", __func__, count++, hparams.word_embed_proj_dim); printf("%-16s %d.hparams.do_layer_norm_before = %-30d\n", __func__, count++, hparams.do_layer_norm_before); @@ -1124,13 +1124,13 @@ struct model_file_loader { file.read_raw(&hparams.freq_base, sizeof(float)); file.read_raw(&hparams.freq_scale, sizeof(float)); printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size); - printf("%-16s %d.hparams.freq_base = %-30f\n", __func__, count++, hparams.freq_base); - printf("%-16s %d.hparams.freq_scale = %-30f\n", __func__, count++, hparams.freq_scale); + printf("%-16s %d.hparams.freq_base = %-30.3f\n", __func__, count++, hparams.freq_base); + printf("%-16s %d.hparams.freq_scale = %-30.3f\n", __func__, count++, hparams.freq_scale); file.read_raw(&hparams.rope_scaling_factor, sizeof(float)); hparams.original_max_position_embeddings = file.read_u32(); hparams.use_yarn = file.read_u32(); - printf("%-16s %d.hparams.rope_scaling_factor = %-30f\n", __func__, count++, hparams.rope_scaling_factor); + printf("%-16s %d.hparams.rope_scaling_factor = %-30.3f\n", __func__, count++, hparams.rope_scaling_factor); printf("%-16s %d.hparams.original_max_position_embeddings = %-30d\n", __func__, count++, hparams.original_max_position_embeddings); printf("%-16s %d.hparams.use_yarn = %-30d\n", __func__, count++, hparams.use_yarn); @@ -1140,7 +1140,7 @@ struct model_file_loader { } } - void read_ne_vocab() { + void load_ne_vocab() { unsigned int count = 0; unsigned int ne_hparams_total = 25; file.read_raw(&vocab.bos_token_id, sizeof(model_vocab::id)); diff --git 
a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h index 6833c94ea..611a54088 100644 --- a/neural_speed/models/model_utils/model_types.h +++ b/neural_speed/models/model_utils/model_types.h @@ -320,7 +320,7 @@ struct model_context { size_t mem_per_token = 0; - float model_scratch_enlarge_scale = 1.0f; // model memory scratch enlarge scale + float scratch_size_ratio = 1.0f; // model memory scratch enlarge scale // decode output (3-dimensional array: [batch_size] [n_tokens] [n_vocab]) std::vector logits; @@ -441,7 +441,7 @@ struct model_context_params { // global generation config generation_config gen_conf; // model memory scratch enlarge scale - float model_scratch_enlarge_scale; + float scratch_size_ratio; // called with a progress value between 0 and 1, pass nullptr to disable model_progress_callback progress_callback; diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp index 896ed4231..eb6c0b4f0 100644 --- a/neural_speed/models/model_utils/model_utils.cpp +++ b/neural_speed/models/model_utils/model_utils.cpp @@ -188,7 +188,7 @@ struct model_context_params model_context_default_params() { /*cont_batching =*/true, /*.max_request_num =*/1, /*.gen_conf =*/generation_config(), - /*model_scratch_enlarge_scale =*/1.0f, + /*scratch_size_ratio =*/1.0f, /*.progress_callback =*/nullptr, /*.progress_callback_user_data =*/nullptr, }; @@ -911,7 +911,9 @@ struct model_context* model_init_from_file(const char* path_model, struct model_ } ctx->cont_batching = params.cont_batching; ctx->generation_conf = params.gen_conf; - ctx->model_scratch_enlarge_scale = params.model_scratch_enlarge_scale; + + ctx->scratch_size_ratio = params.scratch_size_ratio * params.max_request_num * params.beam_size; + const model_archs arch = params.arch; // the type so that kv-cache allocated according to this type must be large enough @@ -1268,6 +1270,13 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) { lparams.n_gpu_layers = params.n_gpu_layers; lparams.seed = params.seed; lparams.kv_type = params.memory_type; + + // TODO(Yi): MHA FOR LONG TOKENS + int32_t long_tokens = 6144; + if (lparams.n_ctx > long_tokens) { + lparams.kv_type = KV_MEM_TYPE_F16; + } + lparams.use_mmap = params.use_mmap; lparams.use_mlock = params.use_mlock; lparams.logits_all = params.perplexity; @@ -1284,7 +1293,7 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) { lparams.gen_conf.min_new_tokens = params.min_new_tokens; lparams.gen_conf.length_penalty = params.length_penalty; lparams.gen_conf.do_early_stopping = params.do_early_stopping; - lparams.model_scratch_enlarge_scale = params.model_scratch_enlarge_scale; + lparams.scratch_size_ratio = params.scratch_size_ratio; NE_ASSERT(("Start size cannot be greater than the maximum context size!", lparams.n_keep < lparams.n_ctx)); diff --git a/neural_speed/models/mpt/mpt.h b/neural_speed/models/mpt/mpt.h index 9ce987602..0a4d614e9 100644 --- a/neural_speed/models/mpt/mpt.h +++ b/neural_speed/models/mpt/mpt.h @@ -24,12 +24,20 @@ enum mpt_model { MPT_30B, }; -static const model_scratch mpt_mem_req(int n_layers) { +static const model_scratch mpt_mem_req(int n_layers, float scratch_size_ratio = 1.0f) { switch (n_layers) { case 32: - return {2048ull * MB, 2048ull * MB, 4096ull * MB}; + return { + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 2048) * MB, + static_cast(scratch_size_ratio * 4096) * MB, + }; case 48: - return 
{4096ull * MB, 4096ull * MB, 8192ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 8192) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/mpt/mpt_utils.cpp b/neural_speed/models/mpt/mpt_utils.cpp
index 489f4d1bc..ba6e3fbed 100644
--- a/neural_speed/models/mpt/mpt_utils.cpp
+++ b/neural_speed/models/mpt/mpt_utils.cpp
@@ -76,7 +76,7 @@ void MPT::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
-  scratch = mpt_mem_req(n_layer);
+  scratch = mpt_mem_req(n_layer, lctx.scratch_size_ratio);
   model.scratchs = scratch;
 }
@@ -89,7 +89,7 @@ void MPT::load(model_context* ctx, model_progress_callback progress_callback, vo
   size_t ctx_size;
   size_t mmapped_size;
   ml->calc_sizes(&ctx_size, &mmapped_size);
-  fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -175,6 +175,9 @@ void MPT::load(model_context* ctx, model_progress_callback progress_callback, vo
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
                               scratch.scratch0 + scratch.scratch1 + scratch.eval;
+  fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
   fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
   (void)n_gpu_layer;
diff --git a/neural_speed/models/opt/opt.h b/neural_speed/models/opt/opt.h
index 10668653a..2590d9be5 100644
--- a/neural_speed/models/opt/opt.h
+++ b/neural_speed/models/opt/opt.h
@@ -34,20 +34,44 @@ enum opt_model {
 };
 // TODO naive memory buffer size
-static const model_scratch opt_mem_req(int n_layers) {
+static const model_scratch opt_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
   switch (n_layers) {
     case 12:  // OPT_125M
-      return {512ull * MB, 512ull * MB, 1024ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+      };
     case 24:  // OPT_350M, OPT_1DOT3B
-      return {1024ull * MB, 1024ull * MB, 2048ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+      };
     case 32:  // OPT_2DOT7B OPT_6DOT7B
-      return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     case 40:
-      return {2560ull * MB, 2560ull * MB, 5120ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 2560) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2560) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 5120) * MB,
+      };
     case 48:
-      return {3072ull * MB, 3072ull * MB, 6144ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 3072) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 3072) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 6144) * MB,
+      };
     case 64:
-      return {4096ull * MB, 4096ull * MB, 8192ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 8192) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/opt/opt_utils.cpp b/neural_speed/models/opt/opt_utils.cpp
index 4be7c2472..2641bcb51 100644
--- a/neural_speed/models/opt/opt_utils.cpp
+++ b/neural_speed/models/opt/opt_utils.cpp
@@ -74,7 +74,7 @@ void OPT::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
   word_embed_proj_dim = hparams.word_embed_proj_dim;
   max_seq_len = hparams.max_seq_len;
   do_layer_norm_before = hparams.do_layer_norm_before;
-  scratch = opt_mem_req(n_layer);
+  scratch = opt_mem_req(n_layer, lctx.scratch_size_ratio);
   model.scratchs = scratch;
 }
@@ -87,7 +87,7 @@ void OPT::load(model_context* ctx, model_progress_callback progress_callback, vo
   size_t ctx_size;
   size_t mmapped_size;
   ml->calc_sizes(&ctx_size, &mmapped_size);
-  fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -168,6 +168,9 @@ void OPT::load(model_context* ctx, model_progress_callback progress_callback, vo
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
                               scratch.scratch0 + scratch.scratch1 + scratch.eval;
+  fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
   fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
   (void)n_gpu_layer;
diff --git a/neural_speed/models/phi/phi.h b/neural_speed/models/phi/phi.h
index b767c34bf..8a3618d62 100644
--- a/neural_speed/models/phi/phi.h
+++ b/neural_speed/models/phi/phi.h
@@ -23,12 +23,20 @@ enum new_model {
   PHI,
 };
-static const model_scratch phi_mem_req(int n_layers) {
+static const model_scratch phi_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
   switch (n_layers) {
     case 24:
-      return {512ull * MB, 512ull * MB, 1026ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+      };
     case 32:
-      return {1024ull * MB, 1024ull * MB, 1026ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/phi/phi_utils.cpp b/neural_speed/models/phi/phi_utils.cpp
index 16ac7e3d7..018318db2 100644
--- a/neural_speed/models/phi/phi_utils.cpp
+++ b/neural_speed/models/phi/phi_utils.cpp
@@ -74,7 +74,7 @@ void phi::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
   n_embd = hparams.n_embd;
-  scratch = phi_mem_req(n_layer);
+  scratch = phi_mem_req(n_layer, lctx.scratch_size_ratio);
   model.scratchs = scratch;
 }
@@ -87,7 +87,7 @@ void phi::load(model_context* ctx, model_progress_callback progress_callback, vo
   size_t ctx_size;
   size_t mmapped_size;
   ml->calc_sizes(&ctx_size, &mmapped_size);
-  fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -159,6 +159,9 @@ void phi::load(model_context* ctx, model_progress_callback progress_callback, vo
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
                               scratch.scratch0 + scratch.scratch1 + scratch.eval;
+  fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
   fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
   (void)n_gpu_layer;
diff --git a/neural_speed/models/qwen/qwen.h b/neural_speed/models/qwen/qwen.h
index ec2c65357..3fb54b7c6 100644
--- a/neural_speed/models/qwen/qwen.h
+++ b/neural_speed/models/qwen/qwen.h
@@ -24,14 +24,26 @@ enum QWEN_model {
   QWEN_14B,
 };
-static const model_scratch qwen_mem_req(int n_layers) {
+static const model_scratch qwen_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
   switch (n_layers) {
     case 40:
-      return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     case 32:
-      return {1024ull * MB, 1024ull * MB, 1608ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     case 24:
-      return {512ull * MB, 512ull * MB, 1026ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp
index 77bba671b..32e222d76 100644
--- a/neural_speed/models/qwen/qwen_utils.cpp
+++ b/neural_speed/models/qwen/qwen_utils.cpp
@@ -70,7 +70,7 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
-  scratch = qwen_mem_req(n_layer);
+  scratch = qwen_mem_req(n_layer, lctx.scratch_size_ratio);
   model.scratchs = scratch;
 }
@@ -83,7 +83,7 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
   size_t ctx_size;
   size_t mmapped_size;
   ml->calc_sizes(&ctx_size, &mmapped_size);
-  fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -201,6 +201,9 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
                               scratch.scratch0 + scratch.scratch1 + scratch.eval;
+  fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
   fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
   (void)n_gpu_layer;
diff --git a/neural_speed/models/starcoder/starcoder.h b/neural_speed/models/starcoder/starcoder.h
index 72ab9234c..5d4317830 100644
--- a/neural_speed/models/starcoder/starcoder.h
+++ b/neural_speed/models/starcoder/starcoder.h
@@ -26,14 +26,26 @@ enum starcoder_model {
   STARCODER_65B,
 };
-static const model_scratch starcoder_mem_req(int n_layers) {
+static const model_scratch starcoder_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
   switch (n_layers) {
     case 24:
-      return {8192ull * MB, 8192ull * MB, 8192ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048 * 2) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+      };
     case 36:
-      return {8192ull * MB, 8192ull * MB, 8192ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048 * 2) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+      };
     case 40:
-      return {32768ull * MB, 32768ull * MB, 32768ull * MB};
+      return {
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 8) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 2048 * 8) * MB,
+          static_cast<unsigned long long>(scratch_size_ratio * 3072 * 8) * MB,
+      };
     default:
       MODEL_ASSERT(false);
   }
diff --git a/neural_speed/models/starcoder/starcoder_utils.cpp b/neural_speed/models/starcoder/starcoder_utils.cpp
index 13ad06afe..6d8d142fb 100644
--- a/neural_speed/models/starcoder/starcoder_utils.cpp
+++ b/neural_speed/models/starcoder/starcoder_utils.cpp
@@ -75,7 +75,7 @@ void STARCODER::init(const char* path_model, model_context* ctx, int n_gpu_layer
   n_embd = hparams.n_embd;
   n_vocab = hparams.n_vocab;
   n_layer = hparams.n_layer;
-  scratch = starcoder_mem_req(n_layer);
+  scratch = starcoder_mem_req(n_layer, lctx.scratch_size_ratio);
   model.scratchs = scratch;
 }
@@ -88,7 +88,7 @@ void STARCODER::load(model_context* ctx, model_progress_callback progress_callba
   size_t ctx_size;
   size_t mmapped_size;
   ml->calc_sizes(&ctx_size, &mmapped_size);
-  fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
   // create the ne context
   lctx.model.buf.resize(ctx_size);
@@ -203,6 +203,9 @@ void STARCODER::load(model_context* ctx, model_progress_callback progress_callba
   // this is the total memory required to run the inference
   const size_t mem_required = ctx_size + mmapped_size - vram_total +  // weights in VRAM not in memory
                               scratch.scratch0 + scratch.scratch1 + scratch.eval;
+  fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+  fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
   fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
   (void)n_gpu_layer;
diff --git a/scripts/python_api_example_for_model_server.py b/scripts/python_api_example_for_model_server.py
index c9889ef70..e0168ca62 100644
--- a/scripts/python_api_example_for_model_server.py
+++ b/scripts/python_api_example_for_model_server.py
@@ -34,7 +34,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
                         help="maximum number of running requests (or queries) for model inference: Int",
                         required=False, default=8)
     parser.add_argument("--print_log", action="store_true", help="print server running logs")
-    parser.add_argument("--model_scratch_enlarge_scale", type=float,
+    parser.add_argument("--scratch_size_ratio", type=float,
                         help="scale for enlarge memory for model inference: Float",
                         required=False, default=1.0)
     parser.add_argument("--memory_dtype", type=str,
                         help="KV cache memory dtype: String",
@@ -86,7 +86,7 @@ def f_response(res, working):
                        threads=args.threads,
                        max_request_num=args.max_request_num,
                        print_log=args.print_log,
-                       model_scratch_enlarge_scale = args.model_scratch_enlarge_scale,
+                       scratch_size_ratio = args.scratch_size_ratio,
                        memory_dtype= args.memory_dtype,
                        )
     for i in range(len(prompts)):
diff --git a/tests/test_model_server.py b/tests/test_model_server.py
index d7e7fd102..3ec66e8ef 100644
--- a/tests/test_model_server.py
+++ b/tests/test_model_server.py
@@ -85,7 +85,7 @@ def f_response(res, working):
                        max_request_num=8,
                        threads=56,
                        print_log=False,
-                       model_scratch_enlarge_scale = 1.0,
+                       scratch_size_ratio = 1.0,
                        memory_dtype= md,
                        )
     for i in range(len(prompts)):
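
Every `*_mem_req()` change above applies the same arithmetic: scale the MB count by `scratch_size_ratio`, truncate to an integer, then multiply by `MB` to convert to bytes. The sketch below is a minimal, standalone illustration of that computation; `scratch_sizes` and `scaled_mem_req()` are stand-ins for this example only, not the real `model_scratch` type or any function added by this patch, and the 2048/2048/4096 MB base values are just one of the defaults seen above.

```cpp
// Illustrative sketch only: mirrors the scaling pattern used by the *_mem_req() helpers.
#include <cstdio>

constexpr unsigned long long MB = 1024ull * 1024ull;

struct scratch_sizes {
  unsigned long long scratch0;
  unsigned long long scratch1;
  unsigned long long eval;
};

static scratch_sizes scaled_mem_req(float scratch_size_ratio = 1.0f) {
  // Scale the MB count first, truncate to an integer, then convert to bytes,
  // matching the static_cast<unsigned long long>(ratio * N) * MB lines above.
  return {
      static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
      static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
      static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
  };
}

int main() {
  const scratch_sizes s = scaled_mem_req(2.0f);  // e.g. larger batches need more scratch
  std::printf("scratch0 = %7.2f MB\n", s.scratch0 / 1024.0 / 1024.0);
  std::printf("scratch1 = %7.2f MB\n", s.scratch1 / 1024.0 / 1024.0);
  std::printf("scratch2 = %7.2f MB\n", s.eval / 1024.0 / 1024.0);
  return 0;
}
```

With a ratio of 2.0 the sketch reports 4096.00 / 4096.00 / 8192.00 MB, which is the effect of passing `--scratch_size_ratio 2.0` to the model server script when larger batch sizes run out of scratch memory.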