diff --git a/.github/workflows/ci_lin.yml b/.github/workflows/ci_lin.yml
index 31353ade9..5b11f1c5f 100644
--- a/.github/workflows/ci_lin.yml
+++ b/.github/workflows/ci_lin.yml
@@ -31,6 +31,7 @@ jobs:
- test_suite: features
- test_suite: api_coverage
- test_suite: behavior_tests
+ - test_suite: applications
steps:
- name: Process ENV
id: process_env
diff --git a/applications/applications.xml b/applications/applications.xml
new file mode 100644
index 000000000..2b3379422
--- /dev/null
+++ b/applications/applications.xml
@@ -0,0 +1,8 @@
+
+
+
+ applications test
+
+
+
+
diff --git a/applications/config/TEMPLATE_llama.xml b/applications/config/TEMPLATE_llama.xml
new file mode 100644
index 000000000..d9be47442
--- /dev/null
+++ b/applications/config/TEMPLATE_llama.xml
@@ -0,0 +1,12 @@
+
+
+
+ test
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/applications/src/llama/LICENSE b/applications/src/llama/LICENSE
new file mode 100644
index 000000000..76f67efdc
--- /dev/null
+++ b/applications/src/llama/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Georgi Gerganov
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/applications/src/llama/common/base64.hpp b/applications/src/llama/common/base64.hpp
new file mode 100644
index 000000000..563247a6e
--- /dev/null
+++ b/applications/src/llama/common/base64.hpp
@@ -0,0 +1,392 @@
+/*
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or
+distribute this software, either in source code form or as a compiled
+binary, for any purpose, commercial or non-commercial, and by any
+means.
+
+In jurisdictions that recognize copyright laws, the author or authors
+of this software dedicate any and all copyright interest in the
+software to the public domain. We make this dedication for the benefit
+of the public at large and to the detriment of our heirs and
+successors. We intend this dedication to be an overt act of
+relinquishment in perpetuity of all present and future rights to this
+software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR
+OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to
+*/
+
+#ifndef PUBLIC_DOMAIN_BASE64_HPP_
+#define PUBLIC_DOMAIN_BASE64_HPP_
+
+#include
+#include
+#include
+#include
+
+class base64_error : public std::runtime_error
+{
+public:
+ using std::runtime_error::runtime_error;
+};
+
+class base64
+{
+public:
+ enum class alphabet
+ {
+ /** the alphabet is detected automatically */
+ auto_,
+ /** the standard base64 alphabet is used */
+ standard,
+ /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/
+ url_filename_safe
+ };
+
+ enum class decoding_behavior
+ {
+ /** if the input is not padded, the remaining bits are ignored */
+ moderate,
+ /** if a padding character is encounter decoding is finished */
+ loose
+ };
+
+ /**
+ Encodes all the elements from `in_begin` to `in_end` to `out`.
+
+ @warning The source and destination cannot overlap. The destination must be able to hold at least
+ `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator.
+
+ @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than
+ 8 bits
+ @tparam Output_iterator the destination; the elements written to it are from the type `char`
+ @param in_begin the beginning of the source
+ @param in_end the ending of the source
+ @param out the destination iterator
+ @param alphabet which alphabet should be used
+ @returns the iterator to the next element past the last element copied
+ @throws see `Input_iterator` and `Output_iterator`
+ */
+ template
+ static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+ alphabet alphabet = alphabet::standard)
+ {
+ constexpr auto pad = '=';
+ const char* alpha = alphabet == alphabet::url_filename_safe
+ ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+ : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+ while (in_begin != in_end) {
+ std::uint8_t i0 = 0, i1 = 0, i2 = 0;
+
+ // first character
+ i0 = static_cast(*in_begin);
+ ++in_begin;
+
+ *out = alpha[i0 >> 2 & 0x3f];
+ ++out;
+
+ // part of first character and second
+ if (in_begin != in_end) {
+ i1 = static_cast(*in_begin);
+ ++in_begin;
+
+ *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)];
+ ++out;
+ } else {
+ *out = alpha[(i0 & 0x3) << 4];
+ ++out;
+
+ // last padding
+ *out = pad;
+ ++out;
+
+ // last padding
+ *out = pad;
+ ++out;
+
+ break;
+ }
+
+ // part of second character and third
+ if (in_begin != in_end) {
+ i2 = static_cast(*in_begin);
+ ++in_begin;
+
+ *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)];
+ ++out;
+ } else {
+ *out = alpha[(i1 & 0xf) << 2];
+ ++out;
+
+ // last padding
+ *out = pad;
+ ++out;
+
+ break;
+ }
+
+ // rest of third
+ *out = alpha[i2 & 0x3f];
+ ++out;
+ }
+
+ return out;
+ }
+ /**
+ Encodes a string.
+
+ @param str the string that should be encoded
+ @param alphabet which alphabet should be used
+ @returns the encoded base64 string
+ @throws see base64::encode()
+ */
+ static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard)
+ {
+ std::string result;
+
+ result.reserve(required_encode_size(str.length()) + 1);
+
+ encode(str.begin(), str.end(), std::back_inserter(result), alphabet);
+
+ return result;
+ }
+ /**
+ Encodes a char array.
+
+ @param buffer the char array
+ @param size the size of the array
+ @param alphabet which alphabet should be used
+ @returns the encoded string
+ */
+ static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard)
+ {
+ std::string result;
+
+ result.reserve(required_encode_size(size) + 1);
+
+ encode(buffer, buffer + size, std::back_inserter(result), alphabet);
+
+ return result;
+ }
+ /**
+ Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`,
+ in other words: inplace decoding is possible.
+
+ @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`,
+ otherwise the behavior depends on the output iterator.
+
+ @tparam Input_iterator the source; the returned elements are cast to `char`
+ @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t`
+ @param in_begin the beginning of the source
+ @param in_end the ending of the source
+ @param out the destination iterator
+ @param alphabet which alphabet should be used
+ @param behavior the behavior when an error was detected
+ @returns the iterator to the next element past the last element copied
+ @throws base64_error depending on the set behavior
+ @throws see `Input_iterator` and `Output_iterator`
+ */
+ template
+ static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out,
+ alphabet alphabet = alphabet::auto_,
+ decoding_behavior behavior = decoding_behavior::moderate)
+ {
+ //constexpr auto pad = '=';
+ std::uint8_t last = 0;
+ auto bits = 0;
+
+ while (in_begin != in_end) {
+ auto c = *in_begin;
+ ++in_begin;
+
+ if (c == '=') {
+ break;
+ }
+
+ auto part = _base64_value(alphabet, c);
+
+ // enough bits for one byte
+ if (bits + 6 >= 8) {
+ *out = (last << (8 - bits)) | (part >> (bits - 2));
+ ++out;
+
+ bits -= 2;
+ } else {
+ bits += 6;
+ }
+
+ last = part;
+ }
+
+ // check padding
+ if (behavior != decoding_behavior::loose) {
+ while (in_begin != in_end) {
+ auto c = *in_begin;
+ ++in_begin;
+
+ if (c != '=') {
+ throw base64_error("invalid base64 character.");
+ }
+ }
+ }
+
+ return out;
+ }
+ /**
+ Decodes a string.
+
+ @param str the base64 encoded string
+ @param alphabet which alphabet should be used
+ @param behavior the behavior when an error was detected
+ @returns the decoded string
+ @throws see base64::decode()
+ */
+ static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_,
+ decoding_behavior behavior = decoding_behavior::moderate)
+ {
+ std::string result;
+
+ result.reserve(max_decode_size(str.length()));
+
+ decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior);
+
+ return result;
+ }
+ /**
+ Decodes a string.
+
+ @param buffer the base64 encoded buffer
+ @param size the size of the buffer
+ @param alphabet which alphabet should be used
+ @param behavior the behavior when an error was detected
+ @returns the decoded string
+ @throws see base64::decode()
+ */
+ static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_,
+ decoding_behavior behavior = decoding_behavior::moderate)
+ {
+ std::string result;
+
+ result.reserve(max_decode_size(size));
+
+ decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior);
+
+ return result;
+ }
+ /**
+ Decodes a string inplace.
+
+ @param[in,out] str the base64 encoded string
+ @param alphabet which alphabet should be used
+ @param behavior the behavior when an error was detected
+ @throws base64::decode_inplace()
+ */
+ static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_,
+ decoding_behavior behavior = decoding_behavior::moderate)
+ {
+ str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin());
+ }
+ /**
+ Decodes a char array inplace.
+
+ @param[in,out] str the string array
+ @param size the length of the array
+ @param alphabet which alphabet should be used
+ @param behavior the behavior when an error was detected
+ @returns the pointer to the next element past the last element decoded
+ @throws base64::decode_inplace()
+ */
+ static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_,
+ decoding_behavior behavior = decoding_behavior::moderate)
+ {
+ return decode(str, str + size, str, alphabet, behavior);
+ }
+ /**
+ Returns the required decoding size for a given size. The value is calculated with the following formula:
+
+ $$
+ \lceil \frac{size}{4} \rceil \cdot 3
+ $$
+
+ @param size the size of the encoded input
+ @returns the size of the resulting decoded buffer; this the absolute maximum
+ */
+ static std::size_t max_decode_size(std::size_t size) noexcept
+ {
+ return (size / 4 + (size % 4 ? 1 : 0)) * 3;
+ }
+ /**
+ Returns the required encoding size for a given size. The value is calculated with the following formula:
+
+ $$
+ \lceil \frac{size}{3} \rceil \cdot 4
+ $$
+
+ @param size the size of the decoded input
+ @returns the size of the resulting encoded buffer
+ */
+ static std::size_t required_encode_size(std::size_t size) noexcept
+ {
+ return (size / 3 + (size % 3 ? 1 : 0)) * 4;
+ }
+
+private:
+ static std::uint8_t _base64_value(alphabet& alphabet, char c)
+ {
+ if (c >= 'A' && c <= 'Z') {
+ return c - 'A';
+ } else if (c >= 'a' && c <= 'z') {
+ return c - 'a' + 26;
+ } else if (c >= '0' && c <= '9') {
+ return c - '0' + 52;
+ }
+
+ // comes down to alphabet
+ if (alphabet == alphabet::standard) {
+ if (c == '+') {
+ return 62;
+ } else if (c == '/') {
+ return 63;
+ }
+ } else if (alphabet == alphabet::url_filename_safe) {
+ if (c == '-') {
+ return 62;
+ } else if (c == '_') {
+ return 63;
+ }
+ } // auto detect
+ else {
+ if (c == '+') {
+ alphabet = alphabet::standard;
+
+ return 62;
+ } else if (c == '/') {
+ alphabet = alphabet::standard;
+
+ return 63;
+ } else if (c == '-') {
+ alphabet = alphabet::url_filename_safe;
+
+ return 62;
+ } else if (c == '_') {
+ alphabet = alphabet::url_filename_safe;
+
+ return 63;
+ }
+ }
+
+ throw base64_error("invalid base64 character.");
+ }
+};
+
+#endif // !PUBLIC_DOMAIN_BASE64_HPP_
diff --git a/applications/src/llama/common/build-info.cpp b/applications/src/llama/common/build-info.cpp
new file mode 100644
index 000000000..b1632a1e3
--- /dev/null
+++ b/applications/src/llama/common/build-info.cpp
@@ -0,0 +1,4 @@
+int LLAMA_BUILD_NUMBER = 1657;
+char const *LLAMA_COMMIT = "3c04bf6";
+char const *LLAMA_COMPILER = "cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0";
+char const *LLAMA_BUILD_TARGET = "x86_64-linux-gnu";
diff --git a/applications/src/llama/common/build-info.cpp.in b/applications/src/llama/common/build-info.cpp.in
new file mode 100644
index 000000000..0b945aa68
--- /dev/null
+++ b/applications/src/llama/common/build-info.cpp.in
@@ -0,0 +1,4 @@
+int LLAMA_BUILD_NUMBER = @BUILD_NUMBER@;
+char const *LLAMA_COMMIT = "@BUILD_COMMIT@";
+char const *LLAMA_COMPILER = "@BUILD_COMPILER@";
+char const *LLAMA_BUILD_TARGET = "@BUILD_TARGET@";
diff --git a/applications/src/llama/common/common.cpp b/applications/src/llama/common/common.cpp
new file mode 100644
index 000000000..93d5483e4
--- /dev/null
+++ b/applications/src/llama/common/common.cpp
@@ -0,0 +1,1616 @@
+#include "common.h"
+#include "llama.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#if defined(__APPLE__) && defined(__MACH__)
+#include
+#include
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+# define NOMINMAX
+#endif
+#include
+#include
+#include
+#include
+#include
+#else
+#include
+#include
+#include
+#endif
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+int32_t get_num_physical_cores() {
+#ifdef __linux__
+ // enumerate the set of thread siblings, num entries is num cores
+ std::unordered_set siblings;
+ for (uint32_t cpu=0; cpu < UINT32_MAX; ++cpu) {
+ std::ifstream thread_siblings("/sys/devices/system/cpu"
+ + std::to_string(cpu) + "/topology/thread_siblings");
+ if (!thread_siblings.is_open()) {
+ break; // no more cpus
+ }
+ std::string line;
+ if (std::getline(thread_siblings, line)) {
+ siblings.insert(line);
+ }
+ }
+ if (!siblings.empty()) {
+ return static_cast(siblings.size());
+ }
+#elif defined(__APPLE__) && defined(__MACH__)
+ int32_t num_physical_cores;
+ size_t len = sizeof(num_physical_cores);
+ int result = sysctlbyname("hw.perflevel0.physicalcpu", &num_physical_cores, &len, NULL, 0);
+ if (result == 0) {
+ return num_physical_cores;
+ }
+ result = sysctlbyname("hw.physicalcpu", &num_physical_cores, &len, NULL, 0);
+ if (result == 0) {
+ return num_physical_cores;
+ }
+#elif defined(_WIN32)
+ //TODO: Implement
+#endif
+ unsigned int n_threads = std::thread::hardware_concurrency();
+ return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
+}
+
+void process_escapes(std::string& input) {
+ std::size_t input_len = input.length();
+ std::size_t output_idx = 0;
+
+ for (std::size_t input_idx = 0; input_idx < input_len; ++input_idx) {
+ if (input[input_idx] == '\\' && input_idx + 1 < input_len) {
+ switch (input[++input_idx]) {
+ case 'n': input[output_idx++] = '\n'; break;
+ case 'r': input[output_idx++] = '\r'; break;
+ case 't': input[output_idx++] = '\t'; break;
+ case '\'': input[output_idx++] = '\''; break;
+ case '\"': input[output_idx++] = '\"'; break;
+ case '\\': input[output_idx++] = '\\'; break;
+ case 'x':
+ // Handle \x12, etc
+ if (input_idx + 2 < input_len) {
+ const char x[3] = { input[input_idx + 1], input[input_idx + 2], 0 };
+ char *err_p = nullptr;
+ const long val = std::strtol(x, &err_p, 16);
+ if (err_p == x + 2) {
+ input_idx += 2;
+ input[output_idx++] = char(val);
+ break;
+ }
+ }
+ // fall through
+ default: input[output_idx++] = '\\';
+ input[output_idx++] = input[input_idx]; break;
+ }
+ } else {
+ input[output_idx++] = input[input_idx];
+ }
+ }
+
+ input.resize(output_idx);
+}
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
+ bool result = true;
+ try {
+ if (!gpt_params_parse_ex(argc, argv, params)) {
+ gpt_print_usage(argc, argv, gpt_params());
+ exit(0);
+ }
+ }
+ catch (const std::invalid_argument & ex) {
+ fprintf(stderr, "%s\n", ex.what());
+ gpt_print_usage(argc, argv, gpt_params());
+ exit(1);
+ }
+ return result;
+}
+
+bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
+ bool invalid_param = false;
+ std::string arg;
+ const std::string arg_prefix = "--";
+ llama_sampling_params & sparams = params.sparams;
+
+ for (int i = 1; i < argc; i++) {
+ arg = argv[i];
+ if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
+ std::replace(arg.begin(), arg.end(), '_', '-');
+ }
+
+ if (arg == "-s" || arg == "--seed") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.seed = std::stoul(argv[i]);
+ } else if (arg == "-t" || arg == "--threads") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads = std::stoi(argv[i]);
+ if (params.n_threads <= 0) {
+ params.n_threads = std::thread::hardware_concurrency();
+ }
+ } else if (arg == "-tb" || arg == "--threads-batch") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_threads_batch = std::stoi(argv[i]);
+ if (params.n_threads_batch <= 0) {
+ params.n_threads_batch = std::thread::hardware_concurrency();
+ }
+ } else if (arg == "-p" || arg == "--prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.prompt = argv[i];
+ } else if (arg == "-e" || arg == "--escape") {
+ params.escape = true;
+ } else if (arg == "--prompt-cache") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.path_prompt_cache = argv[i];
+ } else if (arg == "--prompt-cache-all") {
+ params.prompt_cache_all = true;
+ } else if (arg == "--prompt-cache-ro") {
+ params.prompt_cache_ro = true;
+ } else if (arg == "-f" || arg == "--file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ // store the external file name in params
+ params.prompt_file = argv[i];
+ std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt));
+ if (!params.prompt.empty() && params.prompt.back() == '\n') {
+ params.prompt.pop_back();
+ }
+ } else if (arg == "-n" || arg == "--n-predict") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_predict = std::stoi(argv[i]);
+ } else if (arg == "--top-k") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.top_k = std::stoi(argv[i]);
+ } else if (arg == "-c" || arg == "--ctx-size") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_ctx = std::stoi(argv[i]);
+ } else if (arg == "--rope-freq-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_base = std::stof(argv[i]);
+ } else if (arg == "--rope-freq-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = std::stof(argv[i]);
+ } else if (arg == "--rope-scaling") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::string value(argv[i]);
+ /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_NONE; }
+ else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_LINEAR; }
+ else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; }
+ else { invalid_param = true; break; }
+ } else if (arg == "--rope-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+ } else if (arg == "--yarn-orig-ctx") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_orig_ctx = std::stoi(argv[i]);
+ } else if (arg == "--yarn-ext-factor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_ext_factor = std::stof(argv[i]);
+ } else if (arg == "--yarn-attn-factor") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_attn_factor = std::stof(argv[i]);
+ } else if (arg == "--yarn-beta-fast") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_beta_fast = std::stof(argv[i]);
+ } else if (arg == "--yarn-beta-slow") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.yarn_beta_slow = std::stof(argv[i]);
+ } else if (arg == "--samplers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = parse_samplers_input(argv[i]);
+ } else if (arg == "--sampling-seq") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.samplers_sequence = argv[i];
+ } else if (arg == "--top-p") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.top_p = std::stof(argv[i]);
+ } else if (arg == "--min-p") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.min_p = std::stof(argv[i]);
+ } else if (arg == "--temp") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.temp = std::stof(argv[i]);
+ sparams.temp = std::max(sparams.temp, 0.0f);
+ } else if (arg == "--tfs") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.tfs_z = std::stof(argv[i]);
+ } else if (arg == "--typical") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.typical_p = std::stof(argv[i]);
+ } else if (arg == "--repeat-last-n") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.penalty_last_n = std::stoi(argv[i]);
+ sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
+ } else if (arg == "--repeat-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.penalty_repeat = std::stof(argv[i]);
+ } else if (arg == "--frequency-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.penalty_freq = std::stof(argv[i]);
+ } else if (arg == "--presence-penalty") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.penalty_present = std::stof(argv[i]);
+ } else if (arg == "--mirostat") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.mirostat = std::stoi(argv[i]);
+ } else if (arg == "--mirostat-lr") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.mirostat_eta = std::stof(argv[i]);
+ } else if (arg == "--mirostat-ent") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.mirostat_tau = std::stof(argv[i]);
+ } else if (arg == "--cfg-negative-prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.cfg_negative_prompt = argv[i];
+ } else if (arg == "--cfg-negative-prompt-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt));
+ if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') {
+ sparams.cfg_negative_prompt.pop_back();
+ }
+ } else if (arg == "--cfg-scale") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.cfg_scale = std::stof(argv[i]);
+ } else if (arg == "-b" || arg == "--batch-size") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_batch = std::stoi(argv[i]);
+ } else if (arg == "--keep") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_keep = std::stoi(argv[i]);
+ } else if (arg == "--draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_draft = std::stoi(argv[i]);
+ } else if (arg == "--chunks") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_chunks = std::stoi(argv[i]);
+ } else if (arg == "-np" || arg == "--parallel") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_parallel = std::stoi(argv[i]);
+ } else if (arg == "-ns" || arg == "--sequences") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.n_sequences = std::stoi(argv[i]);
+ } else if (arg == "--p-accept" || arg == "-pa") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.p_accept = std::stof(argv[i]);
+ } else if (arg == "--p-split" || arg == "-ps") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.p_split = std::stof(argv[i]);
+ } else if (arg == "-m" || arg == "--model") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model = argv[i];
+ } else if (arg == "-md" || arg == "--model-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model_draft = argv[i];
+ } else if (arg == "-a" || arg == "--alias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.model_alias = argv[i];
+ } else if (arg == "--lora") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_adapter.push_back(std::make_tuple(argv[i], 1.0f));
+ params.use_mmap = false;
+ } else if (arg == "--lora-scaled") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ const char * lora_adapter = argv[i];
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_adapter.push_back(std::make_tuple(lora_adapter, std::stof(argv[i])));
+ params.use_mmap = false;
+ } else if (arg == "--lora-base") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.lora_base = argv[i];
+ } else if (arg == "--mmproj") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.mmproj = argv[i];
+ } else if (arg == "--image") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.image = argv[i];
+ } else if (arg == "-i" || arg == "--interactive") {
+ params.interactive = true;
+ } else if (arg == "--embedding") {
+ params.embedding = true;
+ } else if (arg == "--interactive-first") {
+ params.interactive_first = true;
+ } else if (arg == "-ins" || arg == "--instruct") {
+ params.instruct = true;
+ } else if (arg == "-cml" || arg == "--chatml") {
+ params.chatml = true;
+ } else if (arg == "--infill") {
+ params.infill = true;
+ } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
+ params.dump_kv_cache = true;
+ } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
+ params.no_kv_offload = true;
+ } else if (arg == "-ctk" || arg == "--cache-type-k") {
+ params.cache_type_k = argv[++i];
+ } else if (arg == "-ctv" || arg == "--cache-type-v") {
+ params.cache_type_v = argv[++i];
+ } else if (arg == "--multiline-input") {
+ params.multiline_input = true;
+ } else if (arg == "--simple-io") {
+ params.simple_io = true;
+ } else if (arg == "-cb" || arg == "--cont-batching") {
+ params.cont_batching = true;
+ } else if (arg == "--color") {
+ params.use_color = true;
+ } else if (arg == "--mlock") {
+ params.use_mlock = true;
+ } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+ params.n_gpu_layers = std::stoi(argv[i]);
+#else
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
+ } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+ params.n_gpu_layers_draft = std::stoi(argv[i]);
+#else
+ fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
+ fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
+#endif
+ } else if (arg == "--main-gpu" || arg == "-mg") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef GGML_USE_CUBLAS
+ params.main_gpu = std::stoi(argv[i]);
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
+#endif
+ } else if (arg == "--tensor-split" || arg == "-ts") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+#ifdef GGML_USE_CUBLAS
+ std::string arg_next = argv[i];
+
+ // split string by , and /
+ const std::regex regex{R"([,/]+)"};
+ std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+ std::vector split_arg{it, {}};
+ GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+
+ for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+ if (i < split_arg.size()) {
+ params.tensor_split[i] = std::stof(split_arg[i]);
+ } else {
+ params.tensor_split[i] = 0.0f;
+ }
+ }
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
+ } else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
+#ifdef GGML_USE_CUBLAS
+ params.mul_mat_q = false;
+#else
+ fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
+#endif // GGML_USE_CUBLAS
+ } else if (arg == "--no-mmap") {
+ params.use_mmap = false;
+ } else if (arg == "--numa") {
+ params.numa = true;
+ } else if (arg == "--verbose-prompt") {
+ params.verbose_prompt = true;
+ } else if (arg == "-r" || arg == "--reverse-prompt") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.antiprompt.push_back(argv[i]);
+ } else if (arg == "-ld" || arg == "--logdir") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.logdir = argv[i];
+
+ if (params.logdir.back() != DIRECTORY_SEPARATOR) {
+ params.logdir += DIRECTORY_SEPARATOR;
+ }
+ } else if (arg == "--perplexity" || arg == "--all-logits") {
+ params.logits_all = true;
+ } else if (arg == "--ppl-stride") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.ppl_stride = std::stoi(argv[i]);
+ } else if (arg == "--ppl-output-type") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.ppl_output_type = std::stoi(argv[i]);
+ } else if (arg == "--hellaswag") {
+ params.hellaswag = true;
+ } else if (arg == "--hellaswag-tasks") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.hellaswag_tasks = std::stoi(argv[i]);
+ } else if (arg == "--ignore-eos") {
+ params.ignore_eos = true;
+ } else if (arg == "--no-penalize-nl") {
+ sparams.penalize_nl = false;
+ } else if (arg == "-l" || arg == "--logit-bias") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::stringstream ss(argv[i]);
+ llama_token key;
+ char sign;
+ std::string value_str;
+ try {
+ if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) {
+ sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+ } else {
+ throw std::exception();
+ }
+ } catch (const std::exception&) {
+ invalid_param = true;
+ break;
+ }
+ } else if (arg == "-h" || arg == "--help") {
+ return false;
+
+ } else if (arg == "--version") {
+ fprintf(stderr, "version: %d (%s)\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
+ fprintf(stderr, "built with %s for %s\n", LLAMA_COMPILER, LLAMA_BUILD_TARGET);
+ exit(0);
+ } else if (arg == "--random-prompt") {
+ params.random_prompt = true;
+ } else if (arg == "--in-prefix-bos") {
+ params.input_prefix_bos = true;
+ } else if (arg == "--in-prefix") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.input_prefix = argv[i];
+ } else if (arg == "--in-suffix") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ params.input_suffix = argv[i];
+ } else if (arg == "--grammar") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ sparams.grammar = argv[i];
+ } else if (arg == "--grammar-file") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ std::ifstream file(argv[i]);
+ if (!file) {
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ std::copy(
+ std::istreambuf_iterator(file),
+ std::istreambuf_iterator(),
+ std::back_inserter(sparams.grammar)
+ );
+ } else if (arg == "--override-kv") {
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ char * sep = strchr(argv[i], '=');
+ if (sep == nullptr || sep - argv[i] >= 128) {
+ fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ struct llama_model_kv_override kvo;
+ std::strncpy(kvo.key, argv[i], sep - argv[i]);
+ kvo.key[sep - argv[i]] = 0;
+ sep++;
+ if (strncmp(sep, "int:", 4) == 0) {
+ sep += 4;
+ kvo.tag = LLAMA_KV_OVERRIDE_INT;
+ kvo.int_value = std::atol(sep);
+ } else if (strncmp(sep, "float:", 6) == 0) {
+ sep += 6;
+ kvo.tag = LLAMA_KV_OVERRIDE_FLOAT;
+ kvo.float_value = std::atof(sep);
+ } else if (strncmp(sep, "bool:", 5) == 0) {
+ sep += 5;
+ kvo.tag = LLAMA_KV_OVERRIDE_BOOL;
+ if (std::strcmp(sep, "true") == 0) {
+ kvo.bool_value = true;
+ } else if (std::strcmp(sep, "false") == 0) {
+ kvo.bool_value = false;
+ } else {
+ fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ } else {
+ fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
+ invalid_param = true;
+ break;
+ }
+ params.kv_overrides.push_back(kvo);
+#ifndef LOG_DISABLE_LOGS
+ // Parse args for logging parameters
+ } else if ( log_param_single_parse( argv[i] ) ) {
+ // Do nothing, log_param_single_parse automatically does it's thing
+ // and returns if a match was found and parsed.
+ } else if ( log_param_pair_parse( /*check_but_dont_parse*/ true, argv[i] ) ) {
+ // We have a matching known parameter requiring an argument,
+ // now we need to check if there is anything after this argv
+ // and flag invalid_param or parse it.
+ if (++i >= argc) {
+ invalid_param = true;
+ break;
+ }
+ if( !log_param_pair_parse( /*check_but_dont_parse*/ false, argv[i-1], argv[i]) ) {
+ invalid_param = true;
+ break;
+ }
+ // End of Parse args for logging parameters
+#endif // LOG_DISABLE_LOGS
+ } else {
+ throw std::invalid_argument("error: unknown argument: " + arg);
+ }
+ }
+ if (invalid_param) {
+ throw std::invalid_argument("error: invalid parameter for argument: " + arg);
+ }
+ if (params.prompt_cache_all &&
+ (params.interactive || params.interactive_first ||
+ params.instruct)) {
+
+ throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
+ }
+
+ if (params.escape) {
+ process_escapes(params.prompt);
+ process_escapes(params.input_prefix);
+ process_escapes(params.input_suffix);
+ process_escapes(sparams.cfg_negative_prompt);
+ for (auto & antiprompt : params.antiprompt) {
+ process_escapes(antiprompt);
+ }
+ }
+
+ if (!params.kv_overrides.empty()) {
+ params.kv_overrides.emplace_back(llama_model_kv_override());
+ params.kv_overrides.back().key[0] = 0;
+ }
+
+ return true;
+}
+
+void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
+ const llama_sampling_params & sparams = params.sparams;
+
+ printf("\n");
+ printf("usage: %s [options]\n", argv[0]);
+ printf("\n");
+ printf("options:\n");
+ printf(" -h, --help show this help message and exit\n");
+ printf(" --version show version and build info\n");
+ printf(" -i, --interactive run in interactive mode\n");
+ printf(" --interactive-first run in interactive mode and wait for input right away\n");
+ printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n");
+ printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n");
+ printf(" --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
+ printf(" -r PROMPT, --reverse-prompt PROMPT\n");
+ printf(" halt generation at PROMPT, return control in interactive mode\n");
+ printf(" (can be specified more than once for multiple prompts).\n");
+ printf(" --color colorise output to distinguish prompt and user input from generations\n");
+ printf(" -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
+ printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
+ printf(" -tb N, --threads-batch N\n");
+ printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
+ printf(" -p PROMPT, --prompt PROMPT\n");
+ printf(" prompt to start generation with (default: empty)\n");
+ printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
+ printf(" --prompt-cache FNAME file to cache prompt state for faster startup (default: none)\n");
+ printf(" --prompt-cache-all if specified, saves user input and generations to cache as well.\n");
+ printf(" not supported with --interactive or other interactive options\n");
+ printf(" --prompt-cache-ro if specified, uses the prompt cache but does not update it.\n");
+ printf(" --random-prompt start with a randomized prompt.\n");
+ printf(" --in-prefix-bos prefix BOS to user inputs, preceding the `--in-prefix` string\n");
+ printf(" --in-prefix STRING string to prefix user inputs with (default: empty)\n");
+ printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
+ printf(" -f FNAME, --file FNAME\n");
+ printf(" prompt file to start generation.\n");
+ printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
+ printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx);
+ printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
+ printf(" --samplers samplers that will be used for generation in the order, separated by \';\', for example: \"top_k;tfs;typical;top_p;min_p;temp\"\n");
+ printf(" --sampling-seq simplified sequence for samplers that will be used (default: %s)\n", sparams.samplers_sequence.c_str());
+ printf(" --top-k N top-k sampling (default: %d, 0 = disabled)\n", sparams.top_k);
+ printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
+ printf(" --min-p N min-p sampling (default: %.1f, 0.0 = disabled)\n", (double)sparams.min_p);
+ printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
+ printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
+ printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
+ printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
+ printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
+ printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
+ printf(" --mirostat N use Mirostat sampling.\n");
+ printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
+ printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
+ printf(" --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)sparams.mirostat_eta);
+ printf(" --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)sparams.mirostat_tau);
+ printf(" -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
+ printf(" modifies the likelihood of token appearing in the completion,\n");
+ printf(" i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
+ printf(" or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
+ printf(" --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n");
+ printf(" --grammar-file FNAME file to read grammar from\n");
+ printf(" --cfg-negative-prompt PROMPT\n");
+ printf(" negative prompt to use for guidance. (default: empty)\n");
+ printf(" --cfg-negative-prompt-file FNAME\n");
+ printf(" negative prompt file to use for guidance. (default: empty)\n");
+ printf(" --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", sparams.cfg_scale);
+ printf(" --rope-scaling {none,linear,yarn}\n");
+ printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
+ printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
+ printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
+ printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+ printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
+ printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
+ printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
+ printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
+ printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
+ printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
+ printf(" --no-penalize-nl do not penalize newline token\n");
+ printf(" --temp N temperature (default: %.1f)\n", (double)sparams.temp);
+ printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
+ printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
+ printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
+ printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
+ printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
+ printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
+ printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
+ printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
+ printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
+ printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
+ printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
+ printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
+ printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
+ if (llama_mlock_supported()) {
+ printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
+ }
+ if (llama_mmap_supported()) {
+ printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
+ }
+ printf(" --numa attempt optimizations that help on some NUMA systems\n");
+ printf(" if run without this previously, it is recommended to drop the system page cache before using this\n");
+ printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n");
+#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
+ printf(" -ngl N, --n-gpu-layers N\n");
+ printf(" number of layers to store in VRAM\n");
+ printf(" -ngld N, --n-gpu-layers-draft N\n");
+ printf(" number of layers to store in VRAM for the draft model\n");
+ printf(" -ts SPLIT --tensor-split SPLIT\n");
+ printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
+ printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
+#ifdef GGML_USE_CUBLAS
+ printf(" -nommq, --no-mul-mat-q\n");
+ printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
+ printf(" Not recommended since this is both slower and uses more VRAM.\n");
+#endif // GGML_USE_CUBLAS
+#endif
+ printf(" --verbose-prompt print prompt before generation\n");
+ printf(" -dkvc, --dump-kv-cache\n");
+ printf(" verbose print of the KV cache\n");
+ printf(" -nkvo, --no-kv-offload\n");
+ printf(" disable KV offload\n");
+ printf(" -ctk TYPE, --cache-type-k TYPE\n");
+ printf(" KV cache data type for K (default: %s)\n", params.cache_type_k.c_str());
+ printf(" -ctv TYPE, --cache-type-v TYPE\n");
+ printf(" KV cache data type for V (default: %s)\n", params.cache_type_v.c_str());
+ printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
+ printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
+ printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
+ printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
+ printf(" -m FNAME, --model FNAME\n");
+ printf(" model path (default: %s)\n", params.model.c_str());
+ printf(" -md FNAME, --model-draft FNAME\n");
+ printf(" draft model for speculative decoding (default: %s)\n", params.model.c_str());
+ printf(" -ld LOGDIR, --logdir LOGDIR\n");
+ printf(" path under which to save YAML logs (no logging if unset)\n");
+ printf(" --override-kv KEY=TYPE:VALUE\n");
+ printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
+ printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
+ printf("\n");
+#ifndef LOG_DISABLE_LOGS
+ log_print_usage();
+#endif // LOG_DISABLE_LOGS
+}
+
+std::string get_system_info(const gpt_params & params) {
+ std::ostringstream os;
+
+ os << "system_info: n_threads = " << params.n_threads;
+ if (params.n_threads_batch != -1) {
+ os << " (n_threads_batch = " << params.n_threads_batch << ")";
+ }
+ os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info();
+
+ return os.str();
+}
+
+std::string gpt_random_prompt(std::mt19937 & rng) {
+ const int r = rng() % 10;
+ switch (r) {
+ case 0: return "So";
+ case 1: return "Once upon a time";
+ case 2: return "When";
+ case 3: return "The";
+ case 4: return "After";
+ case 5: return "If";
+ case 6: return "import";
+ case 7: return "He";
+ case 8: return "She";
+ case 9: return "They";
+ }
+
+ GGML_UNREACHABLE();
+}
+
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input) {
+ std::string output = "";
+ // since samplers names are written multiple ways
+ // make it ready for both system names and input names
+ std::unordered_map samplers_symbols {
+ {"top_k", 'k'},
+ {"top-k", 'k'},
+ {"top_p", 'p'},
+ {"top-p", 'p'},
+ {"nucleus", 'p'},
+ {"typical_p", 'y'},
+ {"typical-p", 'y'},
+ {"typical", 'y'},
+ {"min_p", 'm'},
+ {"min-p", 'm'},
+ {"tfs_z", 'f'},
+ {"tfs-z", 'f'},
+ {"tfs", 'f'},
+ {"temp", 't'},
+ {"temperature",'t'}
+ };
+ // expected format example: "temp;top_k;tfs_z;typical_p;top_p;min_p"
+ size_t separator = input.find(';');
+ while (separator != input.npos) {
+ std::string name = input.substr(0,separator);
+ input = input.substr(separator+1);
+ separator = input.find(';');
+
+ if (samplers_symbols.find(name) != samplers_symbols.end()) {
+ output += samplers_symbols[name];
+ }
+ }
+ if (samplers_symbols.find(input) != samplers_symbols.end()) {
+ output += samplers_symbols[input];
+ }
+ return output;
+}
+
+//
+// Model utils
+//
+
+struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
+ auto mparams = llama_model_default_params();
+
+ if (params.n_gpu_layers != -1) {
+ mparams.n_gpu_layers = params.n_gpu_layers;
+ }
+ mparams.main_gpu = params.main_gpu;
+ mparams.tensor_split = params.tensor_split;
+ mparams.use_mmap = params.use_mmap;
+ mparams.use_mlock = params.use_mlock;
+ if (params.kv_overrides.empty()) {
+ mparams.kv_overrides = NULL;
+ } else {
+ GGML_ASSERT(params.kv_overrides.back().key[0] == 0 && "KV overrides not terminated with empty key");
+ mparams.kv_overrides = params.kv_overrides.data();
+ }
+
+ return mparams;
+}
+
+static ggml_type kv_cache_type_from_str(const std::string & s) {
+ if (s == "f16") {
+ return GGML_TYPE_F16;
+ }
+ if (s == "q8_0") {
+ return GGML_TYPE_Q8_0;
+ }
+ if (s == "q4_0") {
+ return GGML_TYPE_Q4_0;
+ }
+ if (s == "q4_1") {
+ return GGML_TYPE_Q4_1;
+ }
+ if (s == "q5_0") {
+ return GGML_TYPE_Q5_0;
+ }
+ if (s == "q5_1") {
+ return GGML_TYPE_Q5_1;
+ }
+
+ throw std::runtime_error("Invalid cache type: " + s);
+}
+
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
+ auto cparams = llama_context_default_params();
+
+ cparams.n_ctx = params.n_ctx;
+ cparams.n_batch = params.n_batch;
+ cparams.n_threads = params.n_threads;
+ cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+ cparams.mul_mat_q = params.mul_mat_q;
+ cparams.seed = params.seed;
+ cparams.logits_all = params.logits_all;
+ cparams.embedding = params.embedding;
+ cparams.rope_scaling_type = params.rope_scaling_type;
+ cparams.rope_freq_base = params.rope_freq_base;
+ cparams.rope_freq_scale = params.rope_freq_scale;
+ cparams.yarn_ext_factor = params.yarn_ext_factor;
+ cparams.yarn_attn_factor = params.yarn_attn_factor;
+ cparams.yarn_beta_fast = params.yarn_beta_fast;
+ cparams.yarn_beta_slow = params.yarn_beta_slow;
+ cparams.yarn_orig_ctx = params.yarn_orig_ctx;
+ cparams.offload_kqv = !params.no_kv_offload;
+
+ cparams.type_k = kv_cache_type_from_str(params.cache_type_k);
+ cparams.type_v = kv_cache_type_from_str(params.cache_type_v);
+
+ return cparams;
+}
+
+void llama_batch_clear(struct llama_batch & batch) {
+ batch.n_tokens = 0;
+}
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector & seq_ids,
+ bool logits) {
+ batch.token [batch.n_tokens] = id;
+ batch.pos [batch.n_tokens] = pos;
+ batch.n_seq_id[batch.n_tokens] = seq_ids.size();
+ for (size_t i = 0; i < seq_ids.size(); ++i) {
+ batch.seq_id[batch.n_tokens][i] = seq_ids[i];
+ }
+ batch.logits [batch.n_tokens] = logits;
+
+ batch.n_tokens++;
+}
+
+std::tuple llama_init_from_gpt_params(gpt_params & params) {
+ auto mparams = llama_model_params_from_gpt_params(params);
+
+ llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
+ if (model == NULL) {
+ fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
+ return std::make_tuple(nullptr, nullptr);
+ }
+
+ auto cparams = llama_context_params_from_gpt_params(params);
+
+ llama_context * lctx = llama_new_context_with_model(model, cparams);
+ if (lctx == NULL) {
+ fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+
+ for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
+ const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]);
+ float lora_scale = std::get<1>(params.lora_adapter[i]);
+ int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(),
+ lora_scale,
+ ((i > 0) || params.lora_base.empty())
+ ? NULL
+ : params.lora_base.c_str(),
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ }
+
+ if (params.ignore_eos) {
+ params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+ }
+
+ {
+ LOG("warming up the model with an empty run\n");
+
+ std::vector tmp = { llama_token_bos(model), llama_token_eos(model), };
+ llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
+ llama_kv_cache_clear(lctx);
+ llama_reset_timings(lctx);
+ }
+
+ return std::make_tuple(model, lctx);
+}
+
+//
+// Vocab utils
+//
+
+std::vector llama_tokenize(
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos,
+ bool special) {
+ return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
+}
+
+std::vector llama_tokenize(
+ const struct llama_model * model,
+ const std::string & text,
+ bool add_bos,
+ bool special) {
+ // upper limit for the number of tokens
+ int n_tokens = text.length() + add_bos;
+ std::vector result(n_tokens);
+ n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+ return result;
+}
+
+std::string llama_token_to_piece(const struct llama_context * ctx, llama_token token) {
+ std::vector result(8, 0);
+ const int n_tokens = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+ if (n_tokens < 0) {
+ result.resize(-n_tokens);
+ int check = llama_token_to_piece(llama_get_model(ctx), token, result.data(), result.size());
+ GGML_ASSERT(check == -n_tokens);
+ } else {
+ result.resize(n_tokens);
+ }
+
+ return std::string(result.data(), result.size());
+}
+
+std::string llama_detokenize_spm(llama_context * ctx, const std::vector & tokens) {
+ const llama_token bos_id = llama_token_bos(llama_get_model(ctx));
+
+ std::string piece;
+ std::string result;
+
+ for (size_t i = 0; i < tokens.size(); ++i) {
+ piece = llama_token_to_piece(ctx, tokens[i]);
+
+ // remove the leading space of the first non-BOS token
+ if (((tokens[0] == bos_id && i == 1) || (tokens[0] != bos_id && i == 0)) && piece[0] == ' ') {
+ piece = piece.substr(1);
+ }
+
+ result += piece;
+ }
+
+ return result;
+}
+
+std::string llama_detokenize_bpe(llama_context * ctx, const std::vector & tokens) {
+ std::string piece;
+ std::string result;
+
+ for (size_t i = 0; i < tokens.size(); ++i) {
+ piece = llama_token_to_piece(ctx, tokens[i]);
+
+ result += piece;
+ }
+
+ // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+ return result;
+}
+
+bool llama_should_add_bos_token(const llama_model * model) {
+ const int add_bos = llama_add_bos_token(model);
+
+ return add_bos != -1 ? bool(add_bos) : (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
+}
+
+//
+// YAML utils
+//
+
+// returns true if successful, false otherwise
+bool create_directory_with_parents(const std::string & path) {
+#ifdef _WIN32
+ std::wstring_convert> converter;
+ std::wstring wpath = converter.from_bytes(path);
+
+ // if the path already exists, check whether it's a directory
+ const DWORD attributes = GetFileAttributesW(wpath.c_str());
+ if ((attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+ return true;
+ }
+
+ size_t pos_slash = 0;
+
+ // process path from front to back, procedurally creating directories
+ while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) {
+ const std::wstring subpath = wpath.substr(0, pos_slash);
+ const wchar_t * test = subpath.c_str();
+
+ const bool success = CreateDirectoryW(test, NULL);
+ if (!success) {
+ const DWORD error = GetLastError();
+
+ // if the path already exists, ensure that it's a directory
+ if (error == ERROR_ALREADY_EXISTS) {
+ const DWORD attributes = GetFileAttributesW(subpath.c_str());
+ if (attributes == INVALID_FILE_ATTRIBUTES || !(attributes & FILE_ATTRIBUTE_DIRECTORY)) {
+ return false;
+ }
+ } else {
+ return false;
+ }
+ }
+
+ pos_slash += 1;
+ }
+
+ return true;
+#else
+ // if the path already exists, check whether it's a directory
+ struct stat info;
+ if (stat(path.c_str(), &info) == 0) {
+ return S_ISDIR(info.st_mode);
+ }
+
+ size_t pos_slash = 1; // skip leading slashes for directory creation
+
+ // process path from front to back, procedurally creating directories
+ while ((pos_slash = path.find('/', pos_slash)) != std::string::npos) {
+ const std::string subpath = path.substr(0, pos_slash);
+ struct stat info;
+
+ // if the path already exists, ensure that it's a directory
+ if (stat(subpath.c_str(), &info) == 0) {
+ if (!S_ISDIR(info.st_mode)) {
+ return false;
+ }
+ } else {
+ // create parent directories
+ const int ret = mkdir(subpath.c_str(), 0755);
+ if (ret != 0) {
+ return false;
+ }
+ }
+
+ pos_slash += 1;
+ }
+
+ return true;
+#endif // _WIN32
+}
+
+void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data) {
+ if (data.empty()) {
+ fprintf(stream, "%s:\n", prop_name);
+ return;
+ }
+
+ fprintf(stream, "%s: [", prop_name);
+ for (size_t i = 0; i < data.size() - 1; ++i) {
+ fprintf(stream, "%e, ", data[i]);
+ }
+ fprintf(stream, "%e]\n", data.back());
+}
+
+void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data) {
+ if (data.empty()) {
+ fprintf(stream, "%s:\n", prop_name);
+ return;
+ }
+
+ fprintf(stream, "%s: [", prop_name);
+ for (size_t i = 0; i < data.size() - 1; ++i) {
+ fprintf(stream, "%d, ", data[i]);
+ }
+ fprintf(stream, "%d]\n", data.back());
+}
+
+void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data) {
+ std::string data_str(data == NULL ? "" : data);
+
+ if (data_str.empty()) {
+ fprintf(stream, "%s:\n", prop_name);
+ return;
+ }
+
+ size_t pos_start = 0;
+ size_t pos_found = 0;
+
+ if (!data_str.empty() && (std::isspace(data_str[0]) || std::isspace(data_str.back()))) {
+ data_str = std::regex_replace(data_str, std::regex("\n"), "\\n");
+ data_str = std::regex_replace(data_str, std::regex("\""), "\\\"");
+ data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)");
+ data_str = "\"" + data_str + "\"";
+ fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
+ return;
+ }
+
+ if (data_str.find('\n') == std::string::npos) {
+ fprintf(stream, "%s: %s\n", prop_name, data_str.c_str());
+ return;
+ }
+
+ fprintf(stream, "%s: |\n", prop_name);
+ while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) {
+ fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str());
+ pos_start = pos_found + 1;
+ }
+}
+
+std::string get_sortable_timestamp() {
+ using clock = std::chrono::system_clock;
+
+ const clock::time_point current_time = clock::now();
+ const time_t as_time_t = clock::to_time_t(current_time);
+ char timestamp_no_ns[100];
+ std::strftime(timestamp_no_ns, 100, "%Y_%m_%d-%H_%M_%S", std::localtime(&as_time_t));
+
+ const int64_t ns = std::chrono::duration_cast(
+ current_time.time_since_epoch() % 1000000000).count();
+ char timestamp_ns[11];
+ snprintf(timestamp_ns, 11, "%09" PRId64, ns);
+
+ return std::string(timestamp_no_ns) + "." + std::string(timestamp_ns);
+}
+
+void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const llama_context * lctx,
+ const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) {
+ const llama_sampling_params & sparams = params.sparams;
+
+ fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT);
+ fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER);
+ fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false");
+ fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false");
+ fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false");
+ fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false");
+ fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false");
+ fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false");
+ fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
+ fprintf(stream, "cpu_has_cublas: %s\n", ggml_cpu_has_cublas() ? "true" : "false");
+ fprintf(stream, "cpu_has_clblast: %s\n", ggml_cpu_has_clblast() ? "true" : "false");
+ fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false");
+ fprintf(stream, "cpu_has_gpublas: %s\n", ggml_cpu_has_gpublas() ? "true" : "false");
+ fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false");
+ fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false");
+ fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false");
+ fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false");
+ fprintf(stream, "cpu_has_blas: %s\n", ggml_cpu_has_blas() ? "true" : "false");
+ fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false");
+ fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false");
+
+#ifdef NDEBUG
+ fprintf(stream, "debug: false\n");
+#else
+ fprintf(stream, "debug: true\n");
+#endif // NDEBUG
+
+ fprintf(stream, "model_desc: %s\n", model_desc);
+ fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx)));
+
+#ifdef __OPTIMIZE__
+ fprintf(stream, "optimize: true\n");
+#else
+ fprintf(stream, "optimize: false\n");
+#endif // __OPTIMIZE__
+
+ fprintf(stream, "time: %s\n", timestamp.c_str());
+
+ fprintf(stream, "\n");
+ fprintf(stream, "###############\n");
+ fprintf(stream, "# User Inputs #\n");
+ fprintf(stream, "###############\n");
+ fprintf(stream, "\n");
+
+ fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str());
+ fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch);
+ dump_string_yaml_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str());
+ fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale);
+ fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks);
+ fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
+ fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
+ fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
+ fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
+ fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
+ dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
+ fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
+ fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
+ fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
+
+ const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
+ const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
+ fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
+
+ dump_string_yaml_multiline(stream, "in_prefix", params.input_prefix.c_str());
+ fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false");
+ dump_string_yaml_multiline(stream, "in_suffix", params.input_prefix.c_str());
+ fprintf(stream, "instruct: %s # default: false\n", params.instruct ? "true" : "false");
+ fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false");
+ fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false");
+ fprintf(stream, "keep: %d # default: 0\n", params.n_keep);
+ fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str());
+
+ fprintf(stream, "logit_bias:\n");
+ for (std::pair lb : sparams.logit_bias) {
+ if (ignore_eos && lb.first == logit_bias_eos->first) {
+ continue;
+ }
+ fprintf(stream, " %d: %f", lb.first, lb.second);
+ }
+
+ fprintf(stream, "lora:\n");
+ for (std::tuple la : params.lora_adapter) {
+ if (std::get<1>(la) != 1.0f) {
+ continue;
+ }
+ fprintf(stream, " - %s\n", std::get<0>(la).c_str());
+ }
+ fprintf(stream, "lora_scaled:\n");
+ for (std::tuple la : params.lora_adapter) {
+ if (std::get<1>(la) == 1.0f) {
+ continue;
+ }
+ fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
+ }
+ fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
+ fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
+ fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
+ fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau);
+ fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta);
+ fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
+ fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+ fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
+ fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
+ fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers);
+ fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict);
+ fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs);
+ fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false");
+ fprintf(stream, "no_mul_mat_q: %s # default: false\n", !params.mul_mat_q ? "true" : "false");
+ fprintf(stream, "no_penalize_nl: %s # default: false\n", !sparams.penalize_nl ? "true" : "false");
+ fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
+ fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
+ fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
+ fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
+ dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
+ fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
+ fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
+ fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
+ dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
+ fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
+ fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
+
+ fprintf(stream, "reverse_prompt:\n");
+ for (std::string ap : params.antiprompt) {
+ size_t pos = 0;
+ while ((pos = ap.find('\n', pos)) != std::string::npos) {
+ ap.replace(pos, 1, "\\n");
+ pos += 1;
+ }
+
+ fprintf(stream, " - %s\n", ap.c_str());
+ }
+
+ fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
+ fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
+ fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed);
+ fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
+ fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
+ fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp);
+
+ const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES);
+ dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector);
+
+ fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z);
+ fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency());
+ fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k);
+ fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p);
+ fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p);
+ fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
+ fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
+}
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+ static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+ printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+ view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+ llama_kv_cache_view_cell * c_curr = view.cells;
+ llama_seq_id * cs_curr = view.cells_sequences;
+
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+ if (i % row_size == 0) {
+ printf("\n%5d: ", i);
+ }
+ int seq_count = 0;
+ for (int j = 0; j < view.n_max_seq; j++) {
+ if (cs_curr[j] >= 0) { seq_count++; }
+ }
+ putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+ }
+
+ printf("\n=== Done dumping\n");
+}
+
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+ static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+ printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+ view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+ std::unordered_map seqs;
+ llama_kv_cache_view_cell * c_curr = view.cells;
+ llama_seq_id * cs_curr = view.cells_sequences;
+
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+ for (int j = 0; j < view.n_max_seq; j++) {
+ if (cs_curr[j] < 0) { continue; }
+ if (seqs.find(cs_curr[j]) == seqs.end()) {
+ if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+ seqs[cs_curr[j]] = seqs.size();
+ }
+ }
+ if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+ }
+
+ printf("=== Sequence legend: ");
+ for (const auto & it : seqs) {
+ printf("%zu=%d, ", it.second, it.first);
+ }
+ printf("'+'=other sequence ids");
+
+ c_curr = view.cells;
+ cs_curr = view.cells_sequences;
+ for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+ if (i % row_size == 0) {
+ printf("\n%5d: ", i);
+ }
+ for (int j = 0; j < view.n_max_seq; j++) {
+ if (cs_curr[j] >= 0) {
+ const auto & it = seqs.find(cs_curr[j]);
+ putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+ } else {
+ putchar('.');
+ }
+ }
+ putchar(' ');
+ }
+
+ printf("\n=== Done dumping\n");
+}
diff --git a/applications/src/llama/common/common.h b/applications/src/llama/common/common.h
new file mode 100644
index 000000000..e87ce1133
--- /dev/null
+++ b/applications/src/llama/common/common.h
@@ -0,0 +1,242 @@
+// Various helper functions and utilities
+
+#pragma once
+
+#include "llama.h"
+
+#include "sampling.h"
+
+#define LOG_NO_FILE_LINE_FUNCTION
+#include "log.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#ifdef _WIN32
+#define DIRECTORY_SEPARATOR '\\'
+#else
+#define DIRECTORY_SEPARATOR '/'
+#endif // _WIN32
+
+#define die(msg) do { fputs("error: " msg "\n", stderr); exit(1); } while (0)
+#define die_fmt(fmt, ...) do { fprintf(stderr, "error: " fmt "\n", __VA_ARGS__); exit(1); } while (0)
+
+#define print_build_info() do { \
+ fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \
+ fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
+} while(0)
+
+// build info
+extern int LLAMA_BUILD_NUMBER;
+extern char const *LLAMA_COMMIT;
+extern char const *LLAMA_COMPILER;
+extern char const *LLAMA_BUILD_TARGET;
+
+//
+// CLI argument parsing
+//
+int32_t get_num_physical_cores();
+
+struct gpt_params {
+ uint32_t seed = -1; // RNG seed
+
+ int32_t n_threads = get_num_physical_cores();
+ int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_ctx = 512; // context size
+ int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_draft = 16; // number of tokens to draft during speculative decoding
+ int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
+ int32_t n_parallel = 1; // number of parallel sequences to decode
+ int32_t n_sequences = 1; // number of sequences to decode
+ float p_accept = 0.5f; // speculative decoding accept probability
+ float p_split = 0.1f; // speculative decoding split probability
+ int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
+ int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
+ int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
+ float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
+ int32_t n_beams = 0; // if non-zero then use beam search of given width.
+ float rope_freq_base = 0.0f; // RoPE base frequency
+ float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
+ float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
+ float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
+ float yarn_beta_fast = 32.0f; // YaRN low correction dim
+ float yarn_beta_slow = 1.0f; // YaRN high correction dim
+ int32_t yarn_orig_ctx = 0; // YaRN original context length
+ int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
+ // pinging @cebtenzzre
+
+ // // sampling parameters
+ struct llama_sampling_params sparams;
+
+ std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+ std::string model_draft = ""; // draft model for speculative decoding
+ std::string model_alias = "unknown"; // model alias
+ std::string prompt = "";
+ std::string prompt_file = ""; // store the external prompt file name
+ std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
+ std::string input_prefix = ""; // string to prefix user inputs with
+ std::string input_suffix = ""; // string to suffix user inputs with
+ std::vector antiprompt; // string upon seeing which more user input is prompted
+ std::string logdir = ""; // directory in which to save YAML log files
+
+ std::vector kv_overrides;
+
+ // TODO: avoid tuple, use struct
+ std::vector> lora_adapter; // lora adapter path with user defined scale
+ std::string lora_base = ""; // base model path for the lora adapter
+
+ int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
+ int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
+ // (which is more convenient to use for plotting)
+ //
+ bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
+ size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
+
+ bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
+ bool random_prompt = false; // do not randomize prompt if none provided
+ bool use_color = false; // use color to distinguish generations and inputs
+ bool interactive = false; // interactive mode
+ bool chatml = false; // chatml mode (used for models trained on chatml syntax)
+ bool prompt_cache_all = false; // save user input and generations to prompt cache
+ bool prompt_cache_ro = false; // open the prompt cache read-only and do not update it
+
+ bool embedding = false; // get only sentence embedding
+ bool escape = false; // escape "\n", "\r", "\t", "\'", "\"", and "\\"
+ bool interactive_first = false; // wait for user input immediately
+ bool multiline_input = false; // reverse the usage of `\`
+ bool simple_io = false; // improves compatibility with subprocesses and limited consoles
+ bool cont_batching = false; // insert new sequences for decoding on-the-fly
+
+ bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
+ bool ignore_eos = false; // ignore generated EOS tokens
+ bool instruct = false; // instruction mode (used for Alpaca models)
+ bool logits_all = false; // return logits for all tokens in the batch
+ bool use_mmap = true; // use mmap for faster loads
+ bool use_mlock = false; // use mlock to keep model in memory
+ bool numa = false; // attempt optimizations that help on some NUMA systems
+ bool verbose_prompt = false; // print prompt tokens before generation
+ bool infill = false; // use infill mode
+ bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
+ bool no_kv_offload = false; // disable KV offloading
+
+ std::string cache_type_k = "f16"; // KV cache data type for the K
+ std::string cache_type_v = "f16"; // KV cache data type for the V
+
+ // multimodal models (see examples/llava)
+ std::string mmproj = ""; // path to multimodal projector
+ std::string image = ""; // path to an image file
+};
+
+bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params);
+
+bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
+
+void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
+
+std::string get_system_info(const gpt_params & params);
+
+std::string gpt_random_prompt(std::mt19937 & rng);
+
+void process_escapes(std::string& input);
+
+//
+// String parsing
+//
+
+std::string parse_samplers_input(std::string input);
+
+//
+// Model utils
+//
+
+// TODO: avoid tuplue, use struct
+std::tuple llama_init_from_gpt_params(gpt_params & params);
+
+struct llama_model_params llama_model_params_from_gpt_params (const gpt_params & params);
+struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params);
+
+// Batch utils
+
+void llama_batch_clear(struct llama_batch & batch);
+
+void llama_batch_add(
+ struct llama_batch & batch,
+ llama_token id,
+ llama_pos pos,
+ const std::vector & seq_ids,
+ bool logits);
+
+//
+// Vocab utils
+//
+
+// tokenizes a string into a vector of tokens
+// should work similar to Python's `tokenizer.encode`
+std::vector llama_tokenize(
+ const struct llama_context * ctx,
+ const std::string & text,
+ bool add_bos,
+ bool special = false);
+
+std::vector llama_tokenize(
+ const struct llama_model * model,
+ const std::string & text,
+ bool add_bos,
+ bool special = false);
+
+// tokenizes a token into a piece
+// should work similar to Python's `tokenizer.id_to_piece`
+std::string llama_token_to_piece(
+ const struct llama_context * ctx,
+ llama_token token);
+
+// TODO: these should be moved in llama.h C-style API under single `llama_detokenize` function
+// that takes into account the tokenizer type and decides how to handle the leading space
+//
+// detokenizes a vector of tokens into a string
+// should work similar to Python's `tokenizer.decode`
+// removes the leading space from the first non-BOS token
+std::string llama_detokenize_spm(
+ llama_context * ctx,
+ const std::vector & tokens);
+
+// detokenizes a vector of tokens into a string
+// should work similar to Python's `tokenizer.decode`
+std::string llama_detokenize_bpe(
+ llama_context * ctx,
+ const std::vector & tokens);
+
+// Uses the value from the model metadata if possible, otherwise
+// defaults to true when model type is SPM, otherwise false.
+bool llama_should_add_bos_token(const llama_model * model);
+
+//
+// YAML utils
+//
+
+bool create_directory_with_parents(const std::string & path);
+void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector & data);
+void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector & data);
+void dump_string_yaml_multiline(FILE * stream, const char * prop_name, const char * data);
+std::string get_sortable_timestamp();
+
+void dump_non_result_info_yaml(
+ FILE * stream, const gpt_params & params, const llama_context * lctx,
+ const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
diff --git a/applications/src/llama/common/console.cpp b/applications/src/llama/common/console.cpp
new file mode 100644
index 000000000..f65cbc6ed
--- /dev/null
+++ b/applications/src/llama/common/console.cpp
@@ -0,0 +1,501 @@
+#include "console.h"
+#include
+#include
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include
+#include
+#include
+#ifndef ENABLE_VIRTUAL_TERMINAL_PROCESSING
+#define ENABLE_VIRTUAL_TERMINAL_PROCESSING 0x0004
+#endif
+#else
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#endif
+
+#define ANSI_COLOR_RED "\x1b[31m"
+#define ANSI_COLOR_GREEN "\x1b[32m"
+#define ANSI_COLOR_YELLOW "\x1b[33m"
+#define ANSI_COLOR_BLUE "\x1b[34m"
+#define ANSI_COLOR_MAGENTA "\x1b[35m"
+#define ANSI_COLOR_CYAN "\x1b[36m"
+#define ANSI_COLOR_RESET "\x1b[0m"
+#define ANSI_BOLD "\x1b[1m"
+
+namespace console {
+
+ //
+ // Console state
+ //
+
+ static bool advanced_display = false;
+ static bool simple_io = true;
+ static display_t current_display = reset;
+
+ static FILE* out = stdout;
+
+#if defined (_WIN32)
+ static void* hConsole;
+#else
+ static FILE* tty = nullptr;
+ static termios initial_state;
+#endif
+
+ //
+ // Init and cleanup
+ //
+
+ void init(bool use_simple_io, bool use_advanced_display) {
+ advanced_display = use_advanced_display;
+ simple_io = use_simple_io;
+#if defined(_WIN32)
+ // Windows-specific console initialization
+ DWORD dwMode = 0;
+ hConsole = GetStdHandle(STD_OUTPUT_HANDLE);
+ if (hConsole == INVALID_HANDLE_VALUE || !GetConsoleMode(hConsole, &dwMode)) {
+ hConsole = GetStdHandle(STD_ERROR_HANDLE);
+ if (hConsole != INVALID_HANDLE_VALUE && (!GetConsoleMode(hConsole, &dwMode))) {
+ hConsole = nullptr;
+ simple_io = true;
+ }
+ }
+ if (hConsole) {
+ // Check conditions combined to reduce nesting
+ if (advanced_display && !(dwMode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) &&
+ !SetConsoleMode(hConsole, dwMode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) {
+ advanced_display = false;
+ }
+ // Set console output codepage to UTF8
+ SetConsoleOutputCP(CP_UTF8);
+ }
+ HANDLE hConIn = GetStdHandle(STD_INPUT_HANDLE);
+ if (hConIn != INVALID_HANDLE_VALUE && GetConsoleMode(hConIn, &dwMode)) {
+ // Set console input codepage to UTF16
+ _setmode(_fileno(stdin), _O_WTEXT);
+
+ // Set ICANON (ENABLE_LINE_INPUT) and ECHO (ENABLE_ECHO_INPUT)
+ if (simple_io) {
+ dwMode |= ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT;
+ } else {
+ dwMode &= ~(ENABLE_LINE_INPUT | ENABLE_ECHO_INPUT);
+ }
+ if (!SetConsoleMode(hConIn, dwMode)) {
+ simple_io = true;
+ }
+ }
+#else
+ // POSIX-specific console initialization
+ if (!simple_io) {
+ struct termios new_termios;
+ tcgetattr(STDIN_FILENO, &initial_state);
+ new_termios = initial_state;
+ new_termios.c_lflag &= ~(ICANON | ECHO);
+ new_termios.c_cc[VMIN] = 1;
+ new_termios.c_cc[VTIME] = 0;
+ tcsetattr(STDIN_FILENO, TCSANOW, &new_termios);
+
+ tty = fopen("/dev/tty", "w+");
+ if (tty != nullptr) {
+ out = tty;
+ }
+ }
+
+ setlocale(LC_ALL, "");
+#endif
+ }
+
+ void cleanup() {
+ // Reset console display
+ set_display(reset);
+
+#if !defined(_WIN32)
+ // Restore settings on POSIX systems
+ if (!simple_io) {
+ if (tty != nullptr) {
+ out = stdout;
+ fclose(tty);
+ tty = nullptr;
+ }
+ tcsetattr(STDIN_FILENO, TCSANOW, &initial_state);
+ }
+#endif
+ }
+
+ //
+ // Display and IO
+ //
+
+ // Keep track of current display and only emit ANSI code if it changes
+ void set_display(display_t display) {
+ if (advanced_display && current_display != display) {
+ fflush(stdout);
+ switch(display) {
+ case reset:
+ fprintf(out, ANSI_COLOR_RESET);
+ break;
+ case prompt:
+ fprintf(out, ANSI_COLOR_YELLOW);
+ break;
+ case user_input:
+ fprintf(out, ANSI_BOLD ANSI_COLOR_GREEN);
+ break;
+ case error:
+ fprintf(out, ANSI_BOLD ANSI_COLOR_RED);
+ }
+ current_display = display;
+ fflush(out);
+ }
+ }
+
+ static char32_t getchar32() {
+#if defined(_WIN32)
+ HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE);
+ wchar_t high_surrogate = 0;
+
+ while (true) {
+ INPUT_RECORD record;
+ DWORD count;
+ if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) {
+ return WEOF;
+ }
+
+ if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) {
+ wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar;
+ if (wc == 0) {
+ continue;
+ }
+
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
+ high_surrogate = wc;
+ continue;
+ }
+ if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate
+ if (high_surrogate != 0) { // Check if we have a high surrogate
+ return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000;
+ }
+ }
+
+ high_surrogate = 0; // Reset the high surrogate
+ return static_cast(wc);
+ }
+ }
+#else
+ wchar_t wc = getwchar();
+ if (static_cast(wc) == WEOF) {
+ return WEOF;
+ }
+
+#if WCHAR_MAX == 0xFFFF
+ if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate
+ wchar_t low_surrogate = getwchar();
+ if ((low_surrogate >= 0xDC00) && (low_surrogate <= 0xDFFF)) { // Check if the next wchar is a low surrogate
+ return (static_cast(wc & 0x03FF) << 10) + (low_surrogate & 0x03FF) + 0x10000;
+ }
+ }
+ if ((wc >= 0xD800) && (wc <= 0xDFFF)) { // Invalid surrogate pair
+ return 0xFFFD; // Return the replacement character U+FFFD
+ }
+#endif
+
+ return static_cast(wc);
+#endif
+ }
+
+ static void pop_cursor() {
+#if defined(_WIN32)
+ if (hConsole != NULL) {
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
+ GetConsoleScreenBufferInfo(hConsole, &bufferInfo);
+
+ COORD newCursorPosition = bufferInfo.dwCursorPosition;
+ if (newCursorPosition.X == 0) {
+ newCursorPosition.X = bufferInfo.dwSize.X - 1;
+ newCursorPosition.Y -= 1;
+ } else {
+ newCursorPosition.X -= 1;
+ }
+
+ SetConsoleCursorPosition(hConsole, newCursorPosition);
+ return;
+ }
+#endif
+ putc('\b', out);
+ }
+
+ static int estimateWidth(char32_t codepoint) {
+#if defined(_WIN32)
+ (void)codepoint;
+ return 1;
+#else
+ return wcwidth(codepoint);
+#endif
+ }
+
+ static int put_codepoint(const char* utf8_codepoint, size_t length, int expectedWidth) {
+#if defined(_WIN32)
+ CONSOLE_SCREEN_BUFFER_INFO bufferInfo;
+ if (!GetConsoleScreenBufferInfo(hConsole, &bufferInfo)) {
+ // go with the default
+ return expectedWidth;
+ }
+ COORD initialPosition = bufferInfo.dwCursorPosition;
+ DWORD nNumberOfChars = length;
+ WriteConsole(hConsole, utf8_codepoint, nNumberOfChars, &nNumberOfChars, NULL);
+
+ CONSOLE_SCREEN_BUFFER_INFO newBufferInfo;
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
+
+ // Figure out our real position if we're in the last column
+ if (utf8_codepoint[0] != 0x09 && initialPosition.X == newBufferInfo.dwSize.X - 1) {
+ DWORD nNumberOfChars;
+ WriteConsole(hConsole, &" \b", 2, &nNumberOfChars, NULL);
+ GetConsoleScreenBufferInfo(hConsole, &newBufferInfo);
+ }
+
+ int width = newBufferInfo.dwCursorPosition.X - initialPosition.X;
+ if (width < 0) {
+ width += newBufferInfo.dwSize.X;
+ }
+ return width;
+#else
+ // We can trust expectedWidth if we've got one
+ if (expectedWidth >= 0 || tty == nullptr) {
+ fwrite(utf8_codepoint, length, 1, out);
+ return expectedWidth;
+ }
+
+ fputs("\033[6n", tty); // Query cursor position
+ int x1;
+ int y1;
+ int x2;
+ int y2;
+ int results = 0;
+ results = fscanf(tty, "\033[%d;%dR", &y1, &x1);
+
+ fwrite(utf8_codepoint, length, 1, tty);
+
+ fputs("\033[6n", tty); // Query cursor position
+ results += fscanf(tty, "\033[%d;%dR", &y2, &x2);
+
+ if (results != 4) {
+ return expectedWidth;
+ }
+
+ int width = x2 - x1;
+ if (width < 0) {
+ // Calculate the width considering text wrapping
+ struct winsize w;
+ ioctl(STDOUT_FILENO, TIOCGWINSZ, &w);
+ width += w.ws_col;
+ }
+ return width;
+#endif
+ }
+
+ static void replace_last(char ch) {
+#if defined(_WIN32)
+ pop_cursor();
+ put_codepoint(&ch, 1, 1);
+#else
+ fprintf(out, "\b%c", ch);
+#endif
+ }
+
+ static void append_utf8(char32_t ch, std::string & out) {
+ if (ch <= 0x7F) {
+ out.push_back(static_cast(ch));
+ } else if (ch <= 0x7FF) {
+ out.push_back(static_cast(0xC0 | ((ch >> 6) & 0x1F)));
+ out.push_back(static_cast(0x80 | (ch & 0x3F)));
+ } else if (ch <= 0xFFFF) {
+ out.push_back(static_cast(0xE0 | ((ch >> 12) & 0x0F)));
+ out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F)));
+ out.push_back(static_cast(0x80 | (ch & 0x3F)));
+ } else if (ch <= 0x10FFFF) {
+ out.push_back(static_cast(0xF0 | ((ch >> 18) & 0x07)));
+ out.push_back(static_cast(0x80 | ((ch >> 12) & 0x3F)));
+ out.push_back(static_cast(0x80 | ((ch >> 6) & 0x3F)));
+ out.push_back(static_cast(0x80 | (ch & 0x3F)));
+ } else {
+ // Invalid Unicode code point
+ }
+ }
+
+ // Helper function to remove the last UTF-8 character from a string
+ static void pop_back_utf8_char(std::string & line) {
+ if (line.empty()) {
+ return;
+ }
+
+ size_t pos = line.length() - 1;
+
+ // Find the start of the last UTF-8 character (checking up to 4 bytes back)
+ for (size_t i = 0; i < 3 && pos > 0; ++i, --pos) {
+ if ((line[pos] & 0xC0) != 0x80) {
+ break; // Found the start of the character
+ }
+ }
+ line.erase(pos);
+ }
+
+ static bool readline_advanced(std::string & line, bool multiline_input) {
+ if (out != stdout) {
+ fflush(stdout);
+ }
+
+ line.clear();
+ std::vector widths;
+ bool is_special_char = false;
+ bool end_of_stream = false;
+
+ char32_t input_char;
+ while (true) {
+ fflush(out); // Ensure all output is displayed before waiting for input
+ input_char = getchar32();
+
+ if (input_char == '\r' || input_char == '\n') {
+ break;
+ }
+
+ if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) {
+ end_of_stream = true;
+ break;
+ }
+
+ if (is_special_char) {
+ set_display(user_input);
+ replace_last(line.back());
+ is_special_char = false;
+ }
+
+ if (input_char == '\033') { // Escape sequence
+ char32_t code = getchar32();
+ if (code == '[' || code == 0x1B) {
+ // Discard the rest of the escape sequence
+ while ((code = getchar32()) != (char32_t) WEOF) {
+ if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') {
+ break;
+ }
+ }
+ }
+ } else if (input_char == 0x08 || input_char == 0x7F) { // Backspace
+ if (!widths.empty()) {
+ int count;
+ do {
+ count = widths.back();
+ widths.pop_back();
+ // Move cursor back, print space, and move cursor back again
+ for (int i = 0; i < count; i++) {
+ replace_last(' ');
+ pop_cursor();
+ }
+ pop_back_utf8_char(line);
+ } while (count == 0 && !widths.empty());
+ }
+ } else {
+ int offset = line.length();
+ append_utf8(input_char, line);
+ int width = put_codepoint(line.c_str() + offset, line.length() - offset, estimateWidth(input_char));
+ if (width < 0) {
+ width = 0;
+ }
+ widths.push_back(width);
+ }
+
+ if (!line.empty() && (line.back() == '\\' || line.back() == '/')) {
+ set_display(prompt);
+ replace_last(line.back());
+ is_special_char = true;
+ }
+ }
+
+ bool has_more = multiline_input;
+ if (is_special_char) {
+ replace_last(' ');
+ pop_cursor();
+
+ char last = line.back();
+ line.pop_back();
+ if (last == '\\') {
+ line += '\n';
+ fputc('\n', out);
+ has_more = !has_more;
+ } else {
+ // llama will just eat the single space, it won't act as a space
+ if (line.length() == 1 && line.back() == ' ') {
+ line.clear();
+ pop_cursor();
+ }
+ has_more = false;
+ }
+ } else {
+ if (end_of_stream) {
+ has_more = false;
+ } else {
+ line += '\n';
+ fputc('\n', out);
+ }
+ }
+
+ fflush(out);
+ return has_more;
+ }
+
+ static bool readline_simple(std::string & line, bool multiline_input) {
+#if defined(_WIN32)
+ std::wstring wline;
+ if (!std::getline(std::wcin, wline)) {
+ // Input stream is bad or EOF received
+ line.clear();
+ GenerateConsoleCtrlEvent(CTRL_C_EVENT, 0);
+ return false;
+ }
+
+ int size_needed = WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), NULL, 0, NULL, NULL);
+ line.resize(size_needed);
+ WideCharToMultiByte(CP_UTF8, 0, &wline[0], (int)wline.size(), &line[0], size_needed, NULL, NULL);
+#else
+ if (!std::getline(std::cin, line)) {
+ // Input stream is bad or EOF received
+ line.clear();
+ return false;
+ }
+#endif
+ if (!line.empty()) {
+ char last = line.back();
+ if (last == '/') { // Always return control on '/' symbol
+ line.pop_back();
+ return false;
+ }
+ if (last == '\\') { // '\\' changes the default action
+ line.pop_back();
+ multiline_input = !multiline_input;
+ }
+ }
+ line += '\n';
+
+ // By default, continue input if multiline_input is set
+ return multiline_input;
+ }
+
+ bool readline(std::string & line, bool multiline_input) {
+ set_display(user_input);
+
+ if (simple_io) {
+ return readline_simple(line, multiline_input);
+ }
+ return readline_advanced(line, multiline_input);
+ }
+
+}
diff --git a/applications/src/llama/common/console.h b/applications/src/llama/common/console.h
new file mode 100644
index 000000000..ec175269b
--- /dev/null
+++ b/applications/src/llama/common/console.h
@@ -0,0 +1,19 @@
+// Console functions
+
+#pragma once
+
+#include
+
+namespace console {
+ enum display_t {
+ reset = 0,
+ prompt,
+ user_input,
+ error
+ };
+
+ void init(bool use_simple_io, bool use_advanced_display);
+ void cleanup();
+ void set_display(display_t display);
+ bool readline(std::string & line, bool multiline_input);
+}
diff --git a/applications/src/llama/common/grammar-parser.cpp b/applications/src/llama/common/grammar-parser.cpp
new file mode 100644
index 000000000..bf89a96f3
--- /dev/null
+++ b/applications/src/llama/common/grammar-parser.cpp
@@ -0,0 +1,424 @@
+#include "grammar-parser.h"
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace grammar_parser {
+ // NOTE: assumes valid utf8 (but checks for overrun)
+ // copied from llama.cpp
+ static std::pair decode_utf8(const char * src) {
+ static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };
+ uint8_t first_byte = static_cast(*src);
+ uint8_t highbits = first_byte >> 4;
+ int len = lookup[highbits];
+ uint8_t mask = (1 << (8 - len)) - 1;
+ uint32_t value = first_byte & mask;
+ const char * end = src + len; // may overrun!
+ const char * pos = src + 1;
+ for ( ; pos < end && *pos; pos++) {
+ value = (value << 6) + (static_cast(*pos) & 0x3F);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
+ uint32_t next_id = static_cast(state.symbol_ids.size());
+ auto result = state.symbol_ids.insert(std::make_pair(std::string(src, len), next_id));
+ return result.first->second;
+ }
+
+ static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
+ uint32_t next_id = static_cast(state.symbol_ids.size());
+ state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
+ return next_id;
+ }
+
+ static void add_rule(
+ parse_state & state,
+ uint32_t rule_id,
+ const std::vector & rule) {
+ if (state.rules.size() <= rule_id) {
+ state.rules.resize(rule_id + 1);
+ }
+ state.rules[rule_id] = rule;
+ }
+
+ static bool is_word_char(char c) {
+ return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
+ }
+
+ static std::pair parse_hex(const char * src, int size) {
+ const char * pos = src;
+ const char * end = src + size;
+ uint32_t value = 0;
+ for ( ; pos < end && *pos; pos++) {
+ value <<= 4;
+ char c = *pos;
+ if ('a' <= c && c <= 'f') {
+ value += c - 'a' + 10;
+ } else if ('A' <= c && c <= 'F') {
+ value += c - 'A' + 10;
+ } else if ('0' <= c && c <= '9') {
+ value += c - '0';
+ } else {
+ break;
+ }
+ }
+ if (pos != end) {
+ throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
+ }
+ return std::make_pair(value, pos);
+ }
+
+ static const char * parse_space(const char * src, bool newline_ok) {
+ const char * pos = src;
+ while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
+ (newline_ok && (*pos == '\r' || *pos == '\n'))) {
+ if (*pos == '#') {
+ while (*pos && *pos != '\r' && *pos != '\n') {
+ pos++;
+ }
+ } else {
+ pos++;
+ }
+ }
+ return pos;
+ }
+
+ static const char * parse_name(const char * src) {
+ const char * pos = src;
+ while (is_word_char(*pos)) {
+ pos++;
+ }
+ if (pos == src) {
+ throw std::runtime_error(std::string("expecting name at ") + src);
+ }
+ return pos;
+ }
+
+ static std::pair parse_char(const char * src) {
+ if (*src == '\\') {
+ switch (src[1]) {
+ case 'x': return parse_hex(src + 2, 2);
+ case 'u': return parse_hex(src + 2, 4);
+ case 'U': return parse_hex(src + 2, 8);
+ case 't': return std::make_pair('\t', src + 2);
+ case 'r': return std::make_pair('\r', src + 2);
+ case 'n': return std::make_pair('\n', src + 2);
+ case '\\':
+ case '"':
+ case '[':
+ case ']':
+ return std::make_pair(src[1], src + 2);
+ default:
+ throw std::runtime_error(std::string("unknown escape at ") + src);
+ }
+ } else if (*src) {
+ return decode_utf8(src);
+ }
+ throw std::runtime_error("unexpected end of input");
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested);
+
+ static const char * parse_sequence(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ std::vector & out_elements,
+ bool is_nested) {
+ size_t last_sym_start = out_elements.size();
+ const char * pos = src;
+ while (*pos) {
+ if (*pos == '"') { // literal string
+ pos++;
+ last_sym_start = out_elements.size();
+ while (*pos != '"') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first});
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '[') { // char range(s)
+ pos++;
+ enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
+ if (*pos == '^') {
+ pos++;
+ start_type = LLAMA_GRETYPE_CHAR_NOT;
+ }
+ last_sym_start = out_elements.size();
+ while (*pos != ']') {
+ auto char_pair = parse_char(pos);
+ pos = char_pair.second;
+ enum llama_gretype type = last_sym_start < out_elements.size()
+ ? LLAMA_GRETYPE_CHAR_ALT
+ : start_type;
+
+ out_elements.push_back({type, char_pair.first});
+ if (pos[0] == '-' && pos[1] != ']') {
+ auto endchar_pair = parse_char(pos + 1);
+ pos = endchar_pair.second;
+ out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first});
+ }
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (is_word_char(*pos)) { // rule reference
+ const char * name_end = parse_name(pos);
+ uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos);
+ pos = parse_space(name_end, is_nested);
+ last_sym_start = out_elements.size();
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
+ } else if (*pos == '(') { // grouping
+ // parse nested alternates into synthesized rule
+ pos = parse_space(pos + 1, true);
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ pos = parse_alternates(state, pos, rule_name, sub_rule_id, true);
+ last_sym_start = out_elements.size();
+ // output reference to synthesized rule
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ if (*pos != ')') {
+ throw std::runtime_error(std::string("expecting ')' at ") + pos);
+ }
+ pos = parse_space(pos + 1, is_nested);
+ } else if (*pos == '*' || *pos == '+' || *pos == '?') { // repetition operator
+ if (last_sym_start == out_elements.size()) {
+ throw std::runtime_error(std::string("expecting preceding item to */+/? at ") + pos);
+ }
+
+ // apply transformation to previous symbol (last_sym_start to end) according to
+ // rewrite rules:
+ // S* --> S' ::= S S' |
+ // S+ --> S' ::= S S' | S
+ // S? --> S' ::= S |
+ uint32_t sub_rule_id = generate_symbol_id(state, rule_name);
+ std::vector sub_rule;
+ // add preceding symbol to generated rule
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ if (*pos == '*' || *pos == '+') {
+ // cause generated rule to recurse
+ sub_rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+ }
+ // mark start of alternate def
+ sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ if (*pos == '+') {
+ // add preceding symbol as alternate only for '+' (otherwise empty)
+ sub_rule.insert(
+ sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
+ }
+ sub_rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule(state, sub_rule_id, sub_rule);
+
+ // in original rule, replace previous symbol with reference to generated rule
+ out_elements.resize(last_sym_start);
+ out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id});
+
+ pos = parse_space(pos + 1, is_nested);
+ } else {
+ break;
+ }
+ }
+ return pos;
+ }
+
+ const char * parse_alternates(
+ parse_state & state,
+ const char * src,
+ const std::string & rule_name,
+ uint32_t rule_id,
+ bool is_nested) {
+ std::vector rule;
+ const char * pos = parse_sequence(state, src, rule_name, rule, is_nested);
+ while (*pos == '|') {
+ rule.push_back({LLAMA_GRETYPE_ALT, 0});
+ pos = parse_space(pos + 1, true);
+ pos = parse_sequence(state, pos, rule_name, rule, is_nested);
+ }
+ rule.push_back({LLAMA_GRETYPE_END, 0});
+ add_rule(state, rule_id, rule);
+ return pos;
+ }
+
+ static const char * parse_rule(parse_state & state, const char * src) {
+ const char * name_end = parse_name(src);
+ const char * pos = parse_space(name_end, false);
+ size_t name_len = name_end - src;
+ uint32_t rule_id = get_symbol_id(state, src, name_len);
+ const std::string name(src, name_len);
+
+ if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) {
+ throw std::runtime_error(std::string("expecting ::= at ") + pos);
+ }
+ pos = parse_space(pos + 3, true);
+
+ pos = parse_alternates(state, pos, name, rule_id, false);
+
+ if (*pos == '\r') {
+ pos += pos[1] == '\n' ? 2 : 1;
+ } else if (*pos == '\n') {
+ pos++;
+ } else if (*pos) {
+ throw std::runtime_error(std::string("expecting newline or end at ") + pos);
+ }
+ return parse_space(pos, true);
+ }
+
+ parse_state parse(const char * src) {
+ try {
+ parse_state state;
+ const char * pos = parse_space(src, true);
+ while (*pos) {
+ pos = parse_rule(state, pos);
+ }
+ return state;
+ } catch (const std::exception & err) {
+ fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
+ return parse_state();
+ }
+ }
+
+ static void print_grammar_char(FILE * file, uint32_t c) {
+ if (0x20 <= c && c <= 0x7f) {
+ fprintf(file, "%c", static_cast(c));
+ } else {
+ // cop out of encoding UTF-8
+ fprintf(file, "", c);
+ }
+ }
+
+ static bool is_char_element(llama_grammar_element elem) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_CHAR: return true;
+ case LLAMA_GRETYPE_CHAR_NOT: return true;
+ case LLAMA_GRETYPE_CHAR_ALT: return true;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
+ default: return false;
+ }
+ }
+
+ static void print_rule_binary(FILE * file, const std::vector & rule) {
+ for (auto elem : rule) {
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
+ case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
+ case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
+ case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
+ case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
+ case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
+ }
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ case LLAMA_GRETYPE_ALT:
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "(%u) ", elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ case LLAMA_GRETYPE_CHAR_NOT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ case LLAMA_GRETYPE_CHAR_ALT:
+ fprintf(file, "(\"");
+ print_grammar_char(file, elem.value);
+ fprintf(file, "\") ");
+ break;
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ static void print_rule(
+ FILE * file,
+ uint32_t rule_id,
+ const std::vector & rule,
+ const std::map & symbol_id_names) {
+ if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
+ throw std::runtime_error(
+ "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
+ }
+ fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
+ for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
+ llama_grammar_element elem = rule[i];
+ switch (elem.type) {
+ case LLAMA_GRETYPE_END:
+ throw std::runtime_error(
+ "unexpected end of rule: " + std::to_string(rule_id) + "," +
+ std::to_string(i));
+ case LLAMA_GRETYPE_ALT:
+ fprintf(file, "| ");
+ break;
+ case LLAMA_GRETYPE_RULE_REF:
+ fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str());
+ break;
+ case LLAMA_GRETYPE_CHAR:
+ fprintf(file, "[");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_NOT:
+ fprintf(file, "[^");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ fprintf(file, "-");
+ print_grammar_char(file, elem.value);
+ break;
+ case LLAMA_GRETYPE_CHAR_ALT:
+ if (i == 0 || !is_char_element(rule[i - 1])) {
+ throw std::runtime_error(
+ "LLAMA_GRETYPE_CHAR_ALT without preceding char: " +
+ std::to_string(rule_id) + "," + std::to_string(i));
+ }
+ print_grammar_char(file, elem.value);
+ break;
+ }
+ if (is_char_element(elem)) {
+ switch (rule[i + 1].type) {
+ case LLAMA_GRETYPE_CHAR_ALT:
+ case LLAMA_GRETYPE_CHAR_RNG_UPPER:
+ break;
+ default:
+ fprintf(file, "] ");
+ }
+ }
+ }
+ fprintf(file, "\n");
+ }
+
+ void print_grammar(FILE * file, const parse_state & state) {
+ try {
+ std::map symbol_id_names;
+ for (const auto & kv : state.symbol_ids) {
+ symbol_id_names[kv.second] = kv.first;
+ }
+ for (size_t i = 0, end = state.rules.size(); i < end; i++) {
+ // fprintf(file, "%zu: ", i);
+ // print_rule_binary(file, state.rules[i]);
+ print_rule(file, uint32_t(i), state.rules[i], symbol_id_names);
+ // fprintf(file, "\n");
+ }
+ } catch (const std::exception & err) {
+ fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what());
+ }
+ }
+
+ std::vector parse_state::c_rules() {
+ std::vector ret;
+ ret.reserve(rules.size());
+ for (const auto & rule : rules) {
+ ret.push_back(rule.data());
+ }
+ return ret;
+ }
+}
diff --git a/applications/src/llama/common/grammar-parser.h b/applications/src/llama/common/grammar-parser.h
new file mode 100644
index 000000000..9037d7272
--- /dev/null
+++ b/applications/src/llama/common/grammar-parser.h
@@ -0,0 +1,29 @@
+// Implements a parser for an extended Backus-Naur form (BNF), producing the
+// binary context-free grammar format specified by llama.h. Supports character
+// ranges, grouping, and repetition operators. As an example, a grammar for
+// arithmetic might look like:
+//
+// root ::= expr
+// expr ::= term ([-+*/] term)*
+// term ::= num | "(" space expr ")" space
+// num ::= [0-9]+ space
+// space ::= [ \t\n]*
+
+#pragma once
+#include "llama.h"
+#include
+#include