diff --git a/src/onnx_engine.cc b/src/onnx_engine.cc
index 915cccb..54a73d1 100644
--- a/src/onnx_engine.cc
+++ b/src/onnx_engine.cc
@@ -221,6 +221,8 @@ void OnnxEngine::HandleChatCompletion(
     params->SetInputSequences(*sequences);
 
     auto generator = OgaGenerator::Create(*oga_model_, *params);
+    auto start = std::chrono::system_clock::now();
+    double generated_tokens = 0;
     while (!generator->IsDone() && model_loaded_) {
       generator->ComputeLogits();
       generator->GenerateNextToken();
@@ -241,6 +243,7 @@ void OnnxEngine::HandleChatCompletion(
       status["is_stream"] = true;
       status["status_code"] = k200OK;
       cb(std::move(status), std::move(resp_data));
+      generated_tokens++;
     }
 
     if (!model_loaded_) {
@@ -255,6 +258,12 @@ void OnnxEngine::HandleChatCompletion(
       cb(std::move(status), std::move(respData));
      return;
     }
+    auto end = std::chrono::system_clock::now();
+    auto duration_ms =
+        std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+            .count();
+    LOG_DEBUG << "Generated tokens per second: "
+              << generated_tokens / duration_ms * 1000;
 
     LOG_INFO << "End of result";
     Json::Value resp_data;
@@ -281,6 +290,7 @@ void OnnxEngine::HandleChatCompletion(
     //   params->SetSearchOption("repetition_penalty", req.frequency_penalty);
 
     params->SetInputSequences(*sequences);
+    auto start = std::chrono::system_clock::now();
     auto output_sequences = oga_model_->Generate(*params);
     const auto output_sequence_length =
         output_sequences->SequenceCount(0) - sequences->SequenceCount(0);
@@ -290,6 +300,12 @@ void OnnxEngine::HandleChatCompletion(
         tokenizer_->Decode(output_sequence_data, output_sequence_length);
 
     // std::cout << "Output: " << std::endl << out_string << std::endl;
+    auto end = std::chrono::system_clock::now();
+    auto duration_ms =
+        std::chrono::duration_cast<std::chrono::milliseconds>(end - start)
+            .count();
+    LOG_DEBUG << "Generated tokens per second: "
+              << static_cast<double>(output_sequence_length) / duration_ms * 1000;
 
     std::string to_send = out_string.p_;
     auto resp_data = CreateFullReturnJson(GenerateRandomString(20), "_",
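
Both hunks follow the same pattern: take a wall-clock timestamp before generation starts, count the generated tokens (per iteration in the streaming loop, via output_sequence_length in the one-shot path), then log tokens / duration_ms * 1000. Below is a minimal, self-contained sketch of that pattern under two assumptions of mine, not taken from the patch: the TokenThroughputTimer name and the main() driver are hypothetical, and std::chrono::steady_clock is swapped in for system_clock since a monotonic clock is the conventional choice for measuring elapsed time. The sketch also guards the division, since a sub-millisecond generation would make duration_ms zero.

#include <chrono>
#include <cstddef>
#include <iostream>
#include <thread>

// Minimal standalone sketch of the throughput measurement the patch adds,
// factored into a helper. Names here are illustrative, not from the patch.
// steady_clock is monotonic, so the measurement cannot go negative if the
// system clock is adjusted mid-generation.
class TokenThroughputTimer {
 public:
  TokenThroughputTimer() : start_(std::chrono::steady_clock::now()) {}

  // Call once per generated token; plays the role of generated_tokens++
  // in the patch's streaming path.
  void OnToken() { ++generated_tokens_; }

  // Tokens per second so far. Returns 0 when no time has elapsed, which the
  // raw generated_tokens / duration_ms expression would divide by.
  double TokensPerSecond() const {
    auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                          std::chrono::steady_clock::now() - start_)
                          .count();
    if (elapsed_ms == 0) return 0.0;
    return static_cast<double>(generated_tokens_) / elapsed_ms * 1000.0;
  }

 private:
  std::chrono::steady_clock::time_point start_;
  std::size_t generated_tokens_ = 0;
};

int main() {
  TokenThroughputTimer timer;
  // Stand-in for the generation loop; each sleep plays the role of one
  // generator->GenerateNextToken() call.
  for (int i = 0; i < 32; ++i) {
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    timer.OnToken();
  }
  std::cout << "Generated tokens per second: " << timer.TokensPerSecond()
            << std::endl;
  return 0;
}

The non-streaming path has no per-token callback to hook, which is presumably why the patch derives the count from output_sequence_length instead of incrementing a counter. Note also that millisecond truncation makes the figure coarse for very short generations; casting the duration to double before dividing would be a finer-grained alternative.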