From 7d0b2e3318184b6bd86f07a07d5e3edbe2c0fa18 Mon Sep 17 00:00:00 2001
From: vansangpfiev
Date: Thu, 23 May 2024 08:37:46 +0700
Subject: [PATCH] fix: flash attention param typo (#50)

Co-authored-by: vansangpfiev
---
 README.md           | 2 +-
 src/llama_engine.cc | 7 +++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 0cb9da9b..e346674b 100644
--- a/README.md
+++ b/README.md
@@ -145,4 +145,4 @@ Table of parameters
 |`model_type` | String | Model type we want to use: llm or embedding, default value is llm|
 |`model_alias`| String | Used as model_id if specified in request, mandatory in loadmodel|
 |`model` | String | Used as model_id if specified in request, mandatory in chat/embedding request|
-|`flash-attn` | Boolean| To enable Flash Attention, default is false|
\ No newline at end of file
+|`flash_attn` | Boolean| To enable Flash Attention, default is false|
\ No newline at end of file
diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index 8ac8d87a..f04859b4 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -334,8 +334,11 @@ bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> jsonBody) {
       jsonBody->get("cpu_threads", std::thread::hardware_concurrency())
           .asInt();
   params.cont_batching = jsonBody->get("cont_batching", false).asBool();
-  params.flash_attn = jsonBody->get("flash-attn", false).asBool();
-  if(params.flash_attn) {
+  // Check for backward compatibility
+  auto fa0 = jsonBody->get("flash-attn", false).asBool();
+  auto fa1 = jsonBody->get("flash_attn", false).asBool();
+  params.flash_attn = fa0 || fa1;
+  if (params.flash_attn) {
     LOG_DEBUG << "Enabled Flash Attention";
   }
   server_map_[model_id].caching_enabled =
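
For illustration, the backward-compatible lookup introduced above can be exercised in isolation. The following is a minimal standalone sketch against plain jsoncpp; the ResolveFlashAttn helper and the sample request bodies are hypothetical and not part of the patch or the engine API.

// Minimal sketch (not from the patch): accept either the legacy "flash-attn"
// key or the corrected "flash_attn" key when deciding whether to enable
// Flash Attention, mirroring the change to LoadModelImpl above.
#include <iostream>
#include <json/json.h>

// Hypothetical helper for illustration only.
bool ResolveFlashAttn(const Json::Value& body) {
  bool legacy  = body.get("flash-attn", false).asBool();   // old, typoed key
  bool current = body.get("flash_attn", false).asBool();   // documented key
  return legacy || current;
}

int main() {
  Json::Value old_style;
  old_style["flash-attn"] = true;  // request from a client using the old key

  Json::Value new_style;
  new_style["flash_attn"] = true;  // request using the corrected key

  std::cout << std::boolalpha
            << ResolveFlashAttn(old_style) << " "   // prints: true
            << ResolveFlashAttn(new_style) << "\n"; // prints: true
  return 0;
}

Either spelling enables the flag, so older clients that still send "flash-attn" keep working while the README now documents the corrected "flash_attn" name.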