diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index 07e8e532b..6390eda54 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -533,12 +533,12 @@ const std::vector<float>& Model::evaluate_(const std::vector<std::vector<model_token>>&
-    } else if (input_id_cb.size() > n_ctx - n_keep) {  // long input_id_cb and empty curr_input_ids[bs]
+    } else if (input_id_cb.size() > n_ctx - params.n_keep) {  // long input_id_cb and empty curr_input_ids[bs]
       fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
-              input_id_cb.size(), n_ctx - n_keep);
-      curr_input_ids[bs].resize(n_ctx - n_keep);
-      std::copy(input_id_cb.end() - n_ctx - n_keep * 2, input_id_cb.end(), curr_input_ids[bs].begin() + n_keep);
-      std::copy(input_id_cb.begin(), input_id_cb.begin() + n_keep, curr_input_ids[bs].begin());
+              input_id_cb.size(), n_ctx - params.n_keep);
+      curr_input_ids[bs].resize(n_ctx - params.n_keep);
+      std::copy(input_id_cb.end() - n_ctx - params.n_keep * 2, input_id_cb.end(), curr_input_ids[bs].begin() + params.n_keep);
+      std::copy(input_id_cb.begin(), input_id_cb.begin() + params.n_keep, curr_input_ids[bs].begin());
     } else {  // good input_id_cb and empty curr_input_ids[bs]
       curr_input_ids[bs] = input_id_cb;
     }
@@ -648,13 +648,13 @@ std::vector<std::vector<model_token>> Model::generate_tokens(const std::vector<std::vector<model_token>>&
-  if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - n_keep) {
+  if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - params.n_keep) {
     fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
-            input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - n_keep);
-    curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - n_keep);
-    std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - n_keep * 2, input_ids[STATIC_INPUT_HEAD_IDX].end(),
-              curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + n_keep);
-    std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + n_keep,
+            input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - params.n_keep);
+    curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - params.n_keep);
+    std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - params.n_keep * 2, input_ids[STATIC_INPUT_HEAD_IDX].end(),
+              curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep);
+    std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep,
               curr_input_ids[STATIC_INPUT_HEAD_IDX].begin());
   } else {
     curr_input_ids[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX];
diff --git a/neural_speed/application/main_run.cpp b/neural_speed/application/main_run.cpp
index 7e500919e..691e468b3 100644
--- a/neural_speed/application/main_run.cpp
+++ b/neural_speed/application/main_run.cpp
@@ -241,9 +241,9 @@ int main(int argc, char** argv) {  // NOLINT
   const int n_ctx = model_n_ctx(ctx);
-  if (static_cast<int>(embd_inp.size()) > n_ctx - n_keep) {
+  if (static_cast<int>(embd_inp.size()) > n_ctx - params.n_keep) {
     fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(embd_inp.size()),
-            n_ctx - n_keep);
+            n_ctx - params.n_keep);
     return 1;
   }
diff --git a/neural_speed/application/pybind_gptj.cpp b/neural_speed/application/pybind_gptj.cpp
index 82cdb2a83..101a14e4a 100644
--- a/neural_speed/application/pybind_gptj.cpp
+++ b/neural_speed/application/pybind_gptj.cpp
@@ -35,9 +35,9 @@ static model_context** g_ctx;
 bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval, size_t n_past, size_t n_threads) {
   const int n_ctx = model_n_ctx(ctx);
-  if (static_cast<int>(n_eval) > n_ctx - n_keep) {
+  if (static_cast<int>(n_eval) > n_ctx) {
     fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(n_eval),
-            n_ctx - n_keep);
+            n_ctx);
     return true;
   }
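For context, a minimal self-contained sketch of the head-plus-tail truncation that the main_pybind.cpp hunks parameterize. The helper name truncate_prompt and the int32_t token alias are illustrative, not part of the patch, and the tail slice is written with the parenthesization ids.end() - (n_ctx - n_keep * 2) that the surrounding resize/copy sizes appear to intend; the patch itself only swaps the free n_keep for params.n_keep and leaves the existing pointer arithmetic untouched.

#include <algorithm>
#include <cstdint>
#include <vector>

using model_token = int32_t;  // illustrative; stands in for the token id type used in the hunks

// Keep the first n_keep tokens (e.g. a system prompt), then fill the remaining
// n_ctx - 2 * n_keep slots with the most recent tokens, so the result holds
// n_ctx - n_keep entries. Assumes n_ctx >= 2 * n_keep.
std::vector<model_token> truncate_prompt(const std::vector<model_token>& ids, int n_ctx, int n_keep) {
  if (static_cast<int>(ids.size()) <= n_ctx - n_keep) return ids;  // fits within the budget, no truncation
  std::vector<model_token> out(n_ctx - n_keep);
  std::copy(ids.begin(), ids.begin() + n_keep, out.begin());                     // head: first n_keep tokens
  std::copy(ids.end() - (n_ctx - n_keep * 2), ids.end(), out.begin() + n_keep);  // tail: newest tokens
  return out;
}

For example, with n_ctx = 8 and n_keep = 2, a 10-token prompt truncates to 6 tokens: its first 2 followed by its last 4.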