diff --git a/developer_document.md b/developer_document.md
index 3b6e41cde..19ad23cf4 100644
--- a/developer_document.md
+++ b/developer_document.md
@@ -8,7 +8,7 @@ For simplicity, we take [polyglot](https://huggingface.co/EleutherAI/polyglot-ko
First, we need to add its temp buffer in the [related model-arch header file](neural_speed/models/gptneox/gptneox.h) and [re-compile](README.md#Install).
```diff
-static const model_scratch gptneox_mem_req(int n_layers) {
+static const model_scratch gptneox_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 44:
return {2048ull * MB, 2048ull * MB, 4096ull * MB};
@@ -167,7 +167,7 @@ and update [model_name_to_arch()](neural_speed/models/model_utils/model_types.h#
+ NEW_MODEL_13B,
+};
-+static const model_scratch new_model_mem_req(int n_layers) {
++static const model_scratch new_model_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
+ switch (n_layers) {
+ case N:
+ return {8192ull * MB, 8192ull * MB, 8192ull * MB};
@@ -390,7 +390,7 @@ We recommend to use continuous batching way since it has no padding effect and c
+ ne_view_2d(ctx0, KQV_merged_contiguous, head_size * n_head, attn_sl * attn_bs, head_size * n_head * ne_element_size(KQV_merged_contiguous), ne_element_size(KQV_merged_contiguous) * off_sl)));
+ off_sl += head_size * n_head * attn_sl * attn_bs;
```
->Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`model_scratch_enlarge_scale`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger.
+>Note: You can set larger [`NE_MAX_NODES`](neural_speed/core/ne.h#43) and [`scratch_size_ratio`](neural_speed/models/llama/llama.h#29) values if out of memory when the inputs' batch size becomes larger.
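If out of memory still occurs with larger input batches, both knobs can be raised. Below is a minimal usage sketch (the model id and values are only illustrative; `NE_MAX_NODES` is a compile-time constant, so enlarging it still requires editing [ne.h](neural_speed/core/ne.h#43) and rebuilding):
```python
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "meta-llama/Llama-2-7b-hf"  # illustrative model id
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# a batch of identical prompts keeps the example free of padding concerns
input_ids = tokenizer(["Once upon a time"] * 4, return_tensors="pt").input_ids

model = Model()
model.init(model_name, use_quant=True, weight_dtype="int4", compute_dtype="int8")
# pass scratch_size_ratio explicitly to override the ctx_size-based heuristic in init_from_bin
outputs = model.generate(input_ids, max_new_tokens=32, ctx_size=4096, scratch_size_ratio=2.0)
```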
## 2.3. Application
- Q4_0 quant: We can quantize the converted model by adding a quant layer class that produces an int4 low-bit file, so as to obtain better inference performance. Register the quant layer class in your new_model_utils.cpp, just like [gptneox_utils.cpp](neural_speed/models/gptneox/gptneox_utils.cpp#L163), replacing `gptneox_quant_layer` with your `new_model_quant_layer`.
diff --git a/docs/gptq_and_awq.md b/docs/gptq_and_awq.md
index ec66f5e43..d8dfbc43f 100644
--- a/docs/gptq_and_awq.md
+++ b/docs/gptq_and_awq.md
@@ -13,7 +13,7 @@ Validated GPTQ & AWQ models directly from the HuggingFace:
* [Qwen-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-GPTQ) & [Qwen-7B-Chat-AWQ](https://huggingface.co/TheBloke/Qwen-7B-Chat-AWQ)
* [Qwen1.5-7B-Chat-GPTQ-Int4](https://huggingface.co/Qwen/Qwen1.5-7B-Chat-GPTQ-Int4)
* [SOLAR-10.7B-v1.0-GPTQ](https://huggingface.co/TheBloke/SOLAR-10.7B-v1.0-GPTQ)
-Please check more validated GPTQ & AWQ models in the list of [supported_models](./docs/supported_models.md).
+Please check more validated GPTQ & AWQ models in the list of [supported_models](./supported_models.md).
## Examples
diff --git a/docs/supported_models.md b/docs/supported_models.md
index ef3f7d362..4aad26d29 100644
--- a/docs/supported_models.md
+++ b/docs/supported_models.md
@@ -10,6 +10,7 @@ Neural Speed supports the following models:
INT8 |
INT4 |
Transformer Version |
+ Max tokens length |
RTN |
@@ -36,6 +37,7 @@ Neural Speed supports the following models:
✅ |
✅ |
Latest |
+ 4096 |
LLaMA-7B,
@@ -49,6 +51,7 @@ Neural Speed supports the following models:
| ✅ |
✅ |
Latest |
+ 2048 |
CodeLlama-7b |
✅ |
@@ -60,6 +63,7 @@ Neural Speed supports the following models:
✅ |
✅ |
Latest |
+ 16384 |
Solar-10.7B |
@@ -72,6 +76,7 @@ Neural Speed supports the following models:
✅ |
✅ |
Latest |
+ 4096 |
Neural-Chat-7B-v3-1,
@@ -85,6 +90,7 @@ Neural Speed supports the following models:
| ✅ |
✅ |
Latest |
+ 32768 |
Mistral-7B,
@@ -98,6 +104,7 @@ Neural Speed supports the following models:
| ✅ |
✅ |
4.36.0 or newer |
+ 32768 |
Qwen-7B,
@@ -113,6 +120,7 @@ Neural Speed supports the following models:
| ✅ |
✅ |
Latest |
+ 8192 / 32768 |
GPT-J-6B |
@@ -125,6 +133,7 @@ Neural Speed supports the following models:
✅ |
✅ |
Latest |
+ 2048 |
GPT-NeoX-20B |
@@ -137,6 +146,7 @@ Neural Speed supports the following models:
|
|
Latest |
+ 2048 |
Dolly-v2-3B |
@@ -149,6 +159,7 @@ Neural Speed supports the following models:
|
|
4.28.1 or newer |
+ 2048 |
MPT-7B,
@@ -162,6 +173,7 @@ Neural Speed supports the following models:
| |
|
Latest |
+ 2048 |
Falcon-7B,
@@ -175,6 +187,7 @@ Neural Speed supports the following models:
| |
|
Latest |
+ 2048 |
BLOOM-7B |
@@ -187,6 +200,7 @@ Neural Speed supports the following models:
|
|
Latest |
+ 2048 |
OPT-125m,
@@ -201,6 +215,7 @@ Neural Speed supports the following models:
| |
|
Latest |
+ 2048 |
ChatGLM-6B,
@@ -214,6 +229,7 @@ Neural Speed supports the following models:
| |
|
4.33.1 |
+ 2048 / 32768 |
Baichuan-13B-Chat,
@@ -227,6 +243,7 @@ Neural Speed supports the following models:
| |
|
4.33.1 |
+ 4096 |
phi-2,
@@ -241,6 +258,7 @@ Neural Speed supports the following models:
| |
|
Latest |
+ 2048 |
Whisper-tiny,
@@ -257,6 +275,7 @@ Neural Speed supports the following models:
| |
|
Latest |
+ 448 |
diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index e9f4203fb..dda41c270 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -24,6 +24,7 @@
class Model:
+
def __init__(self):
self.module = None
self.model = None
@@ -84,9 +85,19 @@ def get_model_type(model_config):
model_type = "chatglm2"
return model_type
- def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False,
- weight_dtype="int4", alg="sym", group_size=32,
- scale_dtype="fp32", compute_dtype="int8", use_ggml=False, model_hub="huggingface"):
+ def init(self,
+ model_name,
+ use_quant=True,
+ use_gptq=False,
+ use_awq=False,
+ use_autoround=False,
+ weight_dtype="int4",
+ alg="sym",
+ group_size=32,
+ scale_dtype="fp32",
+ compute_dtype="int8",
+ use_ggml=False,
+ model_hub="huggingface"):
if model_hub == "modelscope":
from modelscope import AutoConfig
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -124,8 +135,7 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
self.bin_file = quant_bin
if os.path.exists(self.bin_file):
- print("{} existed, will use cache file. Otherwise please remove the file".
- format(self.bin_file))
+ print("{} existed, will use cache file. Otherwise please remove the file".format(self.bin_file))
return
if use_gptq or use_awq or use_autoround:
@@ -133,15 +143,20 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
return
if not os.path.exists(fp32_bin):
- convert_model(model_name, fp32_bin, "f32", model_hub = model_hub)
+ convert_model(model_name, fp32_bin, "f32", model_hub=model_hub)
assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
if not use_quant:
print("FP32 model will be used.")
return
- self.module.Model.quant_model(model_path=fp32_bin, out_path=quant_bin,
- weight_dtype=weight_dtype, alg=alg, group_size=group_size,
- scale_dtype=scale_dtype, compute_dtype=compute_dtype, use_ggml=use_ggml)
+ self.module.Model.quant_model(model_path=fp32_bin,
+ out_path=quant_bin,
+ weight_dtype=weight_dtype,
+ alg=alg,
+ group_size=group_size,
+ scale_dtype=scale_dtype,
+ compute_dtype=compute_dtype,
+ use_ggml=use_ggml)
assert os.path.exists(quant_bin), "Fail to quantize model"
# clean
@@ -150,9 +165,11 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_au
def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.__import_package(model_type)
self.model = self.module.Model()
+
if self.max_request_num == -1:
- self.max_request_num = max(generate_kwargs.get("max_request_num",
- max_request_num_default), generate_kwargs.get("batch_size", 1))
+ self.max_request_num = max(generate_kwargs.get("max_request_num", max_request_num_default),
+ generate_kwargs.get("batch_size", 1))
+
if "threads" not in generate_kwargs:
threads = os.getenv("OMP_NUM_THREADS")
import platform
@@ -165,29 +182,107 @@ def init_from_bin(self, model_type, model_path, **generate_kwargs):
generate_kwargs["threads"] = len(os.sched_getaffinity(0))
else:
generate_kwargs["threads"] = int(threads)
- self.model.init_model(model_path, **generate_kwargs)
+ # Set scratch_size_ratio according to ctx_size and the model's max tokens length.
+ # If scratch_size_ratio has already been set by the caller, this branch is skipped.
+ if generate_kwargs.get("ctx_size") is not None and generate_kwargs.get(
+ "ctx_size") > 2048 and generate_kwargs.get("scratch_size_ratio") is None:
+
+ def get_max_seq_length():
+ config = self.config.to_dict()
+ # chatglm2, bloom
+ if 'seq_length' in config:
+ return config['seq_length']
+ # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi
+ elif 'max_position_embeddings' in config:
+ return config['max_position_embeddings']
+ # baichuan, baichuan2
+ elif 'model_max_length' in config:
+ return config['model_max_length']
+ # gptj
+ elif 'n_positions' in config:
+ return config['n_positions']
+ # mpt
+ elif 'max_seq_len' in config:
+ return config['max_seq_len']
+ # chatglm
+ elif 'max_sequence_length' in config:
+ return config['max_sequence_length']
+ # whisper
+ elif 'max_length' in config:
+ return config['max_length']
+ # Falcon does not have these parameters.
+ elif model_type == "falcon":
+ return 2048
+ else:
+ print("Max seq length not found, defaulting to 512")
+ return 512
+
+ # map the token count to a scratch_size_ratio; returns -1 for more than 10240 tokens
+ def get_scratch_size_ratio(size):
+ if size > 2048 and size <= 4096:
+ return 2
+ elif size > 4096 and size <= 8192:
+ return 4
+ elif size > 8192 and size <= 10240:
+ return 8
+ else:
+ # more than 10240
+ return -1
+
+ max_seq_length = get_max_seq_length()
+ ctx_size = generate_kwargs.get("ctx_size")
+
+ if ctx_size > max_seq_length:
+ print(f'max_seq_length is {max_seq_length}, but ctx_size is {ctx_size}. Please reduce ctx_size.')
+ exit(0)
+
+ if max_seq_length > 2048 and max_seq_length <= 4096:
+ generate_kwargs["scratch_size_ratio"] = 2
+ elif max_seq_length > 4096 and max_seq_length <= 8192:
+ generate_kwargs["scratch_size_ratio"] = 4
+ elif max_seq_length > 8192:
+ if get_scratch_size_ratio(ctx_size) != -1:
+ generate_kwargs["scratch_size_ratio"] = get_scratch_size_ratio(ctx_size)
+ else:
+ if max_seq_length == 16384:
+ generate_kwargs["scratch_size_ratio"] = 12
+ elif max_seq_length == 32768:
+ if ctx_size < 20480:
+ generate_kwargs["scratch_size_ratio"] = 20
+ else:
+ generate_kwargs["scratch_size_ratio"] = 35
+
+ self.model.init_model(model_path, **generate_kwargs)
def quant_model(self, model_type, model_path, out_path, **quant_kwargs):
self.__import_package(model_type)
self.module.Model.quant_model(model_path=model_path, out_path=out_path, **quant_kwargs)
+ def generate(self,
+ input_ids,
+ streamer=None,
+ interactive=False,
+ ignore_prompt=False,
+ stopping_criteria=None,
+ **generate_kwargs):
+ batch_size = input_ids.shape[0]
- def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=False,
- stopping_criteria=None, **generate_kwargs):
max_new_tokens = generate_kwargs.get("max_new_tokens", -1)
- input_bs = input_ids.shape[0]
max_request_num = generate_kwargs.pop("max_request_num", max_request_num_default)
reinit_from_bin = False
- if max_request_num > self.max_request_num or input_bs > self.max_request_num:
+ if max_request_num > self.max_request_num or batch_size > self.max_request_num:
reinit_from_bin = True
if self.max_request_num > 0:
print("Will start to reinit model from bin due to different max request num.")
- self.max_request_num = max(input_bs, max_request_num)
+ self.max_request_num = max(batch_size, max_request_num)
if self.model is None or reinit_from_bin:
- self.init_from_bin(self.model_type, self.bin_file, batch_size=input_bs,
- max_request_num = self.max_request_num, **generate_kwargs)
+ self.init_from_bin(self.model_type,
+ self.bin_file,
+ batch_size=batch_size,
+ max_request_num=self.max_request_num,
+ **generate_kwargs)
self.generate_round = 0
elif not interactive:
self.model.reinit()
@@ -208,6 +303,7 @@ def generate(self, input_ids, streamer=None, interactive=False, ignore_prompt=Fa
assert input_ids.shape[0] == 1, "Streamer only supports batch size 1."
assert beam_search == False, "ERROR, can not use streamer when use beam search for generation! \
Make sure that `num_beams` is set to 1."
+
if self.generate_round == 0 and not ignore_prompt:
streamer.put(input_ids)
@@ -284,6 +380,6 @@ def _cont_batching_input(self, input_ids, pad_token_id=None):
for il in range(len(input_list)):
count = input_list[il].count(pti)
# padding left
- del input_list[il][0: count]
+ del input_list[il][0:count]
assert input_list[il] != [], "there are all pad tokens in batch {}.".format(il)
return input_list
diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index f82fd803b..9149cd439 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -84,7 +84,7 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_
bool early_stopping = false, int n_keep = 0, int n_discard = -1, bool shift_roped_k = false,
int batch_size = 1, model_vocab::id pad_token = -1, const std::string& memory_dtype = "auto",
bool continuous_batching = true, const int& max_request_num = MODEL_MAX_REQUEST_NUM,
- const float& model_scratch_enlarge_scale = 1.0f) {
+ const float& scratch_size_ratio = 1.0f) {
MODEL_ASSERT(params != nullptr);
#ifdef MODEL_NAME
params->model_name = MODEL_NAME;
@@ -115,6 +115,7 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_
params->memory_type = KV_MEM_TYPE_AUTO;
else
fprintf(stderr, "Unexpected memory dtype %s!", memory_dtype.c_str());
+
// TODO(Yi & YZT): MHA IN MULTI-BATCH For More Model Archs
params->cont_batching = continuous_batching;
if (params->shift_roped_k) params->cont_batching = false;
@@ -126,13 +127,19 @@ void init_gpt_params(gpt_params* params, const std::string& model_path, int max_
params->min_new_tokens = min_new_tokens;
params->length_penalty = length_penalty;
params->do_early_stopping = early_stopping;
- params->model_scratch_enlarge_scale = model_scratch_enlarge_scale;
+ params->scratch_size_ratio = scratch_size_ratio;
+
+ // TODO(Yi): MHA FOR LONG TOKENS
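+ // contexts longer than 6144 tokens fall back to the fp16 KV cache instead of the auto-selected type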
+ int32_t tokens_length = 6144;
+ if (params->n_ctx > tokens_length) {
+ params->memory_type = KV_MEM_TYPE_F16;
+ }
printf(
- "beam_size: %d, do_sample: %d, top_k: %d, top_p: %f, continuous_batching: %d, max_request_num: %d, "
- "early_stopping: %d\n",
+ "beam_size: %d, do_sample: %d, top_k: %d, top_p: %.3f, continuous_batching: %d, max_request_num: %d, "
+ "early_stopping: %d, scratch_size_ratio: %.3f\n",
params->beam_size, params->do_sample, params->top_k, params->top_p, params->cont_batching,
- params->max_request_num, params->do_early_stopping);
+ params->max_request_num, params->do_early_stopping, params->scratch_size_ratio);
}
class ModelServer {
@@ -142,7 +149,7 @@ class ModelServer {
int top_k, float top_p, float temperature, int min_new_tokens, float length_penalty, bool early_stopping,
int n_keep, int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
- const float& model_scratch_enlarge_scale, const std::string& policy, bool print_log,
+ const float& scratch_size_ratio, const std::string& policy, bool print_log,
const std::function<void()>& init_cb)
: response(response),
waiting(),
@@ -161,7 +168,7 @@ class ModelServer {
this->InitServerParams(model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty,
num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty,
early_stopping, n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype,
- true, max_request_num, model_scratch_enlarge_scale);
+ true, max_request_num, scratch_size_ratio);
Cont_batch_gen_scheduler scheduler(this->params, policy, print_log ? 0 : 1);
std::vector<sequence> added_seqs;
while (running) {
@@ -263,11 +270,11 @@ class ModelServer {
float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
- const float& model_scratch_enlarge_scale) {
+ const float& scratch_size_ratio) {
init_gpt_params(&params, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty,
num_beams, do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping,
n_keep, n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching,
- max_request_num, model_scratch_enlarge_scale);
+ max_request_num, scratch_size_ratio);
if (cont_batching_model_archs.count(params.model_arch) == 0) {
fprintf(stderr, "\nERROR: ModelServer only supports gpt-j, llama!\n");
running = false;
@@ -325,7 +332,7 @@ class Model {
float repetition_penalty, int num_beams, bool do_sample, int top_k, float top_p, float temperature,
int min_new_tokens, float length_penalty, bool early_stopping, int n_keep, int n_discard,
bool shift_roped_k, int batch_size, model_vocab::id pad_token, const std::string& memory_dtype,
- bool continuous_batching, const int& max_request_num, const float& model_scratch_enlarge_scale);
+ bool continuous_batching, const int& max_request_num, const float& scratch_size_ratio);
void reinit();
std::vector<std::vector<model_token>> generate(const std::vector<std::vector<model_token>>& input_ids);
// deprecated API
@@ -419,11 +426,11 @@ void Model::init_model(const std::string& model_path, int max_new_tokens, int n_
float temperature, int min_new_tokens, float length_penalty, bool early_stopping, int n_keep,
int n_discard, bool shift_roped_k, int batch_size, model_vocab::id pad_token,
const std::string& memory_dtype, bool continuous_batching, const int& max_request_num,
- const float& model_scratch_enlarge_scale) {
+ const float& scratch_size_ratio) {
init_gpt_params(&params, model_path, max_new_tokens, n_batch, ctx_size, seed, threads, repetition_penalty, num_beams,
do_sample, top_k, top_p, temperature, min_new_tokens, length_penalty, early_stopping, n_keep,
n_discard, shift_roped_k, batch_size, pad_token, memory_dtype, continuous_batching, max_request_num,
- model_scratch_enlarge_scale);
+ scratch_size_ratio);
n_past = 0;
n_total = 0;
@@ -533,12 +540,13 @@ const std::vector<float>& Model::evaluate_(const std::vector<std::vector<model_
- } else if (input_id_cb.size() > n_ctx - 4) { // long input_id_cb and empty curr_input_ids[bs]
+ } else if (input_id_cb.size() > n_ctx - params.n_keep) { // long input_id_cb and empty curr_input_ids[bs]
fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
- input_id_cb.size(), n_ctx - 4);
- curr_input_ids[bs].resize(n_ctx - 4);
- std::copy(input_id_cb.end() - n_ctx - 8, input_id_cb.end(), curr_input_ids[bs].begin() + 4);
- std::copy(input_id_cb.begin(), input_id_cb.begin() + 4, curr_input_ids[bs].begin());
+ input_id_cb.size(), n_ctx - params.n_keep);
+ curr_input_ids[bs].resize(n_ctx - params.n_keep);
+ std::copy(input_id_cb.end() - n_ctx - params.n_keep * 2, input_id_cb.end(),
+ curr_input_ids[bs].begin() + params.n_keep);
+ std::copy(input_id_cb.begin(), input_id_cb.begin() + params.n_keep, curr_input_ids[bs].begin());
} else { // good input_id_cb and empty curr_input_ids[bs]
curr_input_ids[bs] = input_id_cb;
}
@@ -648,13 +656,13 @@ std::vector<std::vector<model_token>> Model::generate_tokens(const std::vector<std::vector<model_
- if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - 4) {
+ if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - params.n_keep) {
fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
- input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - 4);
- curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - 4);
- std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - 8, input_ids[STATIC_INPUT_HEAD_IDX].end(),
- curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + 4);
- std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + 4,
+ input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - params.n_keep);
+ curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - params.n_keep);
+ std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - params.n_keep * 2,
+ input_ids[STATIC_INPUT_HEAD_IDX].end(), curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep);
+ std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep,
curr_input_ids[STATIC_INPUT_HEAD_IDX].begin());
} else {
curr_input_ids[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX];
@@ -922,7 +930,7 @@ PYBIND11_MODULE(mixtral_cpp, m)
py::arg("n_keep") = 0, py::arg("n_discard") = -1, py::arg("shift_roped_k") = false,
py::arg("batch_size") = 1, py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto",
py::arg("continuous_batching") = true, py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM,
- py::arg("model_scratch_enlarge_scale") = 1.0f)
+ py::arg("scratch_size_ratio") = 1.0f)
.def("generate", &Model::generate, "Generate token with input ids", py::arg("input_ids"))
.def("evaluate", &Model::evaluate, "Evaluate token with input ids and output logits",
py::arg("input_ids") = std::vector<std::vector<model_token>>{}, py::arg("logits_all") = false)
@@ -962,7 +970,7 @@ PYBIND11_MODULE(mixtral_cpp, m)
py::arg("length_penalty") = 1.0, py::arg("early_stopping") = false, py::arg("n_keep") = 0,
py::arg("n_discard") = -1, py::arg("shift_roped_k") = false, py::arg("batch_size") = 1,
py::arg("pad_token") = -1, py::arg("memory_dtype") = "auto", py::arg("continuous_batching") = true,
- py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, py::arg("model_scratch_enlarge_scale") = 1.0f,
+ py::arg("max_request_num") = MODEL_MAX_REQUEST_NUM, py::arg("scratch_size_ratio") = 1.0f,
py::arg("policy") = "fcfs", py::arg("print_log") = false,
py::arg("init_cb") = std::function<void()>{[]() {}})
.def("issueQuery", &ModelServer::issueQuery, "desc placeholder", py::arg("qs"))
diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp
index 7a74197e4..55cd4fec8 100644
--- a/neural_speed/application/main_run.cpp
+++ b/neural_speed/application/main_run.cpp
@@ -250,9 +250,9 @@ int main(int argc, char** argv) { // NOLINT
const int n_ctx = model_n_ctx(ctx);
- if (static_cast<int>(embd_inp.size()) > n_ctx - 4) {
+ if (static_cast<int>(embd_inp.size()) > n_ctx - params.n_keep) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(embd_inp.size()),
- n_ctx - 4);
+ n_ctx - params.n_keep);
return 1;
}
@@ -352,8 +352,8 @@ int main(int argc, char** argv) { // NOLINT
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k,
params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta,
params.mirostat_tau);
- fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch,
- params.n_predict, params.n_keep);
+ fprintf(stderr, "generate: n_ctx = %d, tokens_length = %zu, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx,
+ embd_inp.size(), params.n_batch, params.n_predict, params.n_keep);
fprintf(stderr, "\n\n");
// TODO(Bo): replace with ring-buffer
diff --git a/neural_speed/application/pybind_gptj.cpp b/neural_speed/application/pybind_gptj.cpp
index 9779b6913..f98dc7a0a 100644
--- a/neural_speed/application/pybind_gptj.cpp
+++ b/neural_speed/application/pybind_gptj.cpp
@@ -35,9 +35,8 @@ static model_context** g_ctx;
bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval, size_t n_past, size_t n_threads) {
const int n_ctx = model_n_ctx(ctx);
- if (static_cast<int>(n_eval) > n_ctx - 4) {
- fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(n_eval),
- n_ctx - 4);
+ if (static_cast<int>(n_eval) > n_ctx) {
+ fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(n_eval), n_ctx);
return true;
}
diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c
index 9030e67cd..b8110b74c 100644
--- a/neural_speed/core/ne_layers.c
+++ b/neural_speed/core/ne_layers.c
@@ -911,8 +911,10 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type,
size_needed += sizeof(struct ne_tensor);
if (cur_end + size_needed + NE_OBJECT_SIZE > ctx->mem_size) {
- NE_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__,
- cur_end + size_needed + NE_OBJECT_SIZE, ctx->mem_size);
+ NE_PRINT(
+ "%s: %d Context's memory pool is not enough (current %zu MB, ctx->mem_size available %zu MB), please increase "
+ "the scratch_size_ratio.\n",
+ __func__, __LINE__, (cur_end + size_needed + NE_OBJECT_SIZE) / 1024 / 1024, ctx->mem_size / 1024 / 1024);
assert(false);
return NULL;
}
@@ -924,14 +926,17 @@ struct ne_tensor* ne_new_tensor_impl(struct ne_context* ctx, enum ne_type type,
};
} else {
if (ctx->scratch.offs + size_needed > ctx->scratch.size) {
- NE_PRINT("%s: not enough space in the scratch memory\n", __func__);
+ NE_PRINT(
+ "%s: %d scratch.size pool is not enough (current %zu MB, ctx->scratch.size available %zu MB), please increase "
+ "the scratch_size_ratio.\n",
+ __func__, __LINE__, (ctx->scratch.offs + size_needed) / 1024 / 1024, ctx->scratch.size / 1024 / 1024);
assert(false);
return NULL;
}
if (cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE > ctx->mem_size) {
- NE_PRINT("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__,
- cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size);
+ NE_PRINT("%s: %d not enough space in the context's memory pool (needed %zu, ctx->mem_size available %zu)\n",
+ __func__, __LINE__, cur_end + sizeof(struct ne_tensor) + NE_OBJECT_SIZE, ctx->mem_size);
assert(false);
return NULL;
}
diff --git a/neural_speed/models/baichuan/baichuan.h b/neural_speed/models/baichuan/baichuan.h
index d2cb1ad5d..2803d9617 100644
--- a/neural_speed/models/baichuan/baichuan.h
+++ b/neural_speed/models/baichuan/baichuan.h
@@ -23,10 +23,14 @@ enum baichuan_model {
BAICHUAN_13B,
};
-static const model_scratch baichuan_mem_req(int n_layers) {
+static const model_scratch baichuan_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 40:
- return {8192ull * MB, 8192ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/baichuan/baichuan_utils.cpp b/neural_speed/models/baichuan/baichuan_utils.cpp
index e755833cb..155bc1e61 100644
--- a/neural_speed/models/baichuan/baichuan_utils.cpp
+++ b/neural_speed/models/baichuan/baichuan_utils.cpp
@@ -75,7 +75,7 @@ void BAICHUAN::init(const char* path_model, model_context* ctx, int n_gpu_layer_
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = baichuan_mem_req(n_layer);
+ scratch = baichuan_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -89,7 +89,7 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
ctx_size = ctx_size * 2;
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
const auto& hparams = model.hparams;
const int head_dim = n_embd / hparams.n_head;
@@ -153,6 +153,9 @@ void BAICHUAN::load(model_context* ctx, model_progress_callback progress_callbac
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/bloom/bloom.h b/neural_speed/models/bloom/bloom.h
index e66ababe5..86b480bbd 100644
--- a/neural_speed/models/bloom/bloom.h
+++ b/neural_speed/models/bloom/bloom.h
@@ -23,10 +23,14 @@ enum bloom_model {
BLOOM_7B,
};
-static const model_scratch bloom_mem_req(int n_layers) {
+static const model_scratch bloom_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 30:
- return {4 * 2048ull * MB, 4 * 2048ull * MB, 4 * 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/bloom/bloom_utils.cpp b/neural_speed/models/bloom/bloom_utils.cpp
index 10d72bb9a..4f57767b0 100644
--- a/neural_speed/models/bloom/bloom_utils.cpp
+++ b/neural_speed/models/bloom/bloom_utils.cpp
@@ -73,7 +73,7 @@ void BLOOM::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = bloom_mem_req(n_layer);
+ scratch = bloom_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -86,7 +86,7 @@ void BLOOM::load(model_context* ctx, model_progress_callback progress_callback,
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -195,6 +195,9 @@ void BLOOM::load(model_context* ctx, model_progress_callback progress_callback,
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/chatglm/chatglm.h b/neural_speed/models/chatglm/chatglm.h
index 853086bb3..a26194cde 100644
--- a/neural_speed/models/chatglm/chatglm.h
+++ b/neural_speed/models/chatglm/chatglm.h
@@ -23,11 +23,14 @@ enum chatglm_model {
CHATGLM_6B,
};
-static const model_scratch chatglm_mem_req(int n_layers) {
+static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 28:
- return {4096ull * MB, 4096ull * MB, 8192ull * MB};
- // TODO(hengyu): add more variants besides 6B
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/chatglm/chatglm2.h b/neural_speed/models/chatglm/chatglm2.h
index 35db8175b..328ae7d7f 100644
--- a/neural_speed/models/chatglm/chatglm2.h
+++ b/neural_speed/models/chatglm/chatglm2.h
@@ -23,10 +23,14 @@ enum chatglm2_model {
CHATGLM2_6B,
};
-static const model_scratch chatglm_mem_req(int n_layers) {
+static const model_scratch chatglm_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 28:
- return {4096ull * MB, 4096ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/chatglm/chatglm2_utils.cpp b/neural_speed/models/chatglm/chatglm2_utils.cpp
index 3fd38d2f3..b2202b06a 100644
--- a/neural_speed/models/chatglm/chatglm2_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm2_utils.cpp
@@ -78,7 +78,7 @@ void CHATGLM2::init(const char* path_model, model_context* ctx, int n_gpu_layer_
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = chatglm_mem_req(n_layer);
+ scratch = chatglm_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -92,7 +92,7 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
ctx_size = ctx_size * 2;
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
const auto& hparams = model.hparams;
MODEL_ASSERT(("chatglm uses multi_query_group_num rather than n_head_kv",
@@ -174,6 +174,9 @@ void CHATGLM2::load(model_context* ctx, model_progress_callback progress_callbac
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/chatglm/chatglm_utils.cpp b/neural_speed/models/chatglm/chatglm_utils.cpp
index e11579e90..665498cee 100644
--- a/neural_speed/models/chatglm/chatglm_utils.cpp
+++ b/neural_speed/models/chatglm/chatglm_utils.cpp
@@ -72,7 +72,7 @@ void CHATGLM::init(const char* path_model, model_context* ctx, int n_gpu_layer_,
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = chatglm_mem_req(n_layer);
+ scratch = chatglm_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -86,7 +86,7 @@ void CHATGLM::load(model_context* ctx, model_progress_callback progress_callback
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
ctx_size = ctx_size * 2;
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
const auto& hparams = model.hparams;
const int head_dim = n_embd / hparams.n_head;
@@ -160,6 +160,9 @@ void CHATGLM::load(model_context* ctx, model_progress_callback progress_callback
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/falcon/falcon.h b/neural_speed/models/falcon/falcon.h
index b837c141a..3d6ea8b36 100644
--- a/neural_speed/models/falcon/falcon.h
+++ b/neural_speed/models/falcon/falcon.h
@@ -23,14 +23,26 @@ enum falcon_model {
FALCON_7B,
};
-static const model_scratch falcon_mem_req(int n_layers) {
+static const model_scratch falcon_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 32:
- return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 60:
- return {2 * 2048ull * MB, 2 * 2048ull * MB, 2 * 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 2 * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2 * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2 * 3072) * MB,
+ };
case 80:
- return {3 * 2048ull * MB, 3 * 2048ull * MB, 3 * 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 3 * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3 * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3 * 3072) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/falcon/falcon_utils.cpp b/neural_speed/models/falcon/falcon_utils.cpp
index b52919d7f..ea27d7946 100644
--- a/neural_speed/models/falcon/falcon_utils.cpp
+++ b/neural_speed/models/falcon/falcon_utils.cpp
@@ -77,7 +77,7 @@ void FALCON::init(const char* path_model, model_context* ctx, int n_gpu_layer_,
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
n_head_kv = hparams.n_head_kv;
- scratch = falcon_mem_req(n_layer);
+ scratch = falcon_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -90,7 +90,7 @@ void FALCON::load(model_context* ctx, model_progress_callback progress_callback,
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -204,6 +204,9 @@ void FALCON::load(model_context* ctx, model_progress_callback progress_callback,
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/gptj/gptj.h b/neural_speed/models/gptj/gptj.h
index f1edcf97c..dacafc2b4 100644
--- a/neural_speed/models/gptj/gptj.h
+++ b/neural_speed/models/gptj/gptj.h
@@ -26,14 +26,14 @@ enum gptj_model {
GPTJ_65B,
};
-static const model_scratch gptj_mem_req(int n_layers, float enlarge_scale = 1.0f) {
+static const model_scratch gptj_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 28:
// should be enough for batch=8 * beam=4
return {
- static_cast<unsigned long long>(enlarge_scale * 3072) * MB,
- static_cast<unsigned long long>(enlarge_scale * 2048) * MB,
- static_cast<unsigned long long>(enlarge_scale * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
default:
MODEL_ASSERT(false);
diff --git a/neural_speed/models/gptj/gptj_utils.cpp b/neural_speed/models/gptj/gptj_utils.cpp
index 2e6702e48..004a5b423 100644
--- a/neural_speed/models/gptj/gptj_utils.cpp
+++ b/neural_speed/models/gptj/gptj_utils.cpp
@@ -75,7 +75,7 @@ void GPTJ::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = gptj_mem_req(n_layer, lctx.model_scratch_enlarge_scale);
+ scratch = gptj_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -88,7 +88,7 @@ void GPTJ::load(model_context* ctx, model_progress_callback progress_callback, v
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -154,6 +154,9 @@ void GPTJ::load(model_context* ctx, model_progress_callback progress_callback, v
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/gptneox/gptneox.h b/neural_speed/models/gptneox/gptneox.h
index 1304b386f..617827618 100644
--- a/neural_speed/models/gptneox/gptneox.h
+++ b/neural_speed/models/gptneox/gptneox.h
@@ -23,14 +23,26 @@ enum gptneox_model {
GPTNEOX_7B,
};
-static const model_scratch gptneox_mem_req(int n_layers) {
+static const model_scratch gptneox_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 44:
- return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 32:
- return {512ull * MB, 512ull * MB, 1026ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 28: // 5.8B
- return {512ull * MB, 512ull * MB, 1024ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/gptneox/gptneox_utils.cpp b/neural_speed/models/gptneox/gptneox_utils.cpp
index 401c42803..7a7f0c967 100644
--- a/neural_speed/models/gptneox/gptneox_utils.cpp
+++ b/neural_speed/models/gptneox/gptneox_utils.cpp
@@ -74,7 +74,7 @@ void GPTNEOX::init(const char* path_model, model_context* ctx, int n_gpu_layer_,
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = gptneox_mem_req(n_layer);
+ scratch = gptneox_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -87,7 +87,7 @@ void GPTNEOX::load(model_context* ctx, model_progress_callback progress_callback
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -152,6 +152,9 @@ void GPTNEOX::load(model_context* ctx, model_progress_callback progress_callback
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/llama/llama.h b/neural_speed/models/llama/llama.h
index 5c9f07e58..99fb65a72 100644
--- a/neural_speed/models/llama/llama.h
+++ b/neural_speed/models/llama/llama.h
@@ -26,37 +26,37 @@ enum llama_model {
LLAMA_65B,
};
-static const model_scratch llama_mem_req(int n_layers, float enlarge_scale = 1.0f) {
+static const model_scratch llama_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 32:
return {
- static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
- static_cast<unsigned long long>(enlarge_scale * 1024) * MB,
- static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 40:
return {
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 1608) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 48:
return {
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 2366) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 60:
return {
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 512) * MB,
- static_cast<unsigned long long>(enlarge_scale * 3124) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
};
case 80:
return {
- static_cast<unsigned long long>(enlarge_scale * 2048) * MB,
- static_cast<unsigned long long>(enlarge_scale * 2048) * MB,
- static_cast<unsigned long long>(enlarge_scale * 10240) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 3) * MB,
};
default:
MODEL_ASSERT(false);
diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp
index 9757156c0..6685dacdd 100644
--- a/neural_speed/models/llama/llama_utils.cpp
+++ b/neural_speed/models/llama/llama_utils.cpp
@@ -81,7 +81,7 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b
n_head = hparams.n_head;
n_expert = hparams.n_experts;
n_expert_used = hparams.n_experts_used;
- scratch = llama_mem_req(n_layer, lctx.model_scratch_enlarge_scale);
+ scratch = llama_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -93,7 +93,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -226,6 +226,9 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/model_utils/model_config.h b/neural_speed/models/model_utils/model_config.h
index 816780e32..a30935dc0 100644
--- a/neural_speed/models/model_utils/model_config.h
+++ b/neural_speed/models/model_utils/model_config.h
@@ -103,7 +103,7 @@ struct gpt_params {
float length_penalty = 1.0f; // exponential penalty to the length in beam search generation
bool do_early_stopping = false; // early stopping in beam search generation
- float model_scratch_enlarge_scale = 1.0f; // model memory scratch enlarge scale
+ float scratch_size_ratio = 1.0f; // model memory scratch size ratio
};
bool gpt_params_parse(int argc, char** argv, gpt_params& params);
diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h
index e71ee94f0..11ac03260 100644
--- a/neural_speed/models/model_utils/model_files.h
+++ b/neural_speed/models/model_utils/model_files.h
@@ -1014,16 +1014,16 @@ struct model_file_loader {
} else if (model_magic == NE) {
std::cout << "Loading the bin file with NE format..." << std::endl;
fseek(file.fp, 0, SEEK_SET);
- read_ne_magic();
- read_ne_hparams();
- read_ne_vocab();
+ load_ne_magic();
+ load_ne_hparams();
+ load_ne_vocab();
read_tensor_metadata(file_idx, tensors_map);
} else {
throw format("unknown file format model_maigc = %d", model_magic);
}
}
- void read_ne_magic() {
+ void load_ne_magic() {
uint32_t magic = file.read_u32();
if (magic == MODEL_FILE_MAGIC_NE) {
@@ -1075,7 +1075,7 @@ struct model_file_loader {
return model_magic;
}
- void read_ne_hparams() {
+ void load_ne_hparams() {
unsigned int count = 0;
hparams.n_vocab = file.read_u32();
hparams.n_embd = file.read_u32();
@@ -1101,8 +1101,8 @@ struct model_file_loader {
hparams.do_layer_norm_before = bool(file.read_u32());
printf("%-16s %d.hparams.ftype = %-30d\n", __func__, count++, hparams.ftype);
printf("%-16s %d.hparams.max_seq_len = %-30d\n", __func__, count++, hparams.max_seq_len);
- printf("%-16s %d.hparams.alibi_bias_max = %-30f\n", __func__, count++, hparams.alibi_bias_max);
- printf("%-16s %d.hparams.clip_qkv = %-30f\n", __func__, count++, hparams.clip_qkv);
+ printf("%-16s %d.hparams.alibi_bias_max = %-30.3f\n", __func__, count++, hparams.alibi_bias_max);
+ printf("%-16s %d.hparams.clip_qkv = %-30.3f\n", __func__, count++, hparams.clip_qkv);
printf("%-16s %d.hparams.par_res = %-30d\n", __func__, count++, hparams.par_res);
printf("%-16s %d.hparams.word_embed_proj_dim = %-30d\n", __func__, count++, hparams.word_embed_proj_dim);
printf("%-16s %d.hparams.do_layer_norm_before = %-30d\n", __func__, count++, hparams.do_layer_norm_before);
@@ -1124,13 +1124,13 @@ struct model_file_loader {
file.read_raw(&hparams.freq_base, sizeof(float));
file.read_raw(&hparams.freq_scale, sizeof(float));
printf("%-16s %d.hparams.inner_hidden_size = %-30d\n", __func__, count++, hparams.inner_hidden_size);
- printf("%-16s %d.hparams.freq_base = %-30f\n", __func__, count++, hparams.freq_base);
- printf("%-16s %d.hparams.freq_scale = %-30f\n", __func__, count++, hparams.freq_scale);
+ printf("%-16s %d.hparams.freq_base = %-30.3f\n", __func__, count++, hparams.freq_base);
+ printf("%-16s %d.hparams.freq_scale = %-30.3f\n", __func__, count++, hparams.freq_scale);
file.read_raw(&hparams.rope_scaling_factor, sizeof(float));
hparams.original_max_position_embeddings = file.read_u32();
hparams.use_yarn = file.read_u32();
- printf("%-16s %d.hparams.rope_scaling_factor = %-30f\n", __func__, count++, hparams.rope_scaling_factor);
+ printf("%-16s %d.hparams.rope_scaling_factor = %-30.3f\n", __func__, count++, hparams.rope_scaling_factor);
printf("%-16s %d.hparams.original_max_position_embeddings = %-30d\n", __func__, count++,
hparams.original_max_position_embeddings);
printf("%-16s %d.hparams.use_yarn = %-30d\n", __func__, count++, hparams.use_yarn);
@@ -1140,7 +1140,7 @@ struct model_file_loader {
}
}
- void read_ne_vocab() {
+ void load_ne_vocab() {
unsigned int count = 0;
unsigned int ne_hparams_total = 25;
file.read_raw(&vocab.bos_token_id, sizeof(model_vocab::id));
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index 6833c94ea..611a54088 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -320,7 +320,7 @@ struct model_context {
size_t mem_per_token = 0;
- float model_scratch_enlarge_scale = 1.0f; // model memory scratch enlarge scale
+ float scratch_size_ratio = 1.0f; // model memory scratch size ratio
// decode output (3-dimensional array: [batch_size] [n_tokens] [n_vocab])
std::vector logits;
@@ -441,7 +441,7 @@ struct model_context_params {
// global generation config
generation_config gen_conf;
// model memory scratch enlarge scale
- float model_scratch_enlarge_scale;
+ float scratch_size_ratio;
// called with a progress value between 0 and 1, pass nullptr to disable
model_progress_callback progress_callback;
diff --git a/neural_speed/models/model_utils/model_utils.cpp b/neural_speed/models/model_utils/model_utils.cpp
index 896ed4231..eb6c0b4f0 100644
--- a/neural_speed/models/model_utils/model_utils.cpp
+++ b/neural_speed/models/model_utils/model_utils.cpp
@@ -188,7 +188,7 @@ struct model_context_params model_context_default_params() {
/*cont_batching =*/true,
/*.max_request_num =*/1,
/*.gen_conf =*/generation_config(),
- /*model_scratch_enlarge_scale =*/1.0f,
+ /*scratch_size_ratio =*/1.0f,
/*.progress_callback =*/nullptr,
/*.progress_callback_user_data =*/nullptr,
};
@@ -911,7 +911,9 @@ struct model_context* model_init_from_file(const char* path_model, struct model_
}
ctx->cont_batching = params.cont_batching;
ctx->generation_conf = params.gen_conf;
- ctx->model_scratch_enlarge_scale = params.model_scratch_enlarge_scale;
+
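+ // scale the user-provided ratio by max_request_num and beam_size so scratch buffers cover all concurrent sequences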
+ ctx->scratch_size_ratio = params.scratch_size_ratio * params.max_request_num * params.beam_size;
+
const model_archs arch = params.arch;
// the type so that kv-cache allocated according to this type must be large enough
@@ -1268,6 +1270,13 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) {
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.seed = params.seed;
lparams.kv_type = params.memory_type;
+
+ // TODO(Yi): MHA FOR LONG TOKENS
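+ // mirror init_gpt_params: contexts longer than 6144 tokens use the fp16 KV cache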
+ int32_t long_tokens = 6144;
+ if (lparams.n_ctx > long_tokens) {
+ lparams.kv_type = KV_MEM_TYPE_F16;
+ }
+
lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity;
@@ -1284,7 +1293,7 @@ struct model_context* model_init_from_gpt_params(const gpt_params& params) {
lparams.gen_conf.min_new_tokens = params.min_new_tokens;
lparams.gen_conf.length_penalty = params.length_penalty;
lparams.gen_conf.do_early_stopping = params.do_early_stopping;
- lparams.model_scratch_enlarge_scale = params.model_scratch_enlarge_scale;
+ lparams.scratch_size_ratio = params.scratch_size_ratio;
NE_ASSERT(("Start size cannot be greater than the maximum context size!", lparams.n_keep < lparams.n_ctx));
diff --git a/neural_speed/models/mpt/mpt.h b/neural_speed/models/mpt/mpt.h
index 9ce987602..0a4d614e9 100644
--- a/neural_speed/models/mpt/mpt.h
+++ b/neural_speed/models/mpt/mpt.h
@@ -24,12 +24,20 @@ enum mpt_model {
MPT_30B,
};
-static const model_scratch mpt_mem_req(int n_layers) {
+static const model_scratch mpt_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 32:
- return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 48:
- return {4096ull * MB, 4096ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 8192) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/mpt/mpt_utils.cpp b/neural_speed/models/mpt/mpt_utils.cpp
index 489f4d1bc..ba6e3fbed 100644
--- a/neural_speed/models/mpt/mpt_utils.cpp
+++ b/neural_speed/models/mpt/mpt_utils.cpp
@@ -76,7 +76,7 @@ void MPT::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = mpt_mem_req(n_layer);
+ scratch = mpt_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -89,7 +89,7 @@ void MPT::load(model_context* ctx, model_progress_callback progress_callback, vo
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -175,6 +175,9 @@ void MPT::load(model_context* ctx, model_progress_callback progress_callback, vo
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/opt/opt.h b/neural_speed/models/opt/opt.h
index 10668653a..2590d9be5 100644
--- a/neural_speed/models/opt/opt.h
+++ b/neural_speed/models/opt/opt.h
@@ -34,20 +34,44 @@ enum opt_model {
};
// TODO naive memory buffer size
-static const model_scratch opt_mem_req(int n_layers) {
+static const model_scratch opt_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 12: // OPT_125M
- return {512ull * MB, 512ull * MB, 1024ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ };
case 24: // OPT_350M, OPT_1DOT3B
- return {1024ull * MB, 1024ull * MB, 2048ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ };
case 32: // OPT_2DOT7B OPT_6DOT7B
- return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 40:
- return {2560ull * MB, 2560ull * MB, 5120ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 2560) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2560) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 5120) * MB,
+ };
case 48:
- return {3072ull * MB, 3072ull * MB, 6144ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 6144) * MB,
+ };
case 64:
- return {4096ull * MB, 4096ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 8192) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/opt/opt_utils.cpp b/neural_speed/models/opt/opt_utils.cpp
index 4be7c2472..2641bcb51 100644
--- a/neural_speed/models/opt/opt_utils.cpp
+++ b/neural_speed/models/opt/opt_utils.cpp
@@ -74,7 +74,7 @@ void OPT::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
word_embed_proj_dim = hparams.word_embed_proj_dim;
max_seq_len = hparams.max_seq_len;
do_layer_norm_before = hparams.do_layer_norm_before;
- scratch = opt_mem_req(n_layer);
+ scratch = opt_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -87,7 +87,7 @@ void OPT::load(model_context* ctx, model_progress_callback progress_callback, vo
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -168,6 +168,9 @@ void OPT::load(model_context* ctx, model_progress_callback progress_callback, vo
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/phi/phi.h b/neural_speed/models/phi/phi.h
index b767c34bf..8a3618d62 100644
--- a/neural_speed/models/phi/phi.h
+++ b/neural_speed/models/phi/phi.h
@@ -23,12 +23,20 @@ enum new_model {
PHI,
};
-static const model_scratch phi_mem_req(int n_layers) {
+static const model_scratch phi_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 24:
- return {512ull * MB, 512ull * MB, 1026ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 512) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ };
case 32:
- return {1024ull * MB, 1024ull * MB, 1026ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 1024) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/phi/phi_utils.cpp b/neural_speed/models/phi/phi_utils.cpp
index 16ac7e3d7..018318db2 100644
--- a/neural_speed/models/phi/phi_utils.cpp
+++ b/neural_speed/models/phi/phi_utils.cpp
@@ -74,7 +74,7 @@ void phi::init(const char* path_model, model_context* ctx, int n_gpu_layer_, boo
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
n_embd = hparams.n_embd;
- scratch = phi_mem_req(n_layer);
+ scratch = phi_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -87,7 +87,7 @@ void phi::load(model_context* ctx, model_progress_callback progress_callback, vo
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -159,6 +159,9 @@ void phi::load(model_context* ctx, model_progress_callback progress_callback, vo
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/qwen/qwen.h b/neural_speed/models/qwen/qwen.h
index ec2c65357..3fb54b7c6 100644
--- a/neural_speed/models/qwen/qwen.h
+++ b/neural_speed/models/qwen/qwen.h
@@ -24,14 +24,26 @@ enum QWEN_model {
QWEN_14B,
};
-static const model_scratch qwen_mem_req(int n_layers) {
+static const model_scratch qwen_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
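+ // all supported layer counts now share the same base sizes, scaled by scratch_size_ratio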
switch (n_layers) {
case 40:
- return {2048ull * MB, 2048ull * MB, 4096ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 32:
- return {1024ull * MB, 1024ull * MB, 1608ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
case 24:
- return {512ull * MB, 512ull * MB, 1026ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 4096) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/qwen/qwen_utils.cpp b/neural_speed/models/qwen/qwen_utils.cpp
index 77bba671b..32e222d76 100644
--- a/neural_speed/models/qwen/qwen_utils.cpp
+++ b/neural_speed/models/qwen/qwen_utils.cpp
@@ -70,7 +70,7 @@ void QWEN::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bo
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = qwen_mem_req(n_layer);
+ scratch = qwen_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -83,7 +83,7 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -201,6 +201,9 @@ void QWEN::load(model_context* ctx, model_progress_callback progress_callback, v
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/neural_speed/models/starcoder/starcoder.h b/neural_speed/models/starcoder/starcoder.h
index 72ab9234c..5d4317830 100644
--- a/neural_speed/models/starcoder/starcoder.h
+++ b/neural_speed/models/starcoder/starcoder.h
@@ -26,14 +26,26 @@ enum starcoder_model {
STARCODER_65B,
};
-static const model_scratch starcoder_mem_req(int n_layers) {
+static const model_scratch starcoder_mem_req(int n_layers, float scratch_size_ratio = 1.0f) {
switch (n_layers) {
case 24:
- return {8192ull * MB, 8192ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048 * 2) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+ };
case 36:
- return {8192ull * MB, 8192ull * MB, 8192ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048 * 2) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 2) * MB,
+ };
case 40:
- return {32768ull * MB, 32768ull * MB, 32768ull * MB};
+ return {
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 8) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 2048 * 8) * MB,
+ static_cast<unsigned long long>(scratch_size_ratio * 3072 * 8) * MB,
+ };
default:
MODEL_ASSERT(false);
}
diff --git a/neural_speed/models/starcoder/starcoder_utils.cpp b/neural_speed/models/starcoder/starcoder_utils.cpp
index 13ad06afe..6d8d142fb 100644
--- a/neural_speed/models/starcoder/starcoder_utils.cpp
+++ b/neural_speed/models/starcoder/starcoder_utils.cpp
@@ -75,7 +75,7 @@ void STARCODER::init(const char* path_model, model_context* ctx, int n_gpu_layer
n_embd = hparams.n_embd;
n_vocab = hparams.n_vocab;
n_layer = hparams.n_layer;
- scratch = starcoder_mem_req(n_layer);
+ scratch = starcoder_mem_req(n_layer, lctx.scratch_size_ratio);
model.scratchs = scratch;
}
@@ -88,7 +88,7 @@ void STARCODER::load(model_context* ctx, model_progress_callback progress_callba
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
- fprintf(stderr, "%s: ne ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: ctx size = %7.2f MB\n", __func__, ctx_size / 1024.0 / 1024.0);
// create the ne context
lctx.model.buf.resize(ctx_size);
@@ -203,6 +203,9 @@ void STARCODER::load(model_context* ctx, model_progress_callback progress_callba
// this is the total memory required to run the inference
const size_t mem_required = ctx_size + mmapped_size - vram_total + // weights in VRAM not in memory
scratch.scratch0 + scratch.scratch1 + scratch.eval;
+ fprintf(stderr, "%s: scratch0 = %7.2f MB\n", __func__, scratch.scratch0 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch1 = %7.2f MB\n", __func__, scratch.scratch1 / 1024.0 / 1024.0);
+ fprintf(stderr, "%s: scratch2 = %7.2f MB\n", __func__, scratch.eval / 1024.0 / 1024.0);
fprintf(stderr, "%s: mem required = %7.2f MB (+ memory per state)\n", __func__, mem_required / 1024.0 / 1024.0);
(void)n_gpu_layer;
diff --git a/scripts/python_api_example_for_model_server.py b/scripts/python_api_example_for_model_server.py
index c9889ef70..e0168ca62 100644
--- a/scripts/python_api_example_for_model_server.py
+++ b/scripts/python_api_example_for_model_server.py
@@ -34,7 +34,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
help="maximum number of running requests (or queries) for model inference: Int",
required=False, default=8)
parser.add_argument("--print_log", action="store_true", help="print server running logs")
- parser.add_argument("--model_scratch_enlarge_scale", type=float,
+ parser.add_argument("--scratch_size_ratio", type=float,
help="scale for enlarge memory for model inference: Float",
required=False, default=1.0)
parser.add_argument("--memory_dtype", type=str, help="KV cache memory dtype: String",
@@ -86,7 +86,7 @@ def f_response(res, working):
threads=args.threads,
max_request_num=args.max_request_num,
print_log=args.print_log,
- model_scratch_enlarge_scale = args.model_scratch_enlarge_scale,
+ scratch_size_ratio = args.scratch_size_ratio,
memory_dtype= args.memory_dtype,
)
for i in range(len(prompts)):
diff --git a/tests/test_model_server.py b/tests/test_model_server.py
index d7e7fd102..3ec66e8ef 100644
--- a/tests/test_model_server.py
+++ b/tests/test_model_server.py
@@ -85,7 +85,7 @@ def f_response(res, working):
max_request_num=8,
threads=56,
print_log=False,
- model_scratch_enlarge_scale = 1.0,
+ scratch_size_ratio = 1.0,
memory_dtype= md,
)
for i in range(len(prompts)):