Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
-4 -> params.n_keep
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenzhong1 committed Mar 7, 2024
1 parent c6b5fcd commit 7d4e483
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 15 deletions.
22 changes: 11 additions & 11 deletions neural_speed/application/main_pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,12 +533,12 @@ const std::vector<float>& Model::evaluate_(const std::vector<std::vector<model_t
} else if (!curr_input_ids[bs].empty()) {
fprintf(stderr, "%s: error: prompt confliction\n", __func__);
return empty_ret;
} else if (input_id_cb.size() > n_ctx - n_keep) { // long input_id_cb and empty curr_input_ids[bs]
} else if (input_id_cb.size() > n_ctx - params.n_keep) { // long input_id_cb and empty curr_input_ids[bs]
fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
input_id_cb.size(), n_ctx - n_keep);
curr_input_ids[bs].resize(n_ctx - n_keep);
std::copy(input_id_cb.end() - n_ctx - n_keep * 2, input_id_cb.end(), curr_input_ids[bs].begin() + n_keep);
std::copy(input_id_cb.begin(), input_id_cb.begin() + n_keep, curr_input_ids[bs].begin());
input_id_cb.size(), n_ctx - params.n_keep);
curr_input_ids[bs].resize(n_ctx - params.n_keep);
std::copy(input_id_cb.end() - n_ctx - params.n_keep * 2, input_id_cb.end(), curr_input_ids[bs].begin() + params.n_keep);
std::copy(input_id_cb.begin(), input_id_cb.begin() + params.n_keep, curr_input_ids[bs].begin());
} else { // good input_id_cb and empty curr_input_ids[bs]
curr_input_ids[bs] = input_id_cb;
}
Expand Down Expand Up @@ -648,13 +648,13 @@ std::vector<std::vector<model_token>> Model::generate_tokens(const std::vector<s
}

if (curr_input_ids[STATIC_INPUT_HEAD_IDX].empty()) {
if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - n_keep) {
if (input_ids[STATIC_INPUT_HEAD_IDX].size() > n_ctx - params.n_keep) {
fprintf(stderr, "\n%s: Warning: prompt is too long (%zu tokens, max %d), will be truncated\n", __func__,
input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - n_keep);
curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - n_keep);
std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - n_keep * 2, input_ids[STATIC_INPUT_HEAD_IDX].end(),
curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + n_keep);
std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + n_keep,
input_ids[STATIC_INPUT_HEAD_IDX].size(), n_ctx - params.n_keep);
curr_input_ids[STATIC_INPUT_HEAD_IDX].resize(n_ctx - params.n_keep);
std::copy(input_ids[STATIC_INPUT_HEAD_IDX].end() - n_ctx - params.n_keep * 2, input_ids[STATIC_INPUT_HEAD_IDX].end(),
curr_input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep);
std::copy(input_ids[STATIC_INPUT_HEAD_IDX].begin(), input_ids[STATIC_INPUT_HEAD_IDX].begin() + params.n_keep,
curr_input_ids[STATIC_INPUT_HEAD_IDX].begin());
} else {
curr_input_ids[STATIC_INPUT_HEAD_IDX] = input_ids[STATIC_INPUT_HEAD_IDX];
Expand Down
4 changes: 2 additions & 2 deletions neural_speed/application/main_run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ int main(int argc, char** argv) { // NOLINT

const int n_ctx = model_n_ctx(ctx);

if (static_cast<int>(embd_inp.size()) > n_ctx - n_keep) {
if (static_cast<int>(embd_inp.size()) > n_ctx - params.n_keep) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(embd_inp.size()),
n_ctx - n_keep);
n_ctx - params.n_keep);
return 1;
}

Expand Down
4 changes: 2 additions & 2 deletions neural_speed/application/pybind_gptj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ static model_context** g_ctx;

bool gptj_model_eval_ids(model_context* ctx, model_token* tokens, size_t n_eval, size_t n_past, size_t n_threads) {
const int n_ctx = model_n_ctx(ctx);
if (static_cast<int>(n_eval) > n_ctx - n_keep) {
if (static_cast<int>(n_eval) > n_ctx) {
fprintf(stderr, "%s: error: prompt is too long (%d tokens, max %d)\n", __func__, static_cast<int>(n_eval),
n_ctx - n_keep);
n_ctx);
return true;
}

Expand Down

0 comments on commit 7d4e483

Please sign in to comment.